From 1247cb7307359883842e104d141ada470b75b571 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 11 Sep 2023 07:28:20 -0500 Subject: [PATCH] super hacky multi tiling --- api/onnx_web/chain/base.py | 25 ++++++++++++++++++++++--- api/onnx_web/chain/blend_grid.py | 2 +- api/onnx_web/chain/tile.py | 3 +++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 34b1f5f2..077f5998 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -152,6 +152,8 @@ class ChainPipeline: tile, ) + extra_tiles = [] + def stage_tile( source_tile: Image.Image, tile_mask: Image.Image, @@ -169,12 +171,19 @@ class ChainPipeline: callback=callback, dims=dims, **kwargs, - )[0] + ) + + if len(output_tile) > 1: + while len(extra_tiles) < len(output_tile): + extra_tiles.append([]) + + for tile, layer in zip(output_tile, extra_tiles): + layer.append((tile, dims)) if is_debug(): - save_image(server, "last-tile.png", output_tile) + save_image(server, "last-tile.png", output_tile[0]) - return output_tile + return output_tile[0] except Exception: logger.exception( "error while running stage pipeline for tile, retry %s of 3", @@ -194,8 +203,17 @@ class ChainPipeline: [stage_tile], **kwargs, ) + stage_outputs.append(output) + if len(extra_tiles) > 1: + for layer in extra_tiles: + layer_output = Image.new("RGB", output.size) + for tile, dims in layer: + layer_output.paste(tile, (dims[0], dims[1])) + + stage_outputs.append(layer_output) + stage_sources = stage_outputs else: logger.debug("image within tile size of %s, running stage", tile) @@ -208,6 +226,7 @@ class ChainPipeline: per_stage_params, stage_sources, callback=callback, + dims=(0, 0, tile), **kwargs, ) # doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py index cea63322..5a23f779 100644 --- a/api/onnx_web/chain/blend_grid.py +++ b/api/onnx_web/chain/blend_grid.py @@ -47,4 +47,4 @@ class BlendGridStage(BaseStage): n = order[i] output.paste(sources[n], (x * size[0], y * size[1])) - return [output] + return [*sources, output] diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index 8b7898c6..c80a5719 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -343,6 +343,9 @@ def process_tile_order( filters: List[TileCallback], **kwargs, ) -> Image.Image: + """ + TODO: needs to handle more than one image + """ if order == TileOrder.grid: logger.debug("using grid tile order with tile size: %s", tile) return process_tile_grid(source, tile, scale, filters, **kwargs)