full-res inpaint
This commit is contained in:
parent
c15f750821
commit
b9603faccd
|
@ -1,4 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
from math import ceil
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
@ -221,6 +222,8 @@ def run_inpaint_pipeline(
|
||||||
tile_order: str,
|
tile_order: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("building inpaint pipeline")
|
logger.debug("building inpaint pipeline")
|
||||||
|
tile_size = params.tiles
|
||||||
|
full_res_inpaint = False
|
||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
# if no mask was provided, keep the full source image
|
# if no mask was provided, keep the full source image
|
||||||
|
@ -240,6 +243,36 @@ def run_inpaint_pipeline(
|
||||||
mask_filter=mask_filter,
|
mask_filter=mask_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# check if we can do full-res inpainting if no outpainting is done
|
||||||
|
logger.debug("border zero: %s", border.isZero())
|
||||||
|
if border.isZero():
|
||||||
|
mask_left, mask_top, mask_right, mask_bottom = mask.getbbox()
|
||||||
|
logger.debug("mask bbox: %s", mask.getbbox())
|
||||||
|
mask_width = mask_right - mask_left
|
||||||
|
mask_height = mask_bottom - mask_top
|
||||||
|
# ensure we have some padding around the mask when we do the inpaint (and that the region size is even)
|
||||||
|
adj_mask_size = ceil(max(mask_width, mask_height) * 1.5 / 2) * 2
|
||||||
|
logger.debug("adjusted mask size %s", adj_mask_size)
|
||||||
|
if adj_mask_size <= tile_size:
|
||||||
|
full_res_inpaint = True
|
||||||
|
original_source = source
|
||||||
|
mask_center_x = int(round((mask_right + mask_left) / 2))
|
||||||
|
mask_center_y = int(round((mask_bottom + mask_top) / 2))
|
||||||
|
adj_mask_border = (
|
||||||
|
int(mask_center_x - adj_mask_size / 2),
|
||||||
|
int(mask_center_y - adj_mask_size / 2),
|
||||||
|
int(mask_center_x + adj_mask_size / 2),
|
||||||
|
int(mask_center_y + adj_mask_size / 2),
|
||||||
|
)
|
||||||
|
logger.debug("mask bounding box: %s", adj_mask_border)
|
||||||
|
source = source.crop(adj_mask_border)
|
||||||
|
source = ImageOps.contain(source, (tile_size, tile_size))
|
||||||
|
mask = mask.crop(adj_mask_border)
|
||||||
|
mask = ImageOps.contain(mask, (tile_size, tile_size))
|
||||||
|
if is_debug():
|
||||||
|
save_image(server, "adjusted-mask.png", mask)
|
||||||
|
save_image(server, "adjusted-source.png", source)
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "full-source.png", source)
|
save_image(server, "full-source.png", source)
|
||||||
save_image(server, "full-mask.png", mask)
|
save_image(server, "full-mask.png", mask)
|
||||||
|
@ -247,7 +280,7 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
# set up the chain pipeline and base stage
|
# set up the chain pipeline and base stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams(tile_order=tile_order, tile_size=params.tiles)
|
stage = StageParams(tile_order=tile_order, tile_size=tile_size)
|
||||||
chain.stage(
|
chain.stage(
|
||||||
UpscaleOutpaintStage(),
|
UpscaleOutpaintStage(),
|
||||||
stage,
|
stage,
|
||||||
|
@ -292,6 +325,12 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
|
if full_res_inpaint:
|
||||||
|
if is_debug():
|
||||||
|
save_image(server, "adjusted-output.png", image)
|
||||||
|
mini_image = ImageOps.contain(image, (adj_mask_size, adj_mask_size))
|
||||||
|
image = original_source
|
||||||
|
image.paste(mini_image, box=adj_mask_border)
|
||||||
dest = save_image(
|
dest = save_image(
|
||||||
server,
|
server,
|
||||||
output,
|
output,
|
||||||
|
|
|
@ -43,6 +43,11 @@ class Border:
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return "(%s, %s, %s, %s)" % (self.left, self.right, self.top, self.bottom)
|
return "(%s, %s, %s, %s)" % (self.left, self.right, self.top, self.bottom)
|
||||||
|
|
||||||
|
def isZero(self) -> bool:
|
||||||
|
return all(
|
||||||
|
value == 0 for value in (self.left, self.right, self.top, self.bottom)
|
||||||
|
)
|
||||||
|
|
||||||
def tojson(self):
|
def tojson(self):
|
||||||
return {
|
return {
|
||||||
"left": self.left,
|
"left": self.left,
|
||||||
|
|
Loading…
Reference in New Issue