feat(api): add CodeFormer to automatic upscale
This commit is contained in:
parent
9d0609fefe
commit
0a9f108156
|
@ -9,21 +9,22 @@ from ..utils import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
pretrain_model_url = (
|
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
|
||||||
)
|
|
||||||
|
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|
||||||
|
|
||||||
def correct_codeformer(
|
def correct_codeformer(
|
||||||
_job: JobContext,
|
job: JobContext,
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
source_image: Image.Image,
|
source: Image.Image,
|
||||||
|
*,
|
||||||
|
source_image: Image.Image = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
pipe = CodeFormer().to(device)
|
device = job.get_device()
|
||||||
|
# TODO: terrible names, fix
|
||||||
|
image = source or source_image
|
||||||
|
|
||||||
return pipe(source_image)
|
pipe = CodeFormer().to(device.torch_device())
|
||||||
|
return pipe(image)
|
||||||
|
|
|
@ -4,6 +4,7 @@ from PIL import Image
|
||||||
|
|
||||||
from .chain import (
|
from .chain import (
|
||||||
ChainPipeline,
|
ChainPipeline,
|
||||||
|
correct_codeformer,
|
||||||
correct_gfpgan,
|
correct_gfpgan,
|
||||||
upscale_resrgan,
|
upscale_resrgan,
|
||||||
upscale_stable_diffusion,
|
upscale_stable_diffusion,
|
||||||
|
@ -40,9 +41,16 @@ def run_upscale_correction(
|
||||||
mini_tile = min(SizeChart.mini, stage.tile_size)
|
mini_tile = min(SizeChart.mini, stage.tile_size)
|
||||||
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||||
chain.append((upscale_stable_diffusion, stage, None))
|
chain.append((upscale_stable_diffusion, stage, None))
|
||||||
|
else:
|
||||||
|
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
|
||||||
|
|
||||||
if upscale.faces:
|
if upscale.faces:
|
||||||
stage = StageParams(tile_size=stage.tile_size, outscale=1)
|
stage = StageParams(tile_size=stage.tile_size, outscale=1)
|
||||||
chain.append((correct_gfpgan, stage, None))
|
if "codeformer" in upscale.correction_model:
|
||||||
|
chain.append((correct_codeformer, stage, None))
|
||||||
|
elif "gfpgan" in upscale.correction_model:
|
||||||
|
chain.append((correct_gfpgan, stage, None))
|
||||||
|
else:
|
||||||
|
logger.warn("unknown correction model: %s", upscale.correction_model)
|
||||||
|
|
||||||
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)
|
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)
|
||||||
|
|
Loading…
Reference in New Issue