1
0
Fork 0

start replacing image output with results

This commit is contained in:
Sean Sube 2023-11-18 18:08:38 -06:00
parent 5a517704ea
commit a63669c76b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
27 changed files with 159 additions and 149 deletions

View File

@ -17,7 +17,7 @@ class BaseStage:
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
_sources: List[Image.Image], _sources: StageResult,
*args, *args,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,

View File

@ -9,6 +9,7 @@ from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -22,19 +23,19 @@ class BlendDenoiseStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
strength: int = 3, strength: int = 3,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info("denoising source images") logger.info("denoising source images")
results = [] results = []
for source in sources: for source in sources.as_numpy():
data = cv2.cvtColor(np.array(source), cv2.COLOR_RGB2BGR) data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength) data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
results.append(Image.fromarray(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))) results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))
return results return StageResult(arrays=results)

View File

@ -7,6 +7,7 @@ from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,7 +21,7 @@ class BlendGridStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
height: int, height: int,
width: int, width: int,
@ -31,7 +32,7 @@ class BlendGridStage(BaseStage):
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info("combining source images using grid layout") logger.info("combining source images using grid layout")
size = sources[0].size size = sources[0].size
@ -49,7 +50,7 @@ class BlendGridStage(BaseStage):
n = order[i] n = order[i]
output.paste(sources[n], (x * size[0], y * size[1])) output.paste(sources[n], (x * size[0], y * size[1]))
return [*sources, output] return StageResult(images=[*sources, output])
def outputs( def outputs(
self, self,

View File

@ -11,6 +11,7 @@ from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -24,14 +25,14 @@ class BlendImg2ImgStage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
strength: float, strength: float,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
prompt_index: Optional[int] = None, prompt_index: Optional[int] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
# multi-stage prompting # multi-stage prompting
@ -65,7 +66,7 @@ class BlendImg2ImgStage(BaseStage):
pipe_params["strength"] = strength pipe_params["strength"] = strength
outputs = [] outputs = []
for source in sources: for source in sources.as_image():
if params.is_lpw(): if params.is_lpw():
logger.debug("using LPW pipeline for img2img") logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed) rng = torch.manual_seed(params.seed)
@ -101,7 +102,7 @@ class BlendImg2ImgStage(BaseStage):
outputs.extend(result.images) outputs.extend(result.images)
return outputs return StageResult(images=outputs)
def steps( def steps(
self, self,

View File

@ -7,6 +7,7 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,13 +19,13 @@ class BlendLinearStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
alpha: float, alpha: float,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None, _callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info("blending source images using linear interpolation") logger.info("blending source images using linear interpolation")
return [Image.blend(source, stage_source, alpha) for source in sources] return StageResult(images=[Image.blend(source, stage_source, alpha) for source in sources])

View File

@ -9,6 +9,7 @@ from ..server import ServerContext
from ..utils import is_debug from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,13 +21,13 @@ class BlendMaskStage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None, _callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info("blending image using mask") logger.info("blending image using mask")
mult_mask = Image.new("RGBA", stage_mask.size, color="black") mult_mask = Image.new("RGBA", stage_mask.size, color="black")
@ -37,4 +38,4 @@ class BlendMaskStage(BaseStage):
save_image(server, "last-mask.png", stage_mask) save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask) save_image(server, "last-mult-mask.png", mult_mask)
return [Image.composite(stage_source, source, mult_mask) for source in sources] return StageResult(images=[Image.composite(stage_source, source, mult_mask) for source in sources])

View File

@ -7,6 +7,7 @@ from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,12 +19,12 @@ class CorrectCodeformerStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
upscale: UpscaleParams, upscale: UpscaleParams,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
# must be within the load function for patch to take effect # must be within the load function for patch to take effect
# TODO: rewrite and remove # TODO: rewrite and remove
from codeformer import CodeFormer from codeformer import CodeFormer
@ -32,4 +33,4 @@ class CorrectCodeformerStage(BaseStage):
device = worker.get_device() device = worker.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return [pipe(source) for source in sources] return StageResult(images=[pipe(source) for source in sources])

View File

@ -10,6 +10,7 @@ from ..server import ModelTypes, ServerContext
from ..utils import run_gc from ..utils import run_gc
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -57,12 +58,12 @@ class CorrectGFPGANStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
if upscale.correction_model is None: if upscale.correction_model is None:
@ -73,16 +74,12 @@ class CorrectGFPGANStage(BaseStage):
device = worker.get_device() device = worker.get_device()
gfpgan = self.load(server, stage, upscale, device) gfpgan = self.load(server, stage, upscale, device)
outputs = [] outputs = [gfpgan.enhance(
for source in sources: source,
output = np.array(source)
_, _, output = gfpgan.enhance(
output,
has_aligned=False, has_aligned=False,
only_center_face=False, only_center_face=False,
paste_back=True, paste_back=True,
weight=upscale.face_strength, weight=upscale.face_strength,
) ) for source in sources.as_numpy()]
outputs.append(Image.fromarray(output, "RGB"))
return outputs return StageResult(images=outputs)

View File

@ -8,6 +8,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -21,13 +22,13 @@ class PersistDiskStage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
output: List[str], output: List[str],
size: Optional[Size] = None, size: Optional[Size] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info( logger.info(
"persisting images to disk: %s, %s", [s.size for s in sources], output "persisting images to disk: %s, %s", [s.size for s in sources], output
) )

View File

@ -9,6 +9,7 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,7 +21,7 @@ class PersistS3Stage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
output: str, output: str,
bucket: str, bucket: str,
@ -28,11 +29,11 @@ class PersistS3Stage(BaseStage):
profile_name: Optional[str] = None, profile_name: Optional[str] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
session = Session(profile_name=profile_name) session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url) s3 = session.client("s3", endpoint_url=endpoint_url)
for source in sources: for source in sources.as_image():
data = BytesIO() data = BytesIO()
source.save(data, format=server.image_format) source.save(data, format=server.image_format)
data.seek(0) data.seek(0)

View File

@ -13,6 +13,7 @@ from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .tile import needs_tile, process_tile_order from .tile import needs_tile, process_tile_order
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -73,26 +74,27 @@ class ChainPipeline:
worker: WorkerContext, worker: WorkerContext,
server: ServerContext, server: ServerContext,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
callback: Optional[ProgressCallback], callback: Optional[ProgressCallback],
**kwargs **kwargs
) -> List[Image.Image]: ) -> StageResult:
return self( result = self(
worker, server, params, sources=sources, callback=callback, **kwargs worker, server, params, sources=sources, callback=callback, **kwargs
) )
return result.as_image()
def stage(self, callback: BaseStage, params: StageParams, **kwargs): def stage(self, callback: BaseStage, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs)) self.stages.append((callback, params, kwargs))
return self return self
def steps(self, params: ImageParams, size: Size): def steps(self, params: ImageParams, size: Size) -> int:
steps = 0 steps = 0
for callback, _params, kwargs in self.stages: for callback, _params, kwargs in self.stages:
steps += callback.steps(kwargs.get("params", params), size) steps += callback.steps(kwargs.get("params", params), size)
return steps return steps
def outputs(self, params: ImageParams, sources: int): def outputs(self, params: ImageParams, sources: int) -> int:
outputs = sources outputs = sources
for callback, _params, kwargs in self.stages: for callback, _params, kwargs in self.stages:
outputs = callback.outputs(kwargs.get("params", params), outputs) outputs = callback.outputs(kwargs.get("params", params), outputs)
@ -104,10 +106,10 @@ class ChainPipeline:
worker: WorkerContext, worker: WorkerContext,
server: ServerContext, server: ServerContext,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**pipeline_kwargs **pipeline_kwargs
) -> List[Image.Image]: ) -> StageResult:
""" """
DEPRECATED: use `run` instead DEPRECATED: use `run` instead
""" """
@ -161,23 +163,21 @@ class ChainPipeline:
tile = min(stage_pipe.max_tile, stage_params.tile_size) tile = min(stage_pipe.max_tile, stage_params.tile_size)
if stage_sources or must_tile: if stage_sources or must_tile:
stage_outputs = [] stage_results = []
for source in stage_sources: for source in stage_sources:
logger.info( logger.info(
"image contains sources or is larger than tile size of %s, tiling stage", "image contains sources or is larger than tile size of %s, tiling stage",
tile, tile,
) )
extra_tiles = []
def stage_tile( def stage_tile(
source_tile: Image.Image, source_tile: Image.Image,
tile_mask: Image.Image, tile_mask: Image.Image,
dims: Tuple[int, int, int], dims: Tuple[int, int, int],
) -> Image.Image: ) -> StageResult:
for _i in range(worker.retries): for _i in range(worker.retries):
try: try:
output_tile = stage_pipe.run( tile_result = stage_pipe.run(
worker, worker,
server, server,
stage_params, stage_params,
@ -189,17 +189,11 @@ class ChainPipeline:
**kwargs, **kwargs,
) )
if len(output_tile) > 1:
while len(extra_tiles) < len(output_tile):
extra_tiles.append([])
for tile, layer in zip(output_tile, extra_tiles):
layer.append((tile, dims))
if is_debug(): if is_debug():
save_image(server, "last-tile.png", output_tile[0]) for j, image in enumerate(tile_result.as_image()):
save_image(server, f"last-tile-{j}.png", image)
return output_tile[0] return tile_result
except Exception: except Exception:
worker.retries = worker.retries - 1 worker.retries = worker.retries - 1
logger.exception( logger.exception(
@ -220,17 +214,9 @@ class ChainPipeline:
**kwargs, **kwargs,
) )
stage_outputs.append(output) stage_results.append(output)
if len(extra_tiles) > 1: stage_sources = StageResult(images=stage_results)
for layer in extra_tiles:
layer_output = Image.new("RGB", output.size)
for layer_tile, dims in layer:
layer_output.paste(layer_tile, (dims[0], dims[1]))
stage_outputs.append(layer_output)
stage_sources = stage_outputs
else: else:
logger.debug( logger.debug(
"image does not contain sources and is within tile size of %s, running stage", "image does not contain sources and is within tile size of %s, running stage",
@ -238,7 +224,7 @@ class ChainPipeline:
) )
for i in range(worker.retries): for i in range(worker.retries):
try: try:
stage_outputs = stage_pipe.run( stage_result = stage_pipe.run(
worker, worker,
server, server,
stage_params, stage_params,
@ -250,7 +236,7 @@ class ChainPipeline:
) )
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline # doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
# does not like, so it throws # does not like, so it throws
stage_sources = stage_outputs stage_sources = stage_result
break break
except Exception: except Exception:
worker.retries = worker.retries - 1 worker.retries = worker.retries - 1

View File

@ -7,6 +7,7 @@ from ..params import ImageParams, Size, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,20 +19,20 @@ class ReduceCropStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
origin: Size, origin: Size,
size: Size, size: Size,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
outputs = [] outputs = []
for source in sources: for source in sources.as_image():
image = source.crop((origin.width, origin.height, size.width, size.height)) image = source.crop((origin.width, origin.height, size.width, size.height))
logger.info( logger.info(
"created thumbnail with dimensions: %sx%s", image.width, image.height "created thumbnail with dimensions: %sx%s", image.width, image.height
) )
outputs.append(image) outputs.append(image)
return outputs return StageResult(images=outputs)

View File

@ -7,6 +7,7 @@ from ..params import ImageParams, Size, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,15 +19,15 @@ class ReduceThumbnailStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
size: Size, size: Size,
stage_source: Image.Image, stage_source: Image.Image,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
outputs = [] outputs = []
for source in sources: for source in sources.as_image():
image = source.copy() image = source.copy()
image = image.thumbnail((size.width, size.height)) image = image.thumbnail((size.width, size.height))
@ -37,4 +38,4 @@ class ReduceThumbnailStage(BaseStage):
outputs.append(image) outputs.append(image)
return outputs return StageResult(images=outputs)

View File

@ -1,4 +1,4 @@
from PIL.Image import Image, fromarray from PIL import Image
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
@ -7,25 +7,35 @@ class StageResult:
""" """
Chain pipeline stage result. Chain pipeline stage result.
Can contain PIL images or numpy arrays, with helpers to convert between them. Can contain PIL images or numpy arrays, with helpers to convert between them.
This class intentionally does not provide `__iter__`, to ensure clients get results in the format
they expect.
""" """
arrays: Optional[List[np.ndarray]] arrays: Optional[List[np.ndarray]]
images: Optional[List[Image]] images: Optional[List[Image.Image]]
def __init__(self, arrays = None, images = None) -> None: def __init__(self, arrays = None, images = None) -> None:
if arrays is not None and images is not None: if arrays is not None and images is not None:
raise ValueError("stages must only return one type of result") raise ValueError("stages must only return one type of result")
elif arrays is None and images is None:
raise ValueError("stages must return results")
self.arrays = arrays self.arrays = arrays
self.images = images self.images = images
def __len__(self) -> int:
if self.arrays is not None:
return len(self.arrays)
else:
return len(self.images)
def as_numpy(self) -> List[np.ndarray]: def as_numpy(self) -> List[np.ndarray]:
if self.arrays is not None: if self.arrays is not None:
return self.arrays return self.arrays
return [np.array(i) for i in self.images] return [np.array(i) for i in self.images]
def as_image(self) -> List[Image]: def as_image(self) -> List[Image.Image]:
if self.images is not None: if self.images is not None:
return self.images return self.images
return [fromarray(i) for i in self.arrays] return [Image.fromarray(i, "RGB") for i in self.arrays]

View File

@ -7,6 +7,7 @@ from ..params import ImageParams, Size, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,13 +19,13 @@ class SourceNoiseStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
size: Size, size: Size,
noise_source: Callable, noise_source: Callable,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info("generating image from noise source") logger.info("generating image from noise source")
if len(sources) > 0: if len(sources) > 0:
@ -32,16 +33,16 @@ class SourceNoiseStage(BaseStage):
"source images were passed to a source stage, new images will be appended" "source images were passed to a source stage, new images will be appended"
) )
outputs = list(sources) outputs = []
# TODO: looping over sources and ignoring params does not make much sense for a source stage # TODO: looping over sources and ignoring params does not make much sense for a source stage
for source in sources: for source in sources.as_image():
output = noise_source(source, (size.width, size.height), (0, 0)) output = noise_source(source, (size.width, size.height), (0, 0))
logger.info("final output image size: %sx%s", output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output) outputs.append(output)
return outputs return StageResult(images=outputs)
def outputs( def outputs(
self, self,

View File

@ -9,6 +9,7 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,14 +21,14 @@ class SourceS3Stage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
source_keys: List[str], source_keys: List[str],
bucket: str, bucket: str,
endpoint_url: Optional[str] = None, endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None, profile_name: Optional[str] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
session = Session(profile_name=profile_name) session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url) s3 = session.client("s3", endpoint_url=endpoint_url)
@ -36,7 +37,7 @@ class SourceS3Stage(BaseStage):
"source images were passed to a source stage, new images will be appended" "source images were passed to a source stage, new images will be appended"
) )
outputs = list(sources) outputs = sources.as_image()
for key in source_keys: for key in source_keys:
try: try:
logger.info("loading image from s3://%s/%s", bucket, key) logger.info("loading image from s3://%s/%s", bucket, key)
@ -48,7 +49,7 @@ class SourceS3Stage(BaseStage):
except Exception: except Exception:
logger.exception("error loading image from S3") logger.exception("error loading image from S3")
return outputs return StageResult(outputs)
def outputs( def outputs(
self, self,

View File

@ -19,6 +19,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -32,7 +33,7 @@ class SourceTxt2ImgStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
dims: Tuple[int, int, int] = None, dims: Tuple[int, int, int] = None,
size: Size, size: Size,
@ -153,10 +154,10 @@ class SourceTxt2ImgStage(BaseStage):
callback=callback, callback=callback,
) )
output = list(sources) outputs = list(sources)
output.extend(result.images) outputs.extend(result.images)
logger.debug("produced %s outputs", len(output)) logger.debug("produced %s outputs", len(outputs))
return output return StageResult(images=outputs)
def steps( def steps(
self, self,

View File

@ -9,6 +9,7 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,12 +21,12 @@ class SourceURLStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
source_urls: List[str], source_urls: List[str],
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info("loading image from URL source") logger.info("loading image from URL source")
if len(sources) > 0: if len(sources) > 0:
@ -41,7 +42,7 @@ class SourceURLStage(BaseStage):
logger.info("final output image size: %sx%s", output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output) outputs.append(output)
return outputs return StageResult(images=outputs)
def outputs( def outputs(
self, self,

View File

@ -9,6 +9,7 @@ from PIL import Image
from ..image.noise_source import noise_source_histogram from ..image.noise_source import noise_source_histogram
from ..params import Size, TileOrder from ..params import Size, TileOrder
from .result import StageResult
# from skimage.exposure import match_histograms # from skimage.exposure import match_histograms
@ -21,7 +22,7 @@ class TileCallback(Protocol):
Definition for a tile job function. Definition for a tile job function.
""" """
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image: def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
""" """
Run this stage against a single tile. Run this stage against a single tile.
""" """

View File

@ -11,6 +11,7 @@ from ..server import ModelTypes, ServerContext
from ..utils import run_gc from ..utils import run_gc
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -54,12 +55,12 @@ class UpscaleBSRGANStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
if upscale.upscale_model is None: if upscale.upscale_model is None:
@ -71,8 +72,8 @@ class UpscaleBSRGANStage(BaseStage):
bsrgan = self.load(server, stage, upscale, device) bsrgan = self.load(server, stage, upscale, device)
outputs = [] outputs = []
for source in sources: for source in sources.as_numpy():
image = np.array(source) / 255.0 image = source / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0) image = np.expand_dims(image, axis=0)
logger.trace("BSRGAN input shape: %s", image.shape) logger.trace("BSRGAN input shape: %s", image.shape)
@ -99,7 +100,7 @@ class UpscaleBSRGANStage(BaseStage):
outputs.append(output) outputs.append(output)
return outputs return StageResult(images=outputs)
def steps( def steps(
self, self,

View File

@ -9,6 +9,7 @@ from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from ..worker.context import ProgressCallback from ..worker.context import ProgressCallback
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,20 +21,20 @@ class UpscaleHighresStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*args, *args,
highres: HighresParams, highres: HighresParams,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
if highres.scale <= 1: if highres.scale <= 1:
return sources return sources
chain = stage_highres(stage, params, highres, upscale) chain = stage_highres(stage, params, highres, upscale)
return [ outputs = [
chain( chain(
worker, worker,
server, server,
@ -43,3 +44,5 @@ class UpscaleHighresStage(BaseStage):
) )
for source in sources for source in sources
] ]
return StageResult(images=outputs)

View File

@ -19,6 +19,7 @@ from ..server import ServerContext
from ..utils import is_debug from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -32,7 +33,7 @@ class UpscaleOutpaintStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
border: Border, border: Border,
dims: Tuple[int, int, int], dims: Tuple[int, int, int],
@ -45,7 +46,7 @@ class UpscaleOutpaintStage(BaseStage):
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
params params
) )
@ -61,7 +62,7 @@ class UpscaleOutpaintStage(BaseStage):
) )
outputs = [] outputs = []
for source in sources: for source in sources.as_image():
if is_debug(): if is_debug():
save_image(server, "tile-source.png", source) save_image(server, "tile-source.png", source)
save_image(server, "tile-mask.png", tile_mask) save_image(server, "tile-mask.png", tile_mask)
@ -122,4 +123,4 @@ class UpscaleOutpaintStage(BaseStage):
outputs.extend(result.images) outputs.extend(result.images)
return outputs return StageResult(images=outputs)

View File

@ -11,6 +11,7 @@ from ..server import ModelTypes, ServerContext
from ..utils import run_gc from ..utils import run_gc
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -77,25 +78,22 @@ class UpscaleRealESRGANStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
upsampler = self.load(
server, upscale, worker.get_device(), tile=stage.tile_size
)
outputs = [] outputs = []
for source in sources: for source in sources.as_numpy():
output = np.array(source) output, _ = upsampler.enhance(source, outscale=upscale.outscale)
upsampler = self.load( logger.info("final output image size: %s", output.shape)
server, upscale, worker.get_device(), tile=stage.tile_size
)
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
output = Image.fromarray(output, "RGB")
logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output) outputs.append(output)
return outputs return StageResult(arrays=outputs)

View File

@ -7,6 +7,7 @@ from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,13 +19,13 @@ class UpscaleSimpleStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
method: str, method: str,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
if upscale.scale <= 1: if upscale.scale <= 1:
logger.debug( logger.debug(
"simple upscale stage run with scale of %s, skipping", upscale.scale "simple upscale stage run with scale of %s, skipping", upscale.scale
@ -32,7 +33,7 @@ class UpscaleSimpleStage(BaseStage):
return sources return sources
outputs = [] outputs = []
for source in sources: for source in sources.as_image():
scaled_size = (source.width * upscale.scale, source.height * upscale.scale) scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
if method == "bilinear": if method == "bilinear":

View File

@ -11,6 +11,7 @@ from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -22,13 +23,13 @@ class UpscaleStableDiffusionStage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
logger.info( logger.info(
@ -58,7 +59,7 @@ class UpscaleStableDiffusionStage(BaseStage):
pipeline.unet.set_prompts(prompt_embeds) pipeline.unet.set_prompts(prompt_embeds)
outputs = [] outputs = []
for source in sources: for source in sources.as_image():
result = pipeline( result = pipeline(
prompt, prompt,
source, source,

View File

@ -11,6 +11,7 @@ from ..server import ModelTypes, ServerContext
from ..utils import run_gc from ..utils import run_gc
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -54,12 +55,12 @@ class UpscaleSwinIRStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
if upscale.upscale_model is None: if upscale.upscale_model is None:
@ -71,31 +72,27 @@ class UpscaleSwinIRStage(BaseStage):
swinir = self.load(server, stage, upscale, device) swinir = self.load(server, stage, upscale, device)
outputs = [] outputs = []
for source in sources: for source in sources.as_numpy():
# TODO: add support for grayscale (1-channel) images # TODO: add support for grayscale (1-channel) images
image = np.array(source) / 255.0 image = source / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0) image = np.expand_dims(image, axis=0)
logger.trace("SwinIR input shape: %s", image.shape) logger.trace("SwinIR input shape: %s", image.shape)
scale = upscale.outscale scale = upscale.outscale
dest = np.zeros( logger.trace("SwinIR output shape: %s", (
(
image.shape[0], image.shape[0],
image.shape[1], image.shape[1],
image.shape[2] * scale, image.shape[2] * scale,
image.shape[3] * scale, image.shape[3] * scale,
) ))
)
logger.trace("SwinIR output shape: %s", dest.shape)
dest = swinir(image) output = swinir(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1) output = np.clip(np.squeeze(output, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8) output = (output * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB") logger.info("output image size: %s", output.shape)
logger.info("output image size: %s x %s", output.width, output.height)
outputs.append(output) outputs.append(output)
return outputs return StageResult(images=outputs)

View File

@ -486,7 +486,7 @@ def run_blend_pipeline(
outputs: List[str], outputs: List[str],
upscale: UpscaleParams, upscale: UpscaleParams,
# highres: HighresParams, # highres: HighresParams,
sources: List[Image.Image], sources: StageResult,
mask: Image.Image, mask: Image.Image,
) -> None: ) -> None:
# set up the chain pipeline and base stage # set up the chain pipeline and base stage