diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py index fe8a6532..dc178e30 100644 --- a/api/onnx_web/diffusers/pipelines/panorama_xl.py +++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py @@ -1,5 +1,6 @@ import inspect import logging +from math import ceil from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -13,9 +14,9 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import ( ) from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg -from onnx_web.chain.tile import make_tile_mask - -from ..utils import LATENT_FACTOR, parse_regions, repair_nan +from ...chain.tile import make_tile_mask +from ...params import Size +from ..utils import LATENT_FACTOR, expand_latents, parse_regions, repair_nan logger = logging.getLogger(__name__) @@ -41,13 +42,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix self.window = window self.stride = stride - def get_views(self, panorama_height, panorama_width, window_size, stride): + def get_views( + self, panorama_height: int, panorama_width: int, window_size: int, stride: int + ) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]: # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) panorama_height /= 8 panorama_width /= 8 - num_blocks_height = abs((panorama_height - window_size) // stride) + 1 - num_blocks_width = abs((panorama_width - window_size) // stride) + 1 + num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1 + num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1 + total_num_blocks = int(num_blocks_height * num_blocks_width) logger.debug( "panorama generated %s views, %s by %s blocks", @@ -64,7 +68,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix w_end = w_start + window_size views.append((h_start, h_end, w_start, w_end)) - return views + return (views, (h_end, w_end)) # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents_img2img( @@ -382,9 +386,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) # 8. Panorama additions - views = self.get_views(height, width, self.window, self.stride) - count = np.zeros_like(latents) - value = np.zeros_like(latents) + views, resize = self.get_views(height, width, self.window, self.stride) + count = np.zeros_like((latents[0], latents[1], *resize)) + value = np.zeros_like((latents[0], latents[1], *resize)) + + # adjust latents + latents = expand_latents(latents, generator.randint(), Size(width, height)) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -560,6 +567,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix if callback is not None and i % callback_steps == 0: callback(i, t, latents) + # remove extra margins + latents = latents[:, :, 0:height, 0:width] if output_type == "latent": image = latents else: diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index 140ece45..f01c6703 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -276,6 +276,17 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray: return image_latents +def expand_latents( + latents: np.ndarray, + seed: int, + size: Size, +) -> np.ndarray: + batch, _channels, height, width = latents.shape + extra_latents = get_latents_from_seed(seed, size, batch=batch) + extra_latents[:, :, 0:height, 0:width] = latents + return extra_latents + + def get_tile_latents( full_latents: np.ndarray, seed: int, @@ -301,11 +312,7 @@ def get_tile_latents( tile_latents = full_latents[:, :, y:yt, x:xt] if tile_latents.shape[2] < t or tile_latents.shape[3] < t: - extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0]) - extra_latents[ - :, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3] - ] = tile_latents - tile_latents = extra_latents + tile_latents = expand_latents(tile_latents, seed, size) return tile_latents