1
0
Fork 0

feat(api): add CodeFormer to automatic upscale

This commit is contained in:
Sean Sube 2023-02-05 10:49:20 -06:00
parent 9d0609fefe
commit 0a9f108156
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 18 additions and 9 deletions

View File

@ -9,21 +9,22 @@ from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)
pretrain_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
device = "cpu" device = "cpu"
def correct_codeformer( def correct_codeformer(
_job: JobContext, job: JobContext,
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source: Image.Image,
*,
source_image: Image.Image = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
pipe = CodeFormer().to(device) device = job.get_device()
# TODO: terrible names, fix
image = source or source_image
return pipe(source_image) pipe = CodeFormer().to(device.torch_device())
return pipe(image)

View File

@ -4,6 +4,7 @@ from PIL import Image
from .chain import ( from .chain import (
ChainPipeline, ChainPipeline,
correct_codeformer,
correct_gfpgan, correct_gfpgan,
upscale_resrgan, upscale_resrgan,
upscale_stable_diffusion, upscale_stable_diffusion,
@ -40,9 +41,16 @@ def run_upscale_correction(
mini_tile = min(SizeChart.mini, stage.tile_size) mini_tile = min(SizeChart.mini, stage.tile_size)
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale) stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
chain.append((upscale_stable_diffusion, stage, None)) chain.append((upscale_stable_diffusion, stage, None))
else:
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
if upscale.faces: if upscale.faces:
stage = StageParams(tile_size=stage.tile_size, outscale=1) stage = StageParams(tile_size=stage.tile_size, outscale=1)
chain.append((correct_gfpgan, stage, None)) if "codeformer" in upscale.correction_model:
chain.append((correct_codeformer, stage, None))
elif "gfpgan" in upscale.correction_model:
chain.append((correct_gfpgan, stage, None))
else:
logger.warn("unknown correction model: %s", upscale.correction_model)
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale) return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)