
clean up result and metadata handling

Sean Sube 2024-01-05 20:11:58 -06:00
parent 10acad232c
commit 4f230f4111
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
27 changed files with 119 additions and 85 deletions


@@ -32,9 +32,9 @@ class BlendDenoiseFastNLMeansStage(BaseStage):
         logger.info("denoising source images")
 
         results = []
-        for source in sources.as_numpy():
+        for source in sources.as_arrays():
             data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
             data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
             results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))
 
-        return StageResult(arrays=results)
+        return StageResult.from_arrays(results, metadata=sources.metadata)

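The same two changes repeat in nearly every stage below: pixel data is read through the renamed accessors (as_arrays()/as_images() instead of as_numpy()/as_image()), and results are built with factories that forward the source metadata. A rough sketch of the new calling convention; the module path and the denoise helper are assumptions for illustration, not part of the commit:

    import cv2
    import numpy as np

    from onnx_web.chain.result import StageResult  # assumed import path


    def denoise_array(source: np.ndarray, strength: int) -> np.ndarray:
        # hypothetical helper mirroring the stage above: OpenCV works in BGR, so round-trip the channels
        data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
        data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
        return cv2.cvtColor(data, cv2.COLOR_BGR2RGB)


    def run_denoise(sources: StageResult, strength: int = 3) -> StageResult:
        # as_arrays() replaces as_numpy(); from_arrays() now carries the source metadata forward
        results = [denoise_array(source, strength) for source in sources.as_arrays()]
        return StageResult.from_arrays(results, metadata=sources.metadata)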

@@ -14,6 +14,11 @@ logger = getLogger(__name__)
 class BlendDenoiseLocalStdStage(BaseStage):
+    """
+    Experimental stage to blend and denoise images using local means compared to local standard deviation.
+    Very slow.
+    """
+
     max_tile = SizeChart.max
 
     def run(
@@ -35,8 +40,9 @@ class BlendDenoiseLocalStdStage(BaseStage):
         return StageResult.from_arrays(
             [
                 remove_noise(source, threshold=strength, deviation=range)[0]
-                for source in sources.as_numpy()
-            ]
+                for source in sources.as_arrays()
+            ],
+            metadata=sources.metadata,
         )


@@ -35,7 +35,7 @@ class BlendGridStage(BaseStage):
     ) -> StageResult:
         logger.info("combining source images using grid layout")
 
-        images = sources.as_image()
+        images = sources.as_images()
         ref_image = images[0]
         size = Size(*ref_image.size)
@@ -52,7 +52,9 @@
             n = order[i]
             output.paste(images[n], (x * size.width, y * size.height))
 
-        return StageResult(images=[*images, output])
+        result = StageResult(source=sources)
+        result.push_image(output, sources.metadata[0])
+        return result
 
     def outputs(
         self,

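BlendGridStage takes a different route: instead of rebuilding the result from a list, it copies the incoming StageResult and pushes the combined grid onto it. A hedged sketch of that pattern, assuming the source result has at least one metadata entry to reuse for the new image (the import path is an assumption):

    from PIL import Image

    from onnx_web.chain.result import StageResult  # assumed import path


    def append_grid(sources: StageResult, output: Image.Image) -> StageResult:
        # StageResult(source=...) copies the arrays, images, and metadata of another result
        result = StageResult(source=sources)
        # push_image() now requires metadata; the grid reuses the first source's entry
        result.push_image(output, sources.metadata[0])
        return result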

@@ -66,7 +66,7 @@ class BlendImg2ImgStage(BaseStage):
             pipe_params["strength"] = strength
 
         outputs = []
-        for source in sources.as_image():
+        for source in sources.as_images():
             if params.is_lpw():
                 logger.debug("using LPW pipeline for img2img")
                 rng = torch.manual_seed(params.seed)


@@ -28,9 +28,10 @@ class BlendLinearStage(BaseStage):
     ) -> StageResult:
         logger.info("blending source images using linear interpolation")
 
-        return StageResult(
-            images=[
+        return StageResult.from_images(
+            [
                 Image.blend(source, stage_source, alpha)
-                for source in sources.as_image()
-            ]
+                for source in sources.as_images()
+            ],
+            metadata=sources.metadata,
         )


@@ -48,6 +48,7 @@ class BlendMaskStage(BaseStage):
         return StageResult.from_images(
             [
                 Image.composite(stage_source_tile, source, mult_mask)
-                for source in sources.as_image()
-            ]
+                for source in sources.as_images()
+            ],
+            metadata=sources.metadata,
         )


@@ -67,7 +67,7 @@ class CorrectCodeformerStage(BaseStage):
         )
 
         results = []
-        for img in sources.as_numpy():
+        for img in sources.as_arrays():
             img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
             # clean all the intermediate results to process the next image
             face_helper.clean_all()
@@ -121,4 +121,4 @@
             )
             results.append(Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB)))
 
-        return StageResult.from_images(results)
+        return StageResult.from_images(results, metadata=sources.metadata)


@@ -74,7 +74,7 @@ class CorrectGFPGANStage(BaseStage):
         gfpgan = self.load(server, stage, upscale, device)
 
         outputs = []
-        for source in sources.as_numpy():
+        for source in sources.as_arrays():
             cropped, restored, result = gfpgan.enhance(
                 source,
                 has_aligned=False,
@@ -84,4 +84,4 @@
             )
             outputs.append(result)
 
-        return StageResult.from_arrays(outputs)
+        return StageResult.from_arrays(outputs, metadata=sources.metadata)


@@ -31,18 +31,14 @@ class PersistDiskStage(BaseStage):
     ) -> StageResult:
         logger.info("persisting %s images to disk: %s", len(sources), output)
 
-        for name, source, metadata in zip(output, sources.as_image(), sources.metadata):
+        for name, source, metadata in zip(
+            output, sources.as_images(), sources.metadata
+        ):
             dest = save_image(
                 server,
                 name,
                 source,
-                params=metadata.params,
-                size=metadata.size,
-                upscale=metadata.upscale,
-                border=metadata.border,
-                highres=metadata.highres,
-                inversions=metadata.inversions,
-                loras=metadata.loras,
+                metadata=metadata,
             )
             logger.info("saved image to %s", dest)

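PersistDiskStage no longer unpacks metadata field by field; save_image() receives the ImageMetadata object whole, and the fields that used to be passed individually (params, size, upscale, border, highres, and so on) travel inside it. A sketch of the new save loop under that assumption; the import paths are guesses and the helper name is hypothetical:

    from typing import List

    from onnx_web.chain.result import StageResult  # assumed import path
    from onnx_web.output import save_image  # assumed import path


    def persist_to_disk(server, output: List[str], sources: StageResult) -> List[str]:
        saved = []
        # names, images, and metadata entries are expected to stay index-aligned
        for name, source, metadata in zip(output, sources.as_images(), sources.metadata):
            saved.append(save_image(server, name, source, metadata=metadata))
        return saved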

@@ -33,7 +33,8 @@ class PersistS3Stage(BaseStage):
         session = Session(profile_name=profile_name)
         s3 = session.client("s3", endpoint_url=endpoint_url)
 
-        for source, name in zip(sources.as_image(), output):
+        # TODO: save metadata as well
+        for source, name in zip(sources.as_images(), output):
             data = BytesIO()
             source.save(data, format=server.image_format)
             data.seek(0)


@@ -11,7 +11,6 @@ from ..params import ImageParams, Size, StageParams
 from ..server import ServerContext
 from ..utils import is_debug, run_gc
 from ..worker import ProgressCallback, WorkerContext
-from ..worker.command import Progress
 from .base import BaseStage
 from .result import StageResult
 from .tile import needs_tile, process_tile_order
@@ -107,7 +106,7 @@ class ChainPipeline:
         result = self(
             worker, server, params, sources=sources, callback=callback, **kwargs
         )
-        return result.as_image()
+        return result.as_images()
 
     def stage(self, callback: BaseStage, params: StageParams, **kwargs):
         self.stages.append((callback, params, kwargs))
@@ -184,7 +183,7 @@
                         size=kwargs.get("size", None),
                         source=source,
                     )
-                    for source in stage_sources.as_image()
+                    for source in stage_sources.as_images()
                 ]
             )
@@ -302,7 +301,7 @@
                 )
 
             if is_debug():
-                for j, image in enumerate(stage_sources.as_image()):
+                for j, image in enumerate(stage_sources.as_images()):
                     save_image(server, f"last-stage-{j}.png", image)
 
             end = monotonic()

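On the calling side, ChainPipeline results are read out with the renamed accessor as well. A hedged sketch of invoking a pipeline and collecting PIL images, assuming the call signature shown in the hunk above; the import paths and the surrounding worker/server/params objects are assumptions:

    from onnx_web.chain.pipeline import ChainPipeline  # assumed import path
    from onnx_web.chain.result import StageResult  # assumed import path


    def run_chain(pipeline: ChainPipeline, worker, server, params, sources: StageResult):
        # the pipeline returns a StageResult; as_images() (formerly as_image())
        # hands back PIL images, converting from arrays when necessary
        result = pipeline(worker, server, params, sources=sources, callback=None)
        return result.as_images()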

@@ -28,11 +28,11 @@ class ReduceCropStage(BaseStage):
     ) -> StageResult:
         outputs = []
 
-        for source in sources.as_image():
+        for source in sources.as_images():
             image = source.crop((origin.width, origin.height, size.width, size.height))
             logger.info(
                 "created thumbnail with dimensions: %sx%s", image.width, image.height
             )
             outputs.append(image)
 
-        return StageResult(images=outputs)
+        return StageResult.from_images(outputs, metadata=sources.metadata)


@@ -26,7 +26,7 @@ class ReduceThumbnailStage(BaseStage):
     ) -> StageResult:
         outputs = []
 
-        for source in sources.as_image():
+        for source in sources.as_images():
             image = source.copy()
             image = image.thumbnail((size.width, size.height))
@@ -37,4 +37,4 @@
             outputs.append(image)
 
-        return StageResult(images=outputs)
+        return StageResult.from_images(outputs, metadata=sources.metadata)


@@ -194,12 +194,16 @@ class StageResult:
         return StageResult(images=[])
 
     @staticmethod
-    def from_arrays(arrays: List[np.ndarray]):
-        return StageResult(arrays=arrays)
+    def from_arrays(
+        arrays: List[np.ndarray], metadata: Optional[List[ImageMetadata]] = None
+    ):
+        return StageResult(arrays=arrays, metadata=metadata)
 
     @staticmethod
-    def from_images(images: List[Image.Image]):
-        return StageResult(images=images)
+    def from_images(
+        images: List[Image.Image], metadata: Optional[List[ImageMetadata]] = None
+    ):
+        return StageResult(images=images, metadata=metadata)
 
     def __init__(
         self,
@@ -208,15 +212,22 @@
         metadata: Optional[List[ImageMetadata]] = None,
         source: Optional[Any] = None,
     ) -> None:
-        if sum([arrays is not None, images is not None, source is not None]) > 1:
-            raise ValueError("stages must only return one type of result")
-        elif arrays is None and images is None and source is None:
-            raise ValueError("stages must return results")
+        data_provided = sum(
+            [arrays is not None, images is not None, source is not None]
+        )
+        if data_provided > 1:
+            raise ValueError("results must only contain one type of data")
+        elif data_provided == 0:
+            raise ValueError("results must contain some data")
 
-        self.arrays = arrays
-        self.images = images
         self.source = source
-        self.metadata = metadata or []
+
+        if source is not None:
+            self.arrays = source.arrays
+            self.images = source.images
+            self.metadata = source.metadata
+        else:
+            self.arrays = arrays
+            self.images = images
+            self.metadata = metadata or []
 
     def __len__(self) -> int:
         if self.arrays is not None:
@@ -226,7 +237,7 @@
         else:
             return 0
 
-    def as_numpy(self) -> List[np.ndarray]:
+    def as_arrays(self) -> List[np.ndarray]:
         if self.arrays is not None:
             return self.arrays
         elif self.images is not None:
@@ -234,7 +245,7 @@
         else:
             return []
 
-    def as_image(self) -> List[Image.Image]:
+    def as_images(self) -> List[Image.Image]:
         if self.images is not None:
             return self.images
         elif self.arrays is not None:
@@ -242,7 +253,7 @@
         else:
             return []
 
-    def push_array(self, array: np.ndarray, metadata: Optional[ImageMetadata]):
+    def push_array(self, array: np.ndarray, metadata: ImageMetadata):
         if self.arrays is not None:
             self.arrays.append(array)
         elif self.images is not None:
@@ -253,9 +264,9 @@
         if metadata is not None:
             self.metadata.append(metadata)
         else:
-            self.metadata.append(ImageMetadata())
+            raise ValueError("metadata must be provided")
 
-    def push_image(self, image: Image.Image, metadata: Optional[ImageMetadata]):
+    def push_image(self, image: Image.Image, metadata: ImageMetadata):
         if self.images is not None:
             self.images.append(image)
         elif self.arrays is not None:
@@ -266,11 +277,9 @@
         if metadata is not None:
             self.metadata.append(metadata)
         else:
-            self.metadata.append(ImageMetadata())
+            raise ValueError("metadata must be provided")
 
-    def insert_array(
-        self, index: int, array: np.ndarray, metadata: Optional[ImageMetadata]
-    ):
+    def insert_array(self, index: int, array: np.ndarray, metadata: ImageMetadata):
         if self.arrays is not None:
             self.arrays.insert(index, array)
         elif self.images is not None:
@@ -283,11 +292,9 @@
         if metadata is not None:
             self.metadata.insert(index, metadata)
         else:
-            self.metadata.insert(index, ImageMetadata())
+            raise ValueError("metadata must be provided")
 
-    def insert_image(
-        self, index: int, image: Image.Image, metadata: Optional[ImageMetadata]
-    ):
+    def insert_image(self, index: int, image: Image.Image, metadata: ImageMetadata):
         if self.images is not None:
             self.images.insert(index, image)
         elif self.arrays is not None:
@@ -298,7 +305,28 @@
         if metadata is not None:
             self.metadata.insert(index, metadata)
         else:
-            self.metadata.insert(index, ImageMetadata())
+            raise ValueError("metadata must be provided")
+
+    def size(self) -> Size:
+        if self.images is not None:
+            return Size(self.images[0].width, self.images[0].height)
+        elif self.arrays is not None:
+            return Size(
+                self.arrays[0].shape[0], self.arrays[0].shape[1]
+            )  # TODO: which fields within the shape are width/height?
+        else:
+            return Size(0, 0)
+
+    def validate(self) -> None:
+        """
+        Make sure the data exists and that data and metadata match in length.
+        """
+        if self.arrays is None and self.images is None:
+            raise ValueError("no data in result")
+
+        if len(self) != len(self.metadata):
+            raise ValueError("metadata and data do not match in length")
 
 
 def shape_mode(arr: np.ndarray) -> str:

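Taken together, the StageResult changes make metadata a first-class part of every result: the factories accept it, push/insert raise instead of inventing an empty ImageMetadata, and the new size() and validate() helpers report on the stored data. A rough usage sketch against the updated class; the ImageMetadata import path and no-argument constructor are assumptions for illustration:

    import numpy as np

    from onnx_web.chain.result import StageResult  # assumed import path
    from onnx_web.params import ImageMetadata  # assumed import path

    # build a result from arrays, carrying one metadata entry per array
    meta = ImageMetadata()  # assumed constructor; the commit stops creating empty defaults like this
    result = StageResult.from_arrays([np.zeros((64, 64, 3), dtype=np.uint8)], metadata=[meta])

    # push_array()/push_image() now require metadata and raise ValueError without it
    result.push_array(np.ones((64, 64, 3), dtype=np.uint8), meta)

    # size() reports the dimensions of the first entry; validate() checks data/metadata alignment
    print(result.size())
    result.validate()

    # a result can also be built from another result, copying its data and metadata
    copy = StageResult(source=result)
    assert len(copy) == len(copy.metadata)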

@@ -36,13 +36,13 @@ class SourceNoiseStage(BaseStage):
         outputs = []
 
         # TODO: looping over sources and ignoring params does not make much sense for a source stage
-        for source in sources.as_image():
+        for source in sources.as_images():
             output = noise_source(source, (size.width, size.height), (0, 0))
             logger.info("final output image size: %sx%s", output.width, output.height)
             outputs.append(output)
 
-        return StageResult(images=outputs)
+        return StageResult.from_images(outputs, metadata=sources.metadata)
 
     def outputs(
         self,


@@ -37,7 +37,7 @@ class SourceS3Stage(BaseStage):
                 "source images were passed to a source stage, new images will be appended"
             )
 
-        outputs = sources.as_image()
+        outputs = sources.as_images()
         for key in source_keys:
             try:
                 logger.info("loading image from s3://%s/%s", bucket, key)


@@ -34,7 +34,7 @@ class SourceURLStage(BaseStage):
                 "source images were passed to a source stage, new images will be appended"
             )
 
-        outputs = sources.as_image()
+        outputs = sources.as_images()
         for url in source_urls:
             response = requests.get(url)
             output = Image.open(BytesIO(response.content))


@@ -257,7 +257,7 @@ def process_tile_stack(
     overlap: float = 0.5,
     **kwargs,
 ) -> List[Image.Image]:
-    sources = stack.as_image()
+    sources = stack.as_images()
     width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
 
     mask = kwargs.get("mask", kwargs.get("stage_mask", None))
@@ -308,7 +308,7 @@
                 bottom_margin,
             )
             tile_stack = add_margin(
-                stack.as_image(),
+                stack.as_images(),
                 left,
                 top,
                 right,
@@ -346,7 +346,7 @@
         if isinstance(tile_stack, list):
             tile_stack = StageResult.from_images(tile_stack)
 
-        tiles.append((left, top, tile_stack.as_image()))
+        tiles.append((left, top, tile_stack.as_images()))
 
     lefts, tops, stacks = list(zip(*tiles))
     coords = list(zip(lefts, tops))
@@ -516,7 +516,7 @@ def get_result_tile(
     top, left = origin
 
     return [
         layer.crop((top, left, top + tile.height, left + tile.width))
-        for layer in result.as_image()
+        for layer in result.as_images()
     ]

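process_tile_stack keeps accepting either a bare list of PIL tiles or a StageResult from the per-tile callback, wrapping the list form before reading it back out through as_images(). A small sketch of that normalization step, assuming only what the hunk above shows (the import path and helper name are illustrative):

    from typing import List, Union

    from PIL import Image

    from onnx_web.chain.result import StageResult  # assumed import path


    def normalize_tiles(tile_stack: Union[List[Image.Image], StageResult]) -> List[Image.Image]:
        # wrap bare lists so the rest of the tiling code can rely on the StageResult API
        if isinstance(tile_stack, list):
            tile_stack = StageResult.from_images(tile_stack)
        return tile_stack.as_images()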

@@ -79,7 +79,7 @@ class UpscaleBSRGANStage(BaseStage):
         bsrgan = self.load(server, stage, upscale, device)
 
         outputs = []
-        for source in sources.as_numpy():
+        for source in sources.as_arrays():
             image = source / 255.0
             image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
             image = np.expand_dims(image, axis=0)
@@ -105,7 +105,7 @@
             logger.debug("output image shape: %s", output.shape)
             outputs.append(output)
 
-        return StageResult(arrays=outputs)
+        return StageResult(arrays=outputs, metadata=sources.metadata)
 
     def steps(
         self,


@@ -42,7 +42,7 @@ class UpscaleHighresStage(BaseStage):
                 source,
                 callback=callback,
             )
-            for source in sources.as_image()
+            for source in sources.as_images()
         ]
 
-        return StageResult(images=outputs)
+        return StageResult(images=outputs, metadata=sources.metadata)


@@ -62,7 +62,7 @@ class UpscaleOutpaintStage(BaseStage):
         )
 
         outputs = []
-        for source in sources.as_image():
+        for source in sources.as_images():
             if is_debug():
                 save_image(server, "tile-source.png", source)
                 save_image(server, "tile-mask.png", tile_mask)
@@ -123,4 +123,4 @@
             outputs.extend(result.images)
 
-        return StageResult(images=outputs)
+        return StageResult(images=outputs, metadata=sources.metadata)


@@ -112,7 +112,7 @@ class UpscaleRealESRGANStage(BaseStage):
         )
 
         outputs = []
-        for source in sources.as_numpy():
+        for source in sources.as_arrays():
             output, _ = upsampler.enhance(source, outscale=upscale.outscale)
             logger.info("final output image size: %s", output.shape)
             outputs.append(output)


@@ -33,7 +33,7 @@ class UpscaleSimpleStage(BaseStage):
             return sources
 
         outputs = []
-        for source in sources.as_image():
+        for source in sources.as_images():
             scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
 
             if method == "bilinear":
@@ -49,4 +49,4 @@
             else:
                 logger.warning("unknown upscaling method: %s", method)
 
-        return StageResult(images=outputs)
+        return StageResult(images=outputs, metadata=sources.metadata)


@@ -59,7 +59,7 @@ class UpscaleStableDiffusionStage(BaseStage):
         pipeline.unet.set_prompts(prompt_embeds)
 
         outputs = []
-        for source in sources.as_image():
+        for source in sources.as_images():
             result = pipeline(
                 prompt,
                 source,
@@ -73,4 +73,4 @@
             )
             outputs.extend(result.images)
 
-        return StageResult(images=outputs)
+        return StageResult(images=outputs, metadata=sources.metadata)


@@ -72,7 +72,7 @@ class UpscaleSwinIRStage(BaseStage):
         swinir = self.load(server, stage, upscale, device)
 
         outputs = []
-        for source in sources.as_numpy():
+        for source in sources.as_arrays():
             # TODO: add support for grayscale (1-channel) images
             image = source / 255.0
             image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
@@ -98,4 +98,4 @@
             logger.info("output image size: %s", output.shape)
             outputs.append(output)
 
-        return StageResult(images=outputs)
+        return StageResult(images=outputs, metadata=sources.metadata)


@@ -118,7 +118,7 @@ def run_txt2img_pipeline(
     )
 
     # add a thumbnail, if requested
-    cover = images.as_image()[0]
+    cover = images.as_images()[0]
     if params.thumbnail and (
         cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
     ):
@@ -385,12 +385,12 @@ def run_inpaint_pipeline(
         worker,
         server,
         params,
-        StageResult(images=[source]),
+        StageResult(images=[source]),  # TODO: load metadata from source image
         callback=progress,
         latents=latents,
     )
 
-    for i, image, metadata in enumerate(zip(images.as_image(), images.metadata)):
+    for i, image, metadata in enumerate(zip(images.as_images(), images.metadata)):
         if full_res_inpaint:
             if is_debug():
                 save_image(server, "adjusted-output.png", image)

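run_inpaint_pipeline now walks the output images and their metadata together. When reusing this pattern, note that enumerate over zip yields (index, (image, metadata)) pairs, so the tuple is unpacked in a second step in the sketch below; the file names and import paths here are made up for illustration:

    from onnx_web.chain.result import StageResult  # assumed import path
    from onnx_web.output import save_image  # assumed import path


    def save_all(server, images: StageResult) -> None:
        for i, (image, metadata) in enumerate(zip(images.as_images(), images.metadata)):
            # save_image() takes the whole ImageMetadata object, as in the PersistDiskStage hunk
            save_image(server, f"output-{i}.png", image, metadata=metadata)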

@@ -62,7 +62,7 @@ def save_result(
     result: StageResult,
     base_name: str,
 ) -> List[str]:
-    images = result.as_image()
+    images = result.as_images()
     outputs = make_output_names(server, base_name, len(images))
 
     results = []
     for image, metadata, filename in zip(images, result.metadata, outputs):