clean up result and metadata handling
This commit is contained in:
parent
10acad232c
commit
4f230f4111
|
@ -32,9 +32,9 @@ class BlendDenoiseFastNLMeansStage(BaseStage):
|
|||
logger.info("denoising source images")
|
||||
|
||||
results = []
|
||||
for source in sources.as_numpy():
|
||||
for source in sources.as_arrays():
|
||||
data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
|
||||
data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
|
||||
results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))
|
||||
|
||||
return StageResult(arrays=results)
|
||||
return StageResult.from_arrays(results, metadata=sources.metadata)
|
||||
|
|
|
@ -14,6 +14,11 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class BlendDenoiseLocalStdStage(BaseStage):
|
||||
"""
|
||||
Experimental stage to blend and denoise images using local means compared to local standard deviation.
|
||||
Very slow.
|
||||
"""
|
||||
|
||||
max_tile = SizeChart.max
|
||||
|
||||
def run(
|
||||
|
@ -35,8 +40,9 @@ class BlendDenoiseLocalStdStage(BaseStage):
|
|||
return StageResult.from_arrays(
|
||||
[
|
||||
remove_noise(source, threshold=strength, deviation=range)[0]
|
||||
for source in sources.as_numpy()
|
||||
]
|
||||
for source in sources.as_arrays()
|
||||
],
|
||||
metadata=sources.metadata,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ class BlendGridStage(BaseStage):
|
|||
) -> StageResult:
|
||||
logger.info("combining source images using grid layout")
|
||||
|
||||
images = sources.as_image()
|
||||
images = sources.as_images()
|
||||
ref_image = images[0]
|
||||
size = Size(*ref_image.size)
|
||||
|
||||
|
@ -52,7 +52,9 @@ class BlendGridStage(BaseStage):
|
|||
n = order[i]
|
||||
output.paste(images[n], (x * size.width, y * size.height))
|
||||
|
||||
return StageResult(images=[*images, output])
|
||||
result = StageResult(source=sources)
|
||||
result.push_image(output, sources.metadata[0])
|
||||
return result
|
||||
|
||||
def outputs(
|
||||
self,
|
||||
|
|
|
@ -66,7 +66,7 @@ class BlendImg2ImgStage(BaseStage):
|
|||
pipe_params["strength"] = strength
|
||||
|
||||
outputs = []
|
||||
for source in sources.as_image():
|
||||
for source in sources.as_images():
|
||||
if params.is_lpw():
|
||||
logger.debug("using LPW pipeline for img2img")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
|
|
|
@ -28,9 +28,10 @@ class BlendLinearStage(BaseStage):
|
|||
) -> StageResult:
|
||||
logger.info("blending source images using linear interpolation")
|
||||
|
||||
return StageResult(
|
||||
images=[
|
||||
return StageResult.from_images(
|
||||
[
|
||||
Image.blend(source, stage_source, alpha)
|
||||
for source in sources.as_image()
|
||||
]
|
||||
for source in sources.as_images()
|
||||
],
|
||||
metadata=sources.metadata,
|
||||
)
|
||||
|
|
|
@ -48,6 +48,7 @@ class BlendMaskStage(BaseStage):
|
|||
return StageResult.from_images(
|
||||
[
|
||||
Image.composite(stage_source_tile, source, mult_mask)
|
||||
for source in sources.as_image()
|
||||
]
|
||||
for source in sources.as_images()
|
||||
],
|
||||
metadata=sources.metadata,
|
||||
)
|
||||
|
|
|
@ -67,7 +67,7 @@ class CorrectCodeformerStage(BaseStage):
|
|||
)
|
||||
|
||||
results = []
|
||||
for img in sources.as_numpy():
|
||||
for img in sources.as_arrays():
|
||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
||||
# clean all the intermediate results to process the next image
|
||||
face_helper.clean_all()
|
||||
|
@ -121,4 +121,4 @@ class CorrectCodeformerStage(BaseStage):
|
|||
)
|
||||
results.append(Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB)))
|
||||
|
||||
return StageResult.from_images(results)
|
||||
return StageResult.from_images(results, metadata=sources.metadata)
|
||||
|
|
|
@ -74,7 +74,7 @@ class CorrectGFPGANStage(BaseStage):
|
|||
gfpgan = self.load(server, stage, upscale, device)
|
||||
|
||||
outputs = []
|
||||
for source in sources.as_numpy():
|
||||
for source in sources.as_arrays():
|
||||
cropped, restored, result = gfpgan.enhance(
|
||||
source,
|
||||
has_aligned=False,
|
||||
|
@ -84,4 +84,4 @@ class CorrectGFPGANStage(BaseStage):
|
|||
)
|
||||
outputs.append(result)
|
||||
|
||||
return StageResult.from_arrays(outputs)
|
||||
return StageResult.from_arrays(outputs, metadata=sources.metadata)
|
||||
|
|
|
@ -31,18 +31,14 @@ class PersistDiskStage(BaseStage):
|
|||
) -> StageResult:
|
||||
logger.info("persisting %s images to disk: %s", len(sources), output)
|
||||
|
||||
for name, source, metadata in zip(output, sources.as_image(), sources.metadata):
|
||||
for name, source, metadata in zip(
|
||||
output, sources.as_images(), sources.metadata
|
||||
):
|
||||
dest = save_image(
|
||||
server,
|
||||
name,
|
||||
source,
|
||||
params=metadata.params,
|
||||
size=metadata.size,
|
||||
upscale=metadata.upscale,
|
||||
border=metadata.border,
|
||||
highres=metadata.highres,
|
||||
inversions=metadata.inversions,
|
||||
loras=metadata.loras,
|
||||
metadata=metadata,
|
||||
)
|
||||
logger.info("saved image to %s", dest)
|
||||
|
||||
|
|
|
@ -33,7 +33,8 @@ class PersistS3Stage(BaseStage):
|
|||
session = Session(profile_name=profile_name)
|
||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||
|
||||
for source, name in zip(sources.as_image(), output):
|
||||
# TODO: save metadata as well
|
||||
for source, name in zip(sources.as_images(), output):
|
||||
data = BytesIO()
|
||||
source.save(data, format=server.image_format)
|
||||
data.seek(0)
|
||||
|
|
|
@ -11,7 +11,6 @@ from ..params import ImageParams, Size, StageParams
|
|||
from ..server import ServerContext
|
||||
from ..utils import is_debug, run_gc
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from ..worker.command import Progress
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
from .tile import needs_tile, process_tile_order
|
||||
|
@ -107,7 +106,7 @@ class ChainPipeline:
|
|||
result = self(
|
||||
worker, server, params, sources=sources, callback=callback, **kwargs
|
||||
)
|
||||
return result.as_image()
|
||||
return result.as_images()
|
||||
|
||||
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
||||
self.stages.append((callback, params, kwargs))
|
||||
|
@ -184,7 +183,7 @@ class ChainPipeline:
|
|||
size=kwargs.get("size", None),
|
||||
source=source,
|
||||
)
|
||||
for source in stage_sources.as_image()
|
||||
for source in stage_sources.as_images()
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -302,7 +301,7 @@ class ChainPipeline:
|
|||
)
|
||||
|
||||
if is_debug():
|
||||
for j, image in enumerate(stage_sources.as_image()):
|
||||
for j, image in enumerate(stage_sources.as_images()):
|
||||
save_image(server, f"last-stage-{j}.png", image)
|
||||
|
||||
end = monotonic()
|
||||
|
|
|
@ -28,11 +28,11 @@ class ReduceCropStage(BaseStage):
|
|||
) -> StageResult:
|
||||
outputs = []
|
||||
|
||||
for source in sources.as_image():
|
||||
for source in sources.as_images():
|
||||
image = source.crop((origin.width, origin.height, size.width, size.height))
|
||||
logger.info(
|
||||
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
||||
)
|
||||
outputs.append(image)
|
||||
|
||||
return StageResult(images=outputs)
|
||||
return StageResult.from_images(outputs, metadata=sources.metadata)
|
||||
|
|
|
@ -26,7 +26,7 @@ class ReduceThumbnailStage(BaseStage):
|
|||
) -> StageResult:
|
||||
outputs = []
|
||||
|
||||
for source in sources.as_image():
|
||||
for source in sources.as_images():
|
||||
image = source.copy()
|
||||
|
||||
image = image.thumbnail((size.width, size.height))
|
||||
|
@ -37,4 +37,4 @@ class ReduceThumbnailStage(BaseStage):
|
|||
|
||||
outputs.append(image)
|
||||
|
||||
return StageResult(images=outputs)
|
||||
return StageResult.from_images(outputs, metadata=sources.metadata)
|
||||
|
|
|
@ -194,12 +194,16 @@ class StageResult:
|
|||
return StageResult(images=[])
|
||||
|
||||
@staticmethod
|
||||
def from_arrays(arrays: List[np.ndarray]):
|
||||
return StageResult(arrays=arrays)
|
||||
def from_arrays(
|
||||
arrays: List[np.ndarray], metadata: Optional[List[ImageMetadata]] = None
|
||||
):
|
||||
return StageResult(arrays=arrays, metadata=metadata)
|
||||
|
||||
@staticmethod
|
||||
def from_images(images: List[Image.Image]):
|
||||
return StageResult(images=images)
|
||||
def from_images(
|
||||
images: List[Image.Image], metadata: Optional[List[ImageMetadata]] = None
|
||||
):
|
||||
return StageResult(images=images, metadata=metadata)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -208,14 +212,21 @@ class StageResult:
|
|||
metadata: Optional[List[ImageMetadata]] = None,
|
||||
source: Optional[Any] = None,
|
||||
) -> None:
|
||||
if sum([arrays is not None, images is not None, source is not None]) > 1:
|
||||
raise ValueError("stages must only return one type of result")
|
||||
elif arrays is None and images is None and source is None:
|
||||
raise ValueError("stages must return results")
|
||||
data_provided = sum(
|
||||
[arrays is not None, images is not None, source is not None]
|
||||
)
|
||||
if data_provided > 1:
|
||||
raise ValueError("results must only contain one type of data")
|
||||
elif data_provided == 0:
|
||||
raise ValueError("results must contain some data")
|
||||
|
||||
if source is not None:
|
||||
self.arrays = source.arrays
|
||||
self.images = source.images
|
||||
self.metadata = source.metadata
|
||||
else:
|
||||
self.arrays = arrays
|
||||
self.images = images
|
||||
self.source = source
|
||||
self.metadata = metadata or []
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
@ -226,7 +237,7 @@ class StageResult:
|
|||
else:
|
||||
return 0
|
||||
|
||||
def as_numpy(self) -> List[np.ndarray]:
|
||||
def as_arrays(self) -> List[np.ndarray]:
|
||||
if self.arrays is not None:
|
||||
return self.arrays
|
||||
elif self.images is not None:
|
||||
|
@ -234,7 +245,7 @@ class StageResult:
|
|||
else:
|
||||
return []
|
||||
|
||||
def as_image(self) -> List[Image.Image]:
|
||||
def as_images(self) -> List[Image.Image]:
|
||||
if self.images is not None:
|
||||
return self.images
|
||||
elif self.arrays is not None:
|
||||
|
@ -242,7 +253,7 @@ class StageResult:
|
|||
else:
|
||||
return []
|
||||
|
||||
def push_array(self, array: np.ndarray, metadata: Optional[ImageMetadata]):
|
||||
def push_array(self, array: np.ndarray, metadata: ImageMetadata):
|
||||
if self.arrays is not None:
|
||||
self.arrays.append(array)
|
||||
elif self.images is not None:
|
||||
|
@ -253,9 +264,9 @@ class StageResult:
|
|||
if metadata is not None:
|
||||
self.metadata.append(metadata)
|
||||
else:
|
||||
self.metadata.append(ImageMetadata())
|
||||
raise ValueError("metadata must be provided")
|
||||
|
||||
def push_image(self, image: Image.Image, metadata: Optional[ImageMetadata]):
|
||||
def push_image(self, image: Image.Image, metadata: ImageMetadata):
|
||||
if self.images is not None:
|
||||
self.images.append(image)
|
||||
elif self.arrays is not None:
|
||||
|
@ -266,11 +277,9 @@ class StageResult:
|
|||
if metadata is not None:
|
||||
self.metadata.append(metadata)
|
||||
else:
|
||||
self.metadata.append(ImageMetadata())
|
||||
raise ValueError("metadata must be provided")
|
||||
|
||||
def insert_array(
|
||||
self, index: int, array: np.ndarray, metadata: Optional[ImageMetadata]
|
||||
):
|
||||
def insert_array(self, index: int, array: np.ndarray, metadata: ImageMetadata):
|
||||
if self.arrays is not None:
|
||||
self.arrays.insert(index, array)
|
||||
elif self.images is not None:
|
||||
|
@ -283,11 +292,9 @@ class StageResult:
|
|||
if metadata is not None:
|
||||
self.metadata.insert(index, metadata)
|
||||
else:
|
||||
self.metadata.insert(index, ImageMetadata())
|
||||
raise ValueError("metadata must be provided")
|
||||
|
||||
def insert_image(
|
||||
self, index: int, image: Image.Image, metadata: Optional[ImageMetadata]
|
||||
):
|
||||
def insert_image(self, index: int, image: Image.Image, metadata: ImageMetadata):
|
||||
if self.images is not None:
|
||||
self.images.insert(index, image)
|
||||
elif self.arrays is not None:
|
||||
|
@ -298,7 +305,28 @@ class StageResult:
|
|||
if metadata is not None:
|
||||
self.metadata.insert(index, metadata)
|
||||
else:
|
||||
self.metadata.insert(index, ImageMetadata())
|
||||
raise ValueError("metadata must be provided")
|
||||
|
||||
def size(self) -> Size:
|
||||
if self.images is not None:
|
||||
return Size(self.images[0].width, self.images[0].height)
|
||||
elif self.arrays is not None:
|
||||
return Size(
|
||||
self.arrays[0].shape[0], self.arrays[0].shape[1]
|
||||
) # TODO: which fields within the shape are width/height?
|
||||
else:
|
||||
return Size(0, 0)
|
||||
|
||||
def validate(self) -> None:
|
||||
"""
|
||||
Make sure the data exists and that data and metadata match in length.
|
||||
"""
|
||||
|
||||
if self.arrays is None and self.images is None:
|
||||
raise ValueError("no data in result")
|
||||
|
||||
if len(self) != len(self.metadata):
|
||||
raise ValueError("metadata and data do not match in length")
|
||||
|
||||
|
||||
def shape_mode(arr: np.ndarray) -> str:
|
||||
|
|
|
@ -36,13 +36,13 @@ class SourceNoiseStage(BaseStage):
|
|||
outputs = []
|
||||
|
||||
# TODO: looping over sources and ignoring params does not make much sense for a source stage
|
||||
for source in sources.as_image():
|
||||
for source in sources.as_images():
|
||||
output = noise_source(source, (size.width, size.height), (0, 0))
|
||||
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
outputs.append(output)
|
||||
|
||||
return StageResult(images=outputs)
|
||||
return StageResult.from_images(outputs, metadata=sources.metadata)
|
||||
|
||||
def outputs(
|
||||
self,
|
||||
|
|
|
@ -37,7 +37,7 @@ class SourceS3Stage(BaseStage):
|
|||
"source images were passed to a source stage, new images will be appended"
|
||||
)
|
||||
|
||||
outputs = sources.as_image()
|
||||
outputs = sources.as_images()
|
||||
for key in source_keys:
|
||||
try:
|
||||
logger.info("loading image from s3://%s/%s", bucket, key)
|
||||
|
|
|
@ -34,7 +34,7 @@ class SourceURLStage(BaseStage):
|
|||
"source images were passed to a source stage, new images will be appended"
|
||||
)
|
||||
|
||||
outputs = sources.as_image()
|
||||
outputs = sources.as_images()
|
||||
for url in source_urls:
|
||||
response = requests.get(url)
|
||||
output = Image.open(BytesIO(response.content))
|
||||
|
|
|
@ -257,7 +257,7 @@ def process_tile_stack(
|
|||
overlap: float = 0.5,
|
||||
**kwargs,
|
||||
) -> List[Image.Image]:
|
||||
sources = stack.as_image()
|
||||
sources = stack.as_images()
|
||||
|
||||
width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
|
||||
mask = kwargs.get("mask", kwargs.get("stage_mask", None))
|
||||
|
@ -308,7 +308,7 @@ def process_tile_stack(
|
|||
bottom_margin,
|
||||
)
|
||||
tile_stack = add_margin(
|
||||
stack.as_image(),
|
||||
stack.as_images(),
|
||||
left,
|
||||
top,
|
||||
right,
|
||||
|
@ -346,7 +346,7 @@ def process_tile_stack(
|
|||
if isinstance(tile_stack, list):
|
||||
tile_stack = StageResult.from_images(tile_stack)
|
||||
|
||||
tiles.append((left, top, tile_stack.as_image()))
|
||||
tiles.append((left, top, tile_stack.as_images()))
|
||||
|
||||
lefts, tops, stacks = list(zip(*tiles))
|
||||
coords = list(zip(lefts, tops))
|
||||
|
@ -516,7 +516,7 @@ def get_result_tile(
|
|||
top, left = origin
|
||||
return [
|
||||
layer.crop((top, left, top + tile.height, left + tile.width))
|
||||
for layer in result.as_image()
|
||||
for layer in result.as_images()
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
bsrgan = self.load(server, stage, upscale, device)
|
||||
|
||||
outputs = []
|
||||
for source in sources.as_numpy():
|
||||
for source in sources.as_arrays():
|
||||
image = source / 255.0
|
||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||
image = np.expand_dims(image, axis=0)
|
||||
|
@ -105,7 +105,7 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
logger.debug("output image shape: %s", output.shape)
|
||||
outputs.append(output)
|
||||
|
||||
return StageResult(arrays=outputs)
|
||||
return StageResult(arrays=outputs, metadata=sources.metadata)
|
||||
|
||||
def steps(
|
||||
self,
|
||||
|
|
|
@ -42,7 +42,7 @@ class UpscaleHighresStage(BaseStage):
|
|||
source,
|
||||
callback=callback,
|
||||
)
|
||||
for source in sources.as_image()
|
||||
for source in sources.as_images()
|
||||
]
|
||||
|
||||
return StageResult(images=outputs)
|
||||
return StageResult(images=outputs, metadata=sources.metadata)
|
||||
|
|
|
@ -62,7 +62,7 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
)
|
||||
|
||||
outputs = []
|
||||
for source in sources.as_image():
|
||||
for source in sources.as_images():
|
||||
if is_debug():
|
||||
save_image(server, "tile-source.png", source)
|
||||
save_image(server, "tile-mask.png", tile_mask)
|
||||
|
@ -123,4 +123,4 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
|
||||
outputs.extend(result.images)
|
||||
|
||||
return StageResult(images=outputs)
|
||||
return StageResult(images=outputs, metadata=sources.metadata)
|
||||
|
|
|
@ -112,7 +112,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
|||
)
|
||||
|
||||
outputs = []
|
||||
for source in sources.as_numpy():
|
||||
for source in sources.as_arrays():
|
||||
output, _ = upsampler.enhance(source, outscale=upscale.outscale)
|
||||
logger.info("final output image size: %s", output.shape)
|
||||
outputs.append(output)
|
||||
|
|
|
@ -33,7 +33,7 @@ class UpscaleSimpleStage(BaseStage):
|
|||
return sources
|
||||
|
||||
outputs = []
|
||||
for source in sources.as_image():
|
||||
for source in sources.as_images():
|
||||
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
|
||||
|
||||
if method == "bilinear":
|
||||
|
@ -49,4 +49,4 @@ class UpscaleSimpleStage(BaseStage):
|
|||
else:
|
||||
logger.warning("unknown upscaling method: %s", method)
|
||||
|
||||
return StageResult(images=outputs)
|
||||
return StageResult(images=outputs, metadata=sources.metadata)
|
||||
|
|
|
@ -59,7 +59,7 @@ class UpscaleStableDiffusionStage(BaseStage):
|
|||
pipeline.unet.set_prompts(prompt_embeds)
|
||||
|
||||
outputs = []
|
||||
for source in sources.as_image():
|
||||
for source in sources.as_images():
|
||||
result = pipeline(
|
||||
prompt,
|
||||
source,
|
||||
|
@ -73,4 +73,4 @@ class UpscaleStableDiffusionStage(BaseStage):
|
|||
)
|
||||
outputs.extend(result.images)
|
||||
|
||||
return StageResult(images=outputs)
|
||||
return StageResult(images=outputs, metadata=sources.metadata)
|
||||
|
|
|
@ -72,7 +72,7 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
swinir = self.load(server, stage, upscale, device)
|
||||
|
||||
outputs = []
|
||||
for source in sources.as_numpy():
|
||||
for source in sources.as_arrays():
|
||||
# TODO: add support for grayscale (1-channel) images
|
||||
image = source / 255.0
|
||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||
|
@ -98,4 +98,4 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
logger.info("output image size: %s", output.shape)
|
||||
outputs.append(output)
|
||||
|
||||
return StageResult(images=outputs)
|
||||
return StageResult(images=outputs, metadata=sources.metadata)
|
||||
|
|
|
@ -118,7 +118,7 @@ def run_txt2img_pipeline(
|
|||
)
|
||||
|
||||
# add a thumbnail, if requested
|
||||
cover = images.as_image()[0]
|
||||
cover = images.as_images()[0]
|
||||
if params.thumbnail and (
|
||||
cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
|
||||
):
|
||||
|
@ -385,12 +385,12 @@ def run_inpaint_pipeline(
|
|||
worker,
|
||||
server,
|
||||
params,
|
||||
StageResult(images=[source]),
|
||||
StageResult(images=[source]), # TODO: load metadata from source image
|
||||
callback=progress,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
for i, image, metadata in enumerate(zip(images.as_image(), images.metadata)):
|
||||
for i, image, metadata in enumerate(zip(images.as_images(), images.metadata)):
|
||||
if full_res_inpaint:
|
||||
if is_debug():
|
||||
save_image(server, "adjusted-output.png", image)
|
||||
|
|
|
@ -62,7 +62,7 @@ def save_result(
|
|||
result: StageResult,
|
||||
base_name: str,
|
||||
) -> List[str]:
|
||||
images = result.as_image()
|
||||
images = result.as_images()
|
||||
outputs = make_output_names(server, base_name, len(images))
|
||||
results = []
|
||||
for image, metadata, filename in zip(images, result.metadata, outputs):
|
||||
|
|
Loading…
Reference in New Issue