From 46098960d854121854093da9c08e9d6f31d6ee12 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 2 Jan 2024 21:49:22 -0600 Subject: [PATCH] attempt to save grid mode metadata --- api/onnx_web/chain/highres.py | 4 +++- api/onnx_web/chain/persist_disk.py | 13 +++++++++++-- api/onnx_web/chain/result.py | 14 ++++++++++---- api/onnx_web/chain/source_txt2img.py | 16 +++++++++------- api/onnx_web/diffusers/run.py | 4 +++- api/onnx_web/server/context.py | 2 +- 6 files changed, 37 insertions(+), 16 deletions(-) diff --git a/api/onnx_web/chain/highres.py b/api/onnx_web/chain/highres.py index ce31463f..8c5f0fb1 100644 --- a/api/onnx_web/chain/highres.py +++ b/api/onnx_web/chain/highres.py @@ -53,7 +53,9 @@ def stage_highres( stage, method=highres.method, overlap=params.vae_overlap, - upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale, upscale=True), + upscale=upscale.with_args( + scale=highres.scale, outscale=highres.scale, upscale=True + ), ) chain.stage( diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 28a08848..e4496a2f 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -31,8 +31,17 @@ class PersistDiskStage(BaseStage): ) -> StageResult: logger.info("persisting %s images to disk: %s", len(sources), output) - for source, name in zip(sources.as_image(), output): - dest = save_image(server, name, source, params=params, size=size) + for name, source, metadata in zip(output, sources.as_image(), sources.metadata): + dest = save_image( + server, + name, + source, + params=metadata.params, + size=metadata.size, + upscale=metadata.upscale, + border=metadata.border, + highres=metadata.highres, + ) # TODO: inversions and loras logger.info("saved image to %s", dest) return sources diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index ce1ee515..0b8dc171 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List, Optional import numpy as np from PIL import Image @@ -63,14 +63,20 @@ class StageResult: def from_images(images: List[Image.Image]): return StageResult(images=images) - def __init__(self, arrays=None, images=None) -> None: - if arrays is not None and images is not None: + def __init__( + self, + arrays: Optional[List[np.ndarray]] = None, + images: Optional[List[Image.Image]] = None, + source: Optional[Any] = None, + ) -> None: + if sum([arrays is not None, images is not None, source is not None]) > 1: raise ValueError("stages must only return one type of result") - elif arrays is None and images is None: + elif arrays is None and images is None and source is None: raise ValueError("stages must return results") self.arrays = arrays self.images = images + self.source = source def __len__(self) -> int: if self.arrays is not None: diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 571e58ad..cdf54a57 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -18,7 +18,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext from .base import BaseStage -from .result import StageResult +from .result import ImageMetadata, StageResult logger = getLogger(__name__) @@ -115,7 +115,7 @@ class SourceTxt2ImgStage(BaseStage): if params.is_lpw(): logger.debug("using LPW pipeline for txt2img") rng = torch.manual_seed(params.seed) - result = pipe.text2img( + output = pipe.text2img( prompt, height=latent_size.height, width=latent_size.width, @@ -141,7 +141,7 @@ class SourceTxt2ImgStage(BaseStage): pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) - result = pipe( + output = pipe( prompt, height=latent_size.height, width=latent_size.width, @@ -155,10 +155,12 @@ class SourceTxt2ImgStage(BaseStage): callback=callback, ) - outputs = sources.as_image() - outputs.extend(result.images) - logger.debug("produced %s outputs", len(outputs)) - return StageResult(images=outputs) + result = StageResult(source=sources) + for image in output.images: + result.push_image(image, ImageMetadata(params, size)) + + logger.debug("produced %s outputs", len(result)) + return result def steps( self, diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index ece37670..1fef8087 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -122,7 +122,9 @@ def run_txt2img_pipeline( # add a thumbnail, if requested cover = images[0] - if params.thumbnail and (cover.width > server.thumbnail_size or cover.height > server.thumbnail_size): + if params.thumbnail and ( + cover.width > server.thumbnail_size or cover.height > server.thumbnail_size + ): thumbnail = cover.copy() thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size)) diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index fcb82fff..dfdc2a07 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -133,7 +133,7 @@ class ServerContext: debug=get_boolean(env, "ONNX_WEB_DEBUG", False), thumbnail_size=int( env.get("ONNX_WEB_THUMBNAIL_SIZE", DEFAULT_THUMBNAIL_SIZE) - ) + ), ) def get_setting(self, flag: str, default: str) -> Optional[str]: