1
0
Fork 0

fix(api): use kwargs for chain stages

This commit is contained in:
Sean Sube 2023-06-30 21:42:24 -05:00
parent 7a73c9ff61
commit 2d10252564
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 80 additions and 89 deletions

View File

@ -17,7 +17,7 @@ from .diffusers.run import (
run_upscale_pipeline,
)
from .diffusers.stub_scheduler import StubScheduler
from .diffusers.upscale import append_upscale_correction
from .diffusers.upscale import stage_upscale_correction
from .image.utils import (
expand_image,
valid_image,

View File

@ -78,11 +78,31 @@ class ChainPipeline:
def append(self, stage: PipelineStage):
"""
DEPRECATED: use `stage` instead
Append an additional stage to this pipeline.
"""
if stage is not None:
self.stages.append(stage)
def run(
self,
job: WorkerContext,
server: ServerContext,
params: ImageParams,
source: Optional[Image.Image],
callback: Optional[ProgressCallback],
**kwargs
) -> Image.Image:
"""
TODO: handle List[Image] inputs and outputs
"""
return self(job, server, params, source=source, callback=callback, **kwargs)
def stage(self, callback: StageCallback, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs))
return self
def __call__(
self,
job: WorkerContext,
@ -93,7 +113,7 @@ class ChainPipeline:
**pipeline_kwargs
) -> Image.Image:
"""
TODO: handle List[Image] inputs and outputs
DEPRECATED: use `run` instead
"""
if callback is not None:
callback = ChainProgress.from_progress(callback)

View File

@ -4,8 +4,8 @@ from typing import Any, Optional
from PIL import Image
from ..chain.base import ChainPipeline
from ..chain.img2img import blend_img2img
from ..diffusers.upscale import append_upscale_correction
from ..chain.blend_img2img import blend_img2img
from ..diffusers.upscale import stage_upscale_correction
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
@ -45,7 +45,7 @@ def upscale_highres(
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.debug("using upscaling pipeline for highres")
append_upscale_correction(
stage_upscale_correction(
StageParams(),
params,
upscale=upscale.with_args(

View File

@ -24,7 +24,7 @@ from ..server import ServerContext
from ..server.load import get_source_filters
from ..utils import run_gc, show_system_toast
from ..worker import WorkerContext
from .upscale import append_upscale_correction, split_upscale
from .upscale import split_upscale, stage_upscale_correction
from .utils import parse_prompt
logger = getLogger(__name__)
@ -42,20 +42,16 @@ def run_txt2img_pipeline(
# prepare the chain pipeline and first stage
chain = ChainPipeline()
stage = StageParams()
chain.append(
(
source_txt2img,
stage,
{
"size": size,
},
)
chain.stage(
source_txt2img,
stage,
size=size,
)
# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=first_upscale,
@ -64,22 +60,21 @@ def run_txt2img_pipeline(
# apply highres
for _i in range(highres.iterations):
chain.append(
(
upscale_highres,
stage,
{
"highres": highres,
"upscale": upscale,
},
)
chain.stage(
upscale_highres,
StageParams(
outscale=highres.scale,
),
highres=highres,
upscale=upscale,
overlap=params.overlap,
)
# apply upscaling and correction, after highres
append_upscale_correction(
StageParams(),
stage_upscale_correction(
stage,
params,
upscale=upscale,
upscale=after_upscale,
chain=chain,
)
@ -128,20 +123,16 @@ def run_img2img_pipeline(
# prepare the chain pipeline and first stage
chain = ChainPipeline()
stage = StageParams()
chain.append(
(
blend_img2img,
stage,
{
"strength": strength,
},
)
chain.stage(
blend_img2img,
stage,
strength=strength,
)
# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=first_upscale,
@ -151,32 +142,24 @@ def run_img2img_pipeline(
# loopback through multiple img2img iterations
if params.loopback > 0:
for _i in range(params.loopback):
chain.append(
(
blend_img2img,
stage,
{
"strength": strength,
},
)
chain.stage(
blend_img2img,
stage,
strength=strength,
)
# highres, if selected
if highres.iterations > 0:
for _i in range(highres.iterations):
chain.append(
(
upscale_highres,
stage,
{
"highres": highres,
"upscale": upscale,
},
)
chain.stage(
upscale_highres,
stage,
highres=highres,
upscale=upscale,
)
# apply upscaling and correction, after highres
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=after_upscale,
@ -237,34 +220,26 @@ def run_inpaint_pipeline(
# set up the chain pipeline and base stage
chain = ChainPipeline()
stage = StageParams(tile_order=tile_order)
chain.append(
(
upscale_outpaint,
stage,
{
"border": border,
"stage_mask": mask,
"fill_color": fill_color,
"mask_filter": mask_filter,
"noise_source": noise_source,
},
)
chain.stage(
upscale_outpaint,
stage,
border=border,
stage_mask=mask,
fill_color=fill_color,
mask_filter=mask_filter,
noise_source=noise_source,
)
# apply highres
chain.append(
(
upscale_highres,
stage,
{
"highres": highres,
"upscale": upscale,
},
)
chain.stage(
upscale_highres,
stage,
highres=highres,
upscale=upscale,
)
# apply upscaling and correction
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=upscale,
@ -313,7 +288,7 @@ def run_upscale_pipeline(
# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=first_upscale,
@ -321,19 +296,15 @@ def run_upscale_pipeline(
)
# apply highres
chain.append(
(
upscale_highres,
stage,
{
"highres": highres,
"upscale": upscale,
},
)
chain.stage(
upscale_highres,
stage,
highres=highres,
upscale=upscale,
)
# apply upscaling and correction, after highres
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=after_upscale,
@ -380,7 +351,7 @@ def run_blend_pipeline(
stage.append((blend_mask, stage, None))
# apply upscaling and correction
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=upscale,

View File

@ -36,7 +36,7 @@ def split_upscale(
)
def append_upscale_correction(
def stage_upscale_correction(
stage: StageParams,
params: ImageParams,
*,