fix(api): report accurate sizes
This commit is contained in:
parent
d406cd4e99
commit
4bf68759d7
|
@ -324,7 +324,7 @@ def img2img():
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'output': output,
|
'output': output,
|
||||||
'params': params.tojson(),
|
'params': params.tojson(),
|
||||||
'size': size.tojson(),
|
'size': upscale.resize(size).tojson(),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@ -345,7 +345,7 @@ def txt2img():
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'output': output,
|
'output': output,
|
||||||
'params': params.tojson(),
|
'params': params.tojson(),
|
||||||
'size': size.tojson(),
|
'size': upscale.resize(size).tojson(),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@ -399,7 +399,7 @@ def inpaint():
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'output': output,
|
'output': output,
|
||||||
'params': params.tojson(),
|
'params': params.tojson(),
|
||||||
'size': size.tojson(),
|
'size': upscale.resize(size.with_border(expand)).tojson(),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,8 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
ServerContext
|
ServerContext,
|
||||||
|
Size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: these should all be params or config
|
# TODO: these should all be params or config
|
||||||
|
@ -49,7 +50,7 @@ class ONNXImage():
|
||||||
|
|
||||||
class ONNXNet():
|
class ONNXNet():
|
||||||
'''
|
'''
|
||||||
Provides the RRDBNet interface but using ONNX.
|
Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None:
|
def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None:
|
||||||
|
@ -102,6 +103,9 @@ class UpscaleParams():
|
||||||
self.platform = platform
|
self.platform = platform
|
||||||
self.half = half
|
self.half = half
|
||||||
|
|
||||||
|
def resize(self, size: Size) -> Size:
|
||||||
|
return Size(size.width * self.scale * self.outscale, size.height * self.scale * self.outscale)
|
||||||
|
|
||||||
|
|
||||||
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||||
model_file = '%s.%s' % (params.upscale_model, params.platform)
|
model_file = '%s.%s' % (params.upscale_model, params.platform)
|
||||||
|
@ -125,9 +129,10 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||||
model_path = [model_path, wdn_model_path]
|
model_path = [model_path, wdn_model_path]
|
||||||
dni_weight = [params.denoise, 1 - params.denoise]
|
dni_weight = [params.denoise, 1 - params.denoise]
|
||||||
|
|
||||||
|
# TODO: shouldn't need the PTH file
|
||||||
upsampler = RealESRGANer(
|
upsampler = RealESRGANer(
|
||||||
scale=params.scale,
|
scale=params.scale,
|
||||||
model_path=model_path,
|
model_path=path.join(ctx.model_path, '%s.pth' % params.upscale_model),
|
||||||
dni_weight=dni_weight,
|
dni_weight=dni_weight,
|
||||||
model=model,
|
model=model,
|
||||||
tile=tile,
|
tile=tile,
|
||||||
|
|
|
@ -76,6 +76,9 @@ class Size:
|
||||||
'width': self.width,
|
'width': self.width,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def with_border(self, border: Border):
|
||||||
|
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
|
||||||
|
|
||||||
|
|
||||||
def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
|
def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
|
||||||
return min(max(float(args.get(key, default_value)), min_value), max_value)
|
return min(max(float(args.get(key, default_value)), min_value), max_value)
|
||||||
|
|
Loading…
Reference in New Issue