1
0
Fork 0
onnx-web/api/onnx_web/diffusers/upscale.py

115 lines
3.6 KiB
Python
Raw Normal View History

2023-01-28 23:09:19 +00:00
from logging import getLogger
from typing import List, Optional
2023-02-05 13:53:26 +00:00
2023-01-16 00:04:10 +00:00
from PIL import Image
2023-04-01 16:26:10 +00:00
from ..chain import (
2023-02-05 13:53:26 +00:00
ChainPipeline,
PipelineStage,
correct_codeformer,
correct_gfpgan,
upscale_bsrgan,
upscale_resrgan,
2023-02-05 13:53:26 +00:00
upscale_stable_diffusion,
upscale_swinir,
)
2023-04-01 16:26:10 +00:00
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
2023-01-16 21:11:40 +00:00
2023-01-28 23:09:19 +00:00
logger = getLogger(__name__)


def _upscale_stage(
    stage: StageParams,
    upscale: UpscaleParams,
) -> Optional[PipelineStage]:
    """
    Build the upscaling stage selected by `upscale.upscale_model`.

    The model is matched by substring, checked in this order: "bsrgan",
    "esrgan", "stable-diffusion", "swinir".

    Returns None when no upscaling is requested (`upscale.scale <= 1`) or when
    the model name matches none of the known upscalers (a warning is logged).
    """
    if upscale.scale <= 1:
        return None

    model = upscale.upscale_model
    if "bsrgan" in model:
        bsrgan_params = StageParams(
            tile_size=stage.tile_size,
            outscale=upscale.outscale,
        )
        return (upscale_bsrgan, bsrgan_params, None)

    if "esrgan" in model:
        esrgan_params = StageParams(
            tile_size=stage.tile_size,
            outscale=upscale.outscale,
        )
        return (upscale_resrgan, esrgan_params, None)

    if "stable-diffusion" in model:
        # The stable diffusion upscaler runs on tiles no larger than
        # SizeChart.mini, even if the caller's stage allows bigger ones.
        mini_tile = min(SizeChart.mini, stage.tile_size)
        sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
        return (upscale_stable_diffusion, sd_params, None)

    if "swinir" in model:
        swinir_params = StageParams(
            tile_size=stage.tile_size,
            outscale=upscale.outscale,
        )
        return (upscale_swinir, swinir_params, None)

    logger.warning("unknown upscaling model: %s", model)
    return None


def _correction_stage(
    stage: StageParams,
    upscale: UpscaleParams,
) -> Optional[PipelineStage]:
    """
    Build the face-correction stage selected by `upscale.correction_model`.

    Returns None when face correction is disabled (`upscale.faces` is falsy)
    or when the model name matches neither "codeformer" nor "gfpgan"
    (a warning is logged).
    """
    if not upscale.faces:
        return None

    face_params = StageParams(
        tile_size=stage.tile_size, outscale=upscale.face_outscale
    )

    model = upscale.correction_model
    if "codeformer" in model:
        return (correct_codeformer, face_params, None)

    if "gfpgan" in model:
        return (correct_gfpgan, face_params, None)

    logger.warning("unknown correction model: %s", model)
    return None


def run_upscale_correction(
    job: WorkerContext,
    server: ServerContext,
    stage: StageParams,
    params: ImageParams,
    image: Image.Image,
    *,
    upscale: UpscaleParams,
    callback: Optional[ProgressCallback] = None,
    pre_stages: Optional[List[PipelineStage]] = None,
    post_stages: Optional[List[PipelineStage]] = None,
) -> Image.Image:
    """
    This is a convenience method for a chain pipeline that will run upscaling and
    correction, based on the `upscale` params.

    The chain is assembled as: `pre_stages`, then the upscale/correction
    stages in the order given by `upscale.upscale_order`, then `post_stages`.
    The supported orders are "correction-first", "correction-last" and
    "correction-both". Any other value logs a warning and adds neither stage.

    Args:
        job: worker context, forwarded to the chain.
        server: server context, forwarded to the chain.
        stage: provides the base `tile_size` for the upscale/correction stages.
        params: image parameters, forwarded to the chain; `params.prompt` is
            also passed as the `prompt` keyword.
        image: source image.
        upscale: selects the upscaling/correction models, scales and order.
        callback: optional progress callback, forwarded to the chain.
        pre_stages: stages appended as-is before upscaling/correction.
        post_stages: stages appended as-is after upscaling/correction.

    Returns:
        The image returned by the assembled chain pipeline.
    """
    logger.info(
        "running upscaling and correction pipeline at %s:%s",
        upscale.scale,
        upscale.outscale,
    )

    chain = ChainPipeline()

    # Caller-supplied stages are appended unchanged. Unpacking them into
    # `stage, params` would shadow this function's own arguments, and it
    # would fail for 3-element stage tuples like the ones built below.
    if pre_stages is not None:
        for pre_stage in pre_stages:
            chain.append(pre_stage)

    upscale_stage = _upscale_stage(stage, upscale)
    correct_stage = _correction_stage(stage, upscale)

    order = upscale.upscale_order
    if order == "correction-both":
        ordered = [correct_stage, upscale_stage, correct_stage]
    elif order == "correction-first":
        ordered = [correct_stage, upscale_stage]
    elif order == "correction-last":
        ordered = [upscale_stage, correct_stage]
    else:
        logger.warning("unknown upscaling order: %s", order)
        ordered = []

    # Disabled or unknown stages come back as None. Only real stages are
    # added, so nothing here depends on whether ChainPipeline tolerates None
    # entries (not visible from this module).
    for ordered_stage in ordered:
        if ordered_stage is not None:
            chain.append(ordered_stage)

    if post_stages is not None:
        for post_stage in post_stages:
            chain.append(post_stage)

    return chain(
        job,
        server,
        params,
        image,
        prompt=params.prompt,
        upscale=upscale,
        callback=callback,
    )