From d52c68d6074b1dc11daba1f3641adc5c0fd47016 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Nov 2023 17:18:23 -0600 Subject: [PATCH] feat(api): add chain pipeline stage result type --- api/onnx_web/chain/__init__.py | 51 +-- api/onnx_web/chain/base.py | 300 ++---------------- api/onnx_web/chain/blend_denoise.py | 2 +- api/onnx_web/chain/blend_grid.py | 2 +- api/onnx_web/chain/blend_img2img.py | 2 +- api/onnx_web/chain/blend_linear.py | 2 +- api/onnx_web/chain/blend_mask.py | 2 +- api/onnx_web/chain/correct_codeformer.py | 2 +- api/onnx_web/chain/correct_gfpgan.py | 2 +- api/onnx_web/chain/highres.py | 2 +- api/onnx_web/chain/persist_disk.py | 2 +- api/onnx_web/chain/persist_s3.py | 2 +- api/onnx_web/chain/pipeline.py | 283 +++++++++++++++++ api/onnx_web/chain/reduce_crop.py | 2 +- api/onnx_web/chain/reduce_thumbnail.py | 2 +- api/onnx_web/chain/result.py | 31 ++ api/onnx_web/chain/source_noise.py | 2 +- api/onnx_web/chain/source_s3.py | 2 +- api/onnx_web/chain/source_txt2img.py | 2 +- api/onnx_web/chain/source_url.py | 2 +- api/onnx_web/chain/stage.py | 38 --- api/onnx_web/chain/stages.py | 64 ++++ api/onnx_web/chain/upscale_bsrgan.py | 2 +- api/onnx_web/chain/upscale_highres.py | 2 +- api/onnx_web/chain/upscale_outpaint.py | 2 +- api/onnx_web/chain/upscale_resrgan.py | 2 +- api/onnx_web/chain/upscale_simple.py | 2 +- .../chain/upscale_stable_diffusion.py | 2 +- api/onnx_web/chain/upscale_swinir.py | 2 +- api/onnx_web/worker/context.py | 2 +- api/tests/chain/test_base.py | 2 +- 31 files changed, 433 insertions(+), 384 deletions(-) create mode 100644 api/onnx_web/chain/pipeline.py create mode 100644 api/onnx_web/chain/result.py delete mode 100644 api/onnx_web/chain/stage.py create mode 100644 api/onnx_web/chain/stages.py diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py index b34372ee..476e3e18 100644 --- a/api/onnx_web/chain/__init__.py +++ b/api/onnx_web/chain/__init__.py @@ -1,49 +1,2 @@ -from .base import ChainPipeline, PipelineStage, StageParams -from .blend_denoise import BlendDenoiseStage -from .blend_img2img import BlendImg2ImgStage -from .blend_grid import BlendGridStage -from .blend_linear import BlendLinearStage -from .blend_mask import BlendMaskStage -from .correct_codeformer import CorrectCodeformerStage -from .correct_gfpgan import CorrectGFPGANStage -from .persist_disk import PersistDiskStage -from .persist_s3 import PersistS3Stage -from .reduce_crop import ReduceCropStage -from .reduce_thumbnail import ReduceThumbnailStage -from .source_noise import SourceNoiseStage -from .source_s3 import SourceS3Stage -from .source_txt2img import SourceTxt2ImgStage -from .source_url import SourceURLStage -from .upscale_bsrgan import UpscaleBSRGANStage -from .upscale_highres import UpscaleHighresStage -from .upscale_outpaint import UpscaleOutpaintStage -from .upscale_resrgan import UpscaleRealESRGANStage -from .upscale_simple import UpscaleSimpleStage -from .upscale_stable_diffusion import UpscaleStableDiffusionStage -from .upscale_swinir import UpscaleSwinIRStage - -CHAIN_STAGES = { - "blend-denoise": BlendDenoiseStage, - "blend-img2img": BlendImg2ImgStage, - "blend-inpaint": UpscaleOutpaintStage, - "blend-grid": BlendGridStage, - "blend-linear": BlendLinearStage, - "blend-mask": BlendMaskStage, - "correct-codeformer": CorrectCodeformerStage, - "correct-gfpgan": CorrectGFPGANStage, - "persist-disk": PersistDiskStage, - "persist-s3": PersistS3Stage, - "reduce-crop": ReduceCropStage, - "reduce-thumbnail": ReduceThumbnailStage, - "source-noise": SourceNoiseStage, - "source-s3": SourceS3Stage, - "source-txt2img": SourceTxt2ImgStage, - "source-url": SourceURLStage, - "upscale-bsrgan": UpscaleBSRGANStage, - "upscale-highres": UpscaleHighresStage, - "upscale-outpaint": UpscaleOutpaintStage, - "upscale-resrgan": UpscaleRealESRGANStage, - "upscale-simple": UpscaleSimpleStage, - "upscale-stable-diffusion": UpscaleStableDiffusionStage, - "upscale-swinir": UpscaleSwinIRStage, -} +from .pipeline import ChainPipeline, PipelineStage, StageParams +from .stages import * \ No newline at end of file diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index cabb8da1..0a220773 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -1,283 +1,39 @@ -from datetime import timedelta -from logging import getLogger -from time import monotonic -from typing import Any, List, Optional, Tuple +from typing import List, Optional from PIL import Image -from ..errors import RetryException -from ..output import save_image -from ..params import ImageParams, Size, StageParams -from ..server import ServerContext -from ..utils import is_debug, run_gc -from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage -from .tile import needs_tile, process_tile_order - -logger = getLogger(__name__) +from .result import StageResult +from ..params import ImageParams, Size, SizeChart, StageParams +from ..server.context import ServerContext +from ..worker.context import WorkerContext -PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]] - - -class ChainProgress: - def __init__(self, parent: ProgressCallback, start=0) -> None: - self.parent = parent - self.step = start - self.total = 0 - - def __call__(self, step: int, timestep: int, latents: Any) -> None: - if step < self.step: - # accumulate on resets - self.total += self.step - - self.step = step - self.parent(self.get_total(), timestep, latents) - - def get_total(self) -> int: - return self.step + self.total - - @classmethod - def from_progress(cls, parent: ProgressCallback): - start = parent.step if hasattr(parent, "step") else 0 - return ChainProgress(parent, start=start) - - -class ChainPipeline: - """ - Run many stages in series, passing the image results from each to the next, and processing - tiles as needed. - """ - - def __init__( - self, - stages: Optional[List[PipelineStage]] = None, - ): - """ - Create a new pipeline that will run the given stages. - """ - self.stages = list(stages or []) - - def append(self, stage: Optional[PipelineStage]): - """ - Append an additional stage to this pipeline. - - This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to - assemble the stage from loose arguments. - """ - if stage is not None: - self.stages.append(stage) +class BaseStage: + max_tile = SizeChart.auto def run( self, - worker: WorkerContext, - server: ServerContext, - params: ImageParams, - sources: List[Image.Image], - callback: Optional[ProgressCallback], - **kwargs - ) -> List[Image.Image]: - return self( - worker, server, params, sources=sources, callback=callback, **kwargs - ) + _worker: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + _sources: List[Image.Image], + *args, + stage_source: Optional[Image.Image] = None, + **kwargs, + ) -> StageResult: + raise NotImplementedError() # noqa - def stage(self, callback: BaseStage, params: StageParams, **kwargs): - self.stages.append((callback, params, kwargs)) - return self - - def steps(self, params: ImageParams, size: Size): - steps = 0 - for callback, _params, kwargs in self.stages: - steps += callback.steps(kwargs.get("params", params), size) - - return steps - - def outputs(self, params: ImageParams, sources: int): - outputs = sources - for callback, _params, kwargs in self.stages: - outputs = callback.outputs(kwargs.get("params", params), outputs) - - return outputs - - def __call__( + def steps( self, - worker: WorkerContext, - server: ServerContext, - params: ImageParams, - sources: List[Image.Image], - callback: Optional[ProgressCallback] = None, - **pipeline_kwargs - ) -> List[Image.Image]: - """ - DEPRECATED: use `run` instead - """ - if callback is None: - callback = worker.get_progress_callback() - else: - callback = ChainProgress.from_progress(callback) + _params: ImageParams, + _size: Size, + ) -> int: + return 1 # noqa - start = monotonic() - - if len(sources) > 0: - logger.info( - "running pipeline on %s source images", - len(sources), - ) - else: - logger.info("running pipeline without source images") - - stage_sources = sources - for stage_pipe, stage_params, stage_kwargs in self.stages: - name = stage_params.name or stage_pipe.__class__.__name__ - kwargs = stage_kwargs or {} - kwargs = {**pipeline_kwargs, **kwargs} - logger.debug( - "running stage %s with %s source images, parameters: %s", - name, - len(stage_sources) - stage_sources.count(None), - kwargs.keys(), - ) - - per_stage_params = params - if "params" in kwargs: - per_stage_params = kwargs["params"] - kwargs.pop("params") - - # the stage must be split and tiled if any image is larger than the selected/max tile size - must_tile = any( - [ - needs_tile( - stage_pipe.max_tile, - stage_params.tile_size, - size=kwargs.get("size", None), - source=source, - ) - for source in stage_sources - ] - ) - - tile = stage_params.tile_size - if stage_pipe.max_tile > 0: - tile = min(stage_pipe.max_tile, stage_params.tile_size) - - if stage_sources or must_tile: - stage_outputs = [] - for source in stage_sources: - logger.info( - "image contains sources or is larger than tile size of %s, tiling stage", - tile, - ) - - extra_tiles = [] - - def stage_tile( - source_tile: Image.Image, - tile_mask: Image.Image, - dims: Tuple[int, int, int], - ) -> Image.Image: - for _i in range(worker.retries): - try: - output_tile = stage_pipe.run( - worker, - server, - stage_params, - per_stage_params, - [source_tile], - tile_mask=tile_mask, - callback=callback, - dims=dims, - **kwargs, - ) - - if len(output_tile) > 1: - while len(extra_tiles) < len(output_tile): - extra_tiles.append([]) - - for tile, layer in zip(output_tile, extra_tiles): - layer.append((tile, dims)) - - if is_debug(): - save_image(server, "last-tile.png", output_tile[0]) - - return output_tile[0] - except Exception: - worker.retries = worker.retries - 1 - logger.exception( - "error while running stage pipeline for tile, %s retries left", - worker.retries, - ) - server.cache.clear() - run_gc([worker.get_device()]) - - raise RetryException("exhausted retries on tile") - - output = process_tile_order( - stage_params.tile_order, - source, - tile, - stage_params.outscale, - [stage_tile], - **kwargs, - ) - - stage_outputs.append(output) - - if len(extra_tiles) > 1: - for layer in extra_tiles: - layer_output = Image.new("RGB", output.size) - for layer_tile, dims in layer: - layer_output.paste(layer_tile, (dims[0], dims[1])) - - stage_outputs.append(layer_output) - - stage_sources = stage_outputs - else: - logger.debug( - "image does not contain sources and is within tile size of %s, running stage", - tile, - ) - for i in range(worker.retries): - try: - stage_outputs = stage_pipe.run( - worker, - server, - stage_params, - per_stage_params, - stage_sources, - callback=callback, - dims=(0, 0, tile), - **kwargs, - ) - # doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline - # does not like, so it throws - stage_sources = stage_outputs - break - except Exception: - worker.retries = worker.retries - 1 - logger.exception( - "error while running stage pipeline, %s retries left", - worker.retries, - ) - server.cache.clear() - run_gc([worker.get_device()]) - - if worker.retries <= 0: - raise RetryException("exhausted retries on stage") - - logger.debug( - "finished stage %s with %s results", - name, - len(stage_sources), - ) - - if is_debug(): - save_image(server, "last-stage.png", stage_sources[0]) - - end = monotonic() - duration = timedelta(seconds=(end - start)) - logger.info( - "finished pipeline in %s with %s results", - duration, - len(stage_sources), - ) - return stage_sources + def outputs( + self, + _params: ImageParams, + sources: int, + ) -> int: + return sources diff --git a/api/onnx_web/chain/blend_denoise.py b/api/onnx_web/chain/blend_denoise.py index 94b8a1ff..efc5b2b3 100644 --- a/api/onnx_web/chain/blend_denoise.py +++ b/api/onnx_web/chain/blend_denoise.py @@ -8,7 +8,7 @@ from PIL import Image from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py index 19af2fbd..a6cab0fe 100644 --- a/api/onnx_web/chain/blend_grid.py +++ b/api/onnx_web/chain/blend_grid.py @@ -6,7 +6,7 @@ from PIL import Image from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 274ab407..d44e52cc 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -10,7 +10,7 @@ from ..diffusers.utils import encode_prompt, parse_prompt, slice_prompt from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_linear.py b/api/onnx_web/chain/blend_linear.py index 6317ef13..1eae984a 100644 --- a/api/onnx_web/chain/blend_linear.py +++ b/api/onnx_web/chain/blend_linear.py @@ -6,7 +6,7 @@ from PIL import Image from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 1038d3ea..d4cd4001 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -8,7 +8,7 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..utils import is_debug from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 121a5cb3..f649acc6 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -6,7 +6,7 @@ from PIL import Image from ..params import ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 145ff36b..6b0e17be 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -9,7 +9,7 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/highres.py b/api/onnx_web/chain/highres.py index 482b86c7..088c16ad 100644 --- a/api/onnx_web/chain/highres.py +++ b/api/onnx_web/chain/highres.py @@ -1,7 +1,7 @@ from logging import getLogger from typing import Optional -from ..chain.base import ChainPipeline +from .pipeline import ChainPipeline from ..chain.blend_img2img import BlendImg2ImgStage from ..chain.upscale import stage_upscale_correction from ..chain.upscale_simple import UpscaleSimpleStage diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 124f0989..38ec2b3f 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -7,7 +7,7 @@ from ..output import save_image from ..params import ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 27f4026f..f2becfc2 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -8,7 +8,7 @@ from PIL import Image from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py new file mode 100644 index 00000000..edba28c9 --- /dev/null +++ b/api/onnx_web/chain/pipeline.py @@ -0,0 +1,283 @@ +from datetime import timedelta +from logging import getLogger +from time import monotonic +from typing import Any, List, Optional, Tuple + +from PIL import Image + +from ..errors import RetryException +from ..output import save_image +from ..params import ImageParams, Size, StageParams +from ..server import ServerContext +from ..utils import is_debug, run_gc +from ..worker import ProgressCallback, WorkerContext +from .base import BaseStage +from .tile import needs_tile, process_tile_order + +logger = getLogger(__name__) + + +PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]] + + +class ChainProgress: + def __init__(self, parent: ProgressCallback, start=0) -> None: + self.parent = parent + self.step = start + self.total = 0 + + def __call__(self, step: int, timestep: int, latents: Any) -> None: + if step < self.step: + # accumulate on resets + self.total += self.step + + self.step = step + self.parent(self.get_total(), timestep, latents) + + def get_total(self) -> int: + return self.step + self.total + + @classmethod + def from_progress(cls, parent: ProgressCallback): + start = parent.step if hasattr(parent, "step") else 0 + return ChainProgress(parent, start=start) + + +class ChainPipeline: + """ + Run many stages in series, passing the image results from each to the next, and processing + tiles as needed. + """ + + def __init__( + self, + stages: Optional[List[PipelineStage]] = None, + ): + """ + Create a new pipeline that will run the given stages. + """ + self.stages = list(stages or []) + + def append(self, stage: Optional[PipelineStage]): + """ + Append an additional stage to this pipeline. + + This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to + assemble the stage from loose arguments. + """ + if stage is not None: + self.stages.append(stage) + + def run( + self, + worker: WorkerContext, + server: ServerContext, + params: ImageParams, + sources: List[Image.Image], + callback: Optional[ProgressCallback], + **kwargs + ) -> List[Image.Image]: + return self( + worker, server, params, sources=sources, callback=callback, **kwargs + ) + + def stage(self, callback: BaseStage, params: StageParams, **kwargs): + self.stages.append((callback, params, kwargs)) + return self + + def steps(self, params: ImageParams, size: Size): + steps = 0 + for callback, _params, kwargs in self.stages: + steps += callback.steps(kwargs.get("params", params), size) + + return steps + + def outputs(self, params: ImageParams, sources: int): + outputs = sources + for callback, _params, kwargs in self.stages: + outputs = callback.outputs(kwargs.get("params", params), outputs) + + return outputs + + def __call__( + self, + worker: WorkerContext, + server: ServerContext, + params: ImageParams, + sources: List[Image.Image], + callback: Optional[ProgressCallback] = None, + **pipeline_kwargs + ) -> List[Image.Image]: + """ + DEPRECATED: use `run` instead + """ + if callback is None: + callback = worker.get_progress_callback() + else: + callback = ChainProgress.from_progress(callback) + + start = monotonic() + + if len(sources) > 0: + logger.info( + "running pipeline on %s source images", + len(sources), + ) + else: + logger.info("running pipeline without source images") + + stage_sources = sources + for stage_pipe, stage_params, stage_kwargs in self.stages: + name = stage_params.name or stage_pipe.__class__.__name__ + kwargs = stage_kwargs or {} + kwargs = {**pipeline_kwargs, **kwargs} + logger.debug( + "running stage %s with %s source images, parameters: %s", + name, + len(stage_sources) - stage_sources.count(None), + kwargs.keys(), + ) + + per_stage_params = params + if "params" in kwargs: + per_stage_params = kwargs["params"] + kwargs.pop("params") + + # the stage must be split and tiled if any image is larger than the selected/max tile size + must_tile = any( + [ + needs_tile( + stage_pipe.max_tile, + stage_params.tile_size, + size=kwargs.get("size", None), + source=source, + ) + for source in stage_sources + ] + ) + + tile = stage_params.tile_size + if stage_pipe.max_tile > 0: + tile = min(stage_pipe.max_tile, stage_params.tile_size) + + if stage_sources or must_tile: + stage_outputs = [] + for source in stage_sources: + logger.info( + "image contains sources or is larger than tile size of %s, tiling stage", + tile, + ) + + extra_tiles = [] + + def stage_tile( + source_tile: Image.Image, + tile_mask: Image.Image, + dims: Tuple[int, int, int], + ) -> Image.Image: + for _i in range(worker.retries): + try: + output_tile = stage_pipe.run( + worker, + server, + stage_params, + per_stage_params, + [source_tile], + tile_mask=tile_mask, + callback=callback, + dims=dims, + **kwargs, + ) + + if len(output_tile) > 1: + while len(extra_tiles) < len(output_tile): + extra_tiles.append([]) + + for tile, layer in zip(output_tile, extra_tiles): + layer.append((tile, dims)) + + if is_debug(): + save_image(server, "last-tile.png", output_tile[0]) + + return output_tile[0] + except Exception: + worker.retries = worker.retries - 1 + logger.exception( + "error while running stage pipeline for tile, %s retries left", + worker.retries, + ) + server.cache.clear() + run_gc([worker.get_device()]) + + raise RetryException("exhausted retries on tile") + + output = process_tile_order( + stage_params.tile_order, + source, + tile, + stage_params.outscale, + [stage_tile], + **kwargs, + ) + + stage_outputs.append(output) + + if len(extra_tiles) > 1: + for layer in extra_tiles: + layer_output = Image.new("RGB", output.size) + for layer_tile, dims in layer: + layer_output.paste(layer_tile, (dims[0], dims[1])) + + stage_outputs.append(layer_output) + + stage_sources = stage_outputs + else: + logger.debug( + "image does not contain sources and is within tile size of %s, running stage", + tile, + ) + for i in range(worker.retries): + try: + stage_outputs = stage_pipe.run( + worker, + server, + stage_params, + per_stage_params, + stage_sources, + callback=callback, + dims=(0, 0, tile), + **kwargs, + ) + # doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline + # does not like, so it throws + stage_sources = stage_outputs + break + except Exception: + worker.retries = worker.retries - 1 + logger.exception( + "error while running stage pipeline, %s retries left", + worker.retries, + ) + server.cache.clear() + run_gc([worker.get_device()]) + + if worker.retries <= 0: + raise RetryException("exhausted retries on stage") + + logger.debug( + "finished stage %s with %s results", + name, + len(stage_sources), + ) + + if is_debug(): + save_image(server, "last-stage.png", stage_sources[0]) + + end = monotonic() + duration = timedelta(seconds=(end - start)) + logger.info( + "finished pipeline in %s with %s results", + duration, + len(stage_sources), + ) + return stage_sources diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 2e258075..24974e36 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -6,7 +6,7 @@ from PIL import Image from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index d7a0efee..c22ba3fe 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -6,7 +6,7 @@ from PIL import Image from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py new file mode 100644 index 00000000..627c5197 --- /dev/null +++ b/api/onnx_web/chain/result.py @@ -0,0 +1,31 @@ +from PIL.Image import Image, fromarray +from typing import List, Optional + +import numpy as np + +class StageResult: + """ + Chain pipeline stage result. + Can contain PIL images or numpy arrays, with helpers to convert between them. + """ + arrays: Optional[List[np.ndarray]] + images: Optional[List[Image]] + + def __init__(self, arrays = None, images = None) -> None: + if arrays is not None and images is not None: + raise ValueError("stages must only return one type of result") + + self.arrays = arrays + self.images = images + + def as_numpy(self) -> List[np.ndarray]: + if self.arrays is not None: + return self.arrays + + return [np.array(i) for i in self.images] + + def as_image(self) -> List[Image]: + if self.images is not None: + return self.images + + return [fromarray(i) for i in self.arrays] diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 89e65abc..2cf5b6b0 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -6,7 +6,7 @@ from PIL import Image from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/source_s3.py b/api/onnx_web/chain/source_s3.py index 1493088c..32eb4357 100644 --- a/api/onnx_web/chain/source_s3.py +++ b/api/onnx_web/chain/source_s3.py @@ -8,7 +8,7 @@ from PIL import Image from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index d46b3711..6eb20285 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -18,7 +18,7 @@ from ..diffusers.utils import ( from ..params import ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/source_url.py b/api/onnx_web/chain/source_url.py index 33e5ac78..2dfcb855 100644 --- a/api/onnx_web/chain/source_url.py +++ b/api/onnx_web/chain/source_url.py @@ -8,7 +8,7 @@ from PIL import Image from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/stage.py b/api/onnx_web/chain/stage.py deleted file mode 100644 index c9c6eafd..00000000 --- a/api/onnx_web/chain/stage.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import List, Optional - -from PIL import Image - -from ..params import ImageParams, Size, SizeChart, StageParams -from ..server.context import ServerContext -from ..worker.context import WorkerContext - - -class BaseStage: - max_tile = SizeChart.auto - - def run( - self, - _worker: WorkerContext, - _server: ServerContext, - _stage: StageParams, - _params: ImageParams, - _sources: List[Image.Image], - *args, - stage_source: Optional[Image.Image] = None, - **kwargs, - ) -> List[Image.Image]: - raise NotImplementedError() # noqa - - def steps( - self, - _params: ImageParams, - _size: Size, - ) -> int: - return 1 # noqa - - def outputs( - self, - _params: ImageParams, - sources: int, - ) -> int: - return sources diff --git a/api/onnx_web/chain/stages.py b/api/onnx_web/chain/stages.py new file mode 100644 index 00000000..f7b6801a --- /dev/null +++ b/api/onnx_web/chain/stages.py @@ -0,0 +1,64 @@ +from logging import getLogger + +from .base import BaseStage +from .blend_denoise import BlendDenoiseStage +from .blend_img2img import BlendImg2ImgStage +from .blend_grid import BlendGridStage +from .blend_linear import BlendLinearStage +from .blend_mask import BlendMaskStage +from .correct_codeformer import CorrectCodeformerStage +from .correct_gfpgan import CorrectGFPGANStage +from .persist_disk import PersistDiskStage +from .persist_s3 import PersistS3Stage +from .reduce_crop import ReduceCropStage +from .reduce_thumbnail import ReduceThumbnailStage +from .source_noise import SourceNoiseStage +from .source_s3 import SourceS3Stage +from .source_txt2img import SourceTxt2ImgStage +from .source_url import SourceURLStage +from .upscale_bsrgan import UpscaleBSRGANStage +from .upscale_highres import UpscaleHighresStage +from .upscale_outpaint import UpscaleOutpaintStage +from .upscale_resrgan import UpscaleRealESRGANStage +from .upscale_simple import UpscaleSimpleStage +from .upscale_stable_diffusion import UpscaleStableDiffusionStage +from .upscale_swinir import UpscaleSwinIRStage + +logger = getLogger(__name__) + +CHAIN_STAGES = { + "blend-denoise": BlendDenoiseStage, + "blend-img2img": BlendImg2ImgStage, + "blend-inpaint": UpscaleOutpaintStage, + "blend-grid": BlendGridStage, + "blend-linear": BlendLinearStage, + "blend-mask": BlendMaskStage, + "correct-codeformer": CorrectCodeformerStage, + "correct-gfpgan": CorrectGFPGANStage, + "persist-disk": PersistDiskStage, + "persist-s3": PersistS3Stage, + "reduce-crop": ReduceCropStage, + "reduce-thumbnail": ReduceThumbnailStage, + "source-noise": SourceNoiseStage, + "source-s3": SourceS3Stage, + "source-txt2img": SourceTxt2ImgStage, + "source-url": SourceURLStage, + "upscale-bsrgan": UpscaleBSRGANStage, + "upscale-highres": UpscaleHighresStage, + "upscale-outpaint": UpscaleOutpaintStage, + "upscale-resrgan": UpscaleRealESRGANStage, + "upscale-simple": UpscaleSimpleStage, + "upscale-stable-diffusion": UpscaleStableDiffusionStage, + "upscale-swinir": UpscaleSwinIRStage, +} + + +def add_stage(name: str, stage: BaseStage) -> bool: + global CHAIN_STAGES + + if name in CHAIN_STAGES: + logger.warning("cannot replace stage: %s", name) + return False + else: + CHAIN_STAGES[name] = stage + return True diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index 0137750e..9afe54ae 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -10,7 +10,7 @@ from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_highres.py b/api/onnx_web/chain/upscale_highres.py index e19f75fb..5ed28f9b 100644 --- a/api/onnx_web/chain/upscale_highres.py +++ b/api/onnx_web/chain/upscale_highres.py @@ -8,7 +8,7 @@ from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext from ..worker.context import ProgressCallback -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index cdc3a067..29883cc0 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -18,7 +18,7 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..utils import is_debug from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index e680af53..7fbb6901 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -10,7 +10,7 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_simple.py b/api/onnx_web/chain/upscale_simple.py index 7dd44200..36095339 100644 --- a/api/onnx_web/chain/upscale_simple.py +++ b/api/onnx_web/chain/upscale_simple.py @@ -6,7 +6,7 @@ from PIL import Image from ..params import ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index cf784b05..763871e0 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -10,7 +10,7 @@ from ..diffusers.utils import encode_prompt, parse_prompt from ..params import ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index a49b99e5..52114bba 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -10,7 +10,7 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage logger = getLogger(__name__) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index a24613ed..2d6d0278 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -86,7 +86,7 @@ class WorkerContext: return 0 def get_progress_callback(self) -> ProgressCallback: - from ..chain.base import ChainProgress + from ..chain.pipeline import ChainProgress def on_progress(step: int, timestep: int, latents: Any): on_progress.step = step diff --git a/api/tests/chain/test_base.py b/api/tests/chain/test_base.py index a2530600..a0f5463b 100644 --- a/api/tests/chain/test_base.py +++ b/api/tests/chain/test_base.py @@ -1,6 +1,6 @@ import unittest -from onnx_web.chain.base import ChainProgress +from onnx_web.chain.pipeline import ChainProgress class ChainProgressTests(unittest.TestCase):