fix(api): use kwargs for chain stages
This commit is contained in:
parent
7a73c9ff61
commit
2d10252564
|
@ -17,7 +17,7 @@ from .diffusers.run import (
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
)
|
)
|
||||||
from .diffusers.stub_scheduler import StubScheduler
|
from .diffusers.stub_scheduler import StubScheduler
|
||||||
from .diffusers.upscale import append_upscale_correction
|
from .diffusers.upscale import stage_upscale_correction
|
||||||
from .image.utils import (
|
from .image.utils import (
|
||||||
expand_image,
|
expand_image,
|
||||||
valid_image,
|
valid_image,
|
||||||
|
|
|
@ -78,11 +78,31 @@ class ChainPipeline:
|
||||||
|
|
||||||
def append(self, stage: PipelineStage):
|
def append(self, stage: PipelineStage):
|
||||||
"""
|
"""
|
||||||
|
DEPRECATED: use `stage` instead
|
||||||
|
|
||||||
Append an additional stage to this pipeline.
|
Append an additional stage to this pipeline.
|
||||||
"""
|
"""
|
||||||
if stage is not None:
|
if stage is not None:
|
||||||
self.stages.append(stage)
|
self.stages.append(stage)
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
job: WorkerContext,
|
||||||
|
server: ServerContext,
|
||||||
|
params: ImageParams,
|
||||||
|
source: Optional[Image.Image],
|
||||||
|
callback: Optional[ProgressCallback],
|
||||||
|
**kwargs
|
||||||
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
|
TODO: handle List[Image] inputs and outputs
|
||||||
|
"""
|
||||||
|
return self(job, server, params, source=source, callback=callback, **kwargs)
|
||||||
|
|
||||||
|
def stage(self, callback: StageCallback, params: StageParams, **kwargs):
|
||||||
|
self.stages.append((callback, params, kwargs))
|
||||||
|
return self
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
job: WorkerContext,
|
job: WorkerContext,
|
||||||
|
@ -93,7 +113,7 @@ class ChainPipeline:
|
||||||
**pipeline_kwargs
|
**pipeline_kwargs
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
"""
|
"""
|
||||||
TODO: handle List[Image] inputs and outputs
|
DEPRECATED: use `run` instead
|
||||||
"""
|
"""
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback = ChainProgress.from_progress(callback)
|
callback = ChainProgress.from_progress(callback)
|
||||||
|
|
|
@ -4,8 +4,8 @@ from typing import Any, Optional
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..chain.base import ChainPipeline
|
from ..chain.base import ChainPipeline
|
||||||
from ..chain.img2img import blend_img2img
|
from ..chain.blend_img2img import blend_img2img
|
||||||
from ..diffusers.upscale import append_upscale_correction
|
from ..diffusers.upscale import stage_upscale_correction
|
||||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
|
@ -45,7 +45,7 @@ def upscale_highres(
|
||||||
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
|
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
|
||||||
else:
|
else:
|
||||||
logger.debug("using upscaling pipeline for highres")
|
logger.debug("using upscaling pipeline for highres")
|
||||||
append_upscale_correction(
|
stage_upscale_correction(
|
||||||
StageParams(),
|
StageParams(),
|
||||||
params,
|
params,
|
||||||
upscale=upscale.with_args(
|
upscale=upscale.with_args(
|
||||||
|
|
|
@ -24,7 +24,7 @@ from ..server import ServerContext
|
||||||
from ..server.load import get_source_filters
|
from ..server.load import get_source_filters
|
||||||
from ..utils import run_gc, show_system_toast
|
from ..utils import run_gc, show_system_toast
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .upscale import append_upscale_correction, split_upscale
|
from .upscale import split_upscale, stage_upscale_correction
|
||||||
from .utils import parse_prompt
|
from .utils import parse_prompt
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -42,20 +42,16 @@ def run_txt2img_pipeline(
|
||||||
# prepare the chain pipeline and first stage
|
# prepare the chain pipeline and first stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams()
|
stage = StageParams()
|
||||||
chain.append(
|
chain.stage(
|
||||||
(
|
source_txt2img,
|
||||||
source_txt2img,
|
stage,
|
||||||
stage,
|
size=size,
|
||||||
{
|
|
||||||
"size": size,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply upscaling and correction, before highres
|
# apply upscaling and correction, before highres
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
if first_upscale:
|
if first_upscale:
|
||||||
append_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
stage,
|
||||||
params,
|
params,
|
||||||
upscale=first_upscale,
|
upscale=first_upscale,
|
||||||
|
@ -64,22 +60,21 @@ def run_txt2img_pipeline(
|
||||||
|
|
||||||
# apply highres
|
# apply highres
|
||||||
for _i in range(highres.iterations):
|
for _i in range(highres.iterations):
|
||||||
chain.append(
|
chain.stage(
|
||||||
(
|
upscale_highres,
|
||||||
upscale_highres,
|
StageParams(
|
||||||
stage,
|
outscale=highres.scale,
|
||||||
{
|
),
|
||||||
"highres": highres,
|
highres=highres,
|
||||||
"upscale": upscale,
|
upscale=upscale,
|
||||||
},
|
overlap=params.overlap,
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply upscaling and correction, after highres
|
# apply upscaling and correction, after highres
|
||||||
append_upscale_correction(
|
stage_upscale_correction(
|
||||||
StageParams(),
|
stage,
|
||||||
params,
|
params,
|
||||||
upscale=upscale,
|
upscale=after_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -128,20 +123,16 @@ def run_img2img_pipeline(
|
||||||
# prepare the chain pipeline and first stage
|
# prepare the chain pipeline and first stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams()
|
stage = StageParams()
|
||||||
chain.append(
|
chain.stage(
|
||||||
(
|
blend_img2img,
|
||||||
blend_img2img,
|
stage,
|
||||||
stage,
|
strength=strength,
|
||||||
{
|
|
||||||
"strength": strength,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply upscaling and correction, before highres
|
# apply upscaling and correction, before highres
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
if first_upscale:
|
if first_upscale:
|
||||||
append_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
stage,
|
||||||
params,
|
params,
|
||||||
upscale=first_upscale,
|
upscale=first_upscale,
|
||||||
|
@ -151,32 +142,24 @@ def run_img2img_pipeline(
|
||||||
# loopback through multiple img2img iterations
|
# loopback through multiple img2img iterations
|
||||||
if params.loopback > 0:
|
if params.loopback > 0:
|
||||||
for _i in range(params.loopback):
|
for _i in range(params.loopback):
|
||||||
chain.append(
|
chain.stage(
|
||||||
(
|
blend_img2img,
|
||||||
blend_img2img,
|
stage,
|
||||||
stage,
|
strength=strength,
|
||||||
{
|
|
||||||
"strength": strength,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# highres, if selected
|
# highres, if selected
|
||||||
if highres.iterations > 0:
|
if highres.iterations > 0:
|
||||||
for _i in range(highres.iterations):
|
for _i in range(highres.iterations):
|
||||||
chain.append(
|
chain.stage(
|
||||||
(
|
upscale_highres,
|
||||||
upscale_highres,
|
stage,
|
||||||
stage,
|
highres=highres,
|
||||||
{
|
upscale=upscale,
|
||||||
"highres": highres,
|
|
||||||
"upscale": upscale,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply upscaling and correction, after highres
|
# apply upscaling and correction, after highres
|
||||||
append_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
stage,
|
||||||
params,
|
params,
|
||||||
upscale=after_upscale,
|
upscale=after_upscale,
|
||||||
|
@ -237,34 +220,26 @@ def run_inpaint_pipeline(
|
||||||
# set up the chain pipeline and base stage
|
# set up the chain pipeline and base stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams(tile_order=tile_order)
|
stage = StageParams(tile_order=tile_order)
|
||||||
chain.append(
|
chain.stage(
|
||||||
(
|
upscale_outpaint,
|
||||||
upscale_outpaint,
|
stage,
|
||||||
stage,
|
border=border,
|
||||||
{
|
stage_mask=mask,
|
||||||
"border": border,
|
fill_color=fill_color,
|
||||||
"stage_mask": mask,
|
mask_filter=mask_filter,
|
||||||
"fill_color": fill_color,
|
noise_source=noise_source,
|
||||||
"mask_filter": mask_filter,
|
|
||||||
"noise_source": noise_source,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply highres
|
# apply highres
|
||||||
chain.append(
|
chain.stage(
|
||||||
(
|
upscale_highres,
|
||||||
upscale_highres,
|
stage,
|
||||||
stage,
|
highres=highres,
|
||||||
{
|
upscale=upscale,
|
||||||
"highres": highres,
|
|
||||||
"upscale": upscale,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply upscaling and correction
|
# apply upscaling and correction
|
||||||
append_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
stage,
|
||||||
params,
|
params,
|
||||||
upscale=upscale,
|
upscale=upscale,
|
||||||
|
@ -313,7 +288,7 @@ def run_upscale_pipeline(
|
||||||
# apply upscaling and correction, before highres
|
# apply upscaling and correction, before highres
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
if first_upscale:
|
if first_upscale:
|
||||||
append_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
stage,
|
||||||
params,
|
params,
|
||||||
upscale=first_upscale,
|
upscale=first_upscale,
|
||||||
|
@ -321,19 +296,15 @@ def run_upscale_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply highres
|
# apply highres
|
||||||
chain.append(
|
chain.stage(
|
||||||
(
|
upscale_highres,
|
||||||
upscale_highres,
|
stage,
|
||||||
stage,
|
highres=highres,
|
||||||
{
|
upscale=upscale,
|
||||||
"highres": highres,
|
|
||||||
"upscale": upscale,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply upscaling and correction, after highres
|
# apply upscaling and correction, after highres
|
||||||
append_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
stage,
|
||||||
params,
|
params,
|
||||||
upscale=after_upscale,
|
upscale=after_upscale,
|
||||||
|
@ -380,7 +351,7 @@ def run_blend_pipeline(
|
||||||
stage.append((blend_mask, stage, None))
|
stage.append((blend_mask, stage, None))
|
||||||
|
|
||||||
# apply upscaling and correction
|
# apply upscaling and correction
|
||||||
append_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
stage,
|
||||||
params,
|
params,
|
||||||
upscale=upscale,
|
upscale=upscale,
|
||||||
|
|
|
@ -36,7 +36,7 @@ def split_upscale(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def append_upscale_correction(
|
def stage_upscale_correction(
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
*,
|
*,
|
||||||
|
|
Loading…
Reference in New Issue