1
0
Fork 0

Merge branch 'feat/172-inpaint-small-images'

This commit is contained in:
Sean Sube 2023-02-18 05:53:53 -06:00
commit 3cfda639f3
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 46 additions and 14 deletions

View File

@ -2,6 +2,7 @@ from logging import getLogger
from typing import List, Optional from typing import List, Optional
from PIL import Image from PIL import Image
from onnx_web.image import valid_image
from onnx_web.output import save_image from onnx_web.output import save_image
@ -18,7 +19,7 @@ def blend_mask(
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
*, *,
sources: Optional[List[Image.Image]] = None, resized: Optional[List[Image.Image]] = None,
mask: Optional[Image.Image] = None, mask: Optional[Image.Image] = None,
_callback: ProgressCallback = None, _callback: ProgressCallback = None,
**kwargs, **kwargs,
@ -33,7 +34,6 @@ def blend_mask(
save_image(server, "last-mask.png", mask) save_image(server, "last-mask.png", mask)
save_image(server, "last-mult-mask.png", mult_mask) save_image(server, "last-mult-mask.png", mult_mask)
for source in sources: resized = [valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size) for s in resized]
source.thumbnail(mult_mask.size)
return Image.composite(sources[0], sources[1], mult_mask) return Image.composite(resized[0], resized[1], mult_mask)

View File

@ -20,6 +20,9 @@ def reduce_thumbnail(
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
image = source_image.copy() image = source_image.copy()
# TODO: should use a call to valid_image
image = image.thumbnail((size.width, size.height)) image = image.thumbnail((size.width, size.height))
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height) logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
return image return image

View File

@ -255,7 +255,7 @@ def run_blend_pipeline(
server, server,
stage, stage,
params, params,
sources=sources, resized=sources,
mask=mask, mask=mask,
callback=progress, callback=progress,
) )

View File

@ -1,8 +1,9 @@
import numpy as np import numpy as np
from numpy import random from numpy import random
from PIL import Image, ImageChops, ImageFilter from PIL import Image, ImageChops, ImageFilter, ImageOps
from typing import Tuple, Union
from .params import Border, Point from .params import Border, Point, Size
def get_pixel_index(x: int, y: int, width: int) -> int: def get_pixel_index(x: int, y: int, width: int) -> int:
@ -189,3 +190,24 @@ def expand_image(
full_source = Image.composite(full_noise, full_source, full_mask.convert("L")) full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))
return (full_source, full_mask, full_noise, (full_width, full_height)) return (full_source, full_mask, full_noise, (full_width, full_height))
def valid_image(
image: Image.Image,
min_dims: Union[Size, Tuple[int, int]] = [512, 512],
max_dims: Union[Size, Tuple[int, int]] = [512, 512],
) -> Image.Image:
min_x, min_y = min_dims
max_x, max_y = max_dims
if image.width > max_x or image.height > max_y:
image = ImageOps.contain(image, (max_x, max_y))
if image.width < min_x or image.height < min_y:
blank = Image.new(image.mode, (min_x, min_y), "black")
blank.paste(image)
image = blank
# check for square
return image

View File

@ -54,6 +54,9 @@ class Size:
self.width = width self.width = width
self.height = height self.height = height
def __iter__(self):
return iter([self.width, self.height])
def __str__(self) -> str: def __str__(self) -> str:
return "%sx%s" % (self.width, self.height) return "%sx%s" % (self.width, self.height)

View File

@ -49,6 +49,7 @@ from .image import ( # mask filters; noise sources
noise_source_histogram, noise_source_histogram,
noise_source_normal, noise_source_normal,
noise_source_uniform, noise_source_uniform,
valid_image,
) )
from .output import json_params, make_output_name from .output import json_params, make_output_name
from .params import ( from .params import (
@ -508,7 +509,7 @@ def img2img():
output = make_output_name(context, "img2img", params, size, extras=(strength,)) output = make_output_name(context, "img2img", params, size, extras=(strength,))
logger.info("img2img job queued for: %s", output) logger.info("img2img job queued for: %s", output)
source_image.thumbnail((size.width, size.height)) source_image = valid_image(source_image, min_dims=size, max_dims=size)
executor.submit( executor.submit(
output, output,
run_img2img_pipeline, run_img2img_pipeline,
@ -597,8 +598,8 @@ def inpaint():
) )
logger.info("inpaint job queued for: %s", output) logger.info("inpaint job queued for: %s", output)
source_image.thumbnail((size.width, size.height)) source_image = valid_image(source_image, min_dims=size, max_dims=size)
mask_image.thumbnail((size.width, size.height)) mask_image = valid_image(mask_image, min_dims=size, max_dims=size)
executor.submit( executor.submit(
output, output,
run_inpaint_pipeline, run_inpaint_pipeline,
@ -635,7 +636,7 @@ def upscale():
output = make_output_name(context, "upscale", params, size) output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output) logger.info("upscale job queued for: %s", output)
source_image.thumbnail((size.width, size.height)) source_image = valid_image(source_image, min_dims=size, max_dims=size)
executor.submit( executor.submit(
output, output,
run_upscale_pipeline, run_upscale_pipeline,
@ -702,7 +703,7 @@ def chain():
) )
source_file = request.files.get(stage_source_name) source_file = request.files.get(stage_source_name)
source_image = Image.open(BytesIO(source_file.read())).convert("RGB") source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
source_image.thumbnail((size.width, size.height)) source_image = valid_image(source_image, max_dims=(size.width, size.height))
kwargs["source_image"] = source_image kwargs["source_image"] = source_image
if stage_mask_name in request.files: if stage_mask_name in request.files:
@ -713,7 +714,7 @@ def chain():
) )
mask_file = request.files.get(stage_mask_name) mask_file = request.files.get(stage_mask_name)
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB") mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask_image.thumbnail((size.width, size.height)) mask_image = valid_image(mask_image, max_dims=(size.width, size.height))
kwargs["mask_image"] = mask_image kwargs["mask_image"] = mask_image
pipeline.append((callback, stage, kwargs)) pipeline.append((callback, stage, kwargs))
@ -743,13 +744,16 @@ def blend():
mask_file = request.files.get("mask") mask_file = request.files.get("mask")
mask = Image.open(BytesIO(mask_file.read())).convert("RGBA") mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
mask = valid_image(mask)
max_sources = 2 max_sources = 2
sources = [] sources = []
for i in range(max_sources): for i in range(max_sources):
source_file = request.files.get("source:%s" % (i)) source_file = request.files.get("source:%s" % (i))
sources.append(Image.open(BytesIO(source_file.read())).convert("RGBA")) source = Image.open(BytesIO(source_file.read())).convert("RGBA")
source = valid_image(source, mask.size, mask.size)
sources.append(source)
device, params, size = pipeline_from_request() device, params, size = pipeline_from_request()
upscale = upscale_from_request() upscale = upscale_from_request()