clean up result and metadata handling
This commit is contained in:
parent
10acad232c
commit
4f230f4111
|
@ -32,9 +32,9 @@ class BlendDenoiseFastNLMeansStage(BaseStage):
|
||||||
logger.info("denoising source images")
|
logger.info("denoising source images")
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for source in sources.as_numpy():
|
for source in sources.as_arrays():
|
||||||
data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
|
data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
|
||||||
data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
|
data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
|
||||||
results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))
|
results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))
|
||||||
|
|
||||||
return StageResult(arrays=results)
|
return StageResult.from_arrays(results, metadata=sources.metadata)
|
||||||
|
|
|
@ -14,6 +14,11 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BlendDenoiseLocalStdStage(BaseStage):
|
class BlendDenoiseLocalStdStage(BaseStage):
|
||||||
|
"""
|
||||||
|
Experimental stage to blend and denoise images using local means compared to local standard deviation.
|
||||||
|
Very slow.
|
||||||
|
"""
|
||||||
|
|
||||||
max_tile = SizeChart.max
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
|
@ -35,8 +40,9 @@ class BlendDenoiseLocalStdStage(BaseStage):
|
||||||
return StageResult.from_arrays(
|
return StageResult.from_arrays(
|
||||||
[
|
[
|
||||||
remove_noise(source, threshold=strength, deviation=range)[0]
|
remove_noise(source, threshold=strength, deviation=range)[0]
|
||||||
for source in sources.as_numpy()
|
for source in sources.as_arrays()
|
||||||
]
|
],
|
||||||
|
metadata=sources.metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ class BlendGridStage(BaseStage):
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info("combining source images using grid layout")
|
logger.info("combining source images using grid layout")
|
||||||
|
|
||||||
images = sources.as_image()
|
images = sources.as_images()
|
||||||
ref_image = images[0]
|
ref_image = images[0]
|
||||||
size = Size(*ref_image.size)
|
size = Size(*ref_image.size)
|
||||||
|
|
||||||
|
@ -52,7 +52,9 @@ class BlendGridStage(BaseStage):
|
||||||
n = order[i]
|
n = order[i]
|
||||||
output.paste(images[n], (x * size.width, y * size.height))
|
output.paste(images[n], (x * size.width, y * size.height))
|
||||||
|
|
||||||
return StageResult(images=[*images, output])
|
result = StageResult(source=sources)
|
||||||
|
result.push_image(output, sources.metadata[0])
|
||||||
|
return result
|
||||||
|
|
||||||
def outputs(
|
def outputs(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -66,7 +66,7 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
pipe_params["strength"] = strength
|
pipe_params["strength"] = strength
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources.as_image():
|
for source in sources.as_images():
|
||||||
if params.is_lpw():
|
if params.is_lpw():
|
||||||
logger.debug("using LPW pipeline for img2img")
|
logger.debug("using LPW pipeline for img2img")
|
||||||
rng = torch.manual_seed(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
|
|
|
@ -28,9 +28,10 @@ class BlendLinearStage(BaseStage):
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info("blending source images using linear interpolation")
|
logger.info("blending source images using linear interpolation")
|
||||||
|
|
||||||
return StageResult(
|
return StageResult.from_images(
|
||||||
images=[
|
[
|
||||||
Image.blend(source, stage_source, alpha)
|
Image.blend(source, stage_source, alpha)
|
||||||
for source in sources.as_image()
|
for source in sources.as_images()
|
||||||
]
|
],
|
||||||
|
metadata=sources.metadata,
|
||||||
)
|
)
|
||||||
|
|
|
@ -48,6 +48,7 @@ class BlendMaskStage(BaseStage):
|
||||||
return StageResult.from_images(
|
return StageResult.from_images(
|
||||||
[
|
[
|
||||||
Image.composite(stage_source_tile, source, mult_mask)
|
Image.composite(stage_source_tile, source, mult_mask)
|
||||||
for source in sources.as_image()
|
for source in sources.as_images()
|
||||||
]
|
],
|
||||||
|
metadata=sources.metadata,
|
||||||
)
|
)
|
||||||
|
|
|
@ -67,7 +67,7 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
)
|
)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for img in sources.as_numpy():
|
for img in sources.as_arrays():
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
||||||
# clean all the intermediate results to process the next image
|
# clean all the intermediate results to process the next image
|
||||||
face_helper.clean_all()
|
face_helper.clean_all()
|
||||||
|
@ -121,4 +121,4 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
)
|
)
|
||||||
results.append(Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB)))
|
results.append(Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB)))
|
||||||
|
|
||||||
return StageResult.from_images(results)
|
return StageResult.from_images(results, metadata=sources.metadata)
|
||||||
|
|
|
@ -74,7 +74,7 @@ class CorrectGFPGANStage(BaseStage):
|
||||||
gfpgan = self.load(server, stage, upscale, device)
|
gfpgan = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources.as_numpy():
|
for source in sources.as_arrays():
|
||||||
cropped, restored, result = gfpgan.enhance(
|
cropped, restored, result = gfpgan.enhance(
|
||||||
source,
|
source,
|
||||||
has_aligned=False,
|
has_aligned=False,
|
||||||
|
@ -84,4 +84,4 @@ class CorrectGFPGANStage(BaseStage):
|
||||||
)
|
)
|
||||||
outputs.append(result)
|
outputs.append(result)
|
||||||
|
|
||||||
return StageResult.from_arrays(outputs)
|
return StageResult.from_arrays(outputs, metadata=sources.metadata)
|
||||||
|
|
|
@ -31,18 +31,14 @@ class PersistDiskStage(BaseStage):
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info("persisting %s images to disk: %s", len(sources), output)
|
logger.info("persisting %s images to disk: %s", len(sources), output)
|
||||||
|
|
||||||
for name, source, metadata in zip(output, sources.as_image(), sources.metadata):
|
for name, source, metadata in zip(
|
||||||
|
output, sources.as_images(), sources.metadata
|
||||||
|
):
|
||||||
dest = save_image(
|
dest = save_image(
|
||||||
server,
|
server,
|
||||||
name,
|
name,
|
||||||
source,
|
source,
|
||||||
params=metadata.params,
|
metadata=metadata,
|
||||||
size=metadata.size,
|
|
||||||
upscale=metadata.upscale,
|
|
||||||
border=metadata.border,
|
|
||||||
highres=metadata.highres,
|
|
||||||
inversions=metadata.inversions,
|
|
||||||
loras=metadata.loras,
|
|
||||||
)
|
)
|
||||||
logger.info("saved image to %s", dest)
|
logger.info("saved image to %s", dest)
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,8 @@ class PersistS3Stage(BaseStage):
|
||||||
session = Session(profile_name=profile_name)
|
session = Session(profile_name=profile_name)
|
||||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||||
|
|
||||||
for source, name in zip(sources.as_image(), output):
|
# TODO: save metadata as well
|
||||||
|
for source, name in zip(sources.as_images(), output):
|
||||||
data = BytesIO()
|
data = BytesIO()
|
||||||
source.save(data, format=server.image_format)
|
source.save(data, format=server.image_format)
|
||||||
data.seek(0)
|
data.seek(0)
|
||||||
|
|
|
@ -11,7 +11,6 @@ from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug, run_gc
|
from ..utils import is_debug, run_gc
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from ..worker.command import Progress
|
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
from .tile import needs_tile, process_tile_order
|
from .tile import needs_tile, process_tile_order
|
||||||
|
@ -107,7 +106,7 @@ class ChainPipeline:
|
||||||
result = self(
|
result = self(
|
||||||
worker, server, params, sources=sources, callback=callback, **kwargs
|
worker, server, params, sources=sources, callback=callback, **kwargs
|
||||||
)
|
)
|
||||||
return result.as_image()
|
return result.as_images()
|
||||||
|
|
||||||
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
||||||
self.stages.append((callback, params, kwargs))
|
self.stages.append((callback, params, kwargs))
|
||||||
|
@ -184,7 +183,7 @@ class ChainPipeline:
|
||||||
size=kwargs.get("size", None),
|
size=kwargs.get("size", None),
|
||||||
source=source,
|
source=source,
|
||||||
)
|
)
|
||||||
for source in stage_sources.as_image()
|
for source in stage_sources.as_images()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -302,7 +301,7 @@ class ChainPipeline:
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
for j, image in enumerate(stage_sources.as_image()):
|
for j, image in enumerate(stage_sources.as_images()):
|
||||||
save_image(server, f"last-stage-{j}.png", image)
|
save_image(server, f"last-stage-{j}.png", image)
|
||||||
|
|
||||||
end = monotonic()
|
end = monotonic()
|
||||||
|
|
|
@ -28,11 +28,11 @@ class ReduceCropStage(BaseStage):
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
for source in sources.as_image():
|
for source in sources.as_images():
|
||||||
image = source.crop((origin.width, origin.height, size.width, size.height))
|
image = source.crop((origin.width, origin.height, size.width, size.height))
|
||||||
logger.info(
|
logger.info(
|
||||||
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
||||||
)
|
)
|
||||||
outputs.append(image)
|
outputs.append(image)
|
||||||
|
|
||||||
return StageResult(images=outputs)
|
return StageResult.from_images(outputs, metadata=sources.metadata)
|
||||||
|
|
|
@ -26,7 +26,7 @@ class ReduceThumbnailStage(BaseStage):
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
for source in sources.as_image():
|
for source in sources.as_images():
|
||||||
image = source.copy()
|
image = source.copy()
|
||||||
|
|
||||||
image = image.thumbnail((size.width, size.height))
|
image = image.thumbnail((size.width, size.height))
|
||||||
|
@ -37,4 +37,4 @@ class ReduceThumbnailStage(BaseStage):
|
||||||
|
|
||||||
outputs.append(image)
|
outputs.append(image)
|
||||||
|
|
||||||
return StageResult(images=outputs)
|
return StageResult.from_images(outputs, metadata=sources.metadata)
|
||||||
|
|
|
@ -194,12 +194,16 @@ class StageResult:
|
||||||
return StageResult(images=[])
|
return StageResult(images=[])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_arrays(arrays: List[np.ndarray]):
|
def from_arrays(
|
||||||
return StageResult(arrays=arrays)
|
arrays: List[np.ndarray], metadata: Optional[List[ImageMetadata]] = None
|
||||||
|
):
|
||||||
|
return StageResult(arrays=arrays, metadata=metadata)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_images(images: List[Image.Image]):
|
def from_images(
|
||||||
return StageResult(images=images)
|
images: List[Image.Image], metadata: Optional[List[ImageMetadata]] = None
|
||||||
|
):
|
||||||
|
return StageResult(images=images, metadata=metadata)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -208,15 +212,22 @@ class StageResult:
|
||||||
metadata: Optional[List[ImageMetadata]] = None,
|
metadata: Optional[List[ImageMetadata]] = None,
|
||||||
source: Optional[Any] = None,
|
source: Optional[Any] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if sum([arrays is not None, images is not None, source is not None]) > 1:
|
data_provided = sum(
|
||||||
raise ValueError("stages must only return one type of result")
|
[arrays is not None, images is not None, source is not None]
|
||||||
elif arrays is None and images is None and source is None:
|
)
|
||||||
raise ValueError("stages must return results")
|
if data_provided > 1:
|
||||||
|
raise ValueError("results must only contain one type of data")
|
||||||
|
elif data_provided == 0:
|
||||||
|
raise ValueError("results must contain some data")
|
||||||
|
|
||||||
self.arrays = arrays
|
if source is not None:
|
||||||
self.images = images
|
self.arrays = source.arrays
|
||||||
self.source = source
|
self.images = source.images
|
||||||
self.metadata = metadata or []
|
self.metadata = source.metadata
|
||||||
|
else:
|
||||||
|
self.arrays = arrays
|
||||||
|
self.images = images
|
||||||
|
self.metadata = metadata or []
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
if self.arrays is not None:
|
if self.arrays is not None:
|
||||||
|
@ -226,7 +237,7 @@ class StageResult:
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def as_numpy(self) -> List[np.ndarray]:
|
def as_arrays(self) -> List[np.ndarray]:
|
||||||
if self.arrays is not None:
|
if self.arrays is not None:
|
||||||
return self.arrays
|
return self.arrays
|
||||||
elif self.images is not None:
|
elif self.images is not None:
|
||||||
|
@ -234,7 +245,7 @@ class StageResult:
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def as_image(self) -> List[Image.Image]:
|
def as_images(self) -> List[Image.Image]:
|
||||||
if self.images is not None:
|
if self.images is not None:
|
||||||
return self.images
|
return self.images
|
||||||
elif self.arrays is not None:
|
elif self.arrays is not None:
|
||||||
|
@ -242,7 +253,7 @@ class StageResult:
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def push_array(self, array: np.ndarray, metadata: Optional[ImageMetadata]):
|
def push_array(self, array: np.ndarray, metadata: ImageMetadata):
|
||||||
if self.arrays is not None:
|
if self.arrays is not None:
|
||||||
self.arrays.append(array)
|
self.arrays.append(array)
|
||||||
elif self.images is not None:
|
elif self.images is not None:
|
||||||
|
@ -253,9 +264,9 @@ class StageResult:
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
self.metadata.append(metadata)
|
self.metadata.append(metadata)
|
||||||
else:
|
else:
|
||||||
self.metadata.append(ImageMetadata())
|
raise ValueError("metadata must be provided")
|
||||||
|
|
||||||
def push_image(self, image: Image.Image, metadata: Optional[ImageMetadata]):
|
def push_image(self, image: Image.Image, metadata: ImageMetadata):
|
||||||
if self.images is not None:
|
if self.images is not None:
|
||||||
self.images.append(image)
|
self.images.append(image)
|
||||||
elif self.arrays is not None:
|
elif self.arrays is not None:
|
||||||
|
@ -266,11 +277,9 @@ class StageResult:
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
self.metadata.append(metadata)
|
self.metadata.append(metadata)
|
||||||
else:
|
else:
|
||||||
self.metadata.append(ImageMetadata())
|
raise ValueError("metadata must be provided")
|
||||||
|
|
||||||
def insert_array(
|
def insert_array(self, index: int, array: np.ndarray, metadata: ImageMetadata):
|
||||||
self, index: int, array: np.ndarray, metadata: Optional[ImageMetadata]
|
|
||||||
):
|
|
||||||
if self.arrays is not None:
|
if self.arrays is not None:
|
||||||
self.arrays.insert(index, array)
|
self.arrays.insert(index, array)
|
||||||
elif self.images is not None:
|
elif self.images is not None:
|
||||||
|
@ -283,11 +292,9 @@ class StageResult:
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
self.metadata.insert(index, metadata)
|
self.metadata.insert(index, metadata)
|
||||||
else:
|
else:
|
||||||
self.metadata.insert(index, ImageMetadata())
|
raise ValueError("metadata must be provided")
|
||||||
|
|
||||||
def insert_image(
|
def insert_image(self, index: int, image: Image.Image, metadata: ImageMetadata):
|
||||||
self, index: int, image: Image.Image, metadata: Optional[ImageMetadata]
|
|
||||||
):
|
|
||||||
if self.images is not None:
|
if self.images is not None:
|
||||||
self.images.insert(index, image)
|
self.images.insert(index, image)
|
||||||
elif self.arrays is not None:
|
elif self.arrays is not None:
|
||||||
|
@ -298,7 +305,28 @@ class StageResult:
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
self.metadata.insert(index, metadata)
|
self.metadata.insert(index, metadata)
|
||||||
else:
|
else:
|
||||||
self.metadata.insert(index, ImageMetadata())
|
raise ValueError("metadata must be provided")
|
||||||
|
|
||||||
|
def size(self) -> Size:
|
||||||
|
if self.images is not None:
|
||||||
|
return Size(self.images[0].width, self.images[0].height)
|
||||||
|
elif self.arrays is not None:
|
||||||
|
return Size(
|
||||||
|
self.arrays[0].shape[0], self.arrays[0].shape[1]
|
||||||
|
) # TODO: which fields within the shape are width/height?
|
||||||
|
else:
|
||||||
|
return Size(0, 0)
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
"""
|
||||||
|
Make sure the data exists and that data and metadata match in length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.arrays is None and self.images is None:
|
||||||
|
raise ValueError("no data in result")
|
||||||
|
|
||||||
|
if len(self) != len(self.metadata):
|
||||||
|
raise ValueError("metadata and data do not match in length")
|
||||||
|
|
||||||
|
|
||||||
def shape_mode(arr: np.ndarray) -> str:
|
def shape_mode(arr: np.ndarray) -> str:
|
||||||
|
|
|
@ -36,13 +36,13 @@ class SourceNoiseStage(BaseStage):
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
# TODO: looping over sources and ignoring params does not make much sense for a source stage
|
# TODO: looping over sources and ignoring params does not make much sense for a source stage
|
||||||
for source in sources.as_image():
|
for source in sources.as_images():
|
||||||
output = noise_source(source, (size.width, size.height), (0, 0))
|
output = noise_source(source, (size.width, size.height), (0, 0))
|
||||||
|
|
||||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return StageResult(images=outputs)
|
return StageResult.from_images(outputs, metadata=sources.metadata)
|
||||||
|
|
||||||
def outputs(
|
def outputs(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -37,7 +37,7 @@ class SourceS3Stage(BaseStage):
|
||||||
"source images were passed to a source stage, new images will be appended"
|
"source images were passed to a source stage, new images will be appended"
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = sources.as_image()
|
outputs = sources.as_images()
|
||||||
for key in source_keys:
|
for key in source_keys:
|
||||||
try:
|
try:
|
||||||
logger.info("loading image from s3://%s/%s", bucket, key)
|
logger.info("loading image from s3://%s/%s", bucket, key)
|
||||||
|
|
|
@ -34,7 +34,7 @@ class SourceURLStage(BaseStage):
|
||||||
"source images were passed to a source stage, new images will be appended"
|
"source images were passed to a source stage, new images will be appended"
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = sources.as_image()
|
outputs = sources.as_images()
|
||||||
for url in source_urls:
|
for url in source_urls:
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
output = Image.open(BytesIO(response.content))
|
output = Image.open(BytesIO(response.content))
|
||||||
|
|
|
@ -257,7 +257,7 @@ def process_tile_stack(
|
||||||
overlap: float = 0.5,
|
overlap: float = 0.5,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> List[Image.Image]:
|
||||||
sources = stack.as_image()
|
sources = stack.as_images()
|
||||||
|
|
||||||
width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
|
width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
|
||||||
mask = kwargs.get("mask", kwargs.get("stage_mask", None))
|
mask = kwargs.get("mask", kwargs.get("stage_mask", None))
|
||||||
|
@ -308,7 +308,7 @@ def process_tile_stack(
|
||||||
bottom_margin,
|
bottom_margin,
|
||||||
)
|
)
|
||||||
tile_stack = add_margin(
|
tile_stack = add_margin(
|
||||||
stack.as_image(),
|
stack.as_images(),
|
||||||
left,
|
left,
|
||||||
top,
|
top,
|
||||||
right,
|
right,
|
||||||
|
@ -346,7 +346,7 @@ def process_tile_stack(
|
||||||
if isinstance(tile_stack, list):
|
if isinstance(tile_stack, list):
|
||||||
tile_stack = StageResult.from_images(tile_stack)
|
tile_stack = StageResult.from_images(tile_stack)
|
||||||
|
|
||||||
tiles.append((left, top, tile_stack.as_image()))
|
tiles.append((left, top, tile_stack.as_images()))
|
||||||
|
|
||||||
lefts, tops, stacks = list(zip(*tiles))
|
lefts, tops, stacks = list(zip(*tiles))
|
||||||
coords = list(zip(lefts, tops))
|
coords = list(zip(lefts, tops))
|
||||||
|
@ -516,7 +516,7 @@ def get_result_tile(
|
||||||
top, left = origin
|
top, left = origin
|
||||||
return [
|
return [
|
||||||
layer.crop((top, left, top + tile.height, left + tile.width))
|
layer.crop((top, left, top + tile.height, left + tile.width))
|
||||||
for layer in result.as_image()
|
for layer in result.as_images()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,7 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
bsrgan = self.load(server, stage, upscale, device)
|
bsrgan = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources.as_numpy():
|
for source in sources.as_arrays():
|
||||||
image = source / 255.0
|
image = source / 255.0
|
||||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||||
image = np.expand_dims(image, axis=0)
|
image = np.expand_dims(image, axis=0)
|
||||||
|
@ -105,7 +105,7 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
logger.debug("output image shape: %s", output.shape)
|
logger.debug("output image shape: %s", output.shape)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return StageResult(arrays=outputs)
|
return StageResult(arrays=outputs, metadata=sources.metadata)
|
||||||
|
|
||||||
def steps(
|
def steps(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -42,7 +42,7 @@ class UpscaleHighresStage(BaseStage):
|
||||||
source,
|
source,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
for source in sources.as_image()
|
for source in sources.as_images()
|
||||||
]
|
]
|
||||||
|
|
||||||
return StageResult(images=outputs)
|
return StageResult(images=outputs, metadata=sources.metadata)
|
||||||
|
|
|
@ -62,7 +62,7 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources.as_image():
|
for source in sources.as_images():
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "tile-source.png", source)
|
save_image(server, "tile-source.png", source)
|
||||||
save_image(server, "tile-mask.png", tile_mask)
|
save_image(server, "tile-mask.png", tile_mask)
|
||||||
|
@ -123,4 +123,4 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
|
|
||||||
outputs.extend(result.images)
|
outputs.extend(result.images)
|
||||||
|
|
||||||
return StageResult(images=outputs)
|
return StageResult(images=outputs, metadata=sources.metadata)
|
||||||
|
|
|
@ -112,7 +112,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources.as_numpy():
|
for source in sources.as_arrays():
|
||||||
output, _ = upsampler.enhance(source, outscale=upscale.outscale)
|
output, _ = upsampler.enhance(source, outscale=upscale.outscale)
|
||||||
logger.info("final output image size: %s", output.shape)
|
logger.info("final output image size: %s", output.shape)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
|
@ -33,7 +33,7 @@ class UpscaleSimpleStage(BaseStage):
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources.as_image():
|
for source in sources.as_images():
|
||||||
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
|
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
|
||||||
|
|
||||||
if method == "bilinear":
|
if method == "bilinear":
|
||||||
|
@ -49,4 +49,4 @@ class UpscaleSimpleStage(BaseStage):
|
||||||
else:
|
else:
|
||||||
logger.warning("unknown upscaling method: %s", method)
|
logger.warning("unknown upscaling method: %s", method)
|
||||||
|
|
||||||
return StageResult(images=outputs)
|
return StageResult(images=outputs, metadata=sources.metadata)
|
||||||
|
|
|
@ -59,7 +59,7 @@ class UpscaleStableDiffusionStage(BaseStage):
|
||||||
pipeline.unet.set_prompts(prompt_embeds)
|
pipeline.unet.set_prompts(prompt_embeds)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources.as_image():
|
for source in sources.as_images():
|
||||||
result = pipeline(
|
result = pipeline(
|
||||||
prompt,
|
prompt,
|
||||||
source,
|
source,
|
||||||
|
@ -73,4 +73,4 @@ class UpscaleStableDiffusionStage(BaseStage):
|
||||||
)
|
)
|
||||||
outputs.extend(result.images)
|
outputs.extend(result.images)
|
||||||
|
|
||||||
return StageResult(images=outputs)
|
return StageResult(images=outputs, metadata=sources.metadata)
|
||||||
|
|
|
@ -72,7 +72,7 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
swinir = self.load(server, stage, upscale, device)
|
swinir = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources.as_numpy():
|
for source in sources.as_arrays():
|
||||||
# TODO: add support for grayscale (1-channel) images
|
# TODO: add support for grayscale (1-channel) images
|
||||||
image = source / 255.0
|
image = source / 255.0
|
||||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||||
|
@ -98,4 +98,4 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
logger.info("output image size: %s", output.shape)
|
logger.info("output image size: %s", output.shape)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return StageResult(images=outputs)
|
return StageResult(images=outputs, metadata=sources.metadata)
|
||||||
|
|
|
@ -118,7 +118,7 @@ def run_txt2img_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# add a thumbnail, if requested
|
# add a thumbnail, if requested
|
||||||
cover = images.as_image()[0]
|
cover = images.as_images()[0]
|
||||||
if params.thumbnail and (
|
if params.thumbnail and (
|
||||||
cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
|
cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
|
||||||
):
|
):
|
||||||
|
@ -385,12 +385,12 @@ def run_inpaint_pipeline(
|
||||||
worker,
|
worker,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
StageResult(images=[source]),
|
StageResult(images=[source]), # TODO: load metadata from source image
|
||||||
callback=progress,
|
callback=progress,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, image, metadata in enumerate(zip(images.as_image(), images.metadata)):
|
for i, image, metadata in enumerate(zip(images.as_images(), images.metadata)):
|
||||||
if full_res_inpaint:
|
if full_res_inpaint:
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "adjusted-output.png", image)
|
save_image(server, "adjusted-output.png", image)
|
||||||
|
|
|
@ -62,7 +62,7 @@ def save_result(
|
||||||
result: StageResult,
|
result: StageResult,
|
||||||
base_name: str,
|
base_name: str,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
images = result.as_image()
|
images = result.as_images()
|
||||||
outputs = make_output_names(server, base_name, len(images))
|
outputs = make_output_names(server, base_name, len(images))
|
||||||
results = []
|
results = []
|
||||||
for image, metadata, filename in zip(images, result.metadata, outputs):
|
for image, metadata, filename in zip(images, result.metadata, outputs):
|
||||||
|
|
Loading…
Reference in New Issue