fix(api): preserve new pixels after outpainting
This commit is contained in:
parent
20beff839e
commit
7083505483
|
@ -2,7 +2,7 @@ from diffusers import (
|
|||
OnnxStableDiffusionInpaintPipeline,
|
||||
)
|
||||
from logging import getLogger
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageDraw
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from ..diffusion.load import (
|
||||
|
@ -41,7 +41,7 @@ def upscale_outpaint(
|
|||
params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
expand: Border,
|
||||
border: Border,
|
||||
prompt: str = None,
|
||||
mask_image: Image.Image = None,
|
||||
fill_color: str = 'white',
|
||||
|
@ -50,7 +50,7 @@ def upscale_outpaint(
|
|||
**kwargs,
|
||||
) -> Image.Image:
|
||||
prompt = prompt or params.prompt
|
||||
logger.info('upscaling image by expanding borders: %s', expand)
|
||||
logger.info('upscaling image by expanding borders: %s', border)
|
||||
|
||||
if mask_image is None:
|
||||
# if no mask was provided, keep the full source image
|
||||
|
@ -59,11 +59,13 @@ def upscale_outpaint(
|
|||
source_image, mask_image, noise_image, _full_dims = expand_image(
|
||||
source_image,
|
||||
mask_image,
|
||||
expand,
|
||||
border,
|
||||
fill=fill_color,
|
||||
noise_source=noise_source,
|
||||
mask_filter=mask_filter)
|
||||
|
||||
draw_mask = ImageDraw.Draw(mask_image)
|
||||
|
||||
if is_debug():
|
||||
source_image.save(base_join(ctx.output_path, 'last-source.png'))
|
||||
mask_image.save(base_join(ctx.output_path, 'last-mask.png'))
|
||||
|
@ -98,6 +100,9 @@ def upscale_outpaint(
|
|||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
)
|
||||
|
||||
# once part of the image has been drawn, keep it
|
||||
draw_mask.rectangle((left, top, left + tile, top + tile), fill='black')
|
||||
return result.images[0]
|
||||
|
||||
output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint])
|
||||
|
|
|
@ -76,7 +76,7 @@ def process_tile_spiral(
|
|||
top = center_y + int(top)
|
||||
|
||||
counter += 1
|
||||
logger.info('processing tile %s of %s', counter, len(tiles))
|
||||
logger.info('processing tile %s of %s, %sx%s', counter, len(tiles), left, top)
|
||||
|
||||
# TODO: only valid for scale == 1, resize source for others
|
||||
tile_image = image.crop((left, top, left + tile, top + tile))
|
||||
|
|
Loading…
Reference in New Issue