From 1cf8c7eff609ace0115bbb04701929508dd3adf0 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 12 Jan 2024 23:28:52 -0600 Subject: [PATCH] make pipeline progress more accurate and complete --- api/onnx_web/chain/pipeline.py | 10 ++++++---- api/onnx_web/chain/tile.py | 11 +++++++---- api/onnx_web/worker/context.py | 6 +++--- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index ad4b3e66..b21bff6e 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -124,12 +124,13 @@ class ChainPipeline: callback = ChainProgress.from_progress(callback) # set estimated totals - steps = params.steps if "size" in pipeline_kwargs and isinstance(pipeline_kwargs["size"], Size): size = pipeline_kwargs["size"] - steps = self.steps(params, size) + else: + size = sources.size() - worker.set_totals(steps, stages=len(self.stages), tiles=0) + total_steps = self.steps(params, size) + worker.set_totals(total_steps, stages=len(self.stages), tiles=0) start = monotonic() @@ -182,6 +183,7 @@ class ChainPipeline: source_tile: List[Image.Image], tile_mask: Image.Image, dims: Tuple[int, int, int], + progress: Tuple[int, int], ) -> List[Image.Image]: for _i in range(worker.retries): try: @@ -203,7 +205,7 @@ class ChainPipeline: for j, image in enumerate(tile_result.as_image()): save_image(server, f"last-tile-{j}.png", image) - worker.set_tiles(worker.tiles.current + 1) + worker.set_tiles(current=progress[0], total=progress[1]) return tile_result except CancelledException as err: diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index 03441a12..6b9b9b70 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -26,7 +26,7 @@ class TileCallback(Protocol): """ def __call__( - self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int] + self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int], progress: Tuple[int, int] ) -> StageResult: """ Run this stage against a single tile. @@ -268,11 +268,13 @@ def process_tile_stack( tiles: List[Tuple[int, int, Image.Image]] = [] tile_coords = tile_generator(width, height, tile, overlap) - single_tile = len(tile_coords) == 1 + + total_tiles = len(tile_coords) + single_tile = total_tiles == 1 for counter, (left, top) in enumerate(tile_coords): logger.info( - "processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top + "processing tile %s of %s, %sx%s", counter, total_tiles, left, top ) right = left + tile @@ -341,8 +343,9 @@ def process_tile_stack( tile_mask = mask.crop((left, top, right, bottom)) for image_filter in filters: - tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile)) + tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile), (counter, total_tiles)) + # TODO: this should be inverted to extract them from the result if isinstance(tile_stack, list): tile_stack = StageResult.from_images(tile_stack, metadata=stack.metadata) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index ebfbd462..e55a563d 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -189,9 +189,9 @@ class WorkerContext: self.tiles = self.tiles.update(current) def set_totals(self, steps: int, stages: int = 0, tiles: int = 0) -> None: - self.steps.total = steps - self.stages.total = stages - self.tiles.total = tiles + self.steps = Progress(0, steps) + self.stages.total = Progress(0, stages) + self.tiles.total = Progress(0, tiles) def finish(self) -> None: if self.job is None: