1
0
Fork 0
onnx-web/api/onnx_web/upscale.py

61 lines
1.9 KiB
Python
Raw Normal View History

2023-01-28 23:09:19 +00:00
from logging import getLogger
2023-02-05 13:53:26 +00:00
2023-01-16 00:04:10 +00:00
from PIL import Image
from .chain import (
2023-02-05 13:53:26 +00:00
ChainPipeline,
correct_codeformer,
correct_gfpgan,
upscale_resrgan,
2023-02-05 13:53:26 +00:00
upscale_stable_diffusion,
)
2023-02-05 13:53:26 +00:00
from .device_pool import JobContext
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
from .utils import ServerContext
2023-01-16 21:11:40 +00:00
2023-01-28 23:09:19 +00:00
# Module-level logger, named after this module via __name__.
logger = getLogger(__name__)
def run_upscale_correction(
    job: JobContext,
    server: ServerContext,
    stage: StageParams,
    params: ImageParams,
    image: Image.Image,
    *,
    upscale: UpscaleParams,
) -> Image.Image:
    """
    Build and run a chain pipeline that upscales and/or face-corrects an image.

    The stages appended to the chain are chosen from the `upscale` params:

    - When ``upscale.scale > 1``, one upscaling stage is added, selected by
      substring match on ``upscale.upscale_model``:
      "esrgan" -> ``upscale_resrgan`` (tiles of ``stage.tile_size``),
      "stable-diffusion" -> ``upscale_stable_diffusion`` (tiles of
      ``min(SizeChart.mini, stage.tile_size)``). Any other model name is
      logged as a warning and no upscaling stage is added.
    - When ``upscale.faces`` is truthy, one face-correction stage is added,
      selected by substring match on ``upscale.correction_model``:
      "codeformer" -> ``correct_codeformer``, "gfpgan" -> ``correct_gfpgan``,
      using ``upscale.face_outscale``. Any other model name is logged as a
      warning and no correction stage is added.

    Args:
        job: job context, passed through to the chain.
        server: server context, passed through to the chain.
        stage: parameters of the calling stage; only ``tile_size`` is read.
        params: image parameters; passed to the chain, and ``params.prompt``
            is forwarded as the ``prompt`` keyword.
        image: the source image.
        upscale: upscaling/correction settings (keyword-only); also forwarded
            to the chain as the ``upscale`` keyword.

    Returns:
        The image produced by running the assembled chain.
    """
    logger.info("running upscaling and correction pipeline")

    chain = ChainPipeline()

    if upscale.scale > 1:
        # NOTE(review): the substring checks below assume upscale_model is a
        # string; a None value would raise TypeError — confirm UpscaleParams
        # guarantees this.
        if "esrgan" in upscale.upscale_model:
            resr_stage = StageParams(
                tile_size=stage.tile_size, outscale=upscale.outscale
            )
            chain.append((upscale_resrgan, resr_stage, None))
        elif "stable-diffusion" in upscale.upscale_model:
            # The stable-diffusion upscaler is limited to at most
            # SizeChart.mini tiles, even if the caller's tile is larger.
            mini_tile = min(SizeChart.mini, stage.tile_size)
            sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
            chain.append((upscale_stable_diffusion, sd_stage, None))
        else:
            # Logger.warn is a deprecated alias; use Logger.warning.
            logger.warning("unknown upscaling model: %s", upscale.upscale_model)

    if upscale.faces:
        face_stage = StageParams(
            tile_size=stage.tile_size, outscale=upscale.face_outscale
        )
        if "codeformer" in upscale.correction_model:
            chain.append((correct_codeformer, face_stage, None))
        elif "gfpgan" in upscale.correction_model:
            chain.append((correct_gfpgan, face_stage, None))
        else:
            logger.warning("unknown correction model: %s", upscale.correction_model)

    return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)