1
0
Fork 0

fix(api): use kwargs for chain stages

This commit is contained in:
Sean Sube 2023-06-30 21:42:24 -05:00
parent 7a73c9ff61
commit 2d10252564
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 80 additions and 89 deletions

View File

@ -17,7 +17,7 @@ from .diffusers.run import (
run_upscale_pipeline, run_upscale_pipeline,
) )
from .diffusers.stub_scheduler import StubScheduler from .diffusers.stub_scheduler import StubScheduler
from .diffusers.upscale import append_upscale_correction from .diffusers.upscale import stage_upscale_correction
from .image.utils import ( from .image.utils import (
expand_image, expand_image,
valid_image, valid_image,

View File

@ -78,11 +78,31 @@ class ChainPipeline:
def append(self, stage: PipelineStage): def append(self, stage: PipelineStage):
""" """
DEPRECATED: use `stage` instead
Append an additional stage to this pipeline. Append an additional stage to this pipeline.
""" """
if stage is not None: if stage is not None:
self.stages.append(stage) self.stages.append(stage)
def run(
self,
job: WorkerContext,
server: ServerContext,
params: ImageParams,
source: Optional[Image.Image],
callback: Optional[ProgressCallback],
**kwargs
) -> Image.Image:
"""
TODO: handle List[Image] inputs and outputs
"""
return self(job, server, params, source=source, callback=callback, **kwargs)
def stage(self, callback: StageCallback, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs))
return self
def __call__( def __call__(
self, self,
job: WorkerContext, job: WorkerContext,
@ -93,7 +113,7 @@ class ChainPipeline:
**pipeline_kwargs **pipeline_kwargs
) -> Image.Image: ) -> Image.Image:
""" """
TODO: handle List[Image] inputs and outputs DEPRECATED: use `run` instead
""" """
if callback is not None: if callback is not None:
callback = ChainProgress.from_progress(callback) callback = ChainProgress.from_progress(callback)

View File

@ -4,8 +4,8 @@ from typing import Any, Optional
from PIL import Image from PIL import Image
from ..chain.base import ChainPipeline from ..chain.base import ChainPipeline
from ..chain.img2img import blend_img2img from ..chain.blend_img2img import blend_img2img
from ..diffusers.upscale import append_upscale_correction from ..diffusers.upscale import stage_upscale_correction
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
@ -45,7 +45,7 @@ def upscale_highres(
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS) source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else: else:
logger.debug("using upscaling pipeline for highres") logger.debug("using upscaling pipeline for highres")
append_upscale_correction( stage_upscale_correction(
StageParams(), StageParams(),
params, params,
upscale=upscale.with_args( upscale=upscale.with_args(

View File

@ -24,7 +24,7 @@ from ..server import ServerContext
from ..server.load import get_source_filters from ..server.load import get_source_filters
from ..utils import run_gc, show_system_toast from ..utils import run_gc, show_system_toast
from ..worker import WorkerContext from ..worker import WorkerContext
from .upscale import append_upscale_correction, split_upscale from .upscale import split_upscale, stage_upscale_correction
from .utils import parse_prompt from .utils import parse_prompt
logger = getLogger(__name__) logger = getLogger(__name__)
@ -42,20 +42,16 @@ def run_txt2img_pipeline(
# prepare the chain pipeline and first stage # prepare the chain pipeline and first stage
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams() stage = StageParams()
chain.append( chain.stage(
( source_txt2img,
source_txt2img, stage,
stage, size=size,
{
"size": size,
},
)
) )
# apply upscaling and correction, before highres # apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale) first_upscale, after_upscale = split_upscale(upscale)
if first_upscale: if first_upscale:
append_upscale_correction( stage_upscale_correction(
stage, stage,
params, params,
upscale=first_upscale, upscale=first_upscale,
@ -64,22 +60,21 @@ def run_txt2img_pipeline(
# apply highres # apply highres
for _i in range(highres.iterations): for _i in range(highres.iterations):
chain.append( chain.stage(
( upscale_highres,
upscale_highres, StageParams(
stage, outscale=highres.scale,
{ ),
"highres": highres, highres=highres,
"upscale": upscale, upscale=upscale,
}, overlap=params.overlap,
)
) )
# apply upscaling and correction, after highres # apply upscaling and correction, after highres
append_upscale_correction( stage_upscale_correction(
StageParams(), stage,
params, params,
upscale=upscale, upscale=after_upscale,
chain=chain, chain=chain,
) )
@ -128,20 +123,16 @@ def run_img2img_pipeline(
# prepare the chain pipeline and first stage # prepare the chain pipeline and first stage
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams() stage = StageParams()
chain.append( chain.stage(
( blend_img2img,
blend_img2img, stage,
stage, strength=strength,
{
"strength": strength,
},
)
) )
# apply upscaling and correction, before highres # apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale) first_upscale, after_upscale = split_upscale(upscale)
if first_upscale: if first_upscale:
append_upscale_correction( stage_upscale_correction(
stage, stage,
params, params,
upscale=first_upscale, upscale=first_upscale,
@ -151,32 +142,24 @@ def run_img2img_pipeline(
# loopback through multiple img2img iterations # loopback through multiple img2img iterations
if params.loopback > 0: if params.loopback > 0:
for _i in range(params.loopback): for _i in range(params.loopback):
chain.append( chain.stage(
( blend_img2img,
blend_img2img, stage,
stage, strength=strength,
{
"strength": strength,
},
)
) )
# highres, if selected # highres, if selected
if highres.iterations > 0: if highres.iterations > 0:
for _i in range(highres.iterations): for _i in range(highres.iterations):
chain.append( chain.stage(
( upscale_highres,
upscale_highres, stage,
stage, highres=highres,
{ upscale=upscale,
"highres": highres,
"upscale": upscale,
},
)
) )
# apply upscaling and correction, after highres # apply upscaling and correction, after highres
append_upscale_correction( stage_upscale_correction(
stage, stage,
params, params,
upscale=after_upscale, upscale=after_upscale,
@ -237,34 +220,26 @@ def run_inpaint_pipeline(
# set up the chain pipeline and base stage # set up the chain pipeline and base stage
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams(tile_order=tile_order) stage = StageParams(tile_order=tile_order)
chain.append( chain.stage(
( upscale_outpaint,
upscale_outpaint, stage,
stage, border=border,
{ stage_mask=mask,
"border": border, fill_color=fill_color,
"stage_mask": mask, mask_filter=mask_filter,
"fill_color": fill_color, noise_source=noise_source,
"mask_filter": mask_filter,
"noise_source": noise_source,
},
)
) )
# apply highres # apply highres
chain.append( chain.stage(
( upscale_highres,
upscale_highres, stage,
stage, highres=highres,
{ upscale=upscale,
"highres": highres,
"upscale": upscale,
},
)
) )
# apply upscaling and correction # apply upscaling and correction
append_upscale_correction( stage_upscale_correction(
stage, stage,
params, params,
upscale=upscale, upscale=upscale,
@ -313,7 +288,7 @@ def run_upscale_pipeline(
# apply upscaling and correction, before highres # apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale) first_upscale, after_upscale = split_upscale(upscale)
if first_upscale: if first_upscale:
append_upscale_correction( stage_upscale_correction(
stage, stage,
params, params,
upscale=first_upscale, upscale=first_upscale,
@ -321,19 +296,15 @@ def run_upscale_pipeline(
) )
# apply highres # apply highres
chain.append( chain.stage(
( upscale_highres,
upscale_highres, stage,
stage, highres=highres,
{ upscale=upscale,
"highres": highres,
"upscale": upscale,
},
)
) )
# apply upscaling and correction, after highres # apply upscaling and correction, after highres
append_upscale_correction( stage_upscale_correction(
stage, stage,
params, params,
upscale=after_upscale, upscale=after_upscale,
@ -380,7 +351,7 @@ def run_blend_pipeline(
stage.append((blend_mask, stage, None)) stage.append((blend_mask, stage, None))
# apply upscaling and correction # apply upscaling and correction
append_upscale_correction( stage_upscale_correction(
stage, stage,
params, params,
upscale=upscale, upscale=upscale,

View File

@ -36,7 +36,7 @@ def split_upscale(
) )
def append_upscale_correction( def stage_upscale_correction(
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
*, *,