diff --git a/api/onnx_web/chain/upscale.py b/api/onnx_web/chain/upscale.py index 54a1754b..faafdb91 100644 --- a/api/onnx_web/chain/upscale.py +++ b/api/onnx_web/chain/upscale.py @@ -67,31 +67,33 @@ def stage_upscale_correction( **kwargs, "upscale": upscale, } + upscale_stage: Optional[PipelineStage] = None - if "bsrgan" in upscale.upscale_model: - bsrgan_params = StageParams( - tile_size=stage.tile_size, - outscale=upscale.outscale, - ) - upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts) - elif "esrgan" in upscale.upscale_model: - esrgan_params = StageParams( - tile_size=stage.tile_size, - outscale=upscale.outscale, - ) - upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts) - elif "stable-diffusion" in upscale.upscale_model: - mini_tile = min(SizeChart.mini, stage.tile_size) - sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale) - upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts) - elif "swinir" in upscale.upscale_model: - swinir_params = StageParams( - tile_size=stage.tile_size, - outscale=upscale.outscale, - ) - upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts) - else: - logger.warning("unknown upscaling model: %s", upscale.upscale_model) + if upscale.upscale: + if "bsrgan" in upscale.upscale_model: + bsrgan_params = StageParams( + tile_size=stage.tile_size, + outscale=upscale.outscale, + ) + upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts) + elif "esrgan" in upscale.upscale_model: + esrgan_params = StageParams( + tile_size=stage.tile_size, + outscale=upscale.outscale, + ) + upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts) + elif "stable-diffusion" in upscale.upscale_model: + mini_tile = min(SizeChart.mini, stage.tile_size) + sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale) + upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts) + elif "swinir" in upscale.upscale_model: + swinir_params = StageParams( + tile_size=stage.tile_size, + outscale=upscale.outscale, + ) + upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts) + else: + logger.warning("unknown upscaling model: %s", upscale.upscale_model) correct_stage: Optional[PipelineStage] = None if upscale.faces: diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index f4b74659..b3e03670 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -391,6 +391,7 @@ class UpscaleParams: upscale_model: str, correction_model: Optional[str] = None, denoise: float = 0.5, + upscale=True, faces=True, face_outscale: int = 1, face_strength: float = 0.5, @@ -406,6 +407,7 @@ class UpscaleParams: self.upscale_model = upscale_model self.correction_model = correction_model self.denoise = denoise + self.upscale = upscale self.faces = faces self.face_outscale = face_outscale self.face_strength = face_strength @@ -421,6 +423,7 @@ class UpscaleParams: self.upscale_model, correction_model=self.correction_model, denoise=self.denoise, + upscale=self.upscale, faces=self.faces, face_outscale=self.face_outscale, face_strength=self.face_strength, @@ -447,6 +450,7 @@ class UpscaleParams: "upscale_model": self.upscale_model, "correction_model": self.correction_model, "denoise": self.denoise, + "upscale": self.upscale, "faces": self.faces, "face_outscale": self.face_outscale, "face_strength": self.face_strength, @@ -463,6 +467,7 @@ class UpscaleParams: kwargs.get("upscale_model", self.upscale_model), kwargs.get("correction_model", self.correction_model), kwargs.get("denoise", self.denoise), + kwargs.get("upscale", self.upscale), kwargs.get("faces", self.faces), kwargs.get("face_outscale", self.face_outscale), kwargs.get("face_strength", self.face_strength), diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index 39468542..b19f541a 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -239,6 +239,7 @@ def build_upscale( if data is None: data = request.args + upscale = get_boolean(data, "upscale", False) denoise = get_and_clamp_float( data, "denoise", @@ -262,7 +263,8 @@ def build_upscale( ) upscaling = get_from_list(data, "upscaling", get_upscaling_models()) correction = get_from_list(data, "correction", get_correction_models()) - faces = get_not_empty(data, "faces", "false") == "true" + + faces = get_boolean(data, "faces", False) face_outscale = get_and_clamp_int( data, "faceOutscale", @@ -283,6 +285,7 @@ def build_upscale( upscaling, correction_model=correction, denoise=denoise, + upscale=upscale, faces=faces, face_outscale=face_outscale, face_strength=face_strength, diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index e2c974f6..8ea871ea 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -110,6 +110,7 @@ export function appendModelToURL(url: URL, params: ModelParams) { * Append the upscale parameters to an existing URL. */ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) { + url.searchParams.append('upscale', String(upscale.enabled)); url.searchParams.append('upscaleOrder', upscale.upscaleOrder); if (upscale.enabled) {