fix(api): use kwargs for chain stages
This commit is contained in:
parent
7a73c9ff61
commit
2d10252564
|
@ -17,7 +17,7 @@ from .diffusers.run import (
|
|||
run_upscale_pipeline,
|
||||
)
|
||||
from .diffusers.stub_scheduler import StubScheduler
|
||||
from .diffusers.upscale import append_upscale_correction
|
||||
from .diffusers.upscale import stage_upscale_correction
|
||||
from .image.utils import (
|
||||
expand_image,
|
||||
valid_image,
|
||||
|
|
|
@ -78,11 +78,31 @@ class ChainPipeline:
|
|||
|
||||
def append(self, stage: PipelineStage):
|
||||
"""
|
||||
DEPRECATED: use `stage` instead
|
||||
|
||||
Append an additional stage to this pipeline.
|
||||
"""
|
||||
if stage is not None:
|
||||
self.stages.append(stage)
|
||||
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
source: Optional[Image.Image],
|
||||
callback: Optional[ProgressCallback],
|
||||
**kwargs
|
||||
) -> Image.Image:
|
||||
"""
|
||||
TODO: handle List[Image] inputs and outputs
|
||||
"""
|
||||
return self(job, server, params, source=source, callback=callback, **kwargs)
|
||||
|
||||
def stage(self, callback: StageCallback, params: StageParams, **kwargs):
|
||||
self.stages.append((callback, params, kwargs))
|
||||
return self
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
|
@ -93,7 +113,7 @@ class ChainPipeline:
|
|||
**pipeline_kwargs
|
||||
) -> Image.Image:
|
||||
"""
|
||||
TODO: handle List[Image] inputs and outputs
|
||||
DEPRECATED: use `run` instead
|
||||
"""
|
||||
if callback is not None:
|
||||
callback = ChainProgress.from_progress(callback)
|
||||
|
|
|
@ -4,8 +4,8 @@ from typing import Any, Optional
|
|||
from PIL import Image
|
||||
|
||||
from ..chain.base import ChainPipeline
|
||||
from ..chain.img2img import blend_img2img
|
||||
from ..diffusers.upscale import append_upscale_correction
|
||||
from ..chain.blend_img2img import blend_img2img
|
||||
from ..diffusers.upscale import stage_upscale_correction
|
||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
|
@ -45,7 +45,7 @@ def upscale_highres(
|
|||
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
|
||||
else:
|
||||
logger.debug("using upscaling pipeline for highres")
|
||||
append_upscale_correction(
|
||||
stage_upscale_correction(
|
||||
StageParams(),
|
||||
params,
|
||||
upscale=upscale.with_args(
|
||||
|
|
|
@ -24,7 +24,7 @@ from ..server import ServerContext
|
|||
from ..server.load import get_source_filters
|
||||
from ..utils import run_gc, show_system_toast
|
||||
from ..worker import WorkerContext
|
||||
from .upscale import append_upscale_correction, split_upscale
|
||||
from .upscale import split_upscale, stage_upscale_correction
|
||||
from .utils import parse_prompt
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -42,20 +42,16 @@ def run_txt2img_pipeline(
|
|||
# prepare the chain pipeline and first stage
|
||||
chain = ChainPipeline()
|
||||
stage = StageParams()
|
||||
chain.append(
|
||||
(
|
||||
source_txt2img,
|
||||
stage,
|
||||
{
|
||||
"size": size,
|
||||
},
|
||||
)
|
||||
chain.stage(
|
||||
source_txt2img,
|
||||
stage,
|
||||
size=size,
|
||||
)
|
||||
|
||||
# apply upscaling and correction, before highres
|
||||
first_upscale, after_upscale = split_upscale(upscale)
|
||||
if first_upscale:
|
||||
append_upscale_correction(
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
params,
|
||||
upscale=first_upscale,
|
||||
|
@ -64,22 +60,21 @@ def run_txt2img_pipeline(
|
|||
|
||||
# apply highres
|
||||
for _i in range(highres.iterations):
|
||||
chain.append(
|
||||
(
|
||||
upscale_highres,
|
||||
stage,
|
||||
{
|
||||
"highres": highres,
|
||||
"upscale": upscale,
|
||||
},
|
||||
)
|
||||
chain.stage(
|
||||
upscale_highres,
|
||||
StageParams(
|
||||
outscale=highres.scale,
|
||||
),
|
||||
highres=highres,
|
||||
upscale=upscale,
|
||||
overlap=params.overlap,
|
||||
)
|
||||
|
||||
# apply upscaling and correction, after highres
|
||||
append_upscale_correction(
|
||||
StageParams(),
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
params,
|
||||
upscale=upscale,
|
||||
upscale=after_upscale,
|
||||
chain=chain,
|
||||
)
|
||||
|
||||
|
@ -128,20 +123,16 @@ def run_img2img_pipeline(
|
|||
# prepare the chain pipeline and first stage
|
||||
chain = ChainPipeline()
|
||||
stage = StageParams()
|
||||
chain.append(
|
||||
(
|
||||
blend_img2img,
|
||||
stage,
|
||||
{
|
||||
"strength": strength,
|
||||
},
|
||||
)
|
||||
chain.stage(
|
||||
blend_img2img,
|
||||
stage,
|
||||
strength=strength,
|
||||
)
|
||||
|
||||
# apply upscaling and correction, before highres
|
||||
first_upscale, after_upscale = split_upscale(upscale)
|
||||
if first_upscale:
|
||||
append_upscale_correction(
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
params,
|
||||
upscale=first_upscale,
|
||||
|
@ -151,32 +142,24 @@ def run_img2img_pipeline(
|
|||
# loopback through multiple img2img iterations
|
||||
if params.loopback > 0:
|
||||
for _i in range(params.loopback):
|
||||
chain.append(
|
||||
(
|
||||
blend_img2img,
|
||||
stage,
|
||||
{
|
||||
"strength": strength,
|
||||
},
|
||||
)
|
||||
chain.stage(
|
||||
blend_img2img,
|
||||
stage,
|
||||
strength=strength,
|
||||
)
|
||||
|
||||
# highres, if selected
|
||||
if highres.iterations > 0:
|
||||
for _i in range(highres.iterations):
|
||||
chain.append(
|
||||
(
|
||||
upscale_highres,
|
||||
stage,
|
||||
{
|
||||
"highres": highres,
|
||||
"upscale": upscale,
|
||||
},
|
||||
)
|
||||
chain.stage(
|
||||
upscale_highres,
|
||||
stage,
|
||||
highres=highres,
|
||||
upscale=upscale,
|
||||
)
|
||||
|
||||
# apply upscaling and correction, after highres
|
||||
append_upscale_correction(
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
params,
|
||||
upscale=after_upscale,
|
||||
|
@ -237,34 +220,26 @@ def run_inpaint_pipeline(
|
|||
# set up the chain pipeline and base stage
|
||||
chain = ChainPipeline()
|
||||
stage = StageParams(tile_order=tile_order)
|
||||
chain.append(
|
||||
(
|
||||
upscale_outpaint,
|
||||
stage,
|
||||
{
|
||||
"border": border,
|
||||
"stage_mask": mask,
|
||||
"fill_color": fill_color,
|
||||
"mask_filter": mask_filter,
|
||||
"noise_source": noise_source,
|
||||
},
|
||||
)
|
||||
chain.stage(
|
||||
upscale_outpaint,
|
||||
stage,
|
||||
border=border,
|
||||
stage_mask=mask,
|
||||
fill_color=fill_color,
|
||||
mask_filter=mask_filter,
|
||||
noise_source=noise_source,
|
||||
)
|
||||
|
||||
# apply highres
|
||||
chain.append(
|
||||
(
|
||||
upscale_highres,
|
||||
stage,
|
||||
{
|
||||
"highres": highres,
|
||||
"upscale": upscale,
|
||||
},
|
||||
)
|
||||
chain.stage(
|
||||
upscale_highres,
|
||||
stage,
|
||||
highres=highres,
|
||||
upscale=upscale,
|
||||
)
|
||||
|
||||
# apply upscaling and correction
|
||||
append_upscale_correction(
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
params,
|
||||
upscale=upscale,
|
||||
|
@ -313,7 +288,7 @@ def run_upscale_pipeline(
|
|||
# apply upscaling and correction, before highres
|
||||
first_upscale, after_upscale = split_upscale(upscale)
|
||||
if first_upscale:
|
||||
append_upscale_correction(
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
params,
|
||||
upscale=first_upscale,
|
||||
|
@ -321,19 +296,15 @@ def run_upscale_pipeline(
|
|||
)
|
||||
|
||||
# apply highres
|
||||
chain.append(
|
||||
(
|
||||
upscale_highres,
|
||||
stage,
|
||||
{
|
||||
"highres": highres,
|
||||
"upscale": upscale,
|
||||
},
|
||||
)
|
||||
chain.stage(
|
||||
upscale_highres,
|
||||
stage,
|
||||
highres=highres,
|
||||
upscale=upscale,
|
||||
)
|
||||
|
||||
# apply upscaling and correction, after highres
|
||||
append_upscale_correction(
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
params,
|
||||
upscale=after_upscale,
|
||||
|
@ -380,7 +351,7 @@ def run_blend_pipeline(
|
|||
stage.append((blend_mask, stage, None))
|
||||
|
||||
# apply upscaling and correction
|
||||
append_upscale_correction(
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
params,
|
||||
upscale=upscale,
|
||||
|
|
|
@ -36,7 +36,7 @@ def split_upscale(
|
|||
)
|
||||
|
||||
|
||||
def append_upscale_correction(
|
||||
def stage_upscale_correction(
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
*,
|
||||
|
|
Loading…
Reference in New Issue