From ecf3c03f0f8433a554dc395bda01217ee8349516 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Feb 2023 10:59:59 -0600 Subject: [PATCH] feat: add parameter to run correction before upscaling (#132) --- api/onnx_web/chain/base.py | 3 ++- api/onnx_web/params.py | 10 ++++++++ api/onnx_web/serve.py | 10 +++++--- api/onnx_web/server/upscale.py | 25 +++++++++++++++---- api/params.json | 8 ++++++ gui/src/client.ts | 2 ++ gui/src/components/control/UpscaleControl.tsx | 20 ++++++++++++++- gui/src/state.ts | 1 + 8 files changed, 68 insertions(+), 11 deletions(-) diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 32d919ec..b144b546 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -72,7 +72,8 @@ class ChainPipeline: """ Append an additional stage to this pipeline. """ - self.stages.append(stage) + if stage is not None: + self.stages.append(stage) def __call__( self, diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 4a744138..77b9a2b1 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -22,6 +22,12 @@ class TileOrder: spiral = "spiral" +class UpscaleOrder: + correct_first = "correct-first" + correct_last = "correct-last" + correct_both = "correct-both" + + Param = Union[str, int, float] Point = Tuple[int, int] @@ -168,6 +174,7 @@ class UpscaleParams: scale: int = 4, pre_pad: int = 0, tile_pad: int = 10, + upscale_order: Literal["correction-first", "correction-last", "correction-both"] = "upscaling-first", ) -> None: self.upscale_model = upscale_model self.correction_model = correction_model @@ -181,6 +188,7 @@ class UpscaleParams: self.pre_pad = pre_pad self.scale = scale self.tile_pad = tile_pad + self.upscale_order = upscale_order def rescale(self, scale: int): return UpscaleParams( @@ -196,6 +204,7 @@ class UpscaleParams: scale=scale, pre_pad=self.pre_pad, tile_pad=self.tile_pad, + upscale_order=self.upscale_order, ) def resize(self, size: Size) -> Size: @@ -218,4 +227,5 @@ class UpscaleParams: "pre_pad": self.pre_pad, "scale": self.scale, "tile_pad": self.tile_pad, + "upscale_order": self.upscale_order, } diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 9e492e19..4818c64c 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -255,6 +255,7 @@ def upscale_from_request() -> UpscaleParams: faces = get_not_empty(request.args, "faces", "false") == "true" face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1) face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0) + upscale_order = request.args.get("upscaleOrder", "correction-first") return UpscaleParams( upscaling, @@ -266,10 +267,11 @@ def upscale_from_request() -> UpscaleParams: format="onnx", outscale=outscale, scale=scale, + upscale_order=upscale_order, ) -def check_paths(context: ServerContext): +def check_paths(context: ServerContext) -> None: if not path.exists(context.model_path): raise RuntimeError("model path must exist") @@ -283,7 +285,7 @@ def get_model_name(model: str) -> str: return file -def load_models(context: ServerContext): +def load_models(context: ServerContext) -> None: global diffusion_models global correction_models global upscaling_models @@ -313,7 +315,7 @@ def load_models(context: ServerContext): upscaling_models.sort() -def load_params(context: ServerContext): +def load_params(context: ServerContext) -> None: global config_params params_file = path.join(context.params_path, "params.json") with open(params_file, "r") as f: @@ -328,7 +330,7 @@ def load_params(context: ServerContext): config_platform["default"] = context.default_platform -def load_platforms(context: ServerContext): +def load_platforms(context: ServerContext) -> None: global available_platforms providers = list(get_available_providers()) diff --git a/api/onnx_web/server/upscale.py b/api/onnx_web/server/upscale.py index 702ef72c..128e7d3c 100644 --- a/api/onnx_web/server/upscale.py +++ b/api/onnx_web/server/upscale.py @@ -36,27 +36,42 @@ def run_upscale_correction( if upscale.scale > 1: if "esrgan" in upscale.upscale_model: - esrgan_stage = StageParams( + esrgan_params = StageParams( tile_size=stage.tile_size, outscale=upscale.outscale ) - chain.append((upscale_resrgan, esrgan_stage, None)) + upscale_stage = (upscale_resrgan, esrgan_params, None) elif "stable-diffusion" in upscale.upscale_model: mini_tile = min(SizeChart.mini, stage.tile_size) sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale) - chain.append((upscale_stable_diffusion, sd_stage, None)) + upscale_stage = (upscale_stable_diffusion, sd_stage, None) else: logger.warn("unknown upscaling model: %s", upscale.upscale_model) + upscale_stage = None if upscale.faces: face_stage = StageParams( tile_size=stage.tile_size, outscale=upscale.face_outscale ) if "codeformer" in upscale.correction_model: - chain.append((correct_codeformer, face_stage, None)) + correct_stage = (correct_codeformer, face_stage, None) elif "gfpgan" in upscale.correction_model: - chain.append((correct_gfpgan, face_stage, None)) + correct_stage = (correct_gfpgan, face_stage, None) else: logger.warn("unknown correction model: %s", upscale.correction_model) + correct_stage = None + + if upscale.upscale_order == "correction-both": + chain.append(correct_stage) + chain.append(upscale_stage) + chain.append(correct_stage) + elif upscale.upscale_order == "correction-first": + chain.append(correct_stage) + chain.append(upscale_stage) + elif upscale.upscale_order == "correction-last": + chain.append(upscale_stage) + chain.append(correct_stage) + else: + logger.warn("unknown upscaling order: %s", upscale.upscale_order) return chain( job, diff --git a/api/params.json b/api/params.json index 80f7f34f..a15164db 100644 --- a/api/params.json +++ b/api/params.json @@ -127,6 +127,14 @@ "max": 512, "step": 8 }, + "upscaleOrder": { + "default": "correction-last", + "keys": [ + "correction-both", + "correction-first", + "correction-last" + ] + }, "upscaling": { "default": "", "keys": [] diff --git a/gui/src/client.ts b/gui/src/client.ts index a0a69645..ee21b163 100644 --- a/gui/src/client.ts +++ b/gui/src/client.ts @@ -114,6 +114,8 @@ export interface BrushParams { */ export interface UpscaleParams { enabled: boolean; + upscaleOrder: string; + denoise: number; scale: number; outscale: number; diff --git a/gui/src/components/control/UpscaleControl.tsx b/gui/src/components/control/UpscaleControl.tsx index 98195606..7ba1ac37 100644 --- a/gui/src/components/control/UpscaleControl.tsx +++ b/gui/src/components/control/UpscaleControl.tsx @@ -1,5 +1,6 @@ import { mustExist } from '@apextoaster/js-utils'; -import { Checkbox, FormControlLabel, Stack } from '@mui/material'; +import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material'; +import { startCase } from 'lodash'; import * as React from 'react'; import { useContext } from 'react'; import { useStore } from 'zustand'; @@ -106,5 +107,22 @@ export function UpscaleControl() { }); }} /> + + Upscale Order + + ; } diff --git a/gui/src/state.ts b/gui/src/state.ts index 535fbc55..563a5cfd 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -408,6 +408,7 @@ export function createStateSlices(server: ServerParams) { faceStrength: server.faceStrength.default, outscale: server.outscale.default, scale: server.scale.default, + upscaleOrder: server.upscaleOrder.default, }, upscaleTab: { source: null,