From 5e1b70091c2d6a7f35da0eda6ca37887fe17e258 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 30 Jun 2023 23:22:38 -0500 Subject: [PATCH] feat(api): remove size restrictions on most pipelines --- api/onnx_web/__init__.py | 1 - api/onnx_web/chain/blend_linear.py | 5 +---- api/onnx_web/chain/blend_mask.py | 5 ++--- api/onnx_web/chain/reduce_thumbnail.py | 1 - api/onnx_web/image/__init__.py | 1 - api/onnx_web/image/utils.py | 21 --------------------- api/onnx_web/server/api.py | 13 ------------- 7 files changed, 3 insertions(+), 44 deletions(-) diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index 77d49e14..a4f1cdaf 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -20,7 +20,6 @@ from .diffusers.stub_scheduler import StubScheduler from .diffusers.upscale import stage_upscale_correction from .image.utils import ( expand_image, - valid_image, ) from .image.mask_filter import ( mask_filter_gaussian_multiply, diff --git a/api/onnx_web/chain/blend_linear.py b/api/onnx_web/chain/blend_linear.py index d1e6e733..69f79429 100644 --- a/api/onnx_web/chain/blend_linear.py +++ b/api/onnx_web/chain/blend_linear.py @@ -3,7 +3,6 @@ from typing import List, Optional from PIL import Image -from ..image import valid_image from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext @@ -24,6 +23,4 @@ def blend_linear( ) -> Image.Image: logger.info("blending image using linear interpolation") - resized = [valid_image(s) for s in sources] - - return Image.blend(resized[1], resized[0], alpha) + return Image.blend(sources[1], sources[0], alpha) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index c5840b30..99aba6b9 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -1,9 +1,8 @@ from logging import getLogger -from typing import List, Optional +from typing import Optional from PIL import Image -from ..image import valid_image from ..output import save_image from ..params import ImageParams, StageParams from ..server import ServerContext @@ -35,4 +34,4 @@ def blend_mask( save_image(server, "last-mask.png", stage_mask) save_image(server, "last-mult-mask.png", mult_mask) - return Image.composite(source, stage_source, mult_mask) + return Image.composite(stage_source, source, mult_mask) diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 6df2ed6e..c51cda7a 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -23,7 +23,6 @@ def reduce_thumbnail( source = stage_source or source image = source.copy() - # TODO: should use a call to valid_image image = image.thumbnail((size.width, size.height)) logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height) diff --git a/api/onnx_web/image/__init__.py b/api/onnx_web/image/__init__.py index 84be6e13..b0133d41 100644 --- a/api/onnx_web/image/__init__.py +++ b/api/onnx_web/image/__init__.py @@ -1,6 +1,5 @@ from .utils import ( expand_image, - valid_image, ) from .mask_filter import ( mask_filter_gaussian_multiply, diff --git a/api/onnx_web/image/utils.py b/api/onnx_web/image/utils.py index 643f5595..ee8dc59e 100644 --- a/api/onnx_web/image/utils.py +++ b/api/onnx_web/image/utils.py @@ -31,24 +31,3 @@ def expand_image( full_source = Image.composite(full_noise, full_source, full_mask.convert("L")) return (full_source, full_mask, full_noise, size) - - -def valid_image( - image: Image.Image, - min_dims: Union[Size, Tuple[int, int]] = [512, 512], - max_dims: Union[Size, Tuple[int, int]] = [512, 512], -) -> Image.Image: - min_x, min_y = min_dims - max_x, max_y = max_dims - - if image.width > max_x or image.height > max_y: - image = ImageOps.contain(image, (max_x, max_y)) - - if image.width < min_x or image.height < min_y: - blank = Image.new(image.mode, (min_x, min_y), "black") - blank.paste(image) - image = blank - - # check for square - - return image diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 972c23d9..49d2be93 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -15,7 +15,6 @@ from ..diffusers.run import ( run_txt2img_pipeline, run_upscale_pipeline, ) -from ..image import valid_image # mask filters; noise sources from ..output import json_params, make_output_name from ..params import Border, StageParams, TileOrder, UpscaleParams from ..transformers.run import run_txt2txt_pipeline @@ -189,12 +188,6 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): server, "img2img", params, size, extras=[strength], count=output_count ) - if params.get_valid_pipeline("img2img") != "panorama": - logger.info( - "resizing input image for limited pipeline, use panorama pipeline for full-size" - ) - source = valid_image(source, min_dims=size, max_dims=size) - job_name = output[0] pool.submit( job_name, @@ -323,9 +316,6 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): output = make_output_name(server, "upscale", params, size) - logger.info("resizing source image for limited pipeline") - source = valid_image(source, min_dims=size, max_dims=size) - job_name = output[0] pool.submit( job_name, @@ -396,7 +386,6 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): source_file = request.files.get(stage_source_name) if source_file is not None: source = Image.open(BytesIO(source_file.read())).convert("RGB") - source = valid_image(source, max_dims=(size.width, size.height)) kwargs["stage_source"] = source if stage_mask_name in request.files: @@ -408,7 +397,6 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): mask_file = request.files.get(stage_mask_name) if mask_file is not None: mask = Image.open(BytesIO(mask_file.read())).convert("RGB") - mask = valid_image(mask, max_dims=(size.width, size.height)) kwargs["stage_mask"] = mask pipeline.append((callback, stage, kwargs)) @@ -447,7 +435,6 @@ def blend(server: ServerContext, pool: DevicePoolExecutor): logger.warning("missing source %s", i) else: source = Image.open(BytesIO(source_file.read())).convert("RGBA") - source = valid_image(source, mask.size, mask.size) sources.append(source) device, params, size = pipeline_from_request(server)