From 4f230f4111c2202eb750d9bb289004da7781687f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 5 Jan 2024 20:11:58 -0600 Subject: [PATCH] clean up result and metadata handling --- .../chain/blend_denoise_fastnlmeans.py | 4 +- api/onnx_web/chain/blend_denoise_localstd.py | 10 ++- api/onnx_web/chain/blend_grid.py | 6 +- api/onnx_web/chain/blend_img2img.py | 2 +- api/onnx_web/chain/blend_linear.py | 9 ++- api/onnx_web/chain/blend_mask.py | 5 +- api/onnx_web/chain/correct_codeformer.py | 4 +- api/onnx_web/chain/correct_gfpgan.py | 4 +- api/onnx_web/chain/persist_disk.py | 12 +-- api/onnx_web/chain/persist_s3.py | 3 +- api/onnx_web/chain/pipeline.py | 7 +- api/onnx_web/chain/reduce_crop.py | 4 +- api/onnx_web/chain/reduce_thumbnail.py | 4 +- api/onnx_web/chain/result.py | 80 +++++++++++++------ api/onnx_web/chain/source_noise.py | 4 +- api/onnx_web/chain/source_s3.py | 2 +- api/onnx_web/chain/source_url.py | 2 +- api/onnx_web/chain/tile.py | 8 +- api/onnx_web/chain/upscale_bsrgan.py | 4 +- api/onnx_web/chain/upscale_highres.py | 4 +- api/onnx_web/chain/upscale_outpaint.py | 4 +- api/onnx_web/chain/upscale_resrgan.py | 2 +- api/onnx_web/chain/upscale_simple.py | 4 +- .../chain/upscale_stable_diffusion.py | 4 +- api/onnx_web/chain/upscale_swinir.py | 4 +- api/onnx_web/diffusers/run.py | 6 +- api/onnx_web/output.py | 2 +- 27 files changed, 119 insertions(+), 85 deletions(-) diff --git a/api/onnx_web/chain/blend_denoise_fastnlmeans.py b/api/onnx_web/chain/blend_denoise_fastnlmeans.py index c3de00e2..ea39a451 100644 --- a/api/onnx_web/chain/blend_denoise_fastnlmeans.py +++ b/api/onnx_web/chain/blend_denoise_fastnlmeans.py @@ -32,9 +32,9 @@ class BlendDenoiseFastNLMeansStage(BaseStage): logger.info("denoising source images") results = [] - for source in sources.as_numpy(): + for source in sources.as_arrays(): data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR) data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength) results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB)) - return StageResult(arrays=results) + return StageResult.from_arrays(results, metadata=sources.metadata) diff --git a/api/onnx_web/chain/blend_denoise_localstd.py b/api/onnx_web/chain/blend_denoise_localstd.py index 389e30a8..08f20d47 100644 --- a/api/onnx_web/chain/blend_denoise_localstd.py +++ b/api/onnx_web/chain/blend_denoise_localstd.py @@ -14,6 +14,11 @@ logger = getLogger(__name__) class BlendDenoiseLocalStdStage(BaseStage): + """ + Experimental stage to blend and denoise images using local means compared to local standard deviation. + Very slow. + """ + max_tile = SizeChart.max def run( @@ -35,8 +40,9 @@ class BlendDenoiseLocalStdStage(BaseStage): return StageResult.from_arrays( [ remove_noise(source, threshold=strength, deviation=range)[0] - for source in sources.as_numpy() - ] + for source in sources.as_arrays() + ], + metadata=sources.metadata, ) diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py index 34e4f535..376c2af9 100644 --- a/api/onnx_web/chain/blend_grid.py +++ b/api/onnx_web/chain/blend_grid.py @@ -35,7 +35,7 @@ class BlendGridStage(BaseStage): ) -> StageResult: logger.info("combining source images using grid layout") - images = sources.as_image() + images = sources.as_images() ref_image = images[0] size = Size(*ref_image.size) @@ -52,7 +52,9 @@ class BlendGridStage(BaseStage): n = order[i] output.paste(images[n], (x * size.width, y * size.height)) - return StageResult(images=[*images, output]) + result = StageResult(source=sources) + result.push_image(output, sources.metadata[0]) + return result def outputs( self, diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 2c838c73..b2350463 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -66,7 +66,7 @@ class BlendImg2ImgStage(BaseStage): pipe_params["strength"] = strength outputs = [] - for source in sources.as_image(): + for source in sources.as_images(): if params.is_lpw(): logger.debug("using LPW pipeline for img2img") rng = torch.manual_seed(params.seed) diff --git a/api/onnx_web/chain/blend_linear.py b/api/onnx_web/chain/blend_linear.py index 1b40a5fd..0b5e85db 100644 --- a/api/onnx_web/chain/blend_linear.py +++ b/api/onnx_web/chain/blend_linear.py @@ -28,9 +28,10 @@ class BlendLinearStage(BaseStage): ) -> StageResult: logger.info("blending source images using linear interpolation") - return StageResult( - images=[ + return StageResult.from_images( + [ Image.blend(source, stage_source, alpha) - for source in sources.as_image() - ] + for source in sources.as_images() + ], + metadata=sources.metadata, ) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 93abd1ee..44358f25 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -48,6 +48,7 @@ class BlendMaskStage(BaseStage): return StageResult.from_images( [ Image.composite(stage_source_tile, source, mult_mask) - for source in sources.as_image() - ] + for source in sources.as_images() + ], + metadata=sources.metadata, ) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index a03da0f6..173e05e4 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -67,7 +67,7 @@ class CorrectCodeformerStage(BaseStage): ) results = [] - for img in sources.as_numpy(): + for img in sources.as_arrays(): img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # clean all the intermediate results to process the next image face_helper.clean_all() @@ -121,4 +121,4 @@ class CorrectCodeformerStage(BaseStage): ) results.append(Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))) - return StageResult.from_images(results) + return StageResult.from_images(results, metadata=sources.metadata) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index f3ce33f3..52ef659a 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -74,7 +74,7 @@ class CorrectGFPGANStage(BaseStage): gfpgan = self.load(server, stage, upscale, device) outputs = [] - for source in sources.as_numpy(): + for source in sources.as_arrays(): cropped, restored, result = gfpgan.enhance( source, has_aligned=False, @@ -84,4 +84,4 @@ class CorrectGFPGANStage(BaseStage): ) outputs.append(result) - return StageResult.from_arrays(outputs) + return StageResult.from_arrays(outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index af023a68..61d39eae 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -31,18 +31,14 @@ class PersistDiskStage(BaseStage): ) -> StageResult: logger.info("persisting %s images to disk: %s", len(sources), output) - for name, source, metadata in zip(output, sources.as_image(), sources.metadata): + for name, source, metadata in zip( + output, sources.as_images(), sources.metadata + ): dest = save_image( server, name, source, - params=metadata.params, - size=metadata.size, - upscale=metadata.upscale, - border=metadata.border, - highres=metadata.highres, - inversions=metadata.inversions, - loras=metadata.loras, + metadata=metadata, ) logger.info("saved image to %s", dest) diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 060afc4f..6c4cc5f3 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -33,7 +33,8 @@ class PersistS3Stage(BaseStage): session = Session(profile_name=profile_name) s3 = session.client("s3", endpoint_url=endpoint_url) - for source, name in zip(sources.as_image(), output): + # TODO: save metadata as well + for source, name in zip(sources.as_images(), output): data = BytesIO() source.save(data, format=server.image_format) data.seek(0) diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index c68f4632..6578bf82 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -11,7 +11,6 @@ from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..utils import is_debug, run_gc from ..worker import ProgressCallback, WorkerContext -from ..worker.command import Progress from .base import BaseStage from .result import StageResult from .tile import needs_tile, process_tile_order @@ -107,7 +106,7 @@ class ChainPipeline: result = self( worker, server, params, sources=sources, callback=callback, **kwargs ) - return result.as_image() + return result.as_images() def stage(self, callback: BaseStage, params: StageParams, **kwargs): self.stages.append((callback, params, kwargs)) @@ -184,7 +183,7 @@ class ChainPipeline: size=kwargs.get("size", None), source=source, ) - for source in stage_sources.as_image() + for source in stage_sources.as_images() ] ) @@ -302,7 +301,7 @@ class ChainPipeline: ) if is_debug(): - for j, image in enumerate(stage_sources.as_image()): + for j, image in enumerate(stage_sources.as_images()): save_image(server, f"last-stage-{j}.png", image) end = monotonic() diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index fe98fbd3..3a81ce39 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -28,11 +28,11 @@ class ReduceCropStage(BaseStage): ) -> StageResult: outputs = [] - for source in sources.as_image(): + for source in sources.as_images(): image = source.crop((origin.width, origin.height, size.width, size.height)) logger.info( "created thumbnail with dimensions: %sx%s", image.width, image.height ) outputs.append(image) - return StageResult(images=outputs) + return StageResult.from_images(outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 9c65a819..79970301 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -26,7 +26,7 @@ class ReduceThumbnailStage(BaseStage): ) -> StageResult: outputs = [] - for source in sources.as_image(): + for source in sources.as_images(): image = source.copy() image = image.thumbnail((size.width, size.height)) @@ -37,4 +37,4 @@ class ReduceThumbnailStage(BaseStage): outputs.append(image) - return StageResult(images=outputs) + return StageResult.from_images(outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 541fb3e7..2f9557b6 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -194,12 +194,16 @@ class StageResult: return StageResult(images=[]) @staticmethod - def from_arrays(arrays: List[np.ndarray]): - return StageResult(arrays=arrays) + def from_arrays( + arrays: List[np.ndarray], metadata: Optional[List[ImageMetadata]] = None + ): + return StageResult(arrays=arrays, metadata=metadata) @staticmethod - def from_images(images: List[Image.Image]): - return StageResult(images=images) + def from_images( + images: List[Image.Image], metadata: Optional[List[ImageMetadata]] = None + ): + return StageResult(images=images, metadata=metadata) def __init__( self, @@ -208,15 +212,22 @@ class StageResult: metadata: Optional[List[ImageMetadata]] = None, source: Optional[Any] = None, ) -> None: - if sum([arrays is not None, images is not None, source is not None]) > 1: - raise ValueError("stages must only return one type of result") - elif arrays is None and images is None and source is None: - raise ValueError("stages must return results") + data_provided = sum( + [arrays is not None, images is not None, source is not None] + ) + if data_provided > 1: + raise ValueError("results must only contain one type of data") + elif data_provided == 0: + raise ValueError("results must contain some data") - self.arrays = arrays - self.images = images - self.source = source - self.metadata = metadata or [] + if source is not None: + self.arrays = source.arrays + self.images = source.images + self.metadata = source.metadata + else: + self.arrays = arrays + self.images = images + self.metadata = metadata or [] def __len__(self) -> int: if self.arrays is not None: @@ -226,7 +237,7 @@ class StageResult: else: return 0 - def as_numpy(self) -> List[np.ndarray]: + def as_arrays(self) -> List[np.ndarray]: if self.arrays is not None: return self.arrays elif self.images is not None: @@ -234,7 +245,7 @@ class StageResult: else: return [] - def as_image(self) -> List[Image.Image]: + def as_images(self) -> List[Image.Image]: if self.images is not None: return self.images elif self.arrays is not None: @@ -242,7 +253,7 @@ class StageResult: else: return [] - def push_array(self, array: np.ndarray, metadata: Optional[ImageMetadata]): + def push_array(self, array: np.ndarray, metadata: ImageMetadata): if self.arrays is not None: self.arrays.append(array) elif self.images is not None: @@ -253,9 +264,9 @@ class StageResult: if metadata is not None: self.metadata.append(metadata) else: - self.metadata.append(ImageMetadata()) + raise ValueError("metadata must be provided") - def push_image(self, image: Image.Image, metadata: Optional[ImageMetadata]): + def push_image(self, image: Image.Image, metadata: ImageMetadata): if self.images is not None: self.images.append(image) elif self.arrays is not None: @@ -266,11 +277,9 @@ class StageResult: if metadata is not None: self.metadata.append(metadata) else: - self.metadata.append(ImageMetadata()) + raise ValueError("metadata must be provided") - def insert_array( - self, index: int, array: np.ndarray, metadata: Optional[ImageMetadata] - ): + def insert_array(self, index: int, array: np.ndarray, metadata: ImageMetadata): if self.arrays is not None: self.arrays.insert(index, array) elif self.images is not None: @@ -283,11 +292,9 @@ class StageResult: if metadata is not None: self.metadata.insert(index, metadata) else: - self.metadata.insert(index, ImageMetadata()) + raise ValueError("metadata must be provided") - def insert_image( - self, index: int, image: Image.Image, metadata: Optional[ImageMetadata] - ): + def insert_image(self, index: int, image: Image.Image, metadata: ImageMetadata): if self.images is not None: self.images.insert(index, image) elif self.arrays is not None: @@ -298,7 +305,28 @@ class StageResult: if metadata is not None: self.metadata.insert(index, metadata) else: - self.metadata.insert(index, ImageMetadata()) + raise ValueError("metadata must be provided") + + def size(self) -> Size: + if self.images is not None: + return Size(self.images[0].width, self.images[0].height) + elif self.arrays is not None: + return Size( + self.arrays[0].shape[0], self.arrays[0].shape[1] + ) # TODO: which fields within the shape are width/height? + else: + return Size(0, 0) + + def validate(self) -> None: + """ + Make sure the data exists and that data and metadata match in length. + """ + + if self.arrays is None and self.images is None: + raise ValueError("no data in result") + + if len(self) != len(self.metadata): + raise ValueError("metadata and data do not match in length") def shape_mode(arr: np.ndarray) -> str: diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index d1b2eac2..bfa2d94b 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -36,13 +36,13 @@ class SourceNoiseStage(BaseStage): outputs = [] # TODO: looping over sources and ignoring params does not make much sense for a source stage - for source in sources.as_image(): + for source in sources.as_images(): output = noise_source(source, (size.width, size.height), (0, 0)) logger.info("final output image size: %sx%s", output.width, output.height) outputs.append(output) - return StageResult(images=outputs) + return StageResult.from_images(outputs, metadata=sources.metadata) def outputs( self, diff --git a/api/onnx_web/chain/source_s3.py b/api/onnx_web/chain/source_s3.py index d9a53aca..539f07df 100644 --- a/api/onnx_web/chain/source_s3.py +++ b/api/onnx_web/chain/source_s3.py @@ -37,7 +37,7 @@ class SourceS3Stage(BaseStage): "source images were passed to a source stage, new images will be appended" ) - outputs = sources.as_image() + outputs = sources.as_images() for key in source_keys: try: logger.info("loading image from s3://%s/%s", bucket, key) diff --git a/api/onnx_web/chain/source_url.py b/api/onnx_web/chain/source_url.py index b6aa62cd..7fe158bf 100644 --- a/api/onnx_web/chain/source_url.py +++ b/api/onnx_web/chain/source_url.py @@ -34,7 +34,7 @@ class SourceURLStage(BaseStage): "source images were passed to a source stage, new images will be appended" ) - outputs = sources.as_image() + outputs = sources.as_images() for url in source_urls: response = requests.get(url) output = Image.open(BytesIO(response.content)) diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index e8e1baff..7a1d6310 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -257,7 +257,7 @@ def process_tile_stack( overlap: float = 0.5, **kwargs, ) -> List[Image.Image]: - sources = stack.as_image() + sources = stack.as_images() width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None) mask = kwargs.get("mask", kwargs.get("stage_mask", None)) @@ -308,7 +308,7 @@ def process_tile_stack( bottom_margin, ) tile_stack = add_margin( - stack.as_image(), + stack.as_images(), left, top, right, @@ -346,7 +346,7 @@ def process_tile_stack( if isinstance(tile_stack, list): tile_stack = StageResult.from_images(tile_stack) - tiles.append((left, top, tile_stack.as_image())) + tiles.append((left, top, tile_stack.as_images())) lefts, tops, stacks = list(zip(*tiles)) coords = list(zip(lefts, tops)) @@ -516,7 +516,7 @@ def get_result_tile( top, left = origin return [ layer.crop((top, left, top + tile.height, left + tile.width)) - for layer in result.as_image() + for layer in result.as_images() ] diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index 08c07759..80f0af5f 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -79,7 +79,7 @@ class UpscaleBSRGANStage(BaseStage): bsrgan = self.load(server, stage, upscale, device) outputs = [] - for source in sources.as_numpy(): + for source in sources.as_arrays(): image = source / 255.0 image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = np.expand_dims(image, axis=0) @@ -105,7 +105,7 @@ class UpscaleBSRGANStage(BaseStage): logger.debug("output image shape: %s", output.shape) outputs.append(output) - return StageResult(arrays=outputs) + return StageResult(arrays=outputs, metadata=sources.metadata) def steps( self, diff --git a/api/onnx_web/chain/upscale_highres.py b/api/onnx_web/chain/upscale_highres.py index 32f891a6..ed86b97e 100644 --- a/api/onnx_web/chain/upscale_highres.py +++ b/api/onnx_web/chain/upscale_highres.py @@ -42,7 +42,7 @@ class UpscaleHighresStage(BaseStage): source, callback=callback, ) - for source in sources.as_image() + for source in sources.as_images() ] - return StageResult(images=outputs) + return StageResult(images=outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 464f5920..de5f1f71 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -62,7 +62,7 @@ class UpscaleOutpaintStage(BaseStage): ) outputs = [] - for source in sources.as_image(): + for source in sources.as_images(): if is_debug(): save_image(server, "tile-source.png", source) save_image(server, "tile-mask.png", tile_mask) @@ -123,4 +123,4 @@ class UpscaleOutpaintStage(BaseStage): outputs.extend(result.images) - return StageResult(images=outputs) + return StageResult(images=outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index f3d48195..53afce2a 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -112,7 +112,7 @@ class UpscaleRealESRGANStage(BaseStage): ) outputs = [] - for source in sources.as_numpy(): + for source in sources.as_arrays(): output, _ = upsampler.enhance(source, outscale=upscale.outscale) logger.info("final output image size: %s", output.shape) outputs.append(output) diff --git a/api/onnx_web/chain/upscale_simple.py b/api/onnx_web/chain/upscale_simple.py index 7e939bd4..5cf5f24c 100644 --- a/api/onnx_web/chain/upscale_simple.py +++ b/api/onnx_web/chain/upscale_simple.py @@ -33,7 +33,7 @@ class UpscaleSimpleStage(BaseStage): return sources outputs = [] - for source in sources.as_image(): + for source in sources.as_images(): scaled_size = (source.width * upscale.scale, source.height * upscale.scale) if method == "bilinear": @@ -49,4 +49,4 @@ class UpscaleSimpleStage(BaseStage): else: logger.warning("unknown upscaling method: %s", method) - return StageResult(images=outputs) + return StageResult(images=outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 6c8a300e..d169fc3f 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -59,7 +59,7 @@ class UpscaleStableDiffusionStage(BaseStage): pipeline.unet.set_prompts(prompt_embeds) outputs = [] - for source in sources.as_image(): + for source in sources.as_images(): result = pipeline( prompt, source, @@ -73,4 +73,4 @@ class UpscaleStableDiffusionStage(BaseStage): ) outputs.extend(result.images) - return StageResult(images=outputs) + return StageResult(images=outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index ef7d421f..7d55d9b1 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -72,7 +72,7 @@ class UpscaleSwinIRStage(BaseStage): swinir = self.load(server, stage, upscale, device) outputs = [] - for source in sources.as_numpy(): + for source in sources.as_arrays(): # TODO: add support for grayscale (1-channel) images image = source / 255.0 image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) @@ -98,4 +98,4 @@ class UpscaleSwinIRStage(BaseStage): logger.info("output image size: %s", output.shape) outputs.append(output) - return StageResult(images=outputs) + return StageResult(images=outputs, metadata=sources.metadata) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 3ed614a3..eed650d2 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -118,7 +118,7 @@ def run_txt2img_pipeline( ) # add a thumbnail, if requested - cover = images.as_image()[0] + cover = images.as_images()[0] if params.thumbnail and ( cover.width > server.thumbnail_size or cover.height > server.thumbnail_size ): @@ -385,12 +385,12 @@ def run_inpaint_pipeline( worker, server, params, - StageResult(images=[source]), + StageResult(images=[source]), # TODO: load metadata from source image callback=progress, latents=latents, ) - for i, image, metadata in enumerate(zip(images.as_image(), images.metadata)): + for i, image, metadata in enumerate(zip(images.as_images(), images.metadata)): if full_res_inpaint: if is_debug(): save_image(server, "adjusted-output.png", image) diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 64c2f8ad..0e9a32c4 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -62,7 +62,7 @@ def save_result( result: StageResult, base_name: str, ) -> List[str]: - images = result.as_image() + images = result.as_images() outputs = make_output_names(server, base_name, len(images)) results = [] for image, metadata, filename in zip(images, result.metadata, outputs):