fix(api): pass hardware platform to upscaling pipeline (#77)
This commit is contained in:
parent
fe9206c894
commit
f319e6a49b
|
@ -187,7 +187,7 @@ def border_from_request() -> Border:
|
|||
return Border(left, right, top, bottom)
|
||||
|
||||
|
||||
def upscale_from_request() -> UpscaleParams:
|
||||
def upscale_from_request(provider: str) -> UpscaleParams:
|
||||
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
|
||||
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
|
||||
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
|
||||
|
@ -199,13 +199,14 @@ def upscale_from_request() -> UpscaleParams:
|
|||
|
||||
return UpscaleParams(
|
||||
upscaling,
|
||||
provider,
|
||||
correction_model=correction,
|
||||
scale=scale,
|
||||
outscale=outscale,
|
||||
faces=faces,
|
||||
platform='onnx',
|
||||
denoise=denoise,
|
||||
faces=faces,
|
||||
face_strength=face_strength,
|
||||
format='onnx',
|
||||
outscale=outscale,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
|
||||
|
@ -355,7 +356,7 @@ def img2img():
|
|||
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
||||
|
||||
params, size = pipeline_from_request()
|
||||
upscale = upscale_from_request()
|
||||
upscale = upscale_from_request(params.provider)
|
||||
|
||||
strength = get_and_clamp_float(
|
||||
request.args,
|
||||
|
@ -385,7 +386,7 @@ def img2img():
|
|||
@app.route('/api/txt2img', methods=['POST'])
|
||||
def txt2img():
|
||||
params, size = pipeline_from_request()
|
||||
upscale = upscale_from_request()
|
||||
upscale = upscale_from_request(params.provider)
|
||||
|
||||
output = make_output_name(
|
||||
'txt2img',
|
||||
|
@ -413,7 +414,7 @@ def inpaint():
|
|||
|
||||
params, size = pipeline_from_request()
|
||||
expand = border_from_request()
|
||||
upscale = upscale_from_request()
|
||||
upscale = upscale_from_request(params.provider)
|
||||
|
||||
fill_color = get_not_empty(request.args, 'fillColor', 'white')
|
||||
mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none')
|
||||
|
@ -474,7 +475,7 @@ def upscale():
|
|||
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
||||
|
||||
params, size = pipeline_from_request()
|
||||
upscale = upscale_from_request()
|
||||
upscale = upscale_from_request(params.provider)
|
||||
|
||||
output = make_output_name(
|
||||
'upscale',
|
||||
|
|
|
@ -4,7 +4,7 @@ from onnxruntime import InferenceSession
|
|||
from os import path
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
from typing import Any, Union
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -86,43 +86,45 @@ class UpscaleParams():
|
|||
def __init__(
|
||||
self,
|
||||
upscale_model: str,
|
||||
provider: str,
|
||||
correction_model: Union[str, None] = None,
|
||||
scale: int = 4,
|
||||
outscale: int = 1,
|
||||
denoise: float = 0.5,
|
||||
faces=True,
|
||||
face_strength: float = 0.5,
|
||||
platform: str = 'onnx',
|
||||
half=False
|
||||
format: Literal['onnx', 'pth'] = 'onnx',
|
||||
half=False,
|
||||
outscale: int = 1,
|
||||
scale: int = 4,
|
||||
) -> None:
|
||||
self.upscale_model = upscale_model
|
||||
self.provider = provider
|
||||
self.correction_model = correction_model
|
||||
self.scale = scale
|
||||
self.outscale = outscale
|
||||
self.denoise = denoise
|
||||
self.faces = faces
|
||||
self.face_strength = face_strength
|
||||
self.platform = platform
|
||||
self.format = format
|
||||
self.half = half
|
||||
self.outscale = outscale
|
||||
self.scale = scale
|
||||
|
||||
def resize(self, size: Size) -> Size:
|
||||
return Size(size.width * self.outscale, size.height * self.outscale)
|
||||
|
||||
|
||||
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||
model_file = '%s.%s' % (params.upscale_model, params.platform)
|
||||
model_file = '%s.%s' % (params.upscale_model, params.format)
|
||||
model_path = path.join(ctx.model_path, model_file)
|
||||
if not path.isfile(model_path):
|
||||
raise Exception('Real ESRGAN model not found at %s' % model_path)
|
||||
|
||||
# use ONNX acceleration, if available
|
||||
if params.platform == 'onnx':
|
||||
model = ONNXNet(ctx, model_file)
|
||||
elif params.platform == 'pth':
|
||||
if params.format == 'onnx':
|
||||
model = ONNXNet(ctx, model_file, provider=params.provider)
|
||||
elif params.format == 'pth':
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=params.scale)
|
||||
else:
|
||||
raise Exception('unknown platform %s' % params.platform)
|
||||
raise Exception('unknown platform %s' % params.format)
|
||||
|
||||
dni_weight = None
|
||||
if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
|
||||
|
|
Loading…
Reference in New Issue