2023-01-28 23:09:19 +00:00
|
|
|
from logging import getLogger
|
2023-02-05 13:53:26 +00:00
|
|
|
|
2023-01-16 00:04:10 +00:00
|
|
|
from PIL import Image
|
2023-01-16 00:13:28 +00:00
|
|
|
|
2023-02-19 02:28:21 +00:00
|
|
|
from .chain import (
|
2023-02-05 13:53:26 +00:00
|
|
|
ChainPipeline,
|
2023-02-05 16:49:20 +00:00
|
|
|
correct_codeformer,
|
2023-01-28 05:28:14 +00:00
|
|
|
correct_gfpgan,
|
|
|
|
upscale_resrgan,
|
2023-02-05 13:53:26 +00:00
|
|
|
upscale_stable_diffusion,
|
2023-01-16 19:02:15 +00:00
|
|
|
)
|
2023-02-19 02:28:21 +00:00
|
|
|
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
2023-02-26 05:49:39 +00:00
|
|
|
from .server import ServerContext
|
2023-02-26 20:15:30 +00:00
|
|
|
from .worker import ProgressCallback, WorkerContext
|
2023-01-16 21:11:40 +00:00
|
|
|
|
2023-01-28 23:09:19 +00:00
|
|
|
# module-scoped logger, named after this module so its records can be filtered per module
logger = getLogger(name=__name__)
|
|
|
|
|
2023-01-16 19:02:15 +00:00
|
|
|
|
2023-01-27 23:08:36 +00:00
|
|
|
def _select_upscale_stage(stage: StageParams, upscale: UpscaleParams):
    """Build the upscaling stage tuple for `upscale`, or return None when no upscaling stage applies."""
    if upscale.scale <= 1:
        return None

    if "esrgan" in upscale.upscale_model:
        esrgan_params = StageParams(
            tile_size=stage.tile_size, outscale=upscale.outscale
        )
        return (upscale_resrgan, esrgan_params, None)

    if "stable-diffusion" in upscale.upscale_model:
        # the SD upscaler runs on tiles no larger than SizeChart.mini
        mini_tile = min(SizeChart.mini, stage.tile_size)
        sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
        return (upscale_stable_diffusion, sd_params, None)

    logger.warning("unknown upscaling model: %s", upscale.upscale_model)
    return None


def _select_correct_stage(stage: StageParams, upscale: UpscaleParams):
    """Build the face-correction stage tuple for `upscale`, or return None when no correction stage applies."""
    if not upscale.faces:
        return None

    face_params = StageParams(
        tile_size=stage.tile_size, outscale=upscale.face_outscale
    )

    if "codeformer" in upscale.correction_model:
        return (correct_codeformer, face_params, None)

    if "gfpgan" in upscale.correction_model:
        return (correct_gfpgan, face_params, None)

    logger.warning("unknown correction model: %s", upscale.correction_model)
    return None


def run_upscale_correction(
    job: WorkerContext,
    server: ServerContext,
    stage: StageParams,
    params: ImageParams,
    image: Image.Image,
    *,
    upscale: UpscaleParams,
    callback: ProgressCallback = None,
) -> Image.Image:
    """
    This is a convenience method for a chain pipeline that will run upscaling and
    correction, based on the `upscale` params.

    Stage selection:

    - upscaling is added only when `upscale.scale > 1`, using ESRGAN or Stable
      Diffusion depending on `upscale.upscale_model`
    - face correction is added only when `upscale.faces` is truthy, using
      CodeFormer or GFPGAN depending on `upscale.correction_model`
    - `upscale.upscale_order` places correction before upscaling
      ("correction-first"), after it ("correction-last"), or on both sides
      ("correction-both")

    Stages that are disabled or whose model name is not recognized are left
    out of the pipeline instead of being appended as `None`. An unknown
    `upscale_order` logs a warning and runs the pipeline with no stages.
    NOTE(review): presumably ChainPipeline returns the input image unchanged
    when it has no stages; confirm.

    Each stage is a `(callable, StageParams, None)` tuple. NOTE(review): the
    trailing `None` is presumably per-stage kwargs; confirm against
    ChainPipeline.

    Args:
        job: worker context, forwarded to the pipeline call.
        server: server context, forwarded to the pipeline call.
        stage: parent stage params; its `tile_size` is used (or capped) for each
            stage's tile size.
        params: image params; forwarded to the pipeline, with `params.prompt`
            passed as `prompt`.
        image: source image, forwarded to the pipeline.
        upscale: upscaling/correction settings that select and order the stages.
        callback: optional progress callback, forwarded to the pipeline.

    Returns:
        Whatever the assembled ChainPipeline returns for `image`.
    """
    logger.info("running upscaling and correction pipeline")

    chain = ChainPipeline()

    upscale_stage = _select_upscale_stage(stage, upscale)
    correct_stage = _select_correct_stage(stage, upscale)

    order = upscale.upscale_order
    if order == "correction-both":
        ordered_stages = [correct_stage, upscale_stage, correct_stage]
    elif order == "correction-first":
        ordered_stages = [correct_stage, upscale_stage]
    elif order == "correction-last":
        ordered_stages = [upscale_stage, correct_stage]
    else:
        logger.warning("unknown upscaling order: %s", upscale.upscale_order)
        ordered_stages = []

    for pipeline_stage in ordered_stages:
        # a stage is None when it was disabled or its model was unknown; do not
        # rely on ChainPipeline.append tolerating None entries
        if pipeline_stage is not None:
            chain.append(pipeline_stage)

    return chain(
        job,
        server,
        params,
        image,
        prompt=params.prompt,
        upscale=upscale,
        callback=callback,
    )
|