diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index beca62dd..77d49e14 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -17,7 +17,7 @@ from .diffusers.run import ( run_upscale_pipeline, ) from .diffusers.stub_scheduler import StubScheduler -from .diffusers.upscale import append_upscale_correction +from .diffusers.upscale import stage_upscale_correction from .image.utils import ( expand_image, valid_image, diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 26d22a5b..2b2ecaaa 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -78,11 +78,31 @@ class ChainPipeline: def append(self, stage: PipelineStage): """ + DEPRECATED: use `stage` instead + Append an additional stage to this pipeline. """ if stage is not None: self.stages.append(stage) + def run( + self, + job: WorkerContext, + server: ServerContext, + params: ImageParams, + source: Optional[Image.Image], + callback: Optional[ProgressCallback], + **kwargs + ) -> Image.Image: + """ + TODO: handle List[Image] inputs and outputs + """ + return self(job, server, params, source=source, callback=callback, **kwargs) + + def stage(self, callback: StageCallback, params: StageParams, **kwargs): + self.stages.append((callback, params, kwargs)) + return self + def __call__( self, job: WorkerContext, @@ -93,7 +113,7 @@ class ChainPipeline: **pipeline_kwargs ) -> Image.Image: """ - TODO: handle List[Image] inputs and outputs + DEPRECATED: use `run` instead """ if callback is not None: callback = ChainProgress.from_progress(callback) diff --git a/api/onnx_web/chain/upscale_highres.py b/api/onnx_web/chain/upscale_highres.py index 23444a48..8d1cd2a2 100644 --- a/api/onnx_web/chain/upscale_highres.py +++ b/api/onnx_web/chain/upscale_highres.py @@ -4,8 +4,8 @@ from typing import Any, Optional from PIL import Image from ..chain.base import ChainPipeline -from ..chain.img2img import blend_img2img -from ..diffusers.upscale import append_upscale_correction +from ..chain.blend_img2img import blend_img2img +from ..diffusers.upscale import stage_upscale_correction from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext @@ -45,7 +45,7 @@ def upscale_highres( source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS) else: logger.debug("using upscaling pipeline for highres") - append_upscale_correction( + stage_upscale_correction( StageParams(), params, upscale=upscale.with_args( diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 18d51cc3..050dbdcf 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -24,7 +24,7 @@ from ..server import ServerContext from ..server.load import get_source_filters from ..utils import run_gc, show_system_toast from ..worker import WorkerContext -from .upscale import append_upscale_correction, split_upscale +from .upscale import split_upscale, stage_upscale_correction from .utils import parse_prompt logger = getLogger(__name__) @@ -42,20 +42,16 @@ def run_txt2img_pipeline( # prepare the chain pipeline and first stage chain = ChainPipeline() stage = StageParams() - chain.append( - ( - source_txt2img, - stage, - { - "size": size, - }, - ) + chain.stage( + source_txt2img, + stage, + size=size, ) # apply upscaling and correction, before highres first_upscale, after_upscale = split_upscale(upscale) if first_upscale: - append_upscale_correction( + stage_upscale_correction( stage, params, upscale=first_upscale, @@ -64,22 +60,21 @@ def run_txt2img_pipeline( # apply highres for _i in range(highres.iterations): - chain.append( - ( - upscale_highres, - stage, - { - "highres": highres, - "upscale": upscale, - }, - ) + chain.stage( + upscale_highres, + StageParams( + outscale=highres.scale, + ), + highres=highres, + upscale=upscale, + overlap=params.overlap, ) # apply upscaling and correction, after highres - append_upscale_correction( - StageParams(), + stage_upscale_correction( + stage, params, - upscale=upscale, + upscale=after_upscale, chain=chain, ) @@ -128,20 +123,16 @@ def run_img2img_pipeline( # prepare the chain pipeline and first stage chain = ChainPipeline() stage = StageParams() - chain.append( - ( - blend_img2img, - stage, - { - "strength": strength, - }, - ) + chain.stage( + blend_img2img, + stage, + strength=strength, ) # apply upscaling and correction, before highres first_upscale, after_upscale = split_upscale(upscale) if first_upscale: - append_upscale_correction( + stage_upscale_correction( stage, params, upscale=first_upscale, @@ -151,32 +142,24 @@ def run_img2img_pipeline( # loopback through multiple img2img iterations if params.loopback > 0: for _i in range(params.loopback): - chain.append( - ( - blend_img2img, - stage, - { - "strength": strength, - }, - ) + chain.stage( + blend_img2img, + stage, + strength=strength, ) # highres, if selected if highres.iterations > 0: for _i in range(highres.iterations): - chain.append( - ( - upscale_highres, - stage, - { - "highres": highres, - "upscale": upscale, - }, - ) + chain.stage( + upscale_highres, + stage, + highres=highres, + upscale=upscale, ) # apply upscaling and correction, after highres - append_upscale_correction( + stage_upscale_correction( stage, params, upscale=after_upscale, @@ -237,34 +220,26 @@ def run_inpaint_pipeline( # set up the chain pipeline and base stage chain = ChainPipeline() stage = StageParams(tile_order=tile_order) - chain.append( - ( - upscale_outpaint, - stage, - { - "border": border, - "stage_mask": mask, - "fill_color": fill_color, - "mask_filter": mask_filter, - "noise_source": noise_source, - }, - ) + chain.stage( + upscale_outpaint, + stage, + border=border, + stage_mask=mask, + fill_color=fill_color, + mask_filter=mask_filter, + noise_source=noise_source, ) # apply highres - chain.append( - ( - upscale_highres, - stage, - { - "highres": highres, - "upscale": upscale, - }, - ) + chain.stage( + upscale_highres, + stage, + highres=highres, + upscale=upscale, ) # apply upscaling and correction - append_upscale_correction( + stage_upscale_correction( stage, params, upscale=upscale, @@ -313,7 +288,7 @@ def run_upscale_pipeline( # apply upscaling and correction, before highres first_upscale, after_upscale = split_upscale(upscale) if first_upscale: - append_upscale_correction( + stage_upscale_correction( stage, params, upscale=first_upscale, @@ -321,19 +296,15 @@ def run_upscale_pipeline( ) # apply highres - chain.append( - ( - upscale_highres, - stage, - { - "highres": highres, - "upscale": upscale, - }, - ) + chain.stage( + upscale_highres, + stage, + highres=highres, + upscale=upscale, ) # apply upscaling and correction, after highres - append_upscale_correction( + stage_upscale_correction( stage, params, upscale=after_upscale, @@ -380,7 +351,7 @@ def run_blend_pipeline( stage.append((blend_mask, stage, None)) # apply upscaling and correction - append_upscale_correction( + stage_upscale_correction( stage, params, upscale=upscale, diff --git a/api/onnx_web/diffusers/upscale.py b/api/onnx_web/diffusers/upscale.py index caee1756..2398a3e9 100644 --- a/api/onnx_web/diffusers/upscale.py +++ b/api/onnx_web/diffusers/upscale.py @@ -36,7 +36,7 @@ def split_upscale( ) -def append_upscale_correction( +def stage_upscale_correction( stage: StageParams, params: ImageParams, *,