From 6fe278c744041cf5b7d33299f6e576b9f5dce76a Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 5 Feb 2023 16:56:11 -0600 Subject: [PATCH] fix(api): switch between spiral and grid tiling based on outpaint margins (#101) --- api/onnx_web/chain/upscale_outpaint.py | 13 +++++++++++-- api/onnx_web/chain/utils.py | 23 +++++++++++------------ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 253a881c..602a7438 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -11,7 +11,7 @@ from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams from ..utils import ServerContext, is_debug -from .utils import process_tile_spiral +from .utils import process_tile_grid, process_tile_spiral logger = getLogger(__name__) @@ -92,7 +92,16 @@ def upscale_outpaint( draw_mask.rectangle((left, top, left + tile, top + tile), fill="black") return result.images[0] - output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint]) + margin_x = float(max(border.left, border.right)) + margin_y = float(max(border.top, border.bottom)) + overlap = min(margin_x / source_image.width, margin_y / source_image.height) + + if overlap > 0 and border.left == border.right and border.top == border.bottom: + logger.debug("outpainting with an even border, using spiral tiling") + output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint], overlap=overlap) + else: + logger.debug("outpainting with an uneven border, using grid tiling") + output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint]) logger.info("final output image size: %sx%s", output.width, output.height) return output diff --git a/api/onnx_web/chain/utils.py b/api/onnx_web/chain/utils.py index 0e112303..1bb3ac22 100644 --- a/api/onnx_web/chain/utils.py +++ b/api/onnx_web/chain/utils.py @@ -57,18 +57,17 @@ def process_tile_spiral( center_x = (width // 2) - (tile // 2) center_y = (height // 2) - (tile // 2) - # TODO: only valid for overlap = 0.5 - if overlap == 0.5: - tiles = [ - (0, tile * -overlap), - (tile * overlap, tile * -overlap), - (tile * overlap, 0), - (tile * overlap, tile * overlap), - (0, tile * overlap), - (tile * -overlap, tile * overlap), - (tile * -overlap, 0), - (tile * -overlap, tile * -overlap), - ] + # TODO: should add/remove tiles when overlap != 0.5 + tiles = [ + (0, tile * -overlap), + (tile * overlap, tile * -overlap), + (tile * overlap, 0), + (tile * overlap, tile * overlap), + (0, tile * overlap), + (tile * -overlap, tile * overlap), + (tile * -overlap, 0), + (tile * -overlap, tile * -overlap), + ] # tile tuples is source, multiply by scale for dest counter = 0