From 0e108daa0fa99c45432a68d037511d972d5b1fbb Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Feb 2023 05:35:53 -0600 Subject: [PATCH 1/3] fix(api): resize images to min dimensions by padding if necessary (#172) --- api/onnx_web/chain/blend_mask.py | 8 ++++---- api/onnx_web/diffusion/run.py | 2 +- api/onnx_web/image.py | 24 +++++++++++++++++++++++- api/onnx_web/serve.py | 18 +++++++++++------- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 537c7abd..21333498 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -2,6 +2,7 @@ from logging import getLogger from typing import List, Optional from PIL import Image +from onnx_web.image import valid_image from onnx_web.output import save_image @@ -18,7 +19,7 @@ def blend_mask( _stage: StageParams, _params: ImageParams, *, - sources: Optional[List[Image.Image]] = None, + resized: Optional[List[Image.Image]] = None, mask: Optional[Image.Image] = None, _callback: ProgressCallback = None, **kwargs, @@ -33,7 +34,6 @@ def blend_mask( save_image(server, "last-mask.png", mask) save_image(server, "last-mult-mask.png", mult_mask) - for source in sources: - source.thumbnail(mult_mask.size) + resized = [valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size) for s in resized] - return Image.composite(sources[0], sources[1], mult_mask) + return Image.composite(resized[0], resized[1], mult_mask) diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index da58a42b..d0d3bc29 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -255,7 +255,7 @@ def run_blend_pipeline( server, stage, params, - sources=sources, + resized=sources, mask=mask, callback=progress, ) diff --git a/api/onnx_web/image.py b/api/onnx_web/image.py index 45a8fa89..fcccfd9a 100644 --- a/api/onnx_web/image.py +++ b/api/onnx_web/image.py @@ -1,6 +1,7 @@ import numpy as np from numpy import random -from PIL import Image, ImageChops, ImageFilter +from PIL import Image, ImageChops, ImageFilter, ImageOps +from typing import Tuple from .params import Border, Point @@ -189,3 +190,24 @@ def expand_image( full_source = Image.composite(full_noise, full_source, full_mask.convert("L")) return (full_source, full_mask, full_noise, (full_width, full_height)) + + +def valid_image( + image: Image.Image, + min_dims: Tuple[int, int] = [512, 512], + max_dims: Tuple[int, int] = [512, 512], +) -> Image.Image: + min_x, min_y = min_dims + max_x, max_y = max_dims + + if image.width > max_x or image.height > max_y: + image = ImageOps.contain(image, max_dims) + + if image.width < min_x or image.height < min_y: + blank = Image.new(image.mode, min_dims, "black") + blank.paste(image) + image = blank + + # check for square + + return image \ No newline at end of file diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index c0c68217..1ae14860 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -49,6 +49,7 @@ from .image import ( # mask filters; noise sources noise_source_histogram, noise_source_normal, noise_source_uniform, + valid_image, ) from .output import json_params, make_output_name from .params import ( @@ -508,7 +509,7 @@ def img2img(): output = make_output_name(context, "img2img", params, size, extras=(strength,)) logger.info("img2img job queued for: %s", output) - source_image.thumbnail((size.width, size.height)) + source_image = valid_image(source_image, min_dims=size, max_dims=size) executor.submit( output, run_img2img_pipeline, @@ -597,8 +598,8 @@ def inpaint(): ) logger.info("inpaint job queued for: %s", output) - source_image.thumbnail((size.width, size.height)) - mask_image.thumbnail((size.width, size.height)) + source_image = valid_image(source_image, min_dims=size, max_dims=size) + mask_image = valid_image(mask_image, min_dims=size, max_dims=size) executor.submit( output, run_inpaint_pipeline, @@ -635,7 +636,7 @@ def upscale(): output = make_output_name(context, "upscale", params, size) logger.info("upscale job queued for: %s", output) - source_image.thumbnail((size.width, size.height)) + source_image = valid_image(source_image, min_dims=size, max_dims=size) executor.submit( output, run_upscale_pipeline, @@ -702,7 +703,7 @@ def chain(): ) source_file = request.files.get(stage_source_name) source_image = Image.open(BytesIO(source_file.read())).convert("RGB") - source_image.thumbnail((size.width, size.height)) + source_image = valid_image(source_image, max_dims=(size.width, size.height)) kwargs["source_image"] = source_image if stage_mask_name in request.files: @@ -713,7 +714,7 @@ def chain(): ) mask_file = request.files.get(stage_mask_name) mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB") - mask_image.thumbnail((size.width, size.height)) + mask_image = valid_image(mask_image, max_dims=(size.width, size.height)) kwargs["mask_image"] = mask_image pipeline.append((callback, stage, kwargs)) @@ -743,13 +744,16 @@ def blend(): mask_file = request.files.get("mask") mask = Image.open(BytesIO(mask_file.read())).convert("RGBA") + mask = valid_image(mask) max_sources = 2 sources = [] for i in range(max_sources): source_file = request.files.get("source:%s" % (i)) - sources.append(Image.open(BytesIO(source_file.read())).convert("RGBA")) + source = Image.open(BytesIO(source_file.read())).convert("RGBA") + source = valid_image(source, mask.size, mask.size) + sources.append(source) device, params, size = pipeline_from_request() upscale = upscale_from_request() From 3ca02d48755a7153f9aa0645b987b9d35189ea47 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Feb 2023 05:44:43 -0600 Subject: [PATCH 2/3] fix(api): make size params iterable --- api/onnx_web/chain/reduce_thumbnail.py | 3 +++ api/onnx_web/params.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 50cccc1d..9114c289 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -20,6 +20,9 @@ def reduce_thumbnail( **kwargs, ) -> Image.Image: image = source_image.copy() + + # TODO: should use a call to valid_image image = image.thumbnail((size.width, size.height)) + logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height) return image diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 1a645056..4a744138 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -54,6 +54,9 @@ class Size: self.width = width self.height = height + def __iter__(self): + return iter([self.width, self.height]) + def __str__(self) -> str: return "%sx%s" % (self.width, self.height) From 431db6e3f865e272bf0b347a81b8b66c9b29edc5 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Feb 2023 05:47:34 -0600 Subject: [PATCH 3/3] repack size into tuples --- api/onnx_web/image.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/image.py b/api/onnx_web/image.py index fcccfd9a..be7153ea 100644 --- a/api/onnx_web/image.py +++ b/api/onnx_web/image.py @@ -1,9 +1,9 @@ import numpy as np from numpy import random from PIL import Image, ImageChops, ImageFilter, ImageOps -from typing import Tuple +from typing import Tuple, Union -from .params import Border, Point +from .params import Border, Point, Size def get_pixel_index(x: int, y: int, width: int) -> int: @@ -194,17 +194,17 @@ def expand_image( def valid_image( image: Image.Image, - min_dims: Tuple[int, int] = [512, 512], - max_dims: Tuple[int, int] = [512, 512], + min_dims: Union[Size, Tuple[int, int]] = [512, 512], + max_dims: Union[Size, Tuple[int, int]] = [512, 512], ) -> Image.Image: min_x, min_y = min_dims max_x, max_y = max_dims if image.width > max_x or image.height > max_y: - image = ImageOps.contain(image, max_dims) + image = ImageOps.contain(image, (max_x, max_y)) if image.width < min_x or image.height < min_y: - blank = Image.new(image.mode, min_dims, "black") + blank = Image.new(image.mode, (min_x, min_y), "black") blank.paste(image) image = blank