1
0
Fork 0

new tiled inpainting method

This commit is contained in:
HoopyFreud 2023-07-09 00:56:20 -04:00
parent 3fbf9baae6
commit 08172a7236
7 changed files with 163 additions and 147 deletions

View File

@ -144,13 +144,14 @@ class ChainPipeline:
tile,
)
def stage_tile(source_tile: Image.Image, _dims) -> Image.Image:
def stage_tile(source_tile: Image.Image, tile_mask: Image.Image, _dims) -> Image.Image:
output_tile = stage_pipe.run(
job,
server,
stage_params,
params,
[source_tile],
tile_mask=tile_mask,
callback=callback,
**kwargs,
)[0]

View File

@ -7,7 +7,14 @@ from typing import List, Optional, Protocol, Tuple
import numpy as np
from PIL import Image
from ..image.noise_source import noise_source_histogram
from ..image.noise_source import (
noise_source_fill_edge,
noise_source_fill_mask,
noise_source_gaussian,
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
)
from ..params import Size, TileOrder
# from skimage.exposure import match_histograms
@ -232,6 +239,11 @@ def process_tile_spiral(
**kwargs,
) -> Image.Image:
width, height = kwargs.get("size", source.size if source else None)
mask = kwargs.get("mask",None)
noise_source = kwargs.get("noise_source",noise_source_histogram)
fill_color = kwargs.get("fill_color",None)
if not mask:
tile_mask = None
tiles: List[Tuple[int, int, Image.Image]] = []
@ -277,25 +289,50 @@ def process_tile_spiral(
if single_tile:
logger.debug("creating and processing single-tile subtile")
tile_image = source
if mask:
tile_mask = mask
# otherwise, add histogram noise outside of the image border
else:
logger.debug("tiling and adding margin")
base_image = source.crop(
logger.debug("tiling and adding margins: %s, %s, %s, %s",
left_margin,
top_margin,
right_margin,
bottom_margin)
base_image = (
source.crop(
(
left + left_margin,
top + top_margin,
right - right_margin,
bottom - bottom_margin,
right + right_margin,
bottom + bottom_margin,
)
)
tile_image = noise_source_histogram(base_image, (tile, tile), (0, 0))
)
tile_image = noise_source(base_image, (tile, tile), (0, 0),fill=fill_color)
tile_image.paste(base_image, (left_margin, top_margin))
if mask:
base_mask = (
mask.crop(
(
left + left_margin,
top + top_margin,
right + right_margin,
bottom + bottom_margin,
)
)
)
tile_mask = Image.new("L",(tile,tile),color=0)
tile_mask.paste(base_mask, (left_margin, top_margin))
else:
logger.debug("tiling normally")
tile_image = source.crop((left, top, right, bottom))
if mask:
tile_mask = mask.crop((left, top, right, bottom))
for image_filter in filters:
tile_image = image_filter(tile_image, (left, top, tile))
tile_image = image_filter(tile_image, tile_mask, (left, top, tile))
tiles.append((left, top, tile_image))
@ -353,7 +390,10 @@ def generate_tile_spiral(
span_x = tile + (width_tile_target - 1) * tile_increment
span_y = tile + (height_tile_target - 1) * tile_increment
logger.debug("tiled image overlap: %s. Span: %s x %s", overlap, span_x, span_y)
logger.debug(
"tiled image overlap: %s. Span: %s x %s",
overlap,span_x,span_y
)
tile_left = (
width - span_x

View File

@ -34,6 +34,7 @@ class UpscaleOutpaintStage(BaseStage):
stage: StageParams,
params: ImageParams,
sources: List[Image.Image],
tile_mask: Image.Image,
*,
border: Border,
stage_source: Optional[Image.Image] = None,
@ -60,69 +61,42 @@ class UpscaleOutpaintStage(BaseStage):
outputs = []
for source in sources:
logger.info(
"upscaling %s x %s image by expanding borders: %s",
source.width,
source.height,
border,
)
margin_x = float(max(border.left, border.right))
margin_y = float(max(border.top, border.bottom))
overlap = min(margin_x / source.width, margin_y / source.height)
if stage_mask is None:
# if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")
# masks start as 512x512, resize to cover the source, then trim the extra
mask_max = max(source.width, source.height)
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
source, stage_mask, noise, full_size = expand_image(
source,
stage_mask,
border,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
draw_mask = ImageDraw.Draw(stage_mask)
if is_debug():
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
tile_mask = complete_tile(tile_mask, tile)
if is_debug():
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-source.png", source)
save_image(server, "tile-mask.png", tile_mask)
latents = get_tile_latents(full_latents, dims, size)
#if the tile mask is all black, skip processing this tile
if not tile_mask.getbbox():
outputs.append(source)
continue
source_width, source_height = source.size
source_size = Size(source_width, source_height)
tile_size = params.tiles
if max(source_size) > tile_size:
latent_size = Size(tile_size,tile_size)
latents = get_latents_from_seed(params.seed, latent_size)
pipe_width=pipe_height=tile_size
else:
latent_size = Size(source_size.width,source_size.height)
latents = get_latents_from_seed(params.seed, latent_size)
pipe_width=source_size.width
pipe_height=source_size.height
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
tile_source,
source,
tile_mask,
prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
latents=latents,
negative_prompt=negative_prompt,
height=pipe_height,
width=pipe_width,
num_inference_steps=params.steps,
width=size.width,
guidance_scale=params.cfg,
generator=rng,
latents=latents,
callback=callback,
)
else:
@ -135,46 +109,18 @@ class UpscaleOutpaintStage(BaseStage):
rng = np.random.RandomState(params.seed)
result = pipe(
prompt,
tile_source,
source,
tile_mask,
height=size.height,
width=size.width,
negative_prompt=negative_prompt,
height=pipe_height,
width=pipe_width,
num_inference_steps=params.steps,
guidance_scale=params.cfg,
negative_prompt=negative_prompt,
generator=rng,
latents=latents,
callback=callback,
)
# once part of the image has been drawn, keep it
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
return result.images[0]
if params.pipeline == "panorama":
logger.debug("outpainting with one shot panorama, no tiling")
output = outpaint(source, (0, 0, max(source.width, source.height)))
if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
elif border.left == border.right and border.top == border.bottom:
logger.debug(
"outpainting with an even border, using spiral tiling with %s overlap",
overlap,
)
output = process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=overlap,
)
else:
logger.debug("outpainting with an uneven border, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output)
outputs.extend(result.images)
return outputs

View File

@ -1,7 +1,7 @@
from logging import getLogger
from typing import Any, List, Optional
from PIL import Image
from PIL import Image, ImageOps
from onnx_web.chain.highres import stage_highres
@ -13,6 +13,7 @@ from ..chain import (
UpscaleOutpaintStage,
)
from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image
from ..output import save_image
from ..params import (
Border,
@ -24,7 +25,7 @@ from ..params import (
)
from ..server import ServerContext
from ..server.load import get_source_filters
from ..utils import run_gc, show_system_toast
from ..utils import run_gc, show_system_toast, is_debug
from ..worker import WorkerContext
from .utils import parse_prompt
@ -221,6 +222,29 @@ def run_inpaint_pipeline(
) -> None:
logger.debug("building inpaint pipeline")
if mask is None:
# if no mask was provided, keep the full source image
mask = Image.new("L", source.size, 0)
# masks start as 512x512, resize to cover the source, then trim the extra
mask_max = max(source.width, source.height)
mask = ImageOps.contain(mask, (mask_max, mask_max))
mask = mask.crop((0, 0, source.width, source.height))
source, mask, noise, full_size = expand_image(
source,
mask,
border,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
if is_debug():
save_image(server, "full-source.png", source)
save_image(server, "full-mask.png", mask)
save_image(server, "full-noise.png", noise)
# set up the chain pipeline and base stage
chain = ChainPipeline()
stage = StageParams(tile_order=tile_order, tile_size=params.tiles)
@ -228,10 +252,11 @@ def run_inpaint_pipeline(
UpscaleOutpaintStage(),
stage,
border=border,
stage_mask=mask,
mask=mask,
fill_color=fill_color,
mask_filter=mask_filter,
noise_source=noise_source,
overlap=params.overlap,
)
# apply upscaling and correction, before highres

View File

@ -14,7 +14,7 @@ def expand_image(
noise_source=noise_source_histogram,
mask_filter=mask_filter_none,
):
size = Size(*source.size).add_border(expand).round_to_tile()
size = Size(*source.size).add_border(expand)
size = tuple(size)
origin = (expand.left, expand.top)

View File

@ -249,7 +249,10 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
return error_reply("mask image is required")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
mask = Image.new("RGBA",mask_top_layer.size,color=(0,0,0,255))
mask.alpha_composite(mask_top_layer)
mask.convert(mode="L")
device, params, size = pipeline_from_request(server, "inpaint")
expand = border_from_request()
@ -262,6 +265,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
tile_order = get_from_list(
request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
)
tile_order = TileOrder.spiral
replace_wildcards(params, get_wildcard_data())