diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index f06cfaaa..9d80b83a 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -9,21 +9,22 @@ from ..utils import ServerContext logger = getLogger(__name__) -pretrain_model_url = ( - "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth" -) - device = "cpu" def correct_codeformer( - _job: JobContext, + job: JobContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, - source_image: Image.Image, + source: Image.Image, + *, + source_image: Image.Image = None, **kwargs, ) -> Image.Image: - pipe = CodeFormer().to(device) + device = job.get_device() + # TODO: terrible names, fix + image = source or source_image - return pipe(source_image) + pipe = CodeFormer().to(device.torch_device()) + return pipe(image) diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 1862338b..d661629f 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -4,6 +4,7 @@ from PIL import Image from .chain import ( ChainPipeline, + correct_codeformer, correct_gfpgan, upscale_resrgan, upscale_stable_diffusion, @@ -40,9 +41,16 @@ def run_upscale_correction( mini_tile = min(SizeChart.mini, stage.tile_size) stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale) chain.append((upscale_stable_diffusion, stage, None)) + else: + logger.warn("unknown upscaling model: %s", upscale.upscale_model) if upscale.faces: stage = StageParams(tile_size=stage.tile_size, outscale=1) - chain.append((correct_gfpgan, stage, None)) + if "codeformer" in upscale.correction_model: + chain.append((correct_codeformer, stage, None)) + elif "gfpgan" in upscale.correction_model: + chain.append((correct_gfpgan, stage, None)) + else: + logger.warn("unknown correction model: %s", upscale.correction_model) return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)