diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 7bfd49a5..63261d29 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -14,7 +14,7 @@ from ..utils import ( ServerContext, ) from .utils import ( - process_tiles, + process_tile_grid, ) logger = getLogger(__name__) @@ -86,7 +86,7 @@ class ChainPipeline: return tile - image = process_tiles( + image = process_tile_grid( image, stage_params.tile_size, stage_params.outscale, [stage_tile]) else: logger.info('source image within tile size, running stage') diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 5ed461db..8af082d6 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -27,7 +27,7 @@ from ..utils import ( ServerContext, ) from .utils import ( - process_tiles, + process_tile_grid, ) import numpy as np @@ -98,7 +98,7 @@ def blend_inpaint( ) return result.images[0] - output = process_tiles(source_image, SizeChart.auto, 1, [outpaint]) + output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint]) logger.info('final output image size', output.size) return output diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index f6957aa8..7712a976 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -27,7 +27,7 @@ from ..utils import ( ServerContext, ) from .utils import ( - process_tiles, + process_tile_spiral, ) import numpy as np @@ -98,7 +98,7 @@ def upscale_outpaint( ) return result.images[0] - output = process_tiles(source_image, SizeChart.auto, 1, [outpaint]) + output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint]) logger.info('final output image size: %sx%s', output.width, output.height) return output diff --git a/api/onnx_web/chain/utils.py b/api/onnx_web/chain/utils.py index db98da3c..15119de1 100644 --- a/api/onnx_web/chain/utils.py +++ b/api/onnx_web/chain/utils.py @@ -10,7 +10,7 @@ class TileCallback(Protocol): pass -def process_tiles( +def process_tile_grid( source: Image.Image, tile: int, scale: int, @@ -37,3 +37,47 @@ def process_tiles( image.paste(tile_image, (left * scale, top * scale)) return image + + +def process_tile_spiral( + source: Image.Image, + tile: int, + scale: int, + filters: List[TileCallback], + overlap: float = 0.5, +) -> Image.Image: + if scale != 1: + raise Exception('unsupported scale') + + width, height = source.size + image = Image.new('RGB', (width * scale, height * scale)) + image.paste(source, (0, 0)) + + # TODO: only valid for overlap = 0.5 + if overlap == 0.5: + tiles = [ + (0, tile * -overlap), + (tile * overlap, tile * -overlap), + (tile * overlap, 0), + (tile * overlap, tile * overlap), + (0, tile * overlap), + (tile * -overlap, tile * -overlap), + (tile * -overlap, 0), + (tile * -overlap, tile * overlap), + ] + + # tile tuples is source, multiply by scale for dest + counter = 0 + for left, top in tiles: + logger.info('processing tile %s of %s', counter, len(tiles)) + counter += 1 + + # TODO: only valid for scale == 1, resize source for others + tile_image = image.crop((left, top, left + tile, top + tile)) + + for filter in filters: + tile_image = filter(tile_image, (left, top, tile)) + + image.paste(tile_image, (left * scale, top * scale)) + + return image diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 64c0be9d..cbd76208 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -97,15 +97,16 @@ def get_size(val: Union[int, str, None]) -> SizeChart: if val is None: return SizeChart.auto - if type(val) is str: - if val in SizeChart: - return SizeChart[val] - else: - return int(val) - if type(val) is int: return val + if type(val) is str: + for size in SizeChart: + if val == size.name: + return size + + return int(val) + raise Exception('invalid size') diff --git a/common/pipelines/example.json b/common/pipelines/example.json index f059c599..547991e0 100644 --- a/common/pipelines/example.json +++ b/common/pipelines/example.json @@ -30,7 +30,7 @@ "name": "save-local", "type": "persist-disk", "params": { - "tile_size": "8k" + "tile_size": "hd8k" } }, { @@ -40,7 +40,7 @@ "bucket": "storage-stable-diffusion", "endpoint_url": "http://scylla.home.holdmyran.ch:8000", "profile_name": "ceph", - "tile_size": "8k" + "tile_size": "hd8k" } } ]