1
0
Fork 0

attempt to save grid mode metadata

This commit is contained in:
Sean Sube 2024-01-02 21:49:22 -06:00
parent 297285fb6f
commit 46098960d8
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 37 additions and 16 deletions

View File

@ -53,7 +53,9 @@ def stage_highres(
stage,
method=highres.method,
overlap=params.vae_overlap,
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale, upscale=True),
upscale=upscale.with_args(
scale=highres.scale, outscale=highres.scale, upscale=True
),
)
chain.stage(

View File

@ -31,8 +31,17 @@ class PersistDiskStage(BaseStage):
) -> StageResult:
logger.info("persisting %s images to disk: %s", len(sources), output)
for source, name in zip(sources.as_image(), output):
dest = save_image(server, name, source, params=params, size=size)
for name, source, metadata in zip(output, sources.as_image(), sources.metadata):
dest = save_image(
server,
name,
source,
params=metadata.params,
size=metadata.size,
upscale=metadata.upscale,
border=metadata.border,
highres=metadata.highres,
) # TODO: inversions and loras
logger.info("saved image to %s", dest)
return sources

View File

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Any, List, Optional
import numpy as np
from PIL import Image
@ -63,14 +63,20 @@ class StageResult:
def from_images(images: List[Image.Image]):
return StageResult(images=images)
def __init__(self, arrays=None, images=None) -> None:
if arrays is not None and images is not None:
def __init__(
self,
arrays: Optional[List[np.ndarray]] = None,
images: Optional[List[Image.Image]] = None,
source: Optional[Any] = None,
) -> None:
if sum([arrays is not None, images is not None, source is not None]) > 1:
raise ValueError("stages must only return one type of result")
elif arrays is None and images is None:
elif arrays is None and images is None and source is None:
raise ValueError("stages must return results")
self.arrays = arrays
self.images = images
self.source = source
def __len__(self) -> int:
if self.arrays is not None:

View File

@ -18,7 +18,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
from .result import ImageMetadata, StageResult
logger = getLogger(__name__)
@ -115,7 +115,7 @@ class SourceTxt2ImgStage(BaseStage):
if params.is_lpw():
logger.debug("using LPW pipeline for txt2img")
rng = torch.manual_seed(params.seed)
result = pipe.text2img(
output = pipe.text2img(
prompt,
height=latent_size.height,
width=latent_size.width,
@ -141,7 +141,7 @@ class SourceTxt2ImgStage(BaseStage):
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(
output = pipe(
prompt,
height=latent_size.height,
width=latent_size.width,
@ -155,10 +155,12 @@ class SourceTxt2ImgStage(BaseStage):
callback=callback,
)
outputs = sources.as_image()
outputs.extend(result.images)
logger.debug("produced %s outputs", len(outputs))
return StageResult(images=outputs)
result = StageResult(source=sources)
for image in output.images:
result.push_image(image, ImageMetadata(params, size))
logger.debug("produced %s outputs", len(result))
return result
def steps(
self,

View File

@ -122,7 +122,9 @@ def run_txt2img_pipeline(
# add a thumbnail, if requested
cover = images[0]
if params.thumbnail and (cover.width > server.thumbnail_size or cover.height > server.thumbnail_size):
if params.thumbnail and (
cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
):
thumbnail = cover.copy()
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))

View File

@ -133,7 +133,7 @@ class ServerContext:
debug=get_boolean(env, "ONNX_WEB_DEBUG", False),
thumbnail_size=int(
env.get("ONNX_WEB_THUMBNAIL_SIZE", DEFAULT_THUMBNAIL_SIZE)
)
),
)
def get_setting(self, flag: str, default: str) -> Optional[str]: