diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 2710df9e..8ada7677 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -1,4 +1,5 @@ from logging import getLogger +from math import ceil from typing import Any, List, Optional from PIL import Image, ImageOps @@ -221,6 +222,8 @@ def run_inpaint_pipeline( tile_order: str, ) -> None: logger.debug("building inpaint pipeline") + tile_size = params.tiles + full_res_inpaint = False if mask is None: # if no mask was provided, keep the full source image @@ -240,6 +243,36 @@ def run_inpaint_pipeline( mask_filter=mask_filter, ) + # check if we can do full-res inpainting if no outpainting is done + logger.debug("border zero: %s", border.isZero()) + if border.isZero(): + mask_left, mask_top, mask_right, mask_bottom = mask.getbbox() + logger.debug("mask bbox: %s", mask.getbbox()) + mask_width = mask_right - mask_left + mask_height = mask_bottom - mask_top + # ensure we have some padding around the mask when we do the inpaint (and that the region size is even) + adj_mask_size = ceil(max(mask_width, mask_height) * 1.5 / 2) * 2 + logger.debug("adjusted mask size %s", adj_mask_size) + if adj_mask_size <= tile_size: + full_res_inpaint = True + original_source = source + mask_center_x = int(round((mask_right + mask_left) / 2)) + mask_center_y = int(round((mask_bottom + mask_top) / 2)) + adj_mask_border = ( + int(mask_center_x - adj_mask_size / 2), + int(mask_center_y - adj_mask_size / 2), + int(mask_center_x + adj_mask_size / 2), + int(mask_center_y + adj_mask_size / 2), + ) + logger.debug("mask bounding box: %s", adj_mask_border) + source = source.crop(adj_mask_border) + source = ImageOps.contain(source, (tile_size, tile_size)) + mask = mask.crop(adj_mask_border) + mask = ImageOps.contain(mask, (tile_size, tile_size)) + if is_debug(): + save_image(server, "adjusted-mask.png", mask) + save_image(server, "adjusted-source.png", source) + if is_debug(): save_image(server, "full-source.png", source) save_image(server, "full-mask.png", mask) @@ -247,7 +280,7 @@ def run_inpaint_pipeline( # set up the chain pipeline and base stage chain = ChainPipeline() - stage = StageParams(tile_order=tile_order, tile_size=params.tiles) + stage = StageParams(tile_order=tile_order, tile_size=tile_size) chain.stage( UpscaleOutpaintStage(), stage, @@ -292,6 +325,12 @@ def run_inpaint_pipeline( _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): + if full_res_inpaint: + if is_debug(): + save_image(server, "adjusted-output.png", image) + mini_image = ImageOps.contain(image, (adj_mask_size, adj_mask_size)) + image = original_source + image.paste(mini_image, box=adj_mask_border) dest = save_image( server, output, diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 4a077fb8..06e038f1 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -43,6 +43,11 @@ class Border: def __str__(self) -> str: return "(%s, %s, %s, %s)" % (self.left, self.right, self.top, self.bottom) + def isZero(self) -> bool: + return all( + value == 0 for value in (self.left, self.right, self.top, self.bottom) + ) + def tojson(self): return { "left": self.left,