From e1fcbb9093ce8757ab136ca37d151b946e835e69 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 4 Jul 2023 13:47:31 -0500 Subject: [PATCH] start updating chain logic for multiple outputs --- api/onnx_web/chain/base.py | 121 ++++++++++++++++++++----------------- 1 file changed, 65 insertions(+), 56 deletions(-) diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 643b7fa1..eae4e45f 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -90,7 +90,7 @@ class ChainPipeline: job: WorkerContext, server: ServerContext, params: ImageParams, - source: List[Image.Image], + sources: List[Image.Image], callback: Optional[ProgressCallback] = None, **pipeline_kwargs ) -> List[Image.Image]: @@ -102,103 +102,112 @@ class ChainPipeline: start = monotonic() - # TODO: turn this into stage images - image = source - - if source is not None: + if len(sources) > 0: logger.info( - "running pipeline on source image with dimensions %sx%s", - source.width, - source.height, + "running pipeline on %s source images", + len(sources), ) else: - logger.info("running pipeline without source image") + logger.info("running pipeline without source images") + stage_sources = sources for stage_pipe, stage_params, stage_kwargs in self.stages: name = stage_params.name or stage_pipe.__class__.__name__ kwargs = stage_kwargs or {} kwargs = {**pipeline_kwargs, **kwargs} - if image is not None: + if len(stage_sources) > 0: logger.debug( - "running stage %s with source size of %sx%s, parameters: %s", + "running stage %s with %s source images, parameters: %s", name, - image.width, - image.height, + len(stage_sources), kwargs.keys(), ) else: logger.debug( - "running stage %s without source image, %s", name, kwargs.keys() + "running stage %s without source images, parameters: %s", + name, + kwargs.keys(), ) - if needs_tile( - stage_pipe.max_tile, - stage_params.tile_size, - size=kwargs.get("size", None), - source=image, - ): - tile = stage_params.tile_size - if stage_pipe.max_tile > 0: - tile = min(stage_pipe.max_tile, stage_params.tile_size) + # the stage must be split and tiled if any image is larger than the selected/max tile size + must_tile = any( + [ + needs_tile( + stage_pipe.max_tile, + stage_params.tile_size, + size=kwargs.get("size", None), + source=source, + ) + for source in stage_sources + ] + ) - logger.info( - "image larger than tile size of %s, tiling stage", - tile, - ) + if must_tile: + stage_outputs = [] + for source in stage_sources: + tile = stage_params.tile_size + if stage_pipe.max_tile > 0: + tile = min(stage_pipe.max_tile, stage_params.tile_size) - def stage_tile(tile: Image.Image, _dims) -> Image.Image: - tile = stage_pipe.run( - job, - server, - stage_params, - params, + logger.info( + "image larger than tile size of %s, tiling stage", tile, - callback=callback, - **kwargs, ) - if is_debug(): - save_image(server, "last-tile.png", tile) + def stage_tile(source_tile: Image.Image, _dims) -> Image.Image: + output_tile = stage_pipe.run( + job, + server, + stage_params, + params, + source_tile, + callback=callback, + **kwargs, + ) - return tile + if is_debug(): + save_image(server, "last-tile.png", output_tile) - image = process_tile_order( - stage_params.tile_order, - image, - tile, - stage_params.outscale, - [stage_tile], - **kwargs, - ) + return output_tile + + output = process_tile_order( + stage_params.tile_order, + source, + tile, + stage_params.outscale, + [stage_tile], + **kwargs, + ) + stage_outputs.append(output) + + stage_sources = stage_outputs else: logger.debug("image within tile size of %s, running stage", tile) - image = stage_pipe.run( + stage_sources = stage_pipe.run( job, server, stage_params, params, - image, + stage_sources, callback=callback, **kwargs, ) logger.debug( - "finished stage %s with result size of %sx%s", + "finished stage %s with %s results", name, - image.width, - image.height, + len(stage_sources), ) if is_debug(): - save_image(server, "last-stage.png", image) + save_image(server, "last-stage.png", stage_sources[0]) end = monotonic() duration = timedelta(seconds=(end - start)) logger.info( - "finished pipeline in %s with result size of %sx%s", + "finished pipeline in %s with %s results", duration, - image.width, - image.height, + len(stage_outputs), ) - return image + return stage_outputs