fix(api): resize images to min dimensions by padding if necessary (#172)
This commit is contained in:
parent
3dde3b9237
commit
0e108daa0f
|
@ -2,6 +2,7 @@ from logging import getLogger
|
|||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
from onnx_web.image import valid_image
|
||||
|
||||
from onnx_web.output import save_image
|
||||
|
||||
|
@ -18,7 +19,7 @@ def blend_mask(
|
|||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
*,
|
||||
sources: Optional[List[Image.Image]] = None,
|
||||
resized: Optional[List[Image.Image]] = None,
|
||||
mask: Optional[Image.Image] = None,
|
||||
_callback: ProgressCallback = None,
|
||||
**kwargs,
|
||||
|
@ -33,7 +34,6 @@ def blend_mask(
|
|||
save_image(server, "last-mask.png", mask)
|
||||
save_image(server, "last-mult-mask.png", mult_mask)
|
||||
|
||||
for source in sources:
|
||||
source.thumbnail(mult_mask.size)
|
||||
resized = [valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size) for s in resized]
|
||||
|
||||
return Image.composite(sources[0], sources[1], mult_mask)
|
||||
return Image.composite(resized[0], resized[1], mult_mask)
|
||||
|
|
|
@ -255,7 +255,7 @@ def run_blend_pipeline(
|
|||
server,
|
||||
stage,
|
||||
params,
|
||||
sources=sources,
|
||||
resized=sources,
|
||||
mask=mask,
|
||||
callback=progress,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import numpy as np
|
||||
from numpy import random
|
||||
from PIL import Image, ImageChops, ImageFilter
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
from typing import Tuple
|
||||
|
||||
from .params import Border, Point
|
||||
|
||||
|
@ -189,3 +190,24 @@ def expand_image(
|
|||
full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))
|
||||
|
||||
return (full_source, full_mask, full_noise, (full_width, full_height))
|
||||
|
||||
|
||||
def valid_image(
    image: Image.Image,
    min_dims: Tuple[int, int] = (512, 512),
    max_dims: Tuple[int, int] = (512, 512),
) -> Image.Image:
    """Force an image into the given dimension range.

    If the image exceeds ``max_dims`` in either direction, it is scaled down
    (aspect ratio preserved) to fit within them. If it then falls short of
    ``min_dims`` in either direction, it is padded out to exactly ``min_dims``
    with black, anchored at the top-left corner.

    Args:
        image: the source image to normalize.
        min_dims: minimum ``(width, height)``; smaller images are padded up.
        max_dims: maximum ``(width, height)``; larger images are scaled down.

    Returns:
        An image whose dimensions lie within ``[min_dims, max_dims]``. May be
        the original image object when no adjustment was needed.
    """
    # unpack once so any 2-sequence (tuple, list, Size-like) works as dims
    min_x, min_y = min_dims
    max_x, max_y = max_dims

    if image.width > max_x or image.height > max_y:
        # scale down, preserving aspect ratio, to fit within the max box
        image = ImageOps.contain(image, (max_x, max_y))

    if image.width < min_x or image.height < min_y:
        # pad with black up to the min box; paste with no offset anchors the
        # source image at the top-left corner
        blank = Image.new(image.mode, (min_x, min_y), "black")
        blank.paste(image)
        image = blank

    # TODO: check for square

    return image
|
|
@ -49,6 +49,7 @@ from .image import ( # mask filters; noise sources
|
|||
noise_source_histogram,
|
||||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
valid_image,
|
||||
)
|
||||
from .output import json_params, make_output_name
|
||||
from .params import (
|
||||
|
@ -508,7 +509,7 @@ def img2img():
|
|||
output = make_output_name(context, "img2img", params, size, extras=(strength,))
|
||||
logger.info("img2img job queued for: %s", output)
|
||||
|
||||
source_image.thumbnail((size.width, size.height))
|
||||
source_image = valid_image(source_image, min_dims=size, max_dims=size)
|
||||
executor.submit(
|
||||
output,
|
||||
run_img2img_pipeline,
|
||||
|
@ -597,8 +598,8 @@ def inpaint():
|
|||
)
|
||||
logger.info("inpaint job queued for: %s", output)
|
||||
|
||||
source_image.thumbnail((size.width, size.height))
|
||||
mask_image.thumbnail((size.width, size.height))
|
||||
source_image = valid_image(source_image, min_dims=size, max_dims=size)
|
||||
mask_image = valid_image(mask_image, min_dims=size, max_dims=size)
|
||||
executor.submit(
|
||||
output,
|
||||
run_inpaint_pipeline,
|
||||
|
@ -635,7 +636,7 @@ def upscale():
|
|||
output = make_output_name(context, "upscale", params, size)
|
||||
logger.info("upscale job queued for: %s", output)
|
||||
|
||||
source_image.thumbnail((size.width, size.height))
|
||||
source_image = valid_image(source_image, min_dims=size, max_dims=size)
|
||||
executor.submit(
|
||||
output,
|
||||
run_upscale_pipeline,
|
||||
|
@ -702,7 +703,7 @@ def chain():
|
|||
)
|
||||
source_file = request.files.get(stage_source_name)
|
||||
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
source_image.thumbnail((size.width, size.height))
|
||||
source_image = valid_image(source_image, max_dims=(size.width, size.height))
|
||||
kwargs["source_image"] = source_image
|
||||
|
||||
if stage_mask_name in request.files:
|
||||
|
@ -713,7 +714,7 @@ def chain():
|
|||
)
|
||||
mask_file = request.files.get(stage_mask_name)
|
||||
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||
mask_image.thumbnail((size.width, size.height))
|
||||
mask_image = valid_image(mask_image, max_dims=(size.width, size.height))
|
||||
kwargs["mask_image"] = mask_image
|
||||
|
||||
pipeline.append((callback, stage, kwargs))
|
||||
|
@ -743,13 +744,16 @@ def blend():
|
|||
|
||||
mask_file = request.files.get("mask")
|
||||
mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
|
||||
mask = valid_image(mask)
|
||||
|
||||
max_sources = 2
|
||||
sources = []
|
||||
|
||||
for i in range(max_sources):
|
||||
source_file = request.files.get("source:%s" % (i))
|
||||
sources.append(Image.open(BytesIO(source_file.read())).convert("RGBA"))
|
||||
source = Image.open(BytesIO(source_file.read())).convert("RGBA")
|
||||
source = valid_image(source, mask.size, mask.size)
|
||||
sources.append(source)
|
||||
|
||||
device, params, size = pipeline_from_request()
|
||||
upscale = upscale_from_request()
|
||||
|
|
Loading…
Reference in New Issue