1
0
Fork 0

fix(api): pass hardware platform to upscaling pipeline (#77)

This commit is contained in:
Sean Sube 2023-01-22 16:35:53 -06:00
parent fe9206c894
commit f319e6a49b
2 changed files with 25 additions and 22 deletions

View File

@ -187,7 +187,7 @@ def border_from_request() -> Border:
return Border(left, right, top, bottom)
def upscale_from_request() -> UpscaleParams:
def upscale_from_request(provider: str) -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
@ -199,13 +199,14 @@ def upscale_from_request() -> UpscaleParams:
return UpscaleParams(
upscaling,
provider,
correction_model=correction,
scale=scale,
outscale=outscale,
faces=faces,
platform='onnx',
denoise=denoise,
faces=faces,
face_strength=face_strength,
format='onnx',
outscale=outscale,
scale=scale,
)
@ -355,7 +356,7 @@ def img2img():
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
params, size = pipeline_from_request()
upscale = upscale_from_request()
upscale = upscale_from_request(params.provider)
strength = get_and_clamp_float(
request.args,
@ -385,7 +386,7 @@ def img2img():
@app.route('/api/txt2img', methods=['POST'])
def txt2img():
params, size = pipeline_from_request()
upscale = upscale_from_request()
upscale = upscale_from_request(params.provider)
output = make_output_name(
'txt2img',
@ -413,7 +414,7 @@ def inpaint():
params, size = pipeline_from_request()
expand = border_from_request()
upscale = upscale_from_request()
upscale = upscale_from_request(params.provider)
fill_color = get_not_empty(request.args, 'fillColor', 'white')
mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none')
@ -474,7 +475,7 @@ def upscale():
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
params, size = pipeline_from_request()
upscale = upscale_from_request()
upscale = upscale_from_request(params.provider)
output = make_output_name(
'upscale',

View File

@ -4,7 +4,7 @@ from onnxruntime import InferenceSession
from os import path
from PIL import Image
from realesrgan import RealESRGANer
from typing import Any, Union
from typing import Any, Literal, Union
import numpy as np
import torch
@ -86,43 +86,45 @@ class UpscaleParams():
def __init__(
self,
upscale_model: str,
provider: str,
correction_model: Union[str, None] = None,
scale: int = 4,
outscale: int = 1,
denoise: float = 0.5,
faces=True,
face_strength: float = 0.5,
platform: str = 'onnx',
half=False
format: Literal['onnx', 'pth'] = 'onnx',
half=False,
outscale: int = 1,
scale: int = 4,
) -> None:
self.upscale_model = upscale_model
self.provider = provider
self.correction_model = correction_model
self.scale = scale
self.outscale = outscale
self.denoise = denoise
self.faces = faces
self.face_strength = face_strength
self.platform = platform
self.format = format
self.half = half
self.outscale = outscale
self.scale = scale
def resize(self, size: Size) -> Size:
return Size(size.width * self.outscale, size.height * self.outscale)
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
model_file = '%s.%s' % (params.upscale_model, params.platform)
model_file = '%s.%s' % (params.upscale_model, params.format)
model_path = path.join(ctx.model_path, model_file)
if not path.isfile(model_path):
raise Exception('Real ESRGAN model not found at %s' % model_path)
# use ONNX acceleration, if available
if params.platform == 'onnx':
model = ONNXNet(ctx, model_file)
elif params.platform == 'pth':
if params.format == 'onnx':
model = ONNXNet(ctx, model_file, provider=params.provider)
elif params.format == 'pth':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale)
else:
raise Exception('unknown platform %s' % params.platform)
raise Exception('unknown platform %s' % params.format)
dni_weight = None
if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1: