From d9dd1e4b57dd4c174fad000555f5f11b02abdc19 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 10 Jul 2023 17:41:08 -0500 Subject: [PATCH] fix(api): improve handling of non-square images around tile size --- api/onnx_web/chain/source_txt2img.py | 29 +++++++------------------- api/onnx_web/chain/upscale_outpaint.py | 23 +++++++------------- api/onnx_web/params.py | 3 +++ 3 files changed, 17 insertions(+), 38 deletions(-) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 77dd899f..6d399ac6 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -6,12 +6,7 @@ import torch from PIL import Image from ..diffusers.load import load_pipeline -from ..diffusers.utils import ( - encode_prompt, - get_latents_from_seed, - get_tile_latents, - parse_prompt, -) +from ..diffusers.utils import encode_prompt, get_latents_from_seed, get_tile_latents, parse_prompt from ..params import ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext @@ -54,23 +49,13 @@ class SourceTxt2ImgStage(BaseStage): ) tile_size = params.tiles - - if max(size) > tile_size: - latent_size = Size(tile_size, tile_size) - pipe_width = pipe_height = tile_size - else: - latent_size = Size(size.width, size.height) - pipe_width = size.width - pipe_height = size.height + latent_size = size.min(tile_size, tile_size) # generate new latents or slice existing if latents is None: - # generate new latents latents = get_latents_from_seed(params.seed, latent_size, params.batch) else: - # slice existing latents - latents = get_tile_latents(latents, dims, Size(tile_size, tile_size)) - pipe_width = pipe_height = tile_size + latents = get_tile_latents(latents, dims, latent_size) pipe_type = params.get_valid_pipeline("txt2img") pipe = load_pipeline( @@ -87,8 +72,8 @@ class SourceTxt2ImgStage(BaseStage): rng = torch.manual_seed(params.seed) result = pipe.text2img( prompt, - height=pipe_height, - width=pipe_width, + height=latent_size.height, + width=latent_size.width, generator=rng, guidance_scale=params.cfg, latents=latents, @@ -108,8 +93,8 @@ class SourceTxt2ImgStage(BaseStage): rng = np.random.RandomState(params.seed) result = pipe( prompt, - height=pipe_height, - width=pipe_width, + height=latent_size.height, + width=latent_size.width, generator=rng, guidance_scale=params.cfg, latents=latents, diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 2949e990..9753b2b3 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -71,24 +71,15 @@ class UpscaleOutpaintStage(BaseStage): outputs.append(source) continue - size = Size(*source.size) tile_size = params.tiles - if max(size) > tile_size: - latent_size = Size(tile_size, tile_size) - pipe_width = pipe_height = tile_size - else: - latent_size = Size(size.width, size.height) - pipe_width = size.width - pipe_height = size.height + size = Size(*source.size) + latent_size = size.min(tile_size, tile_size) # generate new latents or slice existing if latents is None: - # generate new latents latents = get_latents_from_seed(params.seed, latent_size, params.batch) else: - # slice existing latents and make sure there is a complete tile - latents = get_tile_latents(latents, dims, Size(tile_size, tile_size)) - pipe_width = pipe_height = tile_size + latents = get_tile_latents(latents, dims, latent_size) if params.lpw(): logger.debug("using LPW pipeline for inpaint") @@ -98,8 +89,8 @@ class UpscaleOutpaintStage(BaseStage): tile_mask, prompt, negative_prompt=negative_prompt, - height=pipe_height, - width=pipe_width, + height=latent_size.height, + width=latent_size.width, num_inference_steps=params.steps, guidance_scale=params.cfg, generator=rng, @@ -119,8 +110,8 @@ class UpscaleOutpaintStage(BaseStage): source, tile_mask, negative_prompt=negative_prompt, - height=pipe_height, - width=pipe_width, + height=latent_size.height, + width=latent_size.width, num_inference_steps=params.steps, guidance_scale=params.cfg, generator=rng, diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 4a077fb8..91a7ffe6 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -81,6 +81,9 @@ class Size: border.top + self.height + border.bottom, ) + def min(self, width: int, height: int): + return Size(min(self.width, width), min(self.height, height)) + def round_to_tile(self, tile=512): return Size( ceil(self.width / tile) * tile,