diff --git a/api/onnx_web/chain/upscale.py b/api/onnx_web/chain/upscale.py index f8880802..54a1754b 100644 --- a/api/onnx_web/chain/upscale.py +++ b/api/onnx_web/chain/upscale.py @@ -68,31 +68,30 @@ def stage_upscale_correction( "upscale": upscale, } upscale_stage: Optional[PipelineStage] = None - if upscale.scale > 1: - if "bsrgan" in upscale.upscale_model: - bsrgan_params = StageParams( - tile_size=stage.tile_size, - outscale=upscale.outscale, - ) - upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts) - elif "esrgan" in upscale.upscale_model: - esrgan_params = StageParams( - tile_size=stage.tile_size, - outscale=upscale.outscale, - ) - upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts) - elif "stable-diffusion" in upscale.upscale_model: - mini_tile = min(SizeChart.mini, stage.tile_size) - sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale) - upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts) - elif "swinir" in upscale.upscale_model: - swinir_params = StageParams( - tile_size=stage.tile_size, - outscale=upscale.outscale, - ) - upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts) - else: - logger.warning("unknown upscaling model: %s", upscale.upscale_model) + if "bsrgan" in upscale.upscale_model: + bsrgan_params = StageParams( + tile_size=stage.tile_size, + outscale=upscale.outscale, + ) + upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts) + elif "esrgan" in upscale.upscale_model: + esrgan_params = StageParams( + tile_size=stage.tile_size, + outscale=upscale.outscale, + ) + upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts) + elif "stable-diffusion" in upscale.upscale_model: + mini_tile = min(SizeChart.mini, stage.tile_size) + sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale) + upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts) + elif "swinir" in upscale.upscale_model: + swinir_params = StageParams( + tile_size=stage.tile_size, + outscale=upscale.outscale, + ) + upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts) + else: + logger.warning("unknown upscaling model: %s", upscale.upscale_model) correct_stage: Optional[PipelineStage] = None if upscale.faces: