1
0
Fork 0

feat(api): remove size restrictions on most pipelines

This commit is contained in:
Sean Sube 2023-06-30 23:22:38 -05:00
parent 934dabb39e
commit 5e1b70091c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 3 additions and 44 deletions

View File

@ -20,7 +20,6 @@ from .diffusers.stub_scheduler import StubScheduler
from .diffusers.upscale import stage_upscale_correction
from .image.utils import (
expand_image,
valid_image,
)
from .image.mask_filter import (
mask_filter_gaussian_multiply,

View File

@ -3,7 +3,6 @@ from typing import List, Optional
from PIL import Image
from ..image import valid_image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
@ -24,6 +23,4 @@ def blend_linear(
) -> Image.Image:
logger.info("blending image using linear interpolation")
resized = [valid_image(s) for s in sources]
return Image.blend(resized[1], resized[0], alpha)
return Image.blend(sources[1], sources[0], alpha)

View File

@ -1,9 +1,8 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image
from ..image import valid_image
from ..output import save_image
from ..params import ImageParams, StageParams
from ..server import ServerContext
@ -35,4 +34,4 @@ def blend_mask(
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask)
return Image.composite(source, stage_source, mult_mask)
return Image.composite(stage_source, source, mult_mask)

View File

@ -23,7 +23,6 @@ def reduce_thumbnail(
source = stage_source or source
image = source.copy()
# TODO: should use a call to valid_image
image = image.thumbnail((size.width, size.height))
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)

View File

@ -1,6 +1,5 @@
from .utils import (
expand_image,
valid_image,
)
from .mask_filter import (
mask_filter_gaussian_multiply,

View File

@ -31,24 +31,3 @@ def expand_image(
full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))
return (full_source, full_mask, full_noise, size)
def valid_image(
image: Image.Image,
min_dims: Union[Size, Tuple[int, int]] = [512, 512],
max_dims: Union[Size, Tuple[int, int]] = [512, 512],
) -> Image.Image:
min_x, min_y = min_dims
max_x, max_y = max_dims
if image.width > max_x or image.height > max_y:
image = ImageOps.contain(image, (max_x, max_y))
if image.width < min_x or image.height < min_y:
blank = Image.new(image.mode, (min_x, min_y), "black")
blank.paste(image)
image = blank
# check for square
return image

View File

@ -15,7 +15,6 @@ from ..diffusers.run import (
run_txt2img_pipeline,
run_upscale_pipeline,
)
from ..image import valid_image # mask filters; noise sources
from ..output import json_params, make_output_name
from ..params import Border, StageParams, TileOrder, UpscaleParams
from ..transformers.run import run_txt2txt_pipeline
@ -189,12 +188,6 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
server, "img2img", params, size, extras=[strength], count=output_count
)
if params.get_valid_pipeline("img2img") != "panorama":
logger.info(
"resizing input image for limited pipeline, use panorama pipeline for full-size"
)
source = valid_image(source, min_dims=size, max_dims=size)
job_name = output[0]
pool.submit(
job_name,
@ -323,9 +316,6 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
output = make_output_name(server, "upscale", params, size)
logger.info("resizing source image for limited pipeline")
source = valid_image(source, min_dims=size, max_dims=size)
job_name = output[0]
pool.submit(
job_name,
@ -396,7 +386,6 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
source_file = request.files.get(stage_source_name)
if source_file is not None:
source = Image.open(BytesIO(source_file.read())).convert("RGB")
source = valid_image(source, max_dims=(size.width, size.height))
kwargs["stage_source"] = source
if stage_mask_name in request.files:
@ -408,7 +397,6 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
mask_file = request.files.get(stage_mask_name)
if mask_file is not None:
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask = valid_image(mask, max_dims=(size.width, size.height))
kwargs["stage_mask"] = mask
pipeline.append((callback, stage, kwargs))
@ -447,7 +435,6 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
logger.warning("missing source %s", i)
else:
source = Image.open(BytesIO(source_file.read())).convert("RGBA")
source = valid_image(source, mask.size, mask.size)
sources.append(source)
device, params, size = pipeline_from_request(server)