new tiled inpainting method
This commit is contained in:
parent
3fbf9baae6
commit
08172a7236
|
@ -144,13 +144,14 @@ class ChainPipeline:
|
|||
tile,
|
||||
)
|
||||
|
||||
def stage_tile(source_tile: Image.Image, _dims) -> Image.Image:
|
||||
def stage_tile(source_tile: Image.Image, tile_mask: Image.Image, _dims) -> Image.Image:
|
||||
output_tile = stage_pipe.run(
|
||||
job,
|
||||
server,
|
||||
stage_params,
|
||||
params,
|
||||
[source_tile],
|
||||
tile_mask=tile_mask,
|
||||
callback=callback,
|
||||
**kwargs,
|
||||
)[0]
|
||||
|
|
|
@ -7,7 +7,14 @@ from typing import List, Optional, Protocol, Tuple
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..image.noise_source import noise_source_histogram
|
||||
from ..image.noise_source import (
|
||||
noise_source_fill_edge,
|
||||
noise_source_fill_mask,
|
||||
noise_source_gaussian,
|
||||
noise_source_histogram,
|
||||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
)
|
||||
from ..params import Size, TileOrder
|
||||
|
||||
# from skimage.exposure import match_histograms
|
||||
|
@ -232,6 +239,11 @@ def process_tile_spiral(
|
|||
**kwargs,
|
||||
) -> Image.Image:
|
||||
width, height = kwargs.get("size", source.size if source else None)
|
||||
mask = kwargs.get("mask",None)
|
||||
noise_source = kwargs.get("noise_source",noise_source_histogram)
|
||||
fill_color = kwargs.get("fill_color",None)
|
||||
if not mask:
|
||||
tile_mask = None
|
||||
|
||||
tiles: List[Tuple[int, int, Image.Image]] = []
|
||||
|
||||
|
@ -277,25 +289,50 @@ def process_tile_spiral(
|
|||
if single_tile:
|
||||
logger.debug("creating and processing single-tile subtile")
|
||||
tile_image = source
|
||||
if mask:
|
||||
tile_mask = mask
|
||||
# otherwise, add histogram noise outside of the image border
|
||||
else:
|
||||
logger.debug("tiling and adding margin")
|
||||
base_image = source.crop(
|
||||
logger.debug("tiling and adding margins: %s, %s, %s, %s",
|
||||
left_margin,
|
||||
top_margin,
|
||||
right_margin,
|
||||
bottom_margin)
|
||||
base_image = (
|
||||
source.crop(
|
||||
(
|
||||
left + left_margin,
|
||||
top + top_margin,
|
||||
right - right_margin,
|
||||
bottom - bottom_margin,
|
||||
right + right_margin,
|
||||
bottom + bottom_margin,
|
||||
)
|
||||
)
|
||||
tile_image = noise_source_histogram(base_image, (tile, tile), (0, 0))
|
||||
)
|
||||
tile_image = noise_source(base_image, (tile, tile), (0, 0),fill=fill_color)
|
||||
tile_image.paste(base_image, (left_margin, top_margin))
|
||||
|
||||
if mask:
|
||||
base_mask = (
|
||||
mask.crop(
|
||||
(
|
||||
left + left_margin,
|
||||
top + top_margin,
|
||||
right + right_margin,
|
||||
bottom + bottom_margin,
|
||||
)
|
||||
)
|
||||
)
|
||||
tile_mask = Image.new("L",(tile,tile),color=0)
|
||||
tile_mask.paste(base_mask, (left_margin, top_margin))
|
||||
|
||||
else:
|
||||
logger.debug("tiling normally")
|
||||
tile_image = source.crop((left, top, right, bottom))
|
||||
if mask:
|
||||
tile_mask = mask.crop((left, top, right, bottom))
|
||||
|
||||
for image_filter in filters:
|
||||
tile_image = image_filter(tile_image, (left, top, tile))
|
||||
tile_image = image_filter(tile_image, tile_mask, (left, top, tile))
|
||||
|
||||
tiles.append((left, top, tile_image))
|
||||
|
||||
|
@ -353,7 +390,10 @@ def generate_tile_spiral(
|
|||
span_x = tile + (width_tile_target - 1) * tile_increment
|
||||
span_y = tile + (height_tile_target - 1) * tile_increment
|
||||
|
||||
logger.debug("tiled image overlap: %s. Span: %s x %s", overlap, span_x, span_y)
|
||||
logger.debug(
|
||||
"tiled image overlap: %s. Span: %s x %s",
|
||||
overlap,span_x,span_y
|
||||
)
|
||||
|
||||
tile_left = (
|
||||
width - span_x
|
||||
|
|
|
@ -34,6 +34,7 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
sources: List[Image.Image],
|
||||
tile_mask: Image.Image,
|
||||
*,
|
||||
border: Border,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
|
@ -60,69 +61,42 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
|
||||
outputs = []
|
||||
for source in sources:
|
||||
logger.info(
|
||||
"upscaling %s x %s image by expanding borders: %s",
|
||||
source.width,
|
||||
source.height,
|
||||
border,
|
||||
)
|
||||
|
||||
margin_x = float(max(border.left, border.right))
|
||||
margin_y = float(max(border.top, border.bottom))
|
||||
overlap = min(margin_x / source.width, margin_y / source.height)
|
||||
|
||||
if stage_mask is None:
|
||||
# if no mask was provided, keep the full source image
|
||||
stage_mask = Image.new("RGB", source.size, "black")
|
||||
|
||||
# masks start as 512x512, resize to cover the source, then trim the extra
|
||||
mask_max = max(source.width, source.height)
|
||||
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
|
||||
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
|
||||
|
||||
source, stage_mask, noise, full_size = expand_image(
|
||||
source,
|
||||
stage_mask,
|
||||
border,
|
||||
fill=fill_color,
|
||||
noise_source=noise_source,
|
||||
mask_filter=mask_filter,
|
||||
)
|
||||
|
||||
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
|
||||
|
||||
draw_mask = ImageDraw.Draw(stage_mask)
|
||||
|
||||
if is_debug():
|
||||
save_image(server, "last-source.png", source)
|
||||
save_image(server, "last-mask.png", stage_mask)
|
||||
save_image(server, "last-noise.png", noise)
|
||||
|
||||
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
|
||||
left, top, tile = dims
|
||||
size = Size(*tile_source.size)
|
||||
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
|
||||
tile_mask = complete_tile(tile_mask, tile)
|
||||
|
||||
if is_debug():
|
||||
save_image(server, "tile-source.png", tile_source)
|
||||
save_image(server, "tile-source.png", source)
|
||||
save_image(server, "tile-mask.png", tile_mask)
|
||||
|
||||
latents = get_tile_latents(full_latents, dims, size)
|
||||
#if the tile mask is all black, skip processing this tile
|
||||
if not tile_mask.getbbox():
|
||||
outputs.append(source)
|
||||
continue
|
||||
|
||||
source_width, source_height = source.size
|
||||
source_size = Size(source_width, source_height)
|
||||
tile_size = params.tiles
|
||||
if max(source_size) > tile_size:
|
||||
latent_size = Size(tile_size,tile_size)
|
||||
latents = get_latents_from_seed(params.seed, latent_size)
|
||||
pipe_width=pipe_height=tile_size
|
||||
else:
|
||||
latent_size = Size(source_size.width,source_size.height)
|
||||
latents = get_latents_from_seed(params.seed, latent_size)
|
||||
pipe_width=source_size.width
|
||||
pipe_height=source_size.height
|
||||
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for inpaint")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.inpaint(
|
||||
tile_source,
|
||||
source,
|
||||
tile_mask,
|
||||
prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
latents=latents,
|
||||
negative_prompt=negative_prompt,
|
||||
height=pipe_height,
|
||||
width=pipe_width,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
guidance_scale=params.cfg,
|
||||
generator=rng,
|
||||
latents=latents,
|
||||
callback=callback,
|
||||
)
|
||||
else:
|
||||
|
@ -135,46 +109,18 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
prompt,
|
||||
tile_source,
|
||||
source,
|
||||
tile_mask,
|
||||
height=size.height,
|
||||
width=size.width,
|
||||
negative_prompt=negative_prompt,
|
||||
height=pipe_height,
|
||||
width=pipe_width,
|
||||
num_inference_steps=params.steps,
|
||||
guidance_scale=params.cfg,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=rng,
|
||||
latents=latents,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# once part of the image has been drawn, keep it
|
||||
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
|
||||
return result.images[0]
|
||||
|
||||
if params.pipeline == "panorama":
|
||||
logger.debug("outpainting with one shot panorama, no tiling")
|
||||
output = outpaint(source, (0, 0, max(source.width, source.height)))
|
||||
if overlap == 0:
|
||||
logger.debug("outpainting with 0 margin, using grid tiling")
|
||||
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
||||
elif border.left == border.right and border.top == border.bottom:
|
||||
logger.debug(
|
||||
"outpainting with an even border, using spiral tiling with %s overlap",
|
||||
overlap,
|
||||
)
|
||||
output = process_tile_order(
|
||||
stage.tile_order,
|
||||
source,
|
||||
SizeChart.auto,
|
||||
1,
|
||||
[outpaint],
|
||||
overlap=overlap,
|
||||
)
|
||||
else:
|
||||
logger.debug("outpainting with an uneven border, using grid tiling")
|
||||
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
||||
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
outputs.append(output)
|
||||
outputs.extend(result.images)
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from logging import getLogger
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from onnx_web.chain.highres import stage_highres
|
||||
|
||||
|
@ -13,6 +13,7 @@ from ..chain import (
|
|||
UpscaleOutpaintStage,
|
||||
)
|
||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||
from ..image import expand_image
|
||||
from ..output import save_image
|
||||
from ..params import (
|
||||
Border,
|
||||
|
@ -24,7 +25,7 @@ from ..params import (
|
|||
)
|
||||
from ..server import ServerContext
|
||||
from ..server.load import get_source_filters
|
||||
from ..utils import run_gc, show_system_toast
|
||||
from ..utils import run_gc, show_system_toast, is_debug
|
||||
from ..worker import WorkerContext
|
||||
from .utils import parse_prompt
|
||||
|
||||
|
@ -221,6 +222,29 @@ def run_inpaint_pipeline(
|
|||
) -> None:
|
||||
logger.debug("building inpaint pipeline")
|
||||
|
||||
if mask is None:
|
||||
# if no mask was provided, keep the full source image
|
||||
mask = Image.new("L", source.size, 0)
|
||||
|
||||
# masks start as 512x512, resize to cover the source, then trim the extra
|
||||
mask_max = max(source.width, source.height)
|
||||
mask = ImageOps.contain(mask, (mask_max, mask_max))
|
||||
mask = mask.crop((0, 0, source.width, source.height))
|
||||
|
||||
source, mask, noise, full_size = expand_image(
|
||||
source,
|
||||
mask,
|
||||
border,
|
||||
fill=fill_color,
|
||||
noise_source=noise_source,
|
||||
mask_filter=mask_filter,
|
||||
)
|
||||
|
||||
if is_debug():
|
||||
save_image(server, "full-source.png", source)
|
||||
save_image(server, "full-mask.png", mask)
|
||||
save_image(server, "full-noise.png", noise)
|
||||
|
||||
# set up the chain pipeline and base stage
|
||||
chain = ChainPipeline()
|
||||
stage = StageParams(tile_order=tile_order, tile_size=params.tiles)
|
||||
|
@ -228,10 +252,11 @@ def run_inpaint_pipeline(
|
|||
UpscaleOutpaintStage(),
|
||||
stage,
|
||||
border=border,
|
||||
stage_mask=mask,
|
||||
mask=mask,
|
||||
fill_color=fill_color,
|
||||
mask_filter=mask_filter,
|
||||
noise_source=noise_source,
|
||||
overlap=params.overlap,
|
||||
)
|
||||
|
||||
# apply upscaling and correction, before highres
|
||||
|
|
|
@ -14,7 +14,7 @@ def expand_image(
|
|||
noise_source=noise_source_histogram,
|
||||
mask_filter=mask_filter_none,
|
||||
):
|
||||
size = Size(*source.size).add_border(expand).round_to_tile()
|
||||
size = Size(*source.size).add_border(expand)
|
||||
size = tuple(size)
|
||||
origin = (expand.left, expand.top)
|
||||
|
||||
|
|
|
@ -249,7 +249,10 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
|||
return error_reply("mask image is required")
|
||||
|
||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||
mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
|
||||
mask = Image.new("RGBA",mask_top_layer.size,color=(0,0,0,255))
|
||||
mask.alpha_composite(mask_top_layer)
|
||||
mask.convert(mode="L")
|
||||
|
||||
device, params, size = pipeline_from_request(server, "inpaint")
|
||||
expand = border_from_request()
|
||||
|
@ -262,6 +265,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
|||
tile_order = get_from_list(
|
||||
request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
|
||||
)
|
||||
tile_order = TileOrder.spiral
|
||||
|
||||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
|
|
Loading…
Reference in New Issue