
clean up result and metadata handling

Sean Sube 2024-01-05 20:11:58 -06:00
parent 10acad232c
commit 4f230f4111
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
27 changed files with 119 additions and 85 deletions
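
In short, this commit renames the StageResult accessors as_numpy()/as_image() to as_arrays()/as_images(), forwards sources.metadata into the result returned by each chain stage, and makes metadata required when pushing or inserting images and arrays. The same pattern repeats across most of the files below; a minimal before/after sketch of a stage's return value (the process() helper and the surrounding stage body are hypothetical placeholders, not code from this commit):

    # before: metadata was dropped and rebuilt from empty defaults
    outputs = [process(source) for source in sources.as_image()]
    return StageResult(images=outputs)

    # after: accessors renamed and metadata forwarded explicitly
    outputs = [process(source) for source in sources.as_images()]
    return StageResult.from_images(outputs, metadata=sources.metadata)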

View File

@@ -32,9 +32,9 @@ class BlendDenoiseFastNLMeansStage(BaseStage):
         logger.info("denoising source images")

         results = []
-        for source in sources.as_numpy():
+        for source in sources.as_arrays():
             data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
             data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
             results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))

-        return StageResult(arrays=results)
+        return StageResult.from_arrays(results, metadata=sources.metadata)

View File

@@ -14,6 +14,11 @@ logger = getLogger(__name__)


 class BlendDenoiseLocalStdStage(BaseStage):
+    """
+    Experimental stage to blend and denoise images using local means compared to local standard deviation.
+    Very slow.
+    """
+
     max_tile = SizeChart.max

     def run(
@@ -35,8 +40,9 @@ class BlendDenoiseLocalStdStage(BaseStage):
         return StageResult.from_arrays(
             [
                 remove_noise(source, threshold=strength, deviation=range)[0]
-                for source in sources.as_numpy()
-            ]
+                for source in sources.as_arrays()
+            ],
+            metadata=sources.metadata,
         )

View File

@@ -35,7 +35,7 @@ class BlendGridStage(BaseStage):
     ) -> StageResult:
         logger.info("combining source images using grid layout")

-        images = sources.as_image()
+        images = sources.as_images()
         ref_image = images[0]
         size = Size(*ref_image.size)

@@ -52,7 +52,9 @@ class BlendGridStage(BaseStage):
             n = order[i]
             output.paste(images[n], (x * size.width, y * size.height))

-        return StageResult(images=[*images, output])
+        result = StageResult(source=sources)
+        result.push_image(output, sources.metadata[0])
+        return result

     def outputs(
         self,

View File

@@ -66,7 +66,7 @@ class BlendImg2ImgStage(BaseStage):
             pipe_params["strength"] = strength

         outputs = []
-        for source in sources.as_image():
+        for source in sources.as_images():
             if params.is_lpw():
                 logger.debug("using LPW pipeline for img2img")
                 rng = torch.manual_seed(params.seed)

View File

@@ -28,9 +28,10 @@ class BlendLinearStage(BaseStage):
     ) -> StageResult:
         logger.info("blending source images using linear interpolation")

-        return StageResult(
-            images=[
+        return StageResult.from_images(
+            [
                 Image.blend(source, stage_source, alpha)
-                for source in sources.as_image()
-            ]
+                for source in sources.as_images()
+            ],
+            metadata=sources.metadata,
         )

View File

@@ -48,6 +48,7 @@ class BlendMaskStage(BaseStage):
         return StageResult.from_images(
             [
                 Image.composite(stage_source_tile, source, mult_mask)
-                for source in sources.as_image()
-            ]
+                for source in sources.as_images()
+            ],
+            metadata=sources.metadata,
         )

View File

@@ -67,7 +67,7 @@ class CorrectCodeformerStage(BaseStage):
         )

         results = []
-        for img in sources.as_numpy():
+        for img in sources.as_arrays():
             img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
             # clean all the intermediate results to process the next image
             face_helper.clean_all()
@@ -121,4 +121,4 @@ class CorrectCodeformerStage(BaseStage):
             )
             results.append(Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB)))

-        return StageResult.from_images(results)
+        return StageResult.from_images(results, metadata=sources.metadata)

View File

@@ -74,7 +74,7 @@ class CorrectGFPGANStage(BaseStage):
         gfpgan = self.load(server, stage, upscale, device)

         outputs = []
-        for source in sources.as_numpy():
+        for source in sources.as_arrays():
             cropped, restored, result = gfpgan.enhance(
                 source,
                 has_aligned=False,
@@ -84,4 +84,4 @@ class CorrectGFPGANStage(BaseStage):
             )
             outputs.append(result)

-        return StageResult.from_arrays(outputs)
+        return StageResult.from_arrays(outputs, metadata=sources.metadata)

View File

@@ -31,18 +31,14 @@ class PersistDiskStage(BaseStage):
     ) -> StageResult:
         logger.info("persisting %s images to disk: %s", len(sources), output)

-        for name, source, metadata in zip(output, sources.as_image(), sources.metadata):
+        for name, source, metadata in zip(
+            output, sources.as_images(), sources.metadata
+        ):
             dest = save_image(
                 server,
                 name,
                 source,
-                params=metadata.params,
-                size=metadata.size,
-                upscale=metadata.upscale,
-                border=metadata.border,
-                highres=metadata.highres,
-                inversions=metadata.inversions,
-                loras=metadata.loras,
+                metadata=metadata,
             )
             logger.info("saved image to %s", dest)

View File

@@ -33,7 +33,8 @@ class PersistS3Stage(BaseStage):
         session = Session(profile_name=profile_name)
         s3 = session.client("s3", endpoint_url=endpoint_url)

-        for source, name in zip(sources.as_image(), output):
+        # TODO: save metadata as well
+        for source, name in zip(sources.as_images(), output):
             data = BytesIO()
             source.save(data, format=server.image_format)
             data.seek(0)

View File

@@ -11,7 +11,6 @@ from ..params import ImageParams, Size, StageParams
 from ..server import ServerContext
 from ..utils import is_debug, run_gc
 from ..worker import ProgressCallback, WorkerContext
-from ..worker.command import Progress
 from .base import BaseStage
 from .result import StageResult
 from .tile import needs_tile, process_tile_order
@@ -107,7 +106,7 @@ class ChainPipeline:
         result = self(
             worker, server, params, sources=sources, callback=callback, **kwargs
         )
-        return result.as_image()
+        return result.as_images()

     def stage(self, callback: BaseStage, params: StageParams, **kwargs):
         self.stages.append((callback, params, kwargs))
@@ -184,7 +183,7 @@ class ChainPipeline:
                        size=kwargs.get("size", None),
                        source=source,
                    )
-                    for source in stage_sources.as_image()
+                    for source in stage_sources.as_images()
                ]
            )
@@ -302,7 +301,7 @@ class ChainPipeline:
            )

            if is_debug():
-                for j, image in enumerate(stage_sources.as_image()):
+                for j, image in enumerate(stage_sources.as_images()):
                    save_image(server, f"last-stage-{j}.png", image)

            end = monotonic()

View File

@@ -28,11 +28,11 @@ class ReduceCropStage(BaseStage):
     ) -> StageResult:
         outputs = []

-        for source in sources.as_image():
+        for source in sources.as_images():
             image = source.crop((origin.width, origin.height, size.width, size.height))
             logger.info(
                 "created thumbnail with dimensions: %sx%s", image.width, image.height
             )
             outputs.append(image)

-        return StageResult(images=outputs)
+        return StageResult.from_images(outputs, metadata=sources.metadata)

View File

@@ -26,7 +26,7 @@ class ReduceThumbnailStage(BaseStage):
     ) -> StageResult:
         outputs = []

-        for source in sources.as_image():
+        for source in sources.as_images():
             image = source.copy()
             image = image.thumbnail((size.width, size.height))
@@ -37,4 +37,4 @@ class ReduceThumbnailStage(BaseStage):
             outputs.append(image)

-        return StageResult(images=outputs)
+        return StageResult.from_images(outputs, metadata=sources.metadata)

View File

@@ -194,12 +194,16 @@ class StageResult:
         return StageResult(images=[])

     @staticmethod
-    def from_arrays(arrays: List[np.ndarray]):
-        return StageResult(arrays=arrays)
+    def from_arrays(
+        arrays: List[np.ndarray], metadata: Optional[List[ImageMetadata]] = None
+    ):
+        return StageResult(arrays=arrays, metadata=metadata)

     @staticmethod
-    def from_images(images: List[Image.Image]):
-        return StageResult(images=images)
+    def from_images(
+        images: List[Image.Image], metadata: Optional[List[ImageMetadata]] = None
+    ):
+        return StageResult(images=images, metadata=metadata)

     def __init__(
         self,
@@ -208,14 +212,21 @@ class StageResult:
         metadata: Optional[List[ImageMetadata]] = None,
         source: Optional[Any] = None,
     ) -> None:
-        if sum([arrays is not None, images is not None, source is not None]) > 1:
-            raise ValueError("stages must only return one type of result")
-        elif arrays is None and images is None and source is None:
-            raise ValueError("stages must return results")
+        data_provided = sum(
+            [arrays is not None, images is not None, source is not None]
+        )
+        if data_provided > 1:
+            raise ValueError("results must only contain one type of data")
+        elif data_provided == 0:
+            raise ValueError("results must contain some data")

-        self.arrays = arrays
-        self.images = images
-        self.source = source
-        self.metadata = metadata or []
+        if source is not None:
+            self.arrays = source.arrays
+            self.images = source.images
+            self.metadata = source.metadata
+        else:
+            self.arrays = arrays
+            self.images = images
+            self.metadata = metadata or []

     def __len__(self) -> int:
@@ -226,7 +237,7 @@ class StageResult:
         else:
             return 0

-    def as_numpy(self) -> List[np.ndarray]:
+    def as_arrays(self) -> List[np.ndarray]:
         if self.arrays is not None:
             return self.arrays
         elif self.images is not None:
@@ -234,7 +245,7 @@ class StageResult:
         else:
             return []

-    def as_image(self) -> List[Image.Image]:
+    def as_images(self) -> List[Image.Image]:
         if self.images is not None:
             return self.images
         elif self.arrays is not None:
@@ -242,7 +253,7 @@ class StageResult:
         else:
             return []

-    def push_array(self, array: np.ndarray, metadata: Optional[ImageMetadata]):
+    def push_array(self, array: np.ndarray, metadata: ImageMetadata):
         if self.arrays is not None:
             self.arrays.append(array)
         elif self.images is not None:
@@ -253,9 +264,9 @@ class StageResult:
         if metadata is not None:
             self.metadata.append(metadata)
         else:
-            self.metadata.append(ImageMetadata())
+            raise ValueError("metadata must be provided")

-    def push_image(self, image: Image.Image, metadata: Optional[ImageMetadata]):
+    def push_image(self, image: Image.Image, metadata: ImageMetadata):
         if self.images is not None:
             self.images.append(image)
         elif self.arrays is not None:
@@ -266,11 +277,9 @@ class StageResult:
         if metadata is not None:
             self.metadata.append(metadata)
         else:
-            self.metadata.append(ImageMetadata())
+            raise ValueError("metadata must be provided")

-    def insert_array(
-        self, index: int, array: np.ndarray, metadata: Optional[ImageMetadata]
-    ):
+    def insert_array(self, index: int, array: np.ndarray, metadata: ImageMetadata):
         if self.arrays is not None:
             self.arrays.insert(index, array)
         elif self.images is not None:
@@ -283,11 +292,9 @@ class StageResult:
         if metadata is not None:
             self.metadata.insert(index, metadata)
         else:
-            self.metadata.insert(index, ImageMetadata())
+            raise ValueError("metadata must be provided")

-    def insert_image(
-        self, index: int, image: Image.Image, metadata: Optional[ImageMetadata]
-    ):
+    def insert_image(self, index: int, image: Image.Image, metadata: ImageMetadata):
         if self.images is not None:
             self.images.insert(index, image)
         elif self.arrays is not None:
@@ -298,7 +305,28 @@ class StageResult:
         if metadata is not None:
             self.metadata.insert(index, metadata)
         else:
-            self.metadata.insert(index, ImageMetadata())
+            raise ValueError("metadata must be provided")
+
+    def size(self) -> Size:
+        if self.images is not None:
+            return Size(self.images[0].width, self.images[0].height)
+        elif self.arrays is not None:
+            return Size(
+                self.arrays[0].shape[0], self.arrays[0].shape[1]
+            )  # TODO: which fields within the shape are width/height?
+        else:
+            return Size(0, 0)
+
+    def validate(self) -> None:
+        """
+        Make sure the data exists and that data and metadata match in length.
+        """
+        if self.arrays is None and self.images is None:
+            raise ValueError("no data in result")
+
+        if len(self) != len(self.metadata):
+            raise ValueError("metadata and data do not match in length")


 def shape_mode(arr: np.ndarray) -> str:
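
For orientation (not part of the commit): with the constructor and helper changes above, metadata must now accompany every image or array pushed into a StageResult, and validate() checks that data and metadata stay the same length. A small hypothetical usage sketch; only StageResult, ImageMetadata, and Size come from the code above, and ImageMetadata is assumed to still be constructible with defaults:

    from PIL import Image

    blank = Image.new("RGB", (64, 64))
    meta = ImageMetadata()  # assumption: default-constructible, as in the old fallback path

    result = StageResult.from_images([blank], metadata=[meta])
    result.push_image(blank, meta)  # metadata is required; passing None now raises ValueError
    result.validate()               # raises if data and metadata lengths differ
    print(result.size())            # Size(64, 64) for image-backed results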

View File

@@ -36,13 +36,13 @@ class SourceNoiseStage(BaseStage):
         outputs = []

         # TODO: looping over sources and ignoring params does not make much sense for a source stage
-        for source in sources.as_image():
+        for source in sources.as_images():
             output = noise_source(source, (size.width, size.height), (0, 0))
             logger.info("final output image size: %sx%s", output.width, output.height)
             outputs.append(output)

-        return StageResult(images=outputs)
+        return StageResult.from_images(outputs, metadata=sources.metadata)

     def outputs(
         self,

View File

@@ -37,7 +37,7 @@ class SourceS3Stage(BaseStage):
                 "source images were passed to a source stage, new images will be appended"
             )

-        outputs = sources.as_image()
+        outputs = sources.as_images()
         for key in source_keys:
             try:
                 logger.info("loading image from s3://%s/%s", bucket, key)

View File

@@ -34,7 +34,7 @@ class SourceURLStage(BaseStage):
                 "source images were passed to a source stage, new images will be appended"
            )

-        outputs = sources.as_image()
+        outputs = sources.as_images()
         for url in source_urls:
             response = requests.get(url)
             output = Image.open(BytesIO(response.content))

View File

@@ -257,7 +257,7 @@ def process_tile_stack(
     overlap: float = 0.5,
     **kwargs,
 ) -> List[Image.Image]:
-    sources = stack.as_image()
+    sources = stack.as_images()
     width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)

     mask = kwargs.get("mask", kwargs.get("stage_mask", None))
@@ -308,7 +308,7 @@ def process_tile_stack(
             bottom_margin,
         )
         tile_stack = add_margin(
-            stack.as_image(),
+            stack.as_images(),
             left,
             top,
             right,
@@ -346,7 +346,7 @@ def process_tile_stack(
         if isinstance(tile_stack, list):
             tile_stack = StageResult.from_images(tile_stack)

-        tiles.append((left, top, tile_stack.as_image()))
+        tiles.append((left, top, tile_stack.as_images()))

     lefts, tops, stacks = list(zip(*tiles))
     coords = list(zip(lefts, tops))
@@ -516,7 +516,7 @@ def get_result_tile(
     top, left = origin
     return [
         layer.crop((top, left, top + tile.height, left + tile.width))
-        for layer in result.as_image()
+        for layer in result.as_images()
     ]

View File

@@ -79,7 +79,7 @@ class UpscaleBSRGANStage(BaseStage):
         bsrgan = self.load(server, stage, upscale, device)

         outputs = []
-        for source in sources.as_numpy():
+        for source in sources.as_arrays():
             image = source / 255.0
             image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
             image = np.expand_dims(image, axis=0)
@@ -105,7 +105,7 @@ class UpscaleBSRGANStage(BaseStage):
             logger.debug("output image shape: %s", output.shape)
             outputs.append(output)

-        return StageResult(arrays=outputs)
+        return StageResult(arrays=outputs, metadata=sources.metadata)

     def steps(
         self,

View File

@@ -42,7 +42,7 @@ class UpscaleHighresStage(BaseStage):
                 source,
                 callback=callback,
             )
-            for source in sources.as_image()
+            for source in sources.as_images()
         ]

-        return StageResult(images=outputs)
+        return StageResult(images=outputs, metadata=sources.metadata)

View File

@@ -62,7 +62,7 @@ class UpscaleOutpaintStage(BaseStage):
         )

         outputs = []
-        for source in sources.as_image():
+        for source in sources.as_images():
             if is_debug():
                 save_image(server, "tile-source.png", source)
                 save_image(server, "tile-mask.png", tile_mask)
@@ -123,4 +123,4 @@ class UpscaleOutpaintStage(BaseStage):
             outputs.extend(result.images)

-        return StageResult(images=outputs)
+        return StageResult(images=outputs, metadata=sources.metadata)

View File

@@ -112,7 +112,7 @@ class UpscaleRealESRGANStage(BaseStage):
         )

         outputs = []
-        for source in sources.as_numpy():
+        for source in sources.as_arrays():
             output, _ = upsampler.enhance(source, outscale=upscale.outscale)
             logger.info("final output image size: %s", output.shape)
             outputs.append(output)

View File

@@ -33,7 +33,7 @@ class UpscaleSimpleStage(BaseStage):
             return sources

         outputs = []
-        for source in sources.as_image():
+        for source in sources.as_images():
             scaled_size = (source.width * upscale.scale, source.height * upscale.scale)

             if method == "bilinear":
@@ -49,4 +49,4 @@ class UpscaleSimpleStage(BaseStage):
             else:
                 logger.warning("unknown upscaling method: %s", method)

-        return StageResult(images=outputs)
+        return StageResult(images=outputs, metadata=sources.metadata)

View File

@@ -59,7 +59,7 @@ class UpscaleStableDiffusionStage(BaseStage):
         pipeline.unet.set_prompts(prompt_embeds)

         outputs = []
-        for source in sources.as_image():
+        for source in sources.as_images():
             result = pipeline(
                 prompt,
                 source,
@@ -73,4 +73,4 @@ class UpscaleStableDiffusionStage(BaseStage):
             )
             outputs.extend(result.images)

-        return StageResult(images=outputs)
+        return StageResult(images=outputs, metadata=sources.metadata)

View File

@@ -72,7 +72,7 @@ class UpscaleSwinIRStage(BaseStage):
         swinir = self.load(server, stage, upscale, device)

         outputs = []
-        for source in sources.as_numpy():
+        for source in sources.as_arrays():
             # TODO: add support for grayscale (1-channel) images
             image = source / 255.0
             image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
@@ -98,4 +98,4 @@ class UpscaleSwinIRStage(BaseStage):
             logger.info("output image size: %s", output.shape)
             outputs.append(output)

-        return StageResult(images=outputs)
+        return StageResult(images=outputs, metadata=sources.metadata)

View File

@@ -118,7 +118,7 @@ def run_txt2img_pipeline(
     )

     # add a thumbnail, if requested
-    cover = images.as_image()[0]
+    cover = images.as_images()[0]
     if params.thumbnail and (
         cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
     ):
@@ -385,12 +385,12 @@ def run_inpaint_pipeline(
         worker,
         server,
         params,
-        StageResult(images=[source]),
+        StageResult(images=[source]),  # TODO: load metadata from source image
         callback=progress,
         latents=latents,
     )

-    for i, image, metadata in enumerate(zip(images.as_image(), images.metadata)):
+    for i, image, metadata in enumerate(zip(images.as_images(), images.metadata)):
         if full_res_inpaint:
             if is_debug():
                 save_image(server, "adjusted-output.png", image)

View File

@@ -62,7 +62,7 @@ def save_result(
     result: StageResult,
     base_name: str,
 ) -> List[str]:
-    images = result.as_image()
+    images = result.as_images()
     outputs = make_output_names(server, base_name, len(images))
     results = []
     for image, metadata, filename in zip(images, result.metadata, outputs):