diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 1719c1e4..cbcf5a67 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -2,7 +2,7 @@ from diffusers import ( OnnxStableDiffusionInpaintPipeline, ) from logging import getLogger -from PIL import Image +from PIL import Image, ImageDraw from typing import Callable, Tuple from ..diffusion.load import ( @@ -41,7 +41,7 @@ def upscale_outpaint( params: ImageParams, source_image: Image.Image, *, - expand: Border, + border: Border, prompt: str = None, mask_image: Image.Image = None, fill_color: str = 'white', @@ -50,7 +50,7 @@ def upscale_outpaint( **kwargs, ) -> Image.Image: prompt = prompt or params.prompt - logger.info('upscaling image by expanding borders: %s', expand) + logger.info('upscaling image by expanding borders: %s', border) if mask_image is None: # if no mask was provided, keep the full source image @@ -59,11 +59,13 @@ def upscale_outpaint( source_image, mask_image, noise_image, _full_dims = expand_image( source_image, mask_image, - expand, + border, fill=fill_color, noise_source=noise_source, mask_filter=mask_filter) + draw_mask = ImageDraw.Draw(mask_image) + if is_debug(): source_image.save(base_join(ctx.output_path, 'last-source.png')) mask_image.save(base_join(ctx.output_path, 'last-mask.png')) @@ -98,6 +100,9 @@ def upscale_outpaint( num_inference_steps=params.steps, width=size.width, ) + + # once part of the image has been drawn, keep it + draw_mask.rectangle((left, top, left + tile, top + tile), fill='black') return result.images[0] output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint]) diff --git a/api/onnx_web/chain/utils.py b/api/onnx_web/chain/utils.py index 41f36ec4..e4f2162d 100644 --- a/api/onnx_web/chain/utils.py +++ b/api/onnx_web/chain/utils.py @@ -76,7 +76,7 @@ def process_tile_spiral( top = center_y + int(top) counter += 1 - logger.info('processing tile %s of %s', counter, len(tiles)) + logger.info('processing tile %s of %s, %sx%s', counter, len(tiles), left, top) # TODO: only valid for scale == 1, resize source for others tile_image = image.crop((left, top, left + tile, top + tile))