1
0
Fork 0

fix(api): resize images to min dimensions by padding if necessary (#172)

This commit is contained in:
Sean Sube 2023-02-18 05:35:53 -06:00
parent 3dde3b9237
commit 0e108daa0f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 39 additions and 13 deletions

View File

@ -2,6 +2,7 @@ from logging import getLogger
from typing import List, Optional from typing import List, Optional
from PIL import Image from PIL import Image
from onnx_web.image import valid_image
from onnx_web.output import save_image from onnx_web.output import save_image
@ -18,7 +19,7 @@ def blend_mask(
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
*, *,
sources: Optional[List[Image.Image]] = None, resized: Optional[List[Image.Image]] = None,
mask: Optional[Image.Image] = None, mask: Optional[Image.Image] = None,
_callback: ProgressCallback = None, _callback: ProgressCallback = None,
**kwargs, **kwargs,
@ -33,7 +34,6 @@ def blend_mask(
save_image(server, "last-mask.png", mask) save_image(server, "last-mask.png", mask)
save_image(server, "last-mult-mask.png", mult_mask) save_image(server, "last-mult-mask.png", mult_mask)
for source in sources: resized = [valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size) for s in resized]
source.thumbnail(mult_mask.size)
return Image.composite(sources[0], sources[1], mult_mask) return Image.composite(resized[0], resized[1], mult_mask)

View File

@ -255,7 +255,7 @@ def run_blend_pipeline(
server, server,
stage, stage,
params, params,
sources=sources, resized=sources,
mask=mask, mask=mask,
callback=progress, callback=progress,
) )

View File

@ -1,6 +1,7 @@
import numpy as np import numpy as np
from numpy import random from numpy import random
from PIL import Image, ImageChops, ImageFilter from PIL import Image, ImageChops, ImageFilter, ImageOps
from typing import Tuple
from .params import Border, Point from .params import Border, Point
@ -189,3 +190,24 @@ def expand_image(
full_source = Image.composite(full_noise, full_source, full_mask.convert("L")) full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))
return (full_source, full_mask, full_noise, (full_width, full_height)) return (full_source, full_mask, full_noise, (full_width, full_height))
def valid_image(
    image: Image.Image,
    min_dims: Tuple[int, int] = (512, 512),
    max_dims: Tuple[int, int] = (512, 512),
) -> Image.Image:
    """Fit *image* between ``min_dims`` and ``max_dims``, padding if needed.

    The image is first scaled down (aspect ratio preserved) when it exceeds
    ``max_dims`` on either axis. If the result is still smaller than
    ``min_dims`` on either axis, it is pasted into the top-left corner of a
    black canvas so that each axis is at least the minimum.

    Args:
        image: Source image. It is never modified in place; a new image is
            returned whenever scaling or padding happens.
        min_dims: Minimum ``(width, height)``. Any two-item iterable works.
        max_dims: Maximum ``(width, height)``. Any two-item iterable works.

    Returns:
        The original image when it already fits, otherwise a new scaled
        and/or padded image.
    """
    min_x, min_y = min_dims
    max_x, max_y = max_dims

    if image.width > max_x or image.height > max_y:
        # Hand Pillow a real tuple: the caller may have passed a list or
        # another two-item iterable that Pillow would not accept as a size.
        image = ImageOps.contain(image, (max_x, max_y))

    if image.width < min_x or image.height < min_y:
        # Grow only the axes that are too small. A canvas of exactly
        # min_dims would crop an image that is short on one axis but
        # already larger than the minimum on the other.
        canvas_size = (max(image.width, min_x), max(image.height, min_y))
        blank = Image.new(image.mode, canvas_size, "black")
        blank.paste(image)
        image = blank

    # TODO: check for square
    return image

View File

@ -49,6 +49,7 @@ from .image import ( # mask filters; noise sources
noise_source_histogram, noise_source_histogram,
noise_source_normal, noise_source_normal,
noise_source_uniform, noise_source_uniform,
valid_image,
) )
from .output import json_params, make_output_name from .output import json_params, make_output_name
from .params import ( from .params import (
@ -508,7 +509,7 @@ def img2img():
output = make_output_name(context, "img2img", params, size, extras=(strength,)) output = make_output_name(context, "img2img", params, size, extras=(strength,))
logger.info("img2img job queued for: %s", output) logger.info("img2img job queued for: %s", output)
source_image.thumbnail((size.width, size.height)) source_image = valid_image(source_image, min_dims=size, max_dims=size)
executor.submit( executor.submit(
output, output,
run_img2img_pipeline, run_img2img_pipeline,
@ -597,8 +598,8 @@ def inpaint():
) )
logger.info("inpaint job queued for: %s", output) logger.info("inpaint job queued for: %s", output)
source_image.thumbnail((size.width, size.height)) source_image = valid_image(source_image, min_dims=size, max_dims=size)
mask_image.thumbnail((size.width, size.height)) mask_image = valid_image(mask_image, min_dims=size, max_dims=size)
executor.submit( executor.submit(
output, output,
run_inpaint_pipeline, run_inpaint_pipeline,
@ -635,7 +636,7 @@ def upscale():
output = make_output_name(context, "upscale", params, size) output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output) logger.info("upscale job queued for: %s", output)
source_image.thumbnail((size.width, size.height)) source_image = valid_image(source_image, min_dims=size, max_dims=size)
executor.submit( executor.submit(
output, output,
run_upscale_pipeline, run_upscale_pipeline,
@ -702,7 +703,7 @@ def chain():
) )
source_file = request.files.get(stage_source_name) source_file = request.files.get(stage_source_name)
source_image = Image.open(BytesIO(source_file.read())).convert("RGB") source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
source_image.thumbnail((size.width, size.height)) source_image = valid_image(source_image, max_dims=(size.width, size.height))
kwargs["source_image"] = source_image kwargs["source_image"] = source_image
if stage_mask_name in request.files: if stage_mask_name in request.files:
@ -713,7 +714,7 @@ def chain():
) )
mask_file = request.files.get(stage_mask_name) mask_file = request.files.get(stage_mask_name)
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB") mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask_image.thumbnail((size.width, size.height)) mask_image = valid_image(mask_image, max_dims=(size.width, size.height))
kwargs["mask_image"] = mask_image kwargs["mask_image"] = mask_image
pipeline.append((callback, stage, kwargs)) pipeline.append((callback, stage, kwargs))
@ -743,13 +744,16 @@ def blend():
mask_file = request.files.get("mask") mask_file = request.files.get("mask")
mask = Image.open(BytesIO(mask_file.read())).convert("RGBA") mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
mask = valid_image(mask)
max_sources = 2 max_sources = 2
sources = [] sources = []
for i in range(max_sources): for i in range(max_sources):
source_file = request.files.get("source:%s" % (i)) source_file = request.files.get("source:%s" % (i))
sources.append(Image.open(BytesIO(source_file.read())).convert("RGBA")) source = Image.open(BytesIO(source_file.read())).convert("RGBA")
source = valid_image(source, mask.size, mask.size)
sources.append(source)
device, params, size = pipeline_from_request() device, params, size = pipeline_from_request()
upscale = upscale_from_request() upscale = upscale_from_request()