diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index efca1804..26c99a29 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -89,7 +89,11 @@ def run_txt2img_pipeline( callback=progress, ) - for image, output in zip(result.images, outputs): + image_outputs = list(zip(result.images, outputs)) + del result + del pipe + + for image, output in image_outputs: if highres.scale > 1: highres_progress = ChainProgress.from_progress(progress) @@ -99,7 +103,10 @@ def run_txt2img_pipeline( StageParams(), params, image, - upscale=upscale, + upscale=upscale.with_args( + scale=1, + outscale=1, + ), callback=highres_progress, ) @@ -116,7 +123,26 @@ def run_txt2img_pipeline( ) def highres_tile(tile: Image.Image, dims): - tile = tile.resize((size.height, size.width)) + if highres.method == "bilinear": + logger.debug("using bilinear interpolation for highres") + tile = tile.resize((size.height, size.width), resample=Image.Resampling.BILINEAR) + elif highres.method == "lanczos": + logger.debug("using Lanczos interpolation for highres") + tile = tile.resize((size.height, size.width), resample=Image.Resampling.LANCZOS) + else: + logger.debug("using upscaling pipeline for highres") + tile = run_upscale_correction( + job, + server, + StageParams(), + params, + image, + upscale=upscale.with_args( + faces=False, + ), + callback=highres_progress, + ) + if params.lpw: logger.debug("using LPW pipeline for highres") rng = torch.manual_seed(params.seed) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 0c0b7eb1..89644216 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -325,16 +325,19 @@ class HighresParams: scale: int, steps: int, strength: float, + method: Literal["bilinear", "lanczos", "upscale"] = "lanczos", ): self.scale = scale self.steps = steps self.strength = strength + self.method = method def resize(self, size: Size) -> Size: return Size(size.width * self.scale, size.height * self.scale) def tojson(self): return { + "method": self.method, "scale": self.scale, "steps": self.steps, "strength": self.strength, diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 05046054..b218285a 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -51,6 +51,11 @@ mask_filters = { "gaussian-multiply": mask_filter_gaussian_multiply, "gaussian-screen": mask_filter_gaussian_screen, } +highres_methods = { + "bilinear": highres_method_bilinear, + "lanczos": highres_method_lanczos, + "upscale": highres_method_upscale, +} # Available ORT providers @@ -94,6 +99,10 @@ def get_extra_strings(): return extra_strings +def get_highres_methods(): + return highres_methods + + def get_mask_filters(): return mask_filters diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index 3bc5ddbe..dbc52f3d 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -19,6 +19,7 @@ from .load import ( get_available_platforms, get_config_value, get_correction_models, + get_highres_methods, get_upscaling_models, ) from .utils import get_model_path @@ -179,6 +180,7 @@ def upscale_from_request() -> UpscaleParams: def highres_from_request() -> HighresParams: + method = get_from_list(request.args, "highresMethod", get_highres_methods()) scale = get_and_clamp_int(request.args, "highresScale", 1, 4, 1) steps = get_and_clamp_int(request.args, "highresSteps", 1, 200, 1) strength = get_and_clamp_float(request.args, "highresStrength", 0.5, 1.0, 0.0) @@ -187,4 +189,5 @@ def highres_from_request() -> HighresParams: scale, steps, strength, + method=method, ) diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 61b76f1c..eaafbb68 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -22,6 +22,7 @@ MEMORY_ERRORS = [ "hipErrorOutOfMemory", "MIOPEN failure 7", "out of memory", + "rocblas_status_memory_error", ] diff --git a/api/params.json b/api/params.json index a35ee5b0..e9c18976 100644 --- a/api/params.json +++ b/api/params.json @@ -60,6 +60,14 @@ "max": 1024, "step": 8 }, + "highresMethods": { + "default": "lanczos", + "keys": [ + "bilinear", + "lanczos", + "upscale" + ] + }, "highresScale": { "default": 1, "min": 1, @@ -76,7 +84,7 @@ "default": 0.5, "min": 0, "max": 1, - "step": 0.1 + "step": 0.01 }, "inversion": { "default": "", diff --git a/gui/examples/config.json b/gui/examples/config.json index 69de988e..2d663cdd 100644 --- a/gui/examples/config.json +++ b/gui/examples/config.json @@ -64,6 +64,32 @@ "max": 1024, "step": 8 }, + "highresMethods": { + "default": "lanczos", + "keys": [ + "bilinear", + "lanczos", + "upscale" + ] + }, + "highresScale": { + "default": 1, + "min": 1, + "max": 4, + "step": 1 + }, + "highresSteps": { + "default": 0, + "min": 1, + "max": 200, + "step": 1 + }, + "highresStrength": { + "default": 0.5, + "min": 0, + "max": 1, + "step": 0.01 + }, "inversion": { "default": "", "keys": [] @@ -166,4 +192,4 @@ "step": 8 } } -} +} \ No newline at end of file diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index c30ad228..68bf8838 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -147,6 +147,7 @@ export interface BlendParams { export interface HighresParams { enabled: boolean; + highresMethod: string; highresScale: number; highresSteps: number; highresStrength: number; @@ -503,6 +504,7 @@ export function makeClient(root: string, f = fetch): ApiClient { } if (doesExist(highres) && highres.enabled) { + url.searchParams.append('highresMethod', highres.highresMethod); url.searchParams.append('highresScale', highres.highresScale.toFixed(FIXED_INTEGER)); url.searchParams.append('highresSteps', highres.highresSteps.toFixed(FIXED_INTEGER)); url.searchParams.append('highresStrength', highres.highresStrength.toFixed(FIXED_FLOAT)); diff --git a/gui/src/components/control/HighresControl.tsx b/gui/src/components/control/HighresControl.tsx index ff62432b..bdcf7d91 100644 --- a/gui/src/components/control/HighresControl.tsx +++ b/gui/src/components/control/HighresControl.tsx @@ -69,5 +69,22 @@ export function HighresControl() { }); }} /> + + {t('parameter.highres.method')} + + ; } diff --git a/gui/src/components/control/UpscaleControl.tsx b/gui/src/components/control/UpscaleControl.tsx index 47b5658c..4eff8c83 100644 --- a/gui/src/components/control/UpscaleControl.tsx +++ b/gui/src/components/control/UpscaleControl.tsx @@ -109,7 +109,7 @@ export function UpscaleControl() { }} /> - Upscale Order + {t('parameter.upscale.order')}