1
0
Fork 0

attempt to save grid mode metadata

This commit is contained in:
Sean Sube 2024-01-02 21:49:22 -06:00
parent 297285fb6f
commit 46098960d8
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 37 additions and 16 deletions

View File

@ -53,7 +53,9 @@ def stage_highres(
stage, stage,
method=highres.method, method=highres.method,
overlap=params.vae_overlap, overlap=params.vae_overlap,
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale, upscale=True), upscale=upscale.with_args(
scale=highres.scale, outscale=highres.scale, upscale=True
),
) )
chain.stage( chain.stage(

View File

@ -31,8 +31,17 @@ class PersistDiskStage(BaseStage):
) -> StageResult: ) -> StageResult:
logger.info("persisting %s images to disk: %s", len(sources), output) logger.info("persisting %s images to disk: %s", len(sources), output)
for source, name in zip(sources.as_image(), output): for name, source, metadata in zip(output, sources.as_image(), sources.metadata):
dest = save_image(server, name, source, params=params, size=size) dest = save_image(
server,
name,
source,
params=metadata.params,
size=metadata.size,
upscale=metadata.upscale,
border=metadata.border,
highres=metadata.highres,
) # TODO: inversions and loras
logger.info("saved image to %s", dest) logger.info("saved image to %s", dest)
return sources return sources

View File

@ -1,4 +1,4 @@
from typing import List, Optional from typing import Any, List, Optional
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -63,14 +63,20 @@ class StageResult:
def from_images(images: List[Image.Image]): def from_images(images: List[Image.Image]):
return StageResult(images=images) return StageResult(images=images)
def __init__(self, arrays=None, images=None) -> None: def __init__(
if arrays is not None and images is not None: self,
arrays: Optional[List[np.ndarray]] = None,
images: Optional[List[Image.Image]] = None,
source: Optional[Any] = None,
) -> None:
if sum([arrays is not None, images is not None, source is not None]) > 1:
raise ValueError("stages must only return one type of result") raise ValueError("stages must only return one type of result")
elif arrays is None and images is None: elif arrays is None and images is None and source is None:
raise ValueError("stages must return results") raise ValueError("stages must return results")
self.arrays = arrays self.arrays = arrays
self.images = images self.images = images
self.source = source
def __len__(self) -> int: def __len__(self) -> int:
if self.arrays is not None: if self.arrays is not None:

View File

@ -18,7 +18,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import ImageMetadata, StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -115,7 +115,7 @@ class SourceTxt2ImgStage(BaseStage):
if params.is_lpw(): if params.is_lpw():
logger.debug("using LPW pipeline for txt2img") logger.debug("using LPW pipeline for txt2img")
rng = torch.manual_seed(params.seed) rng = torch.manual_seed(params.seed)
result = pipe.text2img( output = pipe.text2img(
prompt, prompt,
height=latent_size.height, height=latent_size.height,
width=latent_size.width, width=latent_size.width,
@ -141,7 +141,7 @@ class SourceTxt2ImgStage(BaseStage):
pipe.unet.set_prompts(prompt_embeds) pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( output = pipe(
prompt, prompt,
height=latent_size.height, height=latent_size.height,
width=latent_size.width, width=latent_size.width,
@ -155,10 +155,12 @@ class SourceTxt2ImgStage(BaseStage):
callback=callback, callback=callback,
) )
outputs = sources.as_image() result = StageResult(source=sources)
outputs.extend(result.images) for image in output.images:
logger.debug("produced %s outputs", len(outputs)) result.push_image(image, ImageMetadata(params, size))
return StageResult(images=outputs)
logger.debug("produced %s outputs", len(result))
return result
def steps( def steps(
self, self,

View File

@ -122,7 +122,9 @@ def run_txt2img_pipeline(
# add a thumbnail, if requested # add a thumbnail, if requested
cover = images[0] cover = images[0]
if params.thumbnail and (cover.width > server.thumbnail_size or cover.height > server.thumbnail_size): if params.thumbnail and (
cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
):
thumbnail = cover.copy() thumbnail = cover.copy()
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size)) thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))

View File

@ -133,7 +133,7 @@ class ServerContext:
debug=get_boolean(env, "ONNX_WEB_DEBUG", False), debug=get_boolean(env, "ONNX_WEB_DEBUG", False),
thumbnail_size=int( thumbnail_size=int(
env.get("ONNX_WEB_THUMBNAIL_SIZE", DEFAULT_THUMBNAIL_SIZE) env.get("ONNX_WEB_THUMBNAIL_SIZE", DEFAULT_THUMBNAIL_SIZE)
) ),
) )
def get_setting(self, flag: str, default: str) -> Optional[str]: def get_setting(self, flag: str, default: str) -> Optional[str]: