1
0
Fork 0

start updating chain logic for multiple outputs

This commit is contained in:
Sean Sube 2023-07-04 13:47:31 -05:00
parent 37185252a5
commit e1fcbb9093
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 65 additions and 56 deletions

View File

@ -90,7 +90,7 @@ class ChainPipeline:
job: WorkerContext,
server: ServerContext,
params: ImageParams,
source: List[Image.Image],
sources: List[Image.Image],
callback: Optional[ProgressCallback] = None,
**pipeline_kwargs
) -> List[Image.Image]:
@ -102,42 +102,50 @@ class ChainPipeline:
start = monotonic()
# TODO: turn this into stage images
image = source
if source is not None:
if len(sources) > 0:
logger.info(
"running pipeline on source image with dimensions %sx%s",
source.width,
source.height,
"running pipeline on %s source images",
len(sources),
)
else:
logger.info("running pipeline without source image")
logger.info("running pipeline without source images")
stage_sources = sources
for stage_pipe, stage_params, stage_kwargs in self.stages:
name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}
if image is not None:
if len(stage_sources) > 0:
logger.debug(
"running stage %s with source size of %sx%s, parameters: %s",
"running stage %s with %s source images, parameters: %s",
name,
image.width,
image.height,
len(stage_sources),
kwargs.keys(),
)
else:
logger.debug(
"running stage %s without source image, %s", name, kwargs.keys()
"running stage %s without source images, parameters: %s",
name,
kwargs.keys(),
)
if needs_tile(
# the stage must be split and tiled if any image is larger than the selected/max tile size
must_tile = any(
[
needs_tile(
stage_pipe.max_tile,
stage_params.tile_size,
size=kwargs.get("size", None),
source=image,
):
source=source,
)
for source in stage_sources
]
)
if must_tile:
stage_outputs = []
for source in stage_sources:
tile = stage_params.tile_size
if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size)
@ -147,58 +155,59 @@ class ChainPipeline:
tile,
)
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe.run(
def stage_tile(source_tile: Image.Image, _dims) -> Image.Image:
output_tile = stage_pipe.run(
job,
server,
stage_params,
params,
tile,
source_tile,
callback=callback,
**kwargs,
)
if is_debug():
save_image(server, "last-tile.png", tile)
save_image(server, "last-tile.png", output_tile)
return tile
return output_tile
image = process_tile_order(
output = process_tile_order(
stage_params.tile_order,
image,
source,
tile,
stage_params.outscale,
[stage_tile],
**kwargs,
)
stage_outputs.append(output)
stage_sources = stage_outputs
else:
logger.debug("image within tile size of %s, running stage", tile)
image = stage_pipe.run(
stage_sources = stage_pipe.run(
job,
server,
stage_params,
params,
image,
stage_sources,
callback=callback,
**kwargs,
)
logger.debug(
"finished stage %s with result size of %sx%s",
"finished stage %s with %s results",
name,
image.width,
image.height,
len(stage_sources),
)
if is_debug():
save_image(server, "last-stage.png", image)
save_image(server, "last-stage.png", stage_sources[0])
end = monotonic()
duration = timedelta(seconds=(end - start))
logger.info(
"finished pipeline in %s with result size of %sx%s",
"finished pipeline in %s with %s results",
duration,
image.width,
image.height,
len(stage_outputs),
)
return image
return stage_outputs