From b54a57b379fbbae37b954cc6a0bbba3fa807ab7b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 2 Dec 2023 20:21:31 -0600 Subject: [PATCH] fix(api): complete panorama tiles for SD pipeline --- api/onnx_web/diffusers/pipelines/panorama.py | 26 +++++++++++++++---- .../diffusers/pipelines/panorama_xl.py | 7 ++++- api/onnx_web/diffusers/utils.py | 3 ++- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py index 11511fa7..54383ca0 100644 --- a/api/onnx_web/diffusers/pipelines/panorama.py +++ b/api/onnx_web/diffusers/pipelines/panorama.py @@ -26,9 +26,15 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMSchedu from diffusers.utils import PIL_INTERPOLATION, deprecate, logging from transformers import CLIPImageProcessor, CLIPTokenizer -from onnx_web.chain.tile import make_tile_mask - -from ..utils import LATENT_CHANNELS, LATENT_FACTOR, parse_regions, repair_nan +from ...chain.tile import make_tile_mask +from ...params import Size +from ..utils import ( + LATENT_CHANNELS, + LATENT_FACTOR, + expand_latents, + parse_regions, + repair_nan, +) logger = logging.get_logger(__name__) @@ -373,7 +379,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): w_end = w_start + window_size views.append((h_start, h_end, w_start, w_end)) - return views + return (views, (h_end, w_end)) @torch.no_grad() def text2img( @@ -552,10 +558,17 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] # panorama additions - views = self.get_views(height, width, self.window, self.stride) + views, resize = self.get_views(height, width, self.window, self.stride) count = np.zeros_like(latents) value = np.zeros_like(latents) + latents = expand_latents( + latents, + generator.randint(), + Size(width, height), + sigma=self.scheduler.init_noise_sigma, + ) + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): last = i == (len(self.scheduler.timesteps) - 1) count.fill(0) @@ -707,6 +720,9 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): if callback is not None and i % callback_steps == 0: callback(i, t, latents) + # remove extra margins + latents = latents[:, :, 0:height, 0:width] + latents = np.clip(latents, -4, +4) latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0] diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py index dc178e30..d42936db 100644 --- a/api/onnx_web/diffusers/pipelines/panorama_xl.py +++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py @@ -391,7 +391,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix value = np.zeros_like((latents[0], latents[1], *resize)) # adjust latents - latents = expand_latents(latents, generator.randint(), Size(width, height)) + latents = expand_latents( + latents, + generator.randint(), + Size(width, height), + sigma=self.scheduler.init_noise_sigma, + ) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index f01c6703..38687204 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -280,11 +280,12 @@ def expand_latents( latents: np.ndarray, seed: int, size: Size, + sigma: float = 1.0, ) -> np.ndarray: batch, _channels, height, width = latents.shape extra_latents = get_latents_from_seed(seed, size, batch=batch) extra_latents[:, :, 0:height, 0:width] = latents - return extra_latents + return extra_latents * np.float64(sigma) def get_tile_latents(