1
0
Fork 0

new tiled inpainting method

This commit is contained in:
HoopyFreud 2023-07-09 00:56:20 -04:00
parent 3fbf9baae6
commit 08172a7236
7 changed files with 163 additions and 147 deletions

View File

@ -144,13 +144,14 @@ class ChainPipeline:
tile, tile,
) )
def stage_tile(source_tile: Image.Image, _dims) -> Image.Image: def stage_tile(source_tile: Image.Image, tile_mask: Image.Image, _dims) -> Image.Image:
output_tile = stage_pipe.run( output_tile = stage_pipe.run(
job, job,
server, server,
stage_params, stage_params,
params, params,
[source_tile], [source_tile],
tile_mask=tile_mask,
callback=callback, callback=callback,
**kwargs, **kwargs,
)[0] )[0]

View File

@ -48,14 +48,14 @@ class SourceTxt2ImgStage(BaseStage):
tile_size = params.tiles tile_size = params.tiles
if max(size) > tile_size: if max(size) > tile_size:
latent_size = Size(tile_size, tile_size) latent_size = Size(tile_size,tile_size)
latents = get_latents_from_seed(params.seed, latent_size, params.batch) latents = get_latents_from_seed(params.seed, latent_size, params.batch)
pipe_width = pipe_height = tile_size pipe_width=pipe_height=tile_size
else: else:
latent_size = Size(size.width, size.height) latent_size = Size(size.width,size.height)
latents = get_latents_from_seed(params.seed, latent_size, params.batch) latents = get_latents_from_seed(params.seed, latent_size, params.batch)
pipe_width = size.width pipe_width=size.width
pipe_height = size.height pipe_height=size.height
pipe_type = params.get_valid_pipeline("txt2img") pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline( pipe = load_pipeline(

View File

@ -7,7 +7,14 @@ from typing import List, Optional, Protocol, Tuple
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from ..image.noise_source import noise_source_histogram from ..image.noise_source import (
noise_source_fill_edge,
noise_source_fill_mask,
noise_source_gaussian,
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
)
from ..params import Size, TileOrder from ..params import Size, TileOrder
# from skimage.exposure import match_histograms # from skimage.exposure import match_histograms
@ -232,6 +239,11 @@ def process_tile_spiral(
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
width, height = kwargs.get("size", source.size if source else None) width, height = kwargs.get("size", source.size if source else None)
mask = kwargs.get("mask",None)
noise_source = kwargs.get("noise_source",noise_source_histogram)
fill_color = kwargs.get("fill_color",None)
if not mask:
tile_mask = None
tiles: List[Tuple[int, int, Image.Image]] = [] tiles: List[Tuple[int, int, Image.Image]] = []
@ -269,33 +281,58 @@ def process_tile_spiral(
needs_margin = True needs_margin = True
bottom_margin = height - bottom bottom_margin = height - bottom
# if no source given, we don't have a source image #if no source given, we don't have a source image
if not source: if not source:
tile_image = None tile_image = None
elif needs_margin: elif needs_margin:
# in the special case where the image is smaller than the specified tile size, just use the image #in the special case where the image is smaller than the specified tile size, just use the image
if single_tile: if single_tile:
logger.debug("creating and processing single-tile subtile") logger.debug("creating and processing single-tile subtile")
tile_image = source tile_image = source
# otherwise use add histogram noise outside of the image border if mask:
tile_mask = mask
#otherwise use add histogram noise outside of the image border
else: else:
logger.debug("tiling and adding margin") logger.debug("tiling and adding margins: %s, %s, %s, %s",
base_image = source.crop( left_margin,
top_margin,
right_margin,
bottom_margin)
base_image = (
source.crop(
( (
left + left_margin, left + left_margin,
top + top_margin, top + top_margin,
right - right_margin, right + right_margin,
bottom - bottom_margin, bottom + bottom_margin,
) )
) )
tile_image = noise_source_histogram(base_image, (tile, tile), (0, 0)) )
tile_image = noise_source(base_image, (tile, tile), (0, 0),fill=fill_color)
tile_image.paste(base_image, (left_margin, top_margin)) tile_image.paste(base_image, (left_margin, top_margin))
if mask:
base_mask = (
mask.crop(
(
left + left_margin,
top + top_margin,
right + right_margin,
bottom + bottom_margin,
)
)
)
tile_mask = Image.new("L",(tile,tile),color=0)
tile_mask.paste(base_mask, (left_margin, top_margin))
else: else:
logger.debug("tiling normally") logger.debug("tiling normally")
tile_image = source.crop((left, top, right, bottom)) tile_image = source.crop((left, top, right, bottom))
if mask:
tile_mask = mask.crop((left, top, right, bottom))
for image_filter in filters: for image_filter in filters:
tile_image = image_filter(tile_image, (left, top, tile)) tile_image = image_filter(tile_image, tile_mask, (left, top, tile))
tiles.append((left, top, tile_image)) tiles.append((left, top, tile_image))
@ -353,7 +390,10 @@ def generate_tile_spiral(
span_x = tile + (width_tile_target - 1) * tile_increment span_x = tile + (width_tile_target - 1) * tile_increment
span_y = tile + (height_tile_target - 1) * tile_increment span_y = tile + (height_tile_target - 1) * tile_increment
logger.debug("tiled image overlap: %s. Span: %s x %s", overlap, span_x, span_y) logger.debug(
"tiled image overlap: %s. Span: %s x %s",
overlap,span_x,span_y
)
tile_left = ( tile_left = (
width - span_x width - span_x

View File

@ -34,6 +34,7 @@ class UpscaleOutpaintStage(BaseStage):
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: List[Image.Image],
tile_mask: Image.Image,
*, *,
border: Border, border: Border,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
@ -60,69 +61,42 @@ class UpscaleOutpaintStage(BaseStage):
outputs = [] outputs = []
for source in sources: for source in sources:
logger.info(
"upscaling %s x %s image by expanding borders: %s",
source.width,
source.height,
border,
)
margin_x = float(max(border.left, border.right)) save_image(server, "tile-source.png", source)
margin_y = float(max(border.top, border.bottom))
overlap = min(margin_x / source.width, margin_y / source.height)
if stage_mask is None:
# if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")
# masks start as 512x512, resize to cover the source, then trim the extra
mask_max = max(source.width, source.height)
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
source, stage_mask, noise, full_size = expand_image(
source,
stage_mask,
border,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
draw_mask = ImageDraw.Draw(stage_mask)
if is_debug():
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
tile_mask = complete_tile(tile_mask, tile)
if is_debug():
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-mask.png", tile_mask) save_image(server, "tile-mask.png", tile_mask)
latents = get_tile_latents(full_latents, dims, size) #if the tile mask is all black, skip processing this tile
if not tile_mask.getbbox():
outputs.append(source)
continue
source_width, source_height = source.size
source_size = Size(source_width, source_height)
tile_size = params.tiles
if max(source_size) > tile_size:
latent_size = Size(tile_size,tile_size)
latents = get_latents_from_seed(params.seed, latent_size)
pipe_width=pipe_height=tile_size
else:
latent_size = Size(source_size.width,source_size.height)
latents = get_latents_from_seed(params.seed, latent_size)
pipe_width=source_size.width
pipe_height=source_size.height
if params.lpw(): if params.lpw():
logger.debug("using LPW pipeline for inpaint") logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed) rng = torch.manual_seed(params.seed)
result = pipe.inpaint( result = pipe.inpaint(
tile_source, source,
tile_mask, tile_mask,
prompt, prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
latents=latents,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
height=pipe_height,
width=pipe_width,
num_inference_steps=params.steps, num_inference_steps=params.steps,
width=size.width, guidance_scale=params.cfg,
generator=rng,
latents=latents,
callback=callback, callback=callback,
) )
else: else:
@ -135,46 +109,18 @@ class UpscaleOutpaintStage(BaseStage):
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(
prompt, prompt,
tile_source, source,
tile_mask, tile_mask,
height=size.height, negative_prompt=negative_prompt,
width=size.width, height=pipe_height,
width=pipe_width,
num_inference_steps=params.steps, num_inference_steps=params.steps,
guidance_scale=params.cfg, guidance_scale=params.cfg,
negative_prompt=negative_prompt,
generator=rng, generator=rng,
latents=latents, latents=latents,
callback=callback, callback=callback,
) )
# once part of the image has been drawn, keep it outputs.extend(result.images)
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
return result.images[0]
if params.pipeline == "panorama":
logger.debug("outpainting with one shot panorama, no tiling")
output = outpaint(source, (0, 0, max(source.width, source.height)))
if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
elif border.left == border.right and border.top == border.bottom:
logger.debug(
"outpainting with an even border, using spiral tiling with %s overlap",
overlap,
)
output = process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=overlap,
)
else:
logger.debug("outpainting with an uneven border, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output)
return outputs return outputs

View File

@ -1,7 +1,7 @@
from logging import getLogger from logging import getLogger
from typing import Any, List, Optional from typing import Any, List, Optional
from PIL import Image from PIL import Image, ImageOps
from onnx_web.chain.highres import stage_highres from onnx_web.chain.highres import stage_highres
@ -13,6 +13,7 @@ from ..chain import (
UpscaleOutpaintStage, UpscaleOutpaintStage,
) )
from ..chain.upscale import split_upscale, stage_upscale_correction from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image
from ..output import save_image from ..output import save_image
from ..params import ( from ..params import (
Border, Border,
@ -24,7 +25,7 @@ from ..params import (
) )
from ..server import ServerContext from ..server import ServerContext
from ..server.load import get_source_filters from ..server.load import get_source_filters
from ..utils import run_gc, show_system_toast from ..utils import run_gc, show_system_toast, is_debug
from ..worker import WorkerContext from ..worker import WorkerContext
from .utils import parse_prompt from .utils import parse_prompt
@ -221,6 +222,29 @@ def run_inpaint_pipeline(
) -> None: ) -> None:
logger.debug("building inpaint pipeline") logger.debug("building inpaint pipeline")
if mask is None:
# if no mask was provided, keep the full source image
mask = Image.new("L", source.size, 0)
# masks start as 512x512, resize to cover the source, then trim the extra
mask_max = max(source.width, source.height)
mask = ImageOps.contain(mask, (mask_max, mask_max))
mask = mask.crop((0, 0, source.width, source.height))
source, mask, noise, full_size = expand_image(
source,
mask,
border,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
if is_debug():
save_image(server, "full-source.png", source)
save_image(server, "full-mask.png", mask)
save_image(server, "full-noise.png", noise)
# set up the chain pipeline and base stage # set up the chain pipeline and base stage
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams(tile_order=tile_order, tile_size=params.tiles) stage = StageParams(tile_order=tile_order, tile_size=params.tiles)
@ -228,10 +252,11 @@ def run_inpaint_pipeline(
UpscaleOutpaintStage(), UpscaleOutpaintStage(),
stage, stage,
border=border, border=border,
stage_mask=mask, mask=mask,
fill_color=fill_color, fill_color=fill_color,
mask_filter=mask_filter, mask_filter=mask_filter,
noise_source=noise_source, noise_source=noise_source,
overlap=params.overlap,
) )
# apply upscaling and correction, before highres # apply upscaling and correction, before highres

View File

@ -14,7 +14,7 @@ def expand_image(
noise_source=noise_source_histogram, noise_source=noise_source_histogram,
mask_filter=mask_filter_none, mask_filter=mask_filter_none,
): ):
size = Size(*source.size).add_border(expand).round_to_tile() size = Size(*source.size).add_border(expand)
size = tuple(size) size = tuple(size)
origin = (expand.left, expand.top) origin = (expand.left, expand.top)

View File

@ -249,7 +249,10 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
return error_reply("mask image is required") return error_reply("mask image is required")
source = Image.open(BytesIO(source_file.read())).convert("RGB") source = Image.open(BytesIO(source_file.read())).convert("RGB")
mask = Image.open(BytesIO(mask_file.read())).convert("RGB") mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
mask = Image.new("RGBA",mask_top_layer.size,color=(0,0,0,255))
mask.alpha_composite(mask_top_layer)
mask.convert(mode="L")
device, params, size = pipeline_from_request(server, "inpaint") device, params, size = pipeline_from_request(server, "inpaint")
expand = border_from_request() expand = border_from_request()
@ -262,6 +265,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
tile_order = get_from_list( tile_order = get_from_list(
request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral] request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
) )
tile_order = TileOrder.spiral
replace_wildcards(params, get_wildcard_data()) replace_wildcards(params, get_wildcard_data())