1
0
Fork 0

feat(api): enable 1x upscaling models

This commit is contained in:
Sean Sube 2023-12-30 11:59:52 -06:00
parent 11e643bcb5
commit 7abe6dc6a9
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 24 additions and 25 deletions

View File

@ -68,31 +68,30 @@ def stage_upscale_correction(
"upscale": upscale, "upscale": upscale,
} }
upscale_stage: Optional[PipelineStage] = None upscale_stage: Optional[PipelineStage] = None
if upscale.scale > 1: if "bsrgan" in upscale.upscale_model:
if "bsrgan" in upscale.upscale_model: bsrgan_params = StageParams(
bsrgan_params = StageParams( tile_size=stage.tile_size,
tile_size=stage.tile_size, outscale=upscale.outscale,
outscale=upscale.outscale, )
) upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts)
upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts) elif "esrgan" in upscale.upscale_model:
elif "esrgan" in upscale.upscale_model: esrgan_params = StageParams(
esrgan_params = StageParams( tile_size=stage.tile_size,
tile_size=stage.tile_size, outscale=upscale.outscale,
outscale=upscale.outscale, )
) upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts)
upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts) elif "stable-diffusion" in upscale.upscale_model:
elif "stable-diffusion" in upscale.upscale_model: mini_tile = min(SizeChart.mini, stage.tile_size)
mini_tile = min(SizeChart.mini, stage.tile_size) sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale) upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts)
upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts) elif "swinir" in upscale.upscale_model:
elif "swinir" in upscale.upscale_model: swinir_params = StageParams(
swinir_params = StageParams( tile_size=stage.tile_size,
tile_size=stage.tile_size, outscale=upscale.outscale,
outscale=upscale.outscale, )
) upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts)
upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts) else:
else: logger.warning("unknown upscaling model: %s", upscale.upscale_model)
logger.warning("unknown upscaling model: %s", upscale.upscale_model)
correct_stage: Optional[PipelineStage] = None correct_stage: Optional[PipelineStage] = None
if upscale.faces: if upscale.faces: