attempt to save grid mode metadata
This commit is contained in:
parent
297285fb6f
commit
46098960d8
|
@ -53,7 +53,9 @@ def stage_highres(
|
||||||
stage,
|
stage,
|
||||||
method=highres.method,
|
method=highres.method,
|
||||||
overlap=params.vae_overlap,
|
overlap=params.vae_overlap,
|
||||||
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale, upscale=True),
|
upscale=upscale.with_args(
|
||||||
|
scale=highres.scale, outscale=highres.scale, upscale=True
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
chain.stage(
|
chain.stage(
|
||||||
|
|
|
@ -31,8 +31,17 @@ class PersistDiskStage(BaseStage):
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info("persisting %s images to disk: %s", len(sources), output)
|
logger.info("persisting %s images to disk: %s", len(sources), output)
|
||||||
|
|
||||||
for source, name in zip(sources.as_image(), output):
|
for name, source, metadata in zip(output, sources.as_image(), sources.metadata):
|
||||||
dest = save_image(server, name, source, params=params, size=size)
|
dest = save_image(
|
||||||
|
server,
|
||||||
|
name,
|
||||||
|
source,
|
||||||
|
params=metadata.params,
|
||||||
|
size=metadata.size,
|
||||||
|
upscale=metadata.upscale,
|
||||||
|
border=metadata.border,
|
||||||
|
highres=metadata.highres,
|
||||||
|
) # TODO: inversions and loras
|
||||||
logger.info("saved image to %s", dest)
|
logger.info("saved image to %s", dest)
|
||||||
|
|
||||||
return sources
|
return sources
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -63,14 +63,20 @@ class StageResult:
|
||||||
def from_images(images: List[Image.Image]):
|
def from_images(images: List[Image.Image]):
|
||||||
return StageResult(images=images)
|
return StageResult(images=images)
|
||||||
|
|
||||||
def __init__(self, arrays=None, images=None) -> None:
|
def __init__(
|
||||||
if arrays is not None and images is not None:
|
self,
|
||||||
|
arrays: Optional[List[np.ndarray]] = None,
|
||||||
|
images: Optional[List[Image.Image]] = None,
|
||||||
|
source: Optional[Any] = None,
|
||||||
|
) -> None:
|
||||||
|
if sum([arrays is not None, images is not None, source is not None]) > 1:
|
||||||
raise ValueError("stages must only return one type of result")
|
raise ValueError("stages must only return one type of result")
|
||||||
elif arrays is None and images is None:
|
elif arrays is None and images is None and source is None:
|
||||||
raise ValueError("stages must return results")
|
raise ValueError("stages must return results")
|
||||||
|
|
||||||
self.arrays = arrays
|
self.arrays = arrays
|
||||||
self.images = images
|
self.images = images
|
||||||
|
self.source = source
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
if self.arrays is not None:
|
if self.arrays is not None:
|
||||||
|
|
|
@ -18,7 +18,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import ImageMetadata, StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -115,7 +115,7 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
if params.is_lpw():
|
if params.is_lpw():
|
||||||
logger.debug("using LPW pipeline for txt2img")
|
logger.debug("using LPW pipeline for txt2img")
|
||||||
rng = torch.manual_seed(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
result = pipe.text2img(
|
output = pipe.text2img(
|
||||||
prompt,
|
prompt,
|
||||||
height=latent_size.height,
|
height=latent_size.height,
|
||||||
width=latent_size.width,
|
width=latent_size.width,
|
||||||
|
@ -141,7 +141,7 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
pipe.unet.set_prompts(prompt_embeds)
|
pipe.unet.set_prompts(prompt_embeds)
|
||||||
|
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
result = pipe(
|
output = pipe(
|
||||||
prompt,
|
prompt,
|
||||||
height=latent_size.height,
|
height=latent_size.height,
|
||||||
width=latent_size.width,
|
width=latent_size.width,
|
||||||
|
@ -155,10 +155,12 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = sources.as_image()
|
result = StageResult(source=sources)
|
||||||
outputs.extend(result.images)
|
for image in output.images:
|
||||||
logger.debug("produced %s outputs", len(outputs))
|
result.push_image(image, ImageMetadata(params, size))
|
||||||
return StageResult(images=outputs)
|
|
||||||
|
logger.debug("produced %s outputs", len(result))
|
||||||
|
return result
|
||||||
|
|
||||||
def steps(
|
def steps(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -122,7 +122,9 @@ def run_txt2img_pipeline(
|
||||||
|
|
||||||
# add a thumbnail, if requested
|
# add a thumbnail, if requested
|
||||||
cover = images[0]
|
cover = images[0]
|
||||||
if params.thumbnail and (cover.width > server.thumbnail_size or cover.height > server.thumbnail_size):
|
if params.thumbnail and (
|
||||||
|
cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
|
||||||
|
):
|
||||||
thumbnail = cover.copy()
|
thumbnail = cover.copy()
|
||||||
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
|
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
|
||||||
|
|
||||||
|
|
|
@ -133,7 +133,7 @@ class ServerContext:
|
||||||
debug=get_boolean(env, "ONNX_WEB_DEBUG", False),
|
debug=get_boolean(env, "ONNX_WEB_DEBUG", False),
|
||||||
thumbnail_size=int(
|
thumbnail_size=int(
|
||||||
env.get("ONNX_WEB_THUMBNAIL_SIZE", DEFAULT_THUMBNAIL_SIZE)
|
env.get("ONNX_WEB_THUMBNAIL_SIZE", DEFAULT_THUMBNAIL_SIZE)
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_setting(self, flag: str, default: str) -> Optional[str]:
|
def get_setting(self, flag: str, default: str) -> Optional[str]:
|
||||||
|
|
Loading…
Reference in New Issue