diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 0a220773..89d48400 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -17,7 +17,7 @@ class BaseStage: _server: ServerContext, _stage: StageParams, _params: ImageParams, - _sources: List[Image.Image], + _sources: StageResult, *args, stage_source: Optional[Image.Image] = None, **kwargs, diff --git a/api/onnx_web/chain/blend_denoise.py b/api/onnx_web/chain/blend_denoise.py index efc5b2b3..beabb871 100644 --- a/api/onnx_web/chain/blend_denoise.py +++ b/api/onnx_web/chain/blend_denoise.py @@ -9,6 +9,7 @@ from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -22,19 +23,19 @@ class BlendDenoiseStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, strength: int = 3, stage_source: Optional[Image.Image] = None, callback: Optional[ProgressCallback] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info("denoising source images") results = [] - for source in sources: - data = cv2.cvtColor(np.array(source), cv2.COLOR_RGB2BGR) + for source in sources.as_numpy(): + data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR) data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength) - results.append(Image.fromarray(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))) + results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB)) - return results + return StageResult(arrays=results) diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py index a6cab0fe..b31ef7bd 100644 --- a/api/onnx_web/chain/blend_grid.py +++ b/api/onnx_web/chain/blend_grid.py @@ -7,6 +7,7 @@ from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -20,7 +21,7 @@ class BlendGridStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, height: int, width: int, @@ -31,7 +32,7 @@ class BlendGridStage(BaseStage): stage_source: Optional[Image.Image] = None, callback: Optional[ProgressCallback] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info("combining source images using grid layout") size = sources[0].size @@ -49,7 +50,7 @@ class BlendGridStage(BaseStage): n = order[i] output.paste(sources[n], (x * size[0], y * size[1])) - return [*sources, output] + return StageResult(images=[*sources, output]) def outputs( self, diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index d44e52cc..aa5aaa5e 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -11,6 +11,7 @@ from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -24,14 +25,14 @@ class BlendImg2ImgStage(BaseStage): server: ServerContext, _stage: StageParams, params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, strength: float, callback: Optional[ProgressCallback] = None, stage_source: Optional[Image.Image] = None, prompt_index: Optional[int] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: params = params.with_args(**kwargs) # multi-stage prompting @@ -65,7 +66,7 @@ class BlendImg2ImgStage(BaseStage): pipe_params["strength"] = strength outputs = [] - for source in sources: + for source in sources.as_image(): if params.is_lpw(): logger.debug("using LPW pipeline for img2img") rng = torch.manual_seed(params.seed) @@ -101,7 +102,7 @@ class BlendImg2ImgStage(BaseStage): outputs.extend(result.images) - return outputs + return StageResult(images=outputs) def steps( self, diff --git a/api/onnx_web/chain/blend_linear.py b/api/onnx_web/chain/blend_linear.py index 1eae984a..6b2c8d6c 100644 --- a/api/onnx_web/chain/blend_linear.py +++ b/api/onnx_web/chain/blend_linear.py @@ -7,6 +7,7 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -18,13 +19,13 @@ class BlendLinearStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, alpha: float, stage_source: Optional[Image.Image] = None, _callback: Optional[ProgressCallback] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info("blending source images using linear interpolation") - return [Image.blend(source, stage_source, alpha) for source in sources] + return StageResult(images=[Image.blend(source, stage_source, alpha) for source in sources]) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index d4cd4001..75cdcc9f 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -9,6 +9,7 @@ from ..server import ServerContext from ..utils import is_debug from ..worker import ProgressCallback, WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -20,13 +21,13 @@ class BlendMaskStage(BaseStage): server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, stage_source: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None, _callback: Optional[ProgressCallback] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info("blending image using mask") mult_mask = Image.new("RGBA", stage_mask.size, color="black") @@ -37,4 +38,4 @@ class BlendMaskStage(BaseStage): save_image(server, "last-mask.png", stage_mask) save_image(server, "last-mult-mask.png", mult_mask) - return [Image.composite(stage_source, source, mult_mask) for source in sources] + return StageResult(images=[Image.composite(stage_source, source, mult_mask) for source in sources]) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index f649acc6..66c6d454 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -7,6 +7,7 @@ from ..params import ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -18,12 +19,12 @@ class CorrectCodeformerStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, stage_source: Optional[Image.Image] = None, upscale: UpscaleParams, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: # must be within the load function for patch to take effect # TODO: rewrite and remove from codeformer import CodeFormer @@ -32,4 +33,4 @@ class CorrectCodeformerStage(BaseStage): device = worker.get_device() pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) - return [pipe(source) for source in sources] + return StageResult(images=[pipe(source) for source in sources]) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 6b0e17be..56aaa849 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -10,6 +10,7 @@ from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -57,12 +58,12 @@ class CorrectGFPGANStage(BaseStage): server: ServerContext, stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: upscale = upscale.with_args(**kwargs) if upscale.correction_model is None: @@ -73,16 +74,12 @@ class CorrectGFPGANStage(BaseStage): device = worker.get_device() gfpgan = self.load(server, stage, upscale, device) - outputs = [] - for source in sources: - output = np.array(source) - _, _, output = gfpgan.enhance( - output, + outputs = [gfpgan.enhance( + source, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength, - ) - outputs.append(Image.fromarray(output, "RGB")) + ) for source in sources.as_numpy()] - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 38ec2b3f..f55d54e1 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -8,6 +8,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -21,13 +22,13 @@ class PersistDiskStage(BaseStage): server: ServerContext, _stage: StageParams, params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, output: List[str], size: Optional[Size] = None, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info( "persisting images to disk: %s, %s", [s.size for s in sources], output ) diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index f2becfc2..91d946e5 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -9,6 +9,7 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -20,7 +21,7 @@ class PersistS3Stage(BaseStage): server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, output: str, bucket: str, @@ -28,11 +29,11 @@ class PersistS3Stage(BaseStage): profile_name: Optional[str] = None, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: session = Session(profile_name=profile_name) s3 = session.client("s3", endpoint_url=endpoint_url) - for source in sources: + for source in sources.as_image(): data = BytesIO() source.save(data, format=server.image_format) data.seek(0) diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index edba28c9..d3d0ebfb 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -13,6 +13,7 @@ from ..utils import is_debug, run_gc from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .tile import needs_tile, process_tile_order +from .result import StageResult logger = getLogger(__name__) @@ -73,26 +74,27 @@ class ChainPipeline: worker: WorkerContext, server: ServerContext, params: ImageParams, - sources: List[Image.Image], + sources: StageResult, callback: Optional[ProgressCallback], **kwargs - ) -> List[Image.Image]: - return self( + ) -> StageResult: + result = self( worker, server, params, sources=sources, callback=callback, **kwargs ) + return result.as_image() def stage(self, callback: BaseStage, params: StageParams, **kwargs): self.stages.append((callback, params, kwargs)) return self - def steps(self, params: ImageParams, size: Size): + def steps(self, params: ImageParams, size: Size) -> int: steps = 0 for callback, _params, kwargs in self.stages: steps += callback.steps(kwargs.get("params", params), size) return steps - def outputs(self, params: ImageParams, sources: int): + def outputs(self, params: ImageParams, sources: int) -> int: outputs = sources for callback, _params, kwargs in self.stages: outputs = callback.outputs(kwargs.get("params", params), outputs) @@ -104,10 +106,10 @@ class ChainPipeline: worker: WorkerContext, server: ServerContext, params: ImageParams, - sources: List[Image.Image], + sources: StageResult, callback: Optional[ProgressCallback] = None, **pipeline_kwargs - ) -> List[Image.Image]: + ) -> StageResult: """ DEPRECATED: use `run` instead """ @@ -161,23 +163,21 @@ class ChainPipeline: tile = min(stage_pipe.max_tile, stage_params.tile_size) if stage_sources or must_tile: - stage_outputs = [] + stage_results = [] for source in stage_sources: logger.info( "image contains sources or is larger than tile size of %s, tiling stage", tile, ) - extra_tiles = [] - def stage_tile( source_tile: Image.Image, tile_mask: Image.Image, dims: Tuple[int, int, int], - ) -> Image.Image: + ) -> StageResult: for _i in range(worker.retries): try: - output_tile = stage_pipe.run( + tile_result = stage_pipe.run( worker, server, stage_params, @@ -189,17 +189,11 @@ class ChainPipeline: **kwargs, ) - if len(output_tile) > 1: - while len(extra_tiles) < len(output_tile): - extra_tiles.append([]) - - for tile, layer in zip(output_tile, extra_tiles): - layer.append((tile, dims)) - if is_debug(): - save_image(server, "last-tile.png", output_tile[0]) + for j, image in enumerate(tile_result.as_image()): + save_image(server, f"last-tile-{j}.png", image) - return output_tile[0] + return tile_result except Exception: worker.retries = worker.retries - 1 logger.exception( @@ -220,17 +214,9 @@ class ChainPipeline: **kwargs, ) - stage_outputs.append(output) + stage_results.append(output) - if len(extra_tiles) > 1: - for layer in extra_tiles: - layer_output = Image.new("RGB", output.size) - for layer_tile, dims in layer: - layer_output.paste(layer_tile, (dims[0], dims[1])) - - stage_outputs.append(layer_output) - - stage_sources = stage_outputs + stage_sources = StageResult(images=stage_results) else: logger.debug( "image does not contain sources and is within tile size of %s, running stage", @@ -238,7 +224,7 @@ class ChainPipeline: ) for i in range(worker.retries): try: - stage_outputs = stage_pipe.run( + stage_result = stage_pipe.run( worker, server, stage_params, @@ -250,7 +236,7 @@ class ChainPipeline: ) # doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline # does not like, so it throws - stage_sources = stage_outputs + stage_sources = stage_result break except Exception: worker.retries = worker.retries - 1 diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 24974e36..31e16b54 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -7,6 +7,7 @@ from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -18,20 +19,20 @@ class ReduceCropStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, origin: Size, size: Size, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: outputs = [] - for source in sources: + for source in sources.as_image(): image = source.crop((origin.width, origin.height, size.width, size.height)) logger.info( "created thumbnail with dimensions: %sx%s", image.width, image.height ) outputs.append(image) - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index c22ba3fe..6c909232 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -7,6 +7,7 @@ from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -18,15 +19,15 @@ class ReduceThumbnailStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, size: Size, stage_source: Image.Image, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: outputs = [] - for source in sources: + for source in sources.as_image(): image = source.copy() image = image.thumbnail((size.width, size.height)) @@ -37,4 +38,4 @@ class ReduceThumbnailStage(BaseStage): outputs.append(image) - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 627c5197..028c63d2 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -1,4 +1,4 @@ -from PIL.Image import Image, fromarray +from PIL import Image from typing import List, Optional import numpy as np @@ -7,25 +7,35 @@ class StageResult: """ Chain pipeline stage result. Can contain PIL images or numpy arrays, with helpers to convert between them. + This class intentionally does not provide `__iter__`, to ensure clients get results in the format + they are expected. """ arrays: Optional[List[np.ndarray]] - images: Optional[List[Image]] + images: Optional[List[Image.Image]] def __init__(self, arrays = None, images = None) -> None: if arrays is not None and images is not None: raise ValueError("stages must only return one type of result") + elif arrays is None and images is None: + raise ValueError("stages must return results") self.arrays = arrays self.images = images + def __len__(self) -> int: + if self.arrays is not None: + return len(self.arrays) + else: + return len(self.images) + def as_numpy(self) -> List[np.ndarray]: if self.arrays is not None: return self.arrays return [np.array(i) for i in self.images] - def as_image(self) -> List[Image]: + def as_image(self) -> List[Image.Image]: if self.images is not None: return self.images - return [fromarray(i) for i in self.arrays] + return [Image.fromarray(i, "RGB") for i in self.arrays] diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 2cf5b6b0..d7a606d6 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -7,6 +7,7 @@ from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -18,13 +19,13 @@ class SourceNoiseStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, size: Size, noise_source: Callable, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info("generating image from noise source") if len(sources) > 0: @@ -32,16 +33,16 @@ class SourceNoiseStage(BaseStage): "source images were passed to a source stage, new images will be appended" ) - outputs = list(sources) + outputs = [] # TODO: looping over sources and ignoring params does not make much sense for a source stage - for source in sources: + for source in sources.as_image(): output = noise_source(source, (size.width, size.height), (0, 0)) logger.info("final output image size: %sx%s", output.width, output.height) outputs.append(output) - return outputs + return StageResult(images=outputs) def outputs( self, diff --git a/api/onnx_web/chain/source_s3.py b/api/onnx_web/chain/source_s3.py index 32eb4357..d9a53aca 100644 --- a/api/onnx_web/chain/source_s3.py +++ b/api/onnx_web/chain/source_s3.py @@ -9,6 +9,7 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -20,14 +21,14 @@ class SourceS3Stage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, source_keys: List[str], bucket: str, endpoint_url: Optional[str] = None, profile_name: Optional[str] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: session = Session(profile_name=profile_name) s3 = session.client("s3", endpoint_url=endpoint_url) @@ -36,7 +37,7 @@ class SourceS3Stage(BaseStage): "source images were passed to a source stage, new images will be appended" ) - outputs = list(sources) + outputs = sources.as_image() for key in source_keys: try: logger.info("loading image from s3://%s/%s", bucket, key) @@ -48,7 +49,7 @@ class SourceS3Stage(BaseStage): except Exception: logger.exception("error loading image from S3") - return outputs + return StageResult(outputs) def outputs( self, diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 6eb20285..f41ebbbf 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -19,6 +19,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -32,7 +33,7 @@ class SourceTxt2ImgStage(BaseStage): server: ServerContext, stage: StageParams, params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, dims: Tuple[int, int, int] = None, size: Size, @@ -153,10 +154,10 @@ class SourceTxt2ImgStage(BaseStage): callback=callback, ) - output = list(sources) - output.extend(result.images) - logger.debug("produced %s outputs", len(output)) - return output + outputs = list(sources) + outputs.extend(result.images) + logger.debug("produced %s outputs", len(outputs)) + return StageResult(images=outputs) def steps( self, diff --git a/api/onnx_web/chain/source_url.py b/api/onnx_web/chain/source_url.py index 2dfcb855..c8d100e1 100644 --- a/api/onnx_web/chain/source_url.py +++ b/api/onnx_web/chain/source_url.py @@ -9,6 +9,7 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -20,12 +21,12 @@ class SourceURLStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, source_urls: List[str], stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info("loading image from URL source") if len(sources) > 0: @@ -41,7 +42,7 @@ class SourceURLStage(BaseStage): logger.info("final output image size: %sx%s", output.width, output.height) outputs.append(output) - return outputs + return StageResult(images=outputs) def outputs( self, diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index 65ac3a46..bc6bfdf5 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -9,6 +9,7 @@ from PIL import Image from ..image.noise_source import noise_source_histogram from ..params import Size, TileOrder +from .result import StageResult # from skimage.exposure import match_histograms @@ -21,7 +22,7 @@ class TileCallback(Protocol): Definition for a tile job function. """ - def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image: + def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult: """ Run this stage against a single tile. """ diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index 9afe54ae..5dabaf32 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -11,6 +11,7 @@ from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -54,12 +55,12 @@ class UpscaleBSRGANStage(BaseStage): server: ServerContext, stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: upscale = upscale.with_args(**kwargs) if upscale.upscale_model is None: @@ -71,8 +72,8 @@ class UpscaleBSRGANStage(BaseStage): bsrgan = self.load(server, stage, upscale, device) outputs = [] - for source in sources: - image = np.array(source) / 255.0 + for source in sources.as_numpy(): + image = source / 255.0 image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = np.expand_dims(image, axis=0) logger.trace("BSRGAN input shape: %s", image.shape) @@ -99,7 +100,7 @@ class UpscaleBSRGANStage(BaseStage): outputs.append(output) - return outputs + return StageResult(images=outputs) def steps( self, diff --git a/api/onnx_web/chain/upscale_highres.py b/api/onnx_web/chain/upscale_highres.py index 5ed28f9b..2564b033 100644 --- a/api/onnx_web/chain/upscale_highres.py +++ b/api/onnx_web/chain/upscale_highres.py @@ -9,6 +9,7 @@ from ..server import ServerContext from ..worker import WorkerContext from ..worker.context import ProgressCallback from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -20,20 +21,20 @@ class UpscaleHighresStage(BaseStage): server: ServerContext, stage: StageParams, params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *args, highres: HighresParams, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, callback: Optional[ProgressCallback] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: if highres.scale <= 1: return sources chain = stage_highres(stage, params, highres, upscale) - return [ + outputs = [ chain( worker, server, @@ -43,3 +44,5 @@ class UpscaleHighresStage(BaseStage): ) for source in sources ] + + return StageResult(images=outputs) \ No newline at end of file diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 29883cc0..3a321f91 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -19,6 +19,7 @@ from ..server import ServerContext from ..utils import is_debug from ..worker import ProgressCallback, WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -32,7 +33,7 @@ class UpscaleOutpaintStage(BaseStage): server: ServerContext, stage: StageParams, params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, border: Border, dims: Tuple[int, int, int], @@ -45,7 +46,7 @@ class UpscaleOutpaintStage(BaseStage): stage_source: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( params ) @@ -61,7 +62,7 @@ class UpscaleOutpaintStage(BaseStage): ) outputs = [] - for source in sources: + for source in sources.as_image(): if is_debug(): save_image(server, "tile-source.png", source) save_image(server, "tile-mask.png", tile_mask) @@ -122,4 +123,4 @@ class UpscaleOutpaintStage(BaseStage): outputs.extend(result.images) - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 7fbb6901..0cd6322d 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -11,6 +11,7 @@ from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -77,25 +78,22 @@ class UpscaleRealESRGANStage(BaseStage): server: ServerContext, stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) + upsampler = self.load( + server, upscale, worker.get_device(), tile=stage.tile_size + ) + outputs = [] - for source in sources: - output = np.array(source) - upsampler = self.load( - server, upscale, worker.get_device(), tile=stage.tile_size - ) - - output, _ = upsampler.enhance(output, outscale=upscale.outscale) - - output = Image.fromarray(output, "RGB") - logger.info("final output image size: %sx%s", output.width, output.height) + for source in sources.as_numpy(): + output, _ = upsampler.enhance(source, outscale=upscale.outscale) + logger.info("final output image size: %s", output.shape) outputs.append(output) - return outputs + return StageResult(arrays=outputs) diff --git a/api/onnx_web/chain/upscale_simple.py b/api/onnx_web/chain/upscale_simple.py index 36095339..33046842 100644 --- a/api/onnx_web/chain/upscale_simple.py +++ b/api/onnx_web/chain/upscale_simple.py @@ -7,6 +7,7 @@ from ..params import ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -18,13 +19,13 @@ class UpscaleSimpleStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, method: str, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: if upscale.scale <= 1: logger.debug( "simple upscale stage run with scale of %s, skipping", upscale.scale @@ -32,7 +33,7 @@ class UpscaleSimpleStage(BaseStage): return sources outputs = [] - for source in sources: + for source in sources.as_image(): scaled_size = (source.width * upscale.scale, source.height * upscale.scale) if method == "bilinear": diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 763871e0..1bc62417 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -11,6 +11,7 @@ from ..params import ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -22,13 +23,13 @@ class UpscaleStableDiffusionStage(BaseStage): server: ServerContext, _stage: StageParams, params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, callback: Optional[ProgressCallback] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: params = params.with_args(**kwargs) upscale = upscale.with_args(**kwargs) logger.info( @@ -58,7 +59,7 @@ class UpscaleStableDiffusionStage(BaseStage): pipeline.unet.set_prompts(prompt_embeds) outputs = [] - for source in sources: + for source in sources.as_image(): result = pipeline( prompt, source, diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index 52114bba..94f63d86 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -11,6 +11,7 @@ from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -54,12 +55,12 @@ class UpscaleSwinIRStage(BaseStage): server: ServerContext, stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: upscale = upscale.with_args(**kwargs) if upscale.upscale_model is None: @@ -71,31 +72,27 @@ class UpscaleSwinIRStage(BaseStage): swinir = self.load(server, stage, upscale, device) outputs = [] - for source in sources: + for source in sources.as_numpy(): # TODO: add support for grayscale (1-channel) images - image = np.array(source) / 255.0 + image = source / 255.0 image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = np.expand_dims(image, axis=0) logger.trace("SwinIR input shape: %s", image.shape) scale = upscale.outscale - dest = np.zeros( - ( + logger.trace("SwinIR output shape: %s", ( image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale, - ) - ) - logger.trace("SwinIR output shape: %s", dest.shape) + )) - dest = swinir(image) - dest = np.clip(np.squeeze(dest, axis=0), 0, 1) - dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) - dest = (dest * 255.0).round().astype(np.uint8) + output = swinir(image) + output = np.clip(np.squeeze(output, axis=0), 0, 1) + output = output[[2, 1, 0], :, :].transpose((1, 2, 0)) + output = (output * 255.0).round().astype(np.uint8) - output = Image.fromarray(dest, "RGB") - logger.info("output image size: %s x %s", output.width, output.height) + logger.info("output image size: %s", output.shape) outputs.append(output) - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 45ef3881..6bd226d5 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -486,7 +486,7 @@ def run_blend_pipeline( outputs: List[str], upscale: UpscaleParams, # highres: HighresParams, - sources: List[Image.Image], + sources: StageResult, mask: Image.Image, ) -> None: # set up the chain pipeline and base stage