1
0
Fork 0

feat(api): implement upscaling and correction as a chain pipeline

This commit is contained in:
Sean Sube 2023-01-27 17:38:21 -06:00
parent 76e25ac057
commit bcaf0f73e6
2 changed files with 17 additions and 13 deletions

View File

@ -18,9 +18,11 @@ class StageParams:
def __init__(
self,
name: Optional[str] = None,
tile_size: int = 512,
outscale: int = 1,
) -> None:
self.name = name
self.tile_size = tile_size
self.outscale = outscale
@ -48,7 +50,7 @@ class ChainPipeline:
def __init__(
self,
stages: List[PipelineStage],
stages: List[PipelineStage] = [],
):
'''
Create a new pipeline that will run the given stages.
@ -70,11 +72,12 @@ class ChainPipeline:
image = source
for stage_fn, stage_params, stage_kwargs in self.stages:
print('running pipeline stage on result image with dimensions %sx%s' %
image.size)
name = stage_params.label or stage_fn.__name__
print('running pipeline stage %s on result image with dimensions %sx%s' %
(name, image.width, image.height))
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
print('source image larger than tile size, tiling stage',
stage_params.tile_size)
print('source image larger than tile size of %s, tiling stage' % (
stage_params.tile_size))
def stage_tile(tile: Image.Image) -> Image.Image:
tile = stage_fn(ctx, stage_params, params, tile,
@ -85,12 +88,12 @@ class ChainPipeline:
image = process_tiles(
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
else:
print('source image within tile size, run stage')
print('source image within tile size, running stage')
image = stage_fn(ctx, stage_params, params, image,
**stage_kwargs)
print('finished running pipeline stage, result size: %sx%s' %
image.size)
print('finished running pipeline stage %s, result size: %sx%s' %
(name, image.width, image.height))
print('finished running pipeline, result size: %sx%s' % image.size)
return image

View File

@ -215,20 +215,21 @@ def run_upscale_correction(
) -> Image.Image:
print('running upscale pipeline')
chain = ChainPipeline()
if upscale.scale > 1:
if 'esrgan' in upscale.upscale_model:
stage = StageParams(tile_size=stage.tile_size,
outscale=upscale.outscale)
image = upscale_resrgan(ctx, stage, params, image, upscale=upscale)
chain.append((upscale_resrgan, stage, {'upscale': upscale}))
elif 'stable-diffusion' in upscale.upscale_model:
mini_tile = min(128, stage.tile_size)
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
image = upscale_stable_diffusion(
ctx, stage, params, image, upscale=upscale)
chain.append((upscale_stable_diffusion, stage, {'upscale': upscale}))
if upscale.faces:
stage = StageParams(tile_size=stage.tile_size,
outscale=upscale.outscale)
image = upscale_gfpgan(ctx, stage, params, image, upscale=upscale)
chain.append((upscale_gfpgan, stage, {'upscale': upscale}))
return image
return chain(ctx, params, image)