From 99c91a301c9c59da03461c59b7800a6c12fff295 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 2 Jul 2023 19:07:59 -0500 Subject: [PATCH] feat: make enabling highres a parameter of its own --- api/onnx_web/chain/highres.py | 57 ++++++++++++++++-------------- api/onnx_web/diffusers/run.py | 66 +++++++++++++++++++---------------- api/onnx_web/params.py | 3 ++ api/onnx_web/server/params.py | 2 ++ api/params.json | 3 ++ gui/src/client/api.ts | 1 + gui/src/config.json | 3 ++ 7 files changed, 79 insertions(+), 56 deletions(-) diff --git a/api/onnx_web/chain/highres.py b/api/onnx_web/chain/highres.py index c2166ef3..da1f7412 100644 --- a/api/onnx_web/chain/highres.py +++ b/api/onnx_web/chain/highres.py @@ -22,36 +22,41 @@ def stage_highres( if chain is None: chain = ChainPipeline() + if not highres.enabled: + logger.debug("highres not enabled, skipping") + return chain + if highres.iterations < 1: logger.debug("no highres iterations, skipping") return chain - if highres.method == "upscale": - logger.debug("using upscaling pipeline for highres") - stage_upscale_correction( - stage, - params, - upscale=upscale.with_args( - faces=False, - scale=highres.scale, - outscale=highres.scale, - ), - chain=chain, - ) - else: - logger.debug("using simple upscaling for highres") - chain.stage( - UpscaleSimpleStage(), - stage, - method=highres.method, - upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale), - ) + for _i in range(highres.iterations): + if highres.method == "upscale": + logger.debug("using upscaling pipeline for highres") + stage_upscale_correction( + stage, + params, + upscale=upscale.with_args( + faces=False, + scale=highres.scale, + outscale=highres.scale, + ), + chain=chain, + ) + else: + logger.debug("using simple upscaling for highres") + chain.stage( + UpscaleSimpleStage(), + stage, + method=highres.method, + upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale), + ) - chain.stage( - BlendImg2ImgStage(), - stage, - overlap=params.overlap, - strength=highres.strength, - ) + chain.stage( + BlendImg2ImgStage(), + stage, + overlap=params.overlap, + strength=highres.strength, + ) return chain diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 7d630d1c..8d3083d7 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -62,14 +62,13 @@ def run_txt2img_pipeline( ) # apply highres - for _i in range(highres.iterations): - stage_highres( - stage, - params, - highres, - upscale, - chain=chain, - ) + stage_highres( + stage, + params, + highres, + upscale, + chain=chain, + ) # apply upscaling and correction, after highres stage_upscale_correction( @@ -152,14 +151,13 @@ def run_img2img_pipeline( ) # highres, if selected - for _i in range(highres.iterations): - stage_highres( - stage, - params, - highres, - upscale, - chain=chain, - ) + stage_highres( + stage, + params, + highres, + upscale, + chain=chain, + ) # apply upscaling and correction, after highres stage_upscale_correction( @@ -234,21 +232,30 @@ def run_inpaint_pipeline( noise_source=noise_source, ) - # apply highres - for _i in range(highres.iterations): - stage_highres( + # apply upscaling and correction, before highres + first_upscale, after_upscale = split_upscale(upscale) + if first_upscale: + stage_upscale_correction( stage, params, - highres, - upscale, + upscale=first_upscale, chain=chain, ) + # apply highres + stage_highres( + stage, + params, + highres, + upscale, + chain=chain, + ) + # apply upscaling and correction stage_upscale_correction( stage, params, - upscale=upscale, + upscale=after_upscale, chain=chain, ) @@ -303,14 +310,13 @@ def run_upscale_pipeline( ) # apply highres - for _i in range(highres.iterations): - stage_highres( - stage, - params, - highres, - upscale, - chain=chain, - ) + stage_highres( + stage, + params, + highres, + upscale, + chain=chain, + ) # apply upscaling and correction, after highres stage_upscale_correction( diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 9338db26..84dd1869 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -421,12 +421,14 @@ class UpscaleParams: class HighresParams: def __init__( self, + enabled: bool, scale: int, steps: int, strength: float, method: Literal["bilinear", "lanczos", "upscale"] = "lanczos", iterations: int = 1, ): + self.enabled = enabled self.scale = scale self.steps = steps self.strength = strength @@ -441,6 +443,7 @@ class HighresParams: def tojson(self): return { + "enabled": self.enabled, "iterations": self.iterations, "method": self.method, "scale": self.scale, diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index caf0ec82..2598e819 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -283,6 +283,7 @@ def upscale_from_request() -> UpscaleParams: def highres_from_request() -> HighresParams: + enabled = get_boolean(request.args, "highres", get_config_value("highres")) iterations = get_and_clamp_int( request.args, "highresIterations", @@ -313,6 +314,7 @@ def highres_from_request() -> HighresParams: get_config_value("highresStrength", "min"), ) return HighresParams( + enabled, scale, steps, strength, diff --git a/api/params.json b/api/params.json index 7f21d78c..7f1e9a54 100644 --- a/api/params.json +++ b/api/params.json @@ -64,6 +64,9 @@ "max": 8192, "step": 8 }, + "highres": { + "default": false + }, "highresIterations": { "default": 1, "min": 1, diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index efcfd7dc..6bed1e46 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -123,6 +123,7 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) { export function appendHighresToURL(url: URL, highres: HighresParams) { if (highres.enabled) { + url.searchParams.append('highres', String(highres.enabled)); url.searchParams.append('highresIterations', highres.highresIterations.toFixed(FIXED_INTEGER)); url.searchParams.append('highresMethod', highres.highresMethod); url.searchParams.append('highresScale', highres.highresScale.toFixed(FIXED_INTEGER)); diff --git a/gui/src/config.json b/gui/src/config.json index b1cdd475..97a53d25 100644 --- a/gui/src/config.json +++ b/gui/src/config.json @@ -68,6 +68,9 @@ "max": 8192, "step": 8 }, + "highres": { + "default": false + }, "highresMethod": { "default": "lanczos", "keys": [