1
0
Fork 0

fix(api): resize latents to complete panorama blocks

This commit is contained in:
Sean Sube 2023-12-02 20:06:27 -06:00
parent 0b31ad0ab6
commit 103d1a449a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 31 additions and 15 deletions

View File

@ -1,5 +1,6 @@
import inspect
import logging
from math import ceil
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
@ -13,9 +14,9 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
)
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg
from onnx_web.chain.tile import make_tile_mask
from ..utils import LATENT_FACTOR, parse_regions, repair_nan
from ...chain.tile import make_tile_mask
from ...params import Size
from ..utils import LATENT_FACTOR, expand_latents, parse_regions, repair_nan
logger = logging.getLogger(__name__)
@ -41,13 +42,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
self.window = window
self.stride = stride
def get_views(self, panorama_height, panorama_width, window_size, stride):
def get_views(
self, panorama_height: int, panorama_width: int, window_size: int, stride: int
) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
panorama_height /= 8
panorama_width /= 8
num_blocks_height = abs((panorama_height - window_size) // stride) + 1
num_blocks_width = abs((panorama_width - window_size) // stride) + 1
num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1
num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1
total_num_blocks = int(num_blocks_height * num_blocks_width)
logger.debug(
"panorama generated %s views, %s by %s blocks",
@ -64,7 +68,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
w_end = w_start + window_size
views.append((h_start, h_end, w_start, w_end))
return views
return (views, (h_end, w_end))
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents_img2img(
@ -382,9 +386,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
# 8. Panorama additions
views = self.get_views(height, width, self.window, self.stride)
count = np.zeros_like(latents)
value = np.zeros_like(latents)
views, resize = self.get_views(height, width, self.window, self.stride)
count = np.zeros_like((latents[0], latents[1], *resize))
value = np.zeros_like((latents[0], latents[1], *resize))
# adjust latents
latents = expand_latents(latents, generator.randint(), Size(width, height))
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@ -560,6 +567,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# remove extra margins
latents = latents[:, :, 0:height, 0:width]
if output_type == "latent":
image = latents
else:

View File

@ -276,6 +276,17 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
return image_latents
def expand_latents(
latents: np.ndarray,
seed: int,
size: Size,
) -> np.ndarray:
batch, _channels, height, width = latents.shape
extra_latents = get_latents_from_seed(seed, size, batch=batch)
extra_latents[:, :, 0:height, 0:width] = latents
return extra_latents
def get_tile_latents(
full_latents: np.ndarray,
seed: int,
@ -301,11 +312,7 @@ def get_tile_latents(
tile_latents = full_latents[:, :, y:yt, x:xt]
if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
extra_latents[
:, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3]
] = tile_latents
tile_latents = extra_latents
tile_latents = expand_latents(tile_latents, seed, size)
return tile_latents