1
0
Fork 0

fix(api): report accurate sizes

This commit is contained in:
Sean Sube 2023-01-16 15:11:40 -06:00
parent d406cd4e99
commit 4bf68759d7
3 changed files with 14 additions and 6 deletions

View File

@ -324,7 +324,7 @@ def img2img():
return jsonify({ return jsonify({
'output': output, 'output': output,
'params': params.tojson(), 'params': params.tojson(),
'size': size.tojson(), 'size': upscale.resize(size).tojson(),
}) })
@ -345,7 +345,7 @@ def txt2img():
return jsonify({ return jsonify({
'output': output, 'output': output,
'params': params.tojson(), 'params': params.tojson(),
'size': size.tojson(), 'size': upscale.resize(size).tojson(),
}) })
@ -399,7 +399,7 @@ def inpaint():
return jsonify({ return jsonify({
'output': output, 'output': output,
'params': params.tojson(), 'params': params.tojson(),
'size': size.tojson(), 'size': upscale.resize(size.with_border(expand)).tojson(),
}) })

View File

@ -10,7 +10,8 @@ import numpy as np
import torch import torch
from .utils import ( from .utils import (
ServerContext ServerContext,
Size,
) )
# TODO: these should all be params or config # TODO: these should all be params or config
@ -49,7 +50,7 @@ class ONNXImage():
class ONNXNet(): class ONNXNet():
''' '''
Provides the RRDBNet interface but using ONNX. Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
''' '''
def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None: def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None:
@ -102,6 +103,9 @@ class UpscaleParams():
self.platform = platform self.platform = platform
self.half = half self.half = half
def resize(self, size: Size) -> Size:
return Size(size.width * self.scale * self.outscale, size.height * self.scale * self.outscale)
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0): def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
model_file = '%s.%s' % (params.upscale_model, params.platform) model_file = '%s.%s' % (params.upscale_model, params.platform)
@ -125,9 +129,10 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
model_path = [model_path, wdn_model_path] model_path = [model_path, wdn_model_path]
dni_weight = [params.denoise, 1 - params.denoise] dni_weight = [params.denoise, 1 - params.denoise]
# TODO: shouldn't need the PTH file
upsampler = RealESRGANer( upsampler = RealESRGANer(
scale=params.scale, scale=params.scale,
model_path=model_path, model_path=path.join(ctx.model_path, '%s.pth' % params.upscale_model),
dni_weight=dni_weight, dni_weight=dni_weight,
model=model, model=model,
tile=tile, tile=tile,

View File

@ -76,6 +76,9 @@ class Size:
'width': self.width, 'width': self.width,
} }
def with_border(self, border: Border):
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float: def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
return min(max(float(args.get(key, default_value)), min_value), max_value) return min(max(float(args.get(key, default_value)), min_value), max_value)