1
0
Fork 0

fix(api): resize images to min dimensions by padding if necessary (#172)

This commit is contained in:
Sean Sube 2023-02-18 05:35:53 -06:00
parent 3dde3b9237
commit 0e108daa0f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 39 additions and 13 deletions

View File

@ -2,6 +2,7 @@ from logging import getLogger
from typing import List, Optional
from PIL import Image
from onnx_web.image import valid_image
from onnx_web.output import save_image
@ -18,7 +19,7 @@ def blend_mask(
_stage: StageParams,
_params: ImageParams,
*,
sources: Optional[List[Image.Image]] = None,
resized: Optional[List[Image.Image]] = None,
mask: Optional[Image.Image] = None,
_callback: ProgressCallback = None,
**kwargs,
@ -33,7 +34,6 @@ def blend_mask(
save_image(server, "last-mask.png", mask)
save_image(server, "last-mult-mask.png", mult_mask)
for source in sources:
source.thumbnail(mult_mask.size)
resized = [valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size) for s in resized]
return Image.composite(sources[0], sources[1], mult_mask)
return Image.composite(resized[0], resized[1], mult_mask)

View File

@ -255,7 +255,7 @@ def run_blend_pipeline(
server,
stage,
params,
sources=sources,
resized=sources,
mask=mask,
callback=progress,
)

View File

@ -1,6 +1,7 @@
import numpy as np
from numpy import random
from PIL import Image, ImageChops, ImageFilter
from PIL import Image, ImageChops, ImageFilter, ImageOps
from typing import Tuple
from .params import Border, Point
@ -189,3 +190,24 @@ def expand_image(
full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))
return (full_source, full_mask, full_noise, (full_width, full_height))
def valid_image(
image: Image.Image,
min_dims: Tuple[int, int] = [512, 512],
max_dims: Tuple[int, int] = [512, 512],
) -> Image.Image:
min_x, min_y = min_dims
max_x, max_y = max_dims
if image.width > max_x or image.height > max_y:
image = ImageOps.contain(image, max_dims)
if image.width < min_x or image.height < min_y:
blank = Image.new(image.mode, min_dims, "black")
blank.paste(image)
image = blank
# check for square
return image

View File

@ -49,6 +49,7 @@ from .image import ( # mask filters; noise sources
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
valid_image,
)
from .output import json_params, make_output_name
from .params import (
@ -508,7 +509,7 @@ def img2img():
output = make_output_name(context, "img2img", params, size, extras=(strength,))
logger.info("img2img job queued for: %s", output)
source_image.thumbnail((size.width, size.height))
source_image = valid_image(source_image, min_dims=size, max_dims=size)
executor.submit(
output,
run_img2img_pipeline,
@ -597,8 +598,8 @@ def inpaint():
)
logger.info("inpaint job queued for: %s", output)
source_image.thumbnail((size.width, size.height))
mask_image.thumbnail((size.width, size.height))
source_image = valid_image(source_image, min_dims=size, max_dims=size)
mask_image = valid_image(mask_image, min_dims=size, max_dims=size)
executor.submit(
output,
run_inpaint_pipeline,
@ -635,7 +636,7 @@ def upscale():
output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output)
source_image.thumbnail((size.width, size.height))
source_image = valid_image(source_image, min_dims=size, max_dims=size)
executor.submit(
output,
run_upscale_pipeline,
@ -702,7 +703,7 @@ def chain():
)
source_file = request.files.get(stage_source_name)
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
source_image.thumbnail((size.width, size.height))
source_image = valid_image(source_image, max_dims=(size.width, size.height))
kwargs["source_image"] = source_image
if stage_mask_name in request.files:
@ -713,7 +714,7 @@ def chain():
)
mask_file = request.files.get(stage_mask_name)
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask_image.thumbnail((size.width, size.height))
mask_image = valid_image(mask_image, max_dims=(size.width, size.height))
kwargs["mask_image"] = mask_image
pipeline.append((callback, stage, kwargs))
@ -743,13 +744,16 @@ def blend():
mask_file = request.files.get("mask")
mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
mask = valid_image(mask)
max_sources = 2
sources = []
for i in range(max_sources):
source_file = request.files.get("source:%s" % (i))
sources.append(Image.open(BytesIO(source_file.read())).convert("RGBA"))
source = Image.open(BytesIO(source_file.read())).convert("RGBA")
source = valid_image(source, mask.size, mask.size)
sources.append(source)
device, params, size = pipeline_from_request()
upscale = upscale_from_request()