feat(api): implement upscaling and correction as a chain pipeline
This commit is contained in:
parent
76e25ac057
commit
bcaf0f73e6
|
@ -18,9 +18,11 @@ class StageParams:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
tile_size: int = 512,
|
||||
outscale: int = 1,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.tile_size = tile_size
|
||||
self.outscale = outscale
|
||||
|
||||
|
@ -48,7 +50,7 @@ class ChainPipeline:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
stages: List[PipelineStage],
|
||||
stages: List[PipelineStage] = [],
|
||||
):
|
||||
'''
|
||||
Create a new pipeline that will run the given stages.
|
||||
|
@ -70,11 +72,12 @@ class ChainPipeline:
|
|||
image = source
|
||||
|
||||
for stage_fn, stage_params, stage_kwargs in self.stages:
|
||||
print('running pipeline stage on result image with dimensions %sx%s' %
|
||||
image.size)
|
||||
name = stage_params.label or stage_fn.__name__
|
||||
print('running pipeline stage %s on result image with dimensions %sx%s' %
|
||||
(name, image.width, image.height))
|
||||
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
|
||||
print('source image larger than tile size, tiling stage',
|
||||
stage_params.tile_size)
|
||||
print('source image larger than tile size of %s, tiling stage' % (
|
||||
stage_params.tile_size))
|
||||
|
||||
def stage_tile(tile: Image.Image) -> Image.Image:
|
||||
tile = stage_fn(ctx, stage_params, params, tile,
|
||||
|
@ -85,12 +88,12 @@ class ChainPipeline:
|
|||
image = process_tiles(
|
||||
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
|
||||
else:
|
||||
print('source image within tile size, run stage')
|
||||
print('source image within tile size, running stage')
|
||||
image = stage_fn(ctx, stage_params, params, image,
|
||||
**stage_kwargs)
|
||||
|
||||
print('finished running pipeline stage, result size: %sx%s' %
|
||||
image.size)
|
||||
print('finished running pipeline stage %s, result size: %sx%s' %
|
||||
(name, image.width, image.height))
|
||||
|
||||
print('finished running pipeline, result size: %sx%s' % image.size)
|
||||
return image
|
||||
|
|
|
@ -215,20 +215,21 @@ def run_upscale_correction(
|
|||
) -> Image.Image:
|
||||
print('running upscale pipeline')
|
||||
|
||||
chain = ChainPipeline()
|
||||
|
||||
if upscale.scale > 1:
|
||||
if 'esrgan' in upscale.upscale_model:
|
||||
stage = StageParams(tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale)
|
||||
image = upscale_resrgan(ctx, stage, params, image, upscale=upscale)
|
||||
chain.append((upscale_resrgan, stage, {'upscale': upscale}))
|
||||
elif 'stable-diffusion' in upscale.upscale_model:
|
||||
mini_tile = min(128, stage.tile_size)
|
||||
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||
image = upscale_stable_diffusion(
|
||||
ctx, stage, params, image, upscale=upscale)
|
||||
chain.append((upscale_stable_diffusion, stage, {'upscale': upscale}))
|
||||
|
||||
if upscale.faces:
|
||||
stage = StageParams(tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale)
|
||||
image = upscale_gfpgan(ctx, stage, params, image, upscale=upscale)
|
||||
chain.append((upscale_gfpgan, stage, {'upscale': upscale}))
|
||||
|
||||
return image
|
||||
return chain(ctx, params, image)
|
||||
|
|
Loading…
Reference in New Issue