new tiled inpainting method
This commit is contained in:
parent
3fbf9baae6
commit
08172a7236
|
@ -144,13 +144,14 @@ class ChainPipeline:
|
||||||
tile,
|
tile,
|
||||||
)
|
)
|
||||||
|
|
||||||
def stage_tile(source_tile: Image.Image, _dims) -> Image.Image:
|
def stage_tile(source_tile: Image.Image, tile_mask: Image.Image, _dims) -> Image.Image:
|
||||||
output_tile = stage_pipe.run(
|
output_tile = stage_pipe.run(
|
||||||
job,
|
job,
|
||||||
server,
|
server,
|
||||||
stage_params,
|
stage_params,
|
||||||
params,
|
params,
|
||||||
[source_tile],
|
[source_tile],
|
||||||
|
tile_mask=tile_mask,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
|
@ -7,7 +7,14 @@ from typing import List, Optional, Protocol, Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..image.noise_source import noise_source_histogram
|
from ..image.noise_source import (
|
||||||
|
noise_source_fill_edge,
|
||||||
|
noise_source_fill_mask,
|
||||||
|
noise_source_gaussian,
|
||||||
|
noise_source_histogram,
|
||||||
|
noise_source_normal,
|
||||||
|
noise_source_uniform,
|
||||||
|
)
|
||||||
from ..params import Size, TileOrder
|
from ..params import Size, TileOrder
|
||||||
|
|
||||||
# from skimage.exposure import match_histograms
|
# from skimage.exposure import match_histograms
|
||||||
|
@ -232,6 +239,11 @@ def process_tile_spiral(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
width, height = kwargs.get("size", source.size if source else None)
|
width, height = kwargs.get("size", source.size if source else None)
|
||||||
|
mask = kwargs.get("mask",None)
|
||||||
|
noise_source = kwargs.get("noise_source",noise_source_histogram)
|
||||||
|
fill_color = kwargs.get("fill_color",None)
|
||||||
|
if not mask:
|
||||||
|
tile_mask = None
|
||||||
|
|
||||||
tiles: List[Tuple[int, int, Image.Image]] = []
|
tiles: List[Tuple[int, int, Image.Image]] = []
|
||||||
|
|
||||||
|
@ -277,25 +289,50 @@ def process_tile_spiral(
|
||||||
if single_tile:
|
if single_tile:
|
||||||
logger.debug("creating and processing single-tile subtile")
|
logger.debug("creating and processing single-tile subtile")
|
||||||
tile_image = source
|
tile_image = source
|
||||||
|
if mask:
|
||||||
|
tile_mask = mask
|
||||||
#otherwise use add histogram noise outside of the image border
|
#otherwise use add histogram noise outside of the image border
|
||||||
else:
|
else:
|
||||||
logger.debug("tiling and adding margin")
|
logger.debug("tiling and adding margins: %s, %s, %s, %s",
|
||||||
base_image = source.crop(
|
left_margin,
|
||||||
|
top_margin,
|
||||||
|
right_margin,
|
||||||
|
bottom_margin)
|
||||||
|
base_image = (
|
||||||
|
source.crop(
|
||||||
(
|
(
|
||||||
left + left_margin,
|
left + left_margin,
|
||||||
top + top_margin,
|
top + top_margin,
|
||||||
right - right_margin,
|
right + right_margin,
|
||||||
bottom - bottom_margin,
|
bottom + bottom_margin,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tile_image = noise_source_histogram(base_image, (tile, tile), (0, 0))
|
)
|
||||||
|
tile_image = noise_source(base_image, (tile, tile), (0, 0),fill=fill_color)
|
||||||
tile_image.paste(base_image, (left_margin, top_margin))
|
tile_image.paste(base_image, (left_margin, top_margin))
|
||||||
|
|
||||||
|
if mask:
|
||||||
|
base_mask = (
|
||||||
|
mask.crop(
|
||||||
|
(
|
||||||
|
left + left_margin,
|
||||||
|
top + top_margin,
|
||||||
|
right + right_margin,
|
||||||
|
bottom + bottom_margin,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tile_mask = Image.new("L",(tile,tile),color=0)
|
||||||
|
tile_mask.paste(base_mask, (left_margin, top_margin))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.debug("tiling normally")
|
logger.debug("tiling normally")
|
||||||
tile_image = source.crop((left, top, right, bottom))
|
tile_image = source.crop((left, top, right, bottom))
|
||||||
|
if mask:
|
||||||
|
tile_mask = mask.crop((left, top, right, bottom))
|
||||||
|
|
||||||
for image_filter in filters:
|
for image_filter in filters:
|
||||||
tile_image = image_filter(tile_image, (left, top, tile))
|
tile_image = image_filter(tile_image, tile_mask, (left, top, tile))
|
||||||
|
|
||||||
tiles.append((left, top, tile_image))
|
tiles.append((left, top, tile_image))
|
||||||
|
|
||||||
|
@ -353,7 +390,10 @@ def generate_tile_spiral(
|
||||||
span_x = tile + (width_tile_target - 1) * tile_increment
|
span_x = tile + (width_tile_target - 1) * tile_increment
|
||||||
span_y = tile + (height_tile_target - 1) * tile_increment
|
span_y = tile + (height_tile_target - 1) * tile_increment
|
||||||
|
|
||||||
logger.debug("tiled image overlap: %s. Span: %s x %s", overlap, span_x, span_y)
|
logger.debug(
|
||||||
|
"tiled image overlap: %s. Span: %s x %s",
|
||||||
|
overlap,span_x,span_y
|
||||||
|
)
|
||||||
|
|
||||||
tile_left = (
|
tile_left = (
|
||||||
width - span_x
|
width - span_x
|
||||||
|
|
|
@ -34,6 +34,7 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: List[Image.Image],
|
||||||
|
tile_mask: Image.Image,
|
||||||
*,
|
*,
|
||||||
border: Border,
|
border: Border,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
@ -60,69 +61,42 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources:
|
||||||
logger.info(
|
|
||||||
"upscaling %s x %s image by expanding borders: %s",
|
|
||||||
source.width,
|
|
||||||
source.height,
|
|
||||||
border,
|
|
||||||
)
|
|
||||||
|
|
||||||
margin_x = float(max(border.left, border.right))
|
save_image(server, "tile-source.png", source)
|
||||||
margin_y = float(max(border.top, border.bottom))
|
|
||||||
overlap = min(margin_x / source.width, margin_y / source.height)
|
|
||||||
|
|
||||||
if stage_mask is None:
|
|
||||||
# if no mask was provided, keep the full source image
|
|
||||||
stage_mask = Image.new("RGB", source.size, "black")
|
|
||||||
|
|
||||||
# masks start as 512x512, resize to cover the source, then trim the extra
|
|
||||||
mask_max = max(source.width, source.height)
|
|
||||||
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
|
|
||||||
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
|
|
||||||
|
|
||||||
source, stage_mask, noise, full_size = expand_image(
|
|
||||||
source,
|
|
||||||
stage_mask,
|
|
||||||
border,
|
|
||||||
fill=fill_color,
|
|
||||||
noise_source=noise_source,
|
|
||||||
mask_filter=mask_filter,
|
|
||||||
)
|
|
||||||
|
|
||||||
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
|
|
||||||
|
|
||||||
draw_mask = ImageDraw.Draw(stage_mask)
|
|
||||||
|
|
||||||
if is_debug():
|
|
||||||
save_image(server, "last-source.png", source)
|
|
||||||
save_image(server, "last-mask.png", stage_mask)
|
|
||||||
save_image(server, "last-noise.png", noise)
|
|
||||||
|
|
||||||
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
|
|
||||||
left, top, tile = dims
|
|
||||||
size = Size(*tile_source.size)
|
|
||||||
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
|
|
||||||
tile_mask = complete_tile(tile_mask, tile)
|
|
||||||
|
|
||||||
if is_debug():
|
|
||||||
save_image(server, "tile-source.png", tile_source)
|
|
||||||
save_image(server, "tile-mask.png", tile_mask)
|
save_image(server, "tile-mask.png", tile_mask)
|
||||||
|
|
||||||
latents = get_tile_latents(full_latents, dims, size)
|
#if the tile mask is all black, skip processing this tile
|
||||||
|
if not tile_mask.getbbox():
|
||||||
|
outputs.append(source)
|
||||||
|
continue
|
||||||
|
|
||||||
|
source_width, source_height = source.size
|
||||||
|
source_size = Size(source_width, source_height)
|
||||||
|
tile_size = params.tiles
|
||||||
|
if max(source_size) > tile_size:
|
||||||
|
latent_size = Size(tile_size,tile_size)
|
||||||
|
latents = get_latents_from_seed(params.seed, latent_size)
|
||||||
|
pipe_width=pipe_height=tile_size
|
||||||
|
else:
|
||||||
|
latent_size = Size(source_size.width,source_size.height)
|
||||||
|
latents = get_latents_from_seed(params.seed, latent_size)
|
||||||
|
pipe_width=source_size.width
|
||||||
|
pipe_height=source_size.height
|
||||||
|
|
||||||
if params.lpw():
|
if params.lpw():
|
||||||
logger.debug("using LPW pipeline for inpaint")
|
logger.debug("using LPW pipeline for inpaint")
|
||||||
rng = torch.manual_seed(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
result = pipe.inpaint(
|
result = pipe.inpaint(
|
||||||
tile_source,
|
source,
|
||||||
tile_mask,
|
tile_mask,
|
||||||
prompt,
|
prompt,
|
||||||
generator=rng,
|
|
||||||
guidance_scale=params.cfg,
|
|
||||||
height=size.height,
|
|
||||||
latents=latents,
|
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
|
height=pipe_height,
|
||||||
|
width=pipe_width,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
width=size.width,
|
guidance_scale=params.cfg,
|
||||||
|
generator=rng,
|
||||||
|
latents=latents,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -135,46 +109,18 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
result = pipe(
|
result = pipe(
|
||||||
prompt,
|
prompt,
|
||||||
tile_source,
|
source,
|
||||||
tile_mask,
|
tile_mask,
|
||||||
height=size.height,
|
negative_prompt=negative_prompt,
|
||||||
width=size.width,
|
height=pipe_height,
|
||||||
|
width=pipe_width,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
guidance_scale=params.cfg,
|
guidance_scale=params.cfg,
|
||||||
negative_prompt=negative_prompt,
|
|
||||||
generator=rng,
|
generator=rng,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
# once part of the image has been drawn, keep it
|
outputs.extend(result.images)
|
||||||
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
|
|
||||||
return result.images[0]
|
|
||||||
|
|
||||||
if params.pipeline == "panorama":
|
|
||||||
logger.debug("outpainting with one shot panorama, no tiling")
|
|
||||||
output = outpaint(source, (0, 0, max(source.width, source.height)))
|
|
||||||
if overlap == 0:
|
|
||||||
logger.debug("outpainting with 0 margin, using grid tiling")
|
|
||||||
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
|
||||||
elif border.left == border.right and border.top == border.bottom:
|
|
||||||
logger.debug(
|
|
||||||
"outpainting with an even border, using spiral tiling with %s overlap",
|
|
||||||
overlap,
|
|
||||||
)
|
|
||||||
output = process_tile_order(
|
|
||||||
stage.tile_order,
|
|
||||||
source,
|
|
||||||
SizeChart.auto,
|
|
||||||
1,
|
|
||||||
[outpaint],
|
|
||||||
overlap=overlap,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug("outpainting with an uneven border, using grid tiling")
|
|
||||||
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
|
||||||
|
|
||||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
|
||||||
outputs.append(output)
|
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from onnx_web.chain.highres import stage_highres
|
from onnx_web.chain.highres import stage_highres
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ from ..chain import (
|
||||||
UpscaleOutpaintStage,
|
UpscaleOutpaintStage,
|
||||||
)
|
)
|
||||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||||
|
from ..image import expand_image
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import (
|
from ..params import (
|
||||||
Border,
|
Border,
|
||||||
|
@ -24,7 +25,7 @@ from ..params import (
|
||||||
)
|
)
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..server.load import get_source_filters
|
from ..server.load import get_source_filters
|
||||||
from ..utils import run_gc, show_system_toast
|
from ..utils import run_gc, show_system_toast, is_debug
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .utils import parse_prompt
|
from .utils import parse_prompt
|
||||||
|
|
||||||
|
@ -221,6 +222,29 @@ def run_inpaint_pipeline(
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("building inpaint pipeline")
|
logger.debug("building inpaint pipeline")
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
|
# if no mask was provided, keep the full source image
|
||||||
|
mask = Image.new("L", source.size, 0)
|
||||||
|
|
||||||
|
# masks start as 512x512, resize to cover the source, then trim the extra
|
||||||
|
mask_max = max(source.width, source.height)
|
||||||
|
mask = ImageOps.contain(mask, (mask_max, mask_max))
|
||||||
|
mask = mask.crop((0, 0, source.width, source.height))
|
||||||
|
|
||||||
|
source, mask, noise, full_size = expand_image(
|
||||||
|
source,
|
||||||
|
mask,
|
||||||
|
border,
|
||||||
|
fill=fill_color,
|
||||||
|
noise_source=noise_source,
|
||||||
|
mask_filter=mask_filter,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_debug():
|
||||||
|
save_image(server, "full-source.png", source)
|
||||||
|
save_image(server, "full-mask.png", mask)
|
||||||
|
save_image(server, "full-noise.png", noise)
|
||||||
|
|
||||||
# set up the chain pipeline and base stage
|
# set up the chain pipeline and base stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams(tile_order=tile_order, tile_size=params.tiles)
|
stage = StageParams(tile_order=tile_order, tile_size=params.tiles)
|
||||||
|
@ -228,10 +252,11 @@ def run_inpaint_pipeline(
|
||||||
UpscaleOutpaintStage(),
|
UpscaleOutpaintStage(),
|
||||||
stage,
|
stage,
|
||||||
border=border,
|
border=border,
|
||||||
stage_mask=mask,
|
mask=mask,
|
||||||
fill_color=fill_color,
|
fill_color=fill_color,
|
||||||
mask_filter=mask_filter,
|
mask_filter=mask_filter,
|
||||||
noise_source=noise_source,
|
noise_source=noise_source,
|
||||||
|
overlap=params.overlap,
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply upscaling and correction, before highres
|
# apply upscaling and correction, before highres
|
||||||
|
|
|
@ -14,7 +14,7 @@ def expand_image(
|
||||||
noise_source=noise_source_histogram,
|
noise_source=noise_source_histogram,
|
||||||
mask_filter=mask_filter_none,
|
mask_filter=mask_filter_none,
|
||||||
):
|
):
|
||||||
size = Size(*source.size).add_border(expand).round_to_tile()
|
size = Size(*source.size).add_border(expand)
|
||||||
size = tuple(size)
|
size = tuple(size)
|
||||||
origin = (expand.left, expand.top)
|
origin = (expand.left, expand.top)
|
||||||
|
|
||||||
|
|
|
@ -249,7 +249,10 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return error_reply("mask image is required")
|
return error_reply("mask image is required")
|
||||||
|
|
||||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
|
||||||
|
mask = Image.new("RGBA",mask_top_layer.size,color=(0,0,0,255))
|
||||||
|
mask.alpha_composite(mask_top_layer)
|
||||||
|
mask.convert(mode="L")
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(server, "inpaint")
|
device, params, size = pipeline_from_request(server, "inpaint")
|
||||||
expand = border_from_request()
|
expand = border_from_request()
|
||||||
|
@ -262,6 +265,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
tile_order = get_from_list(
|
tile_order = get_from_list(
|
||||||
request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
|
request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
|
||||||
)
|
)
|
||||||
|
tile_order = TileOrder.spiral
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue