2023-01-28 23:09:19 +00:00
|
|
|
from logging import getLogger
|
2023-01-16 00:04:10 +00:00
|
|
|
from PIL import Image
|
2023-01-16 00:13:28 +00:00
|
|
|
|
2023-01-27 23:08:36 +00:00
|
|
|
from .chain import (
|
2023-01-28 05:28:14 +00:00
|
|
|
correct_gfpgan,
|
|
|
|
upscale_stable_diffusion,
|
|
|
|
upscale_resrgan,
|
2023-01-27 23:08:36 +00:00
|
|
|
ChainPipeline,
|
2023-01-26 03:04:00 +00:00
|
|
|
)
|
2023-01-28 04:48:06 +00:00
|
|
|
from .params import (
|
2023-01-27 23:08:36 +00:00
|
|
|
ImageParams,
|
2023-01-28 20:56:06 +00:00
|
|
|
SizeChart,
|
2023-01-28 05:28:14 +00:00
|
|
|
StageParams,
|
2023-01-28 04:48:06 +00:00
|
|
|
UpscaleParams,
|
|
|
|
)
|
|
|
|
from .utils import (
|
|
|
|
ServerContext,
|
2023-01-16 19:02:15 +00:00
|
|
|
)
|
2023-01-16 21:11:40 +00:00
|
|
|
|
2023-01-28 23:09:19 +00:00
|
|
|
# Module-level logger, named after this module so its records can be filtered per module.
logger = getLogger(name=__name__)
|
|
|
|
|
2023-01-16 19:02:15 +00:00
|
|
|
|
2023-01-27 23:08:36 +00:00
|
|
|
def run_upscale_correction(
    ctx: ServerContext,
    stage: StageParams,
    params: ImageParams,
    image: Image.Image,
    *,
    upscale: UpscaleParams,
) -> Image.Image:
    '''
    This is a convenience method for a chain pipeline that will run upscaling and
    correction, based on the `upscale` params.

    Stages are appended to the chain in this order:

    1. Upscaling, only when `upscale.scale > 1`:
       - RealESRGAN if `upscale.upscale_model` contains 'esrgan';
       - otherwise Stable Diffusion if it contains 'stable-diffusion', with the
         tile size capped at `SizeChart.mini`;
       - any other model name adds no upscaling stage and logs a warning.
    2. Face correction with GFPGAN, only when `upscale.faces` is truthy.

    Args:
        ctx: server context, passed through to the chain.
        stage: the caller's stage params; only `stage.tile_size` is read. Each
            chain stage gets its own new `StageParams`, so the tile size used
            for face correction is always the caller's, never the reduced
            Stable Diffusion tile.
        params: image params, passed through to the chain.
        image: the source image, passed through to the chain.
        upscale: selects which stages run and supplies `outscale`; it is also
            included (as 'upscale') in the kwargs dict attached to each stage.

    Returns:
        Whatever the assembled `ChainPipeline` returns for `image`.
    '''
    logger.info('running upscaling and correction pipeline')

    chain = ChainPipeline()
    kwargs = {'upscale': upscale}

    if upscale.scale > 1:
        if 'esrgan' in upscale.upscale_model:
            esrgan_stage = StageParams(tile_size=stage.tile_size,
                                       outscale=upscale.outscale)
            chain.append((upscale_resrgan, esrgan_stage, kwargs))
        elif 'stable-diffusion' in upscale.upscale_model:
            # Cap the Stable Diffusion upscaler's tile at SizeChart.mini.
            mini_tile = min(SizeChart.mini, stage.tile_size)
            sd_stage = StageParams(tile_size=mini_tile,
                                   outscale=upscale.outscale)
            chain.append((upscale_stable_diffusion, sd_stage, kwargs))
        else:
            # Previously skipped silently; surface it so misconfiguration is visible.
            logger.warning('unknown upscaling model: %s', upscale.upscale_model)

    if upscale.faces:
        # Build from the caller's `stage`, not from an upscaling stage above:
        # reusing the name `stage` used to leak the reduced SD tile size here.
        face_stage = StageParams(tile_size=stage.tile_size,
                                 outscale=upscale.outscale)
        chain.append((correct_gfpgan, face_stage, kwargs))

    return chain(ctx, params, image)
|