1
0
Fork 0

start updating chain logic for multiple outputs

This commit is contained in:
Sean Sube 2023-07-04 13:47:31 -05:00
parent 37185252a5
commit e1fcbb9093
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 65 additions and 56 deletions

View File

@ -90,7 +90,7 @@ class ChainPipeline:
job: WorkerContext, job: WorkerContext,
server: ServerContext, server: ServerContext,
params: ImageParams, params: ImageParams,
source: List[Image.Image], sources: List[Image.Image],
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**pipeline_kwargs **pipeline_kwargs
) -> List[Image.Image]: ) -> List[Image.Image]:
@ -102,42 +102,50 @@ class ChainPipeline:
start = monotonic() start = monotonic()
# TODO: turn this into stage images if len(sources) > 0:
image = source
if source is not None:
logger.info( logger.info(
"running pipeline on source image with dimensions %sx%s", "running pipeline on %s source images",
source.width, len(sources),
source.height,
) )
else: else:
logger.info("running pipeline without source image") logger.info("running pipeline without source images")
stage_sources = sources
for stage_pipe, stage_params, stage_kwargs in self.stages: for stage_pipe, stage_params, stage_kwargs in self.stages:
name = stage_params.name or stage_pipe.__class__.__name__ name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {} kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs} kwargs = {**pipeline_kwargs, **kwargs}
if image is not None: if len(stage_sources) > 0:
logger.debug( logger.debug(
"running stage %s with source size of %sx%s, parameters: %s", "running stage %s with %s source images, parameters: %s",
name, name,
image.width, len(stage_sources),
image.height,
kwargs.keys(), kwargs.keys(),
) )
else: else:
logger.debug( logger.debug(
"running stage %s without source image, %s", name, kwargs.keys() "running stage %s without source images, parameters: %s",
name,
kwargs.keys(),
) )
if needs_tile( # the stage must be split and tiled if any image is larger than the selected/max tile size
must_tile = any(
[
needs_tile(
stage_pipe.max_tile, stage_pipe.max_tile,
stage_params.tile_size, stage_params.tile_size,
size=kwargs.get("size", None), size=kwargs.get("size", None),
source=image, source=source,
): )
for source in stage_sources
]
)
if must_tile:
stage_outputs = []
for source in stage_sources:
tile = stage_params.tile_size tile = stage_params.tile_size
if stage_pipe.max_tile > 0: if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size) tile = min(stage_pipe.max_tile, stage_params.tile_size)
@ -147,58 +155,59 @@ class ChainPipeline:
tile, tile,
) )
def stage_tile(tile: Image.Image, _dims) -> Image.Image: def stage_tile(source_tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe.run( output_tile = stage_pipe.run(
job, job,
server, server,
stage_params, stage_params,
params, params,
tile, source_tile,
callback=callback, callback=callback,
**kwargs, **kwargs,
) )
if is_debug(): if is_debug():
save_image(server, "last-tile.png", tile) save_image(server, "last-tile.png", output_tile)
return tile return output_tile
image = process_tile_order( output = process_tile_order(
stage_params.tile_order, stage_params.tile_order,
image, source,
tile, tile,
stage_params.outscale, stage_params.outscale,
[stage_tile], [stage_tile],
**kwargs, **kwargs,
) )
stage_outputs.append(output)
stage_sources = stage_outputs
else: else:
logger.debug("image within tile size of %s, running stage", tile) logger.debug("image within tile size of %s, running stage", tile)
image = stage_pipe.run( stage_sources = stage_pipe.run(
job, job,
server, server,
stage_params, stage_params,
params, params,
image, stage_sources,
callback=callback, callback=callback,
**kwargs, **kwargs,
) )
logger.debug( logger.debug(
"finished stage %s with result size of %sx%s", "finished stage %s with %s results",
name, name,
image.width, len(stage_sources),
image.height,
) )
if is_debug(): if is_debug():
save_image(server, "last-stage.png", image) save_image(server, "last-stage.png", stage_sources[0])
end = monotonic() end = monotonic()
duration = timedelta(seconds=(end - start)) duration = timedelta(seconds=(end - start))
logger.info( logger.info(
"finished pipeline in %s with result size of %sx%s", "finished pipeline in %s with %s results",
duration, duration,
image.width, len(stage_outputs),
image.height,
) )
return image return stage_outputs