1
0
Fork 0

fix(api): resize latents to complete panorama blocks

This commit is contained in:
Sean Sube 2023-12-02 20:06:27 -06:00
parent 0b31ad0ab6
commit 103d1a449a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 31 additions and 15 deletions

View File

@ -1,5 +1,6 @@
import inspect import inspect
import logging import logging
from math import ceil
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
@ -13,9 +14,9 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
) )
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg
from onnx_web.chain.tile import make_tile_mask from ...chain.tile import make_tile_mask
from ...params import Size
from ..utils import LATENT_FACTOR, parse_regions, repair_nan from ..utils import LATENT_FACTOR, expand_latents, parse_regions, repair_nan
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,13 +42,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
self.window = window self.window = window
self.stride = stride self.stride = stride
def get_views(self, panorama_height, panorama_width, window_size, stride): def get_views(
self, panorama_height: int, panorama_width: int, window_size: int, stride: int
) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
panorama_height /= 8 panorama_height /= 8
panorama_width /= 8 panorama_width /= 8
num_blocks_height = abs((panorama_height - window_size) // stride) + 1 num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1
num_blocks_width = abs((panorama_width - window_size) // stride) + 1 num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1
total_num_blocks = int(num_blocks_height * num_blocks_width) total_num_blocks = int(num_blocks_height * num_blocks_width)
logger.debug( logger.debug(
"panorama generated %s views, %s by %s blocks", "panorama generated %s views, %s by %s blocks",
@ -64,7 +68,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
w_end = w_start + window_size w_end = w_start + window_size
views.append((h_start, h_end, w_start, w_end)) views.append((h_start, h_end, w_start, w_end))
return views return (views, (h_end, w_end))
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents_img2img( def prepare_latents_img2img(
@ -382,9 +386,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
# 8. Panorama additions # 8. Panorama additions
views = self.get_views(height, width, self.window, self.stride) views, resize = self.get_views(height, width, self.window, self.stride)
count = np.zeros_like(latents) count = np.zeros_like((latents[0], latents[1], *resize))
value = np.zeros_like(latents) value = np.zeros_like((latents[0], latents[1], *resize))
# adjust latents
latents = expand_latents(latents, generator.randint(), Size(width, height))
# 8. Denoising loop # 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@ -560,6 +567,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
# remove extra margins
latents = latents[:, :, 0:height, 0:width]
if output_type == "latent": if output_type == "latent":
image = latents image = latents
else: else:

View File

@ -276,6 +276,17 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
return image_latents return image_latents
def expand_latents(
    latents: np.ndarray,
    seed: int,
    size: Size,
) -> np.ndarray:
    """Embed ``latents`` in the top-left corner of freshly seeded latents.

    A new array is requested from ``get_latents_from_seed`` for ``seed`` and
    ``size`` with the same batch count as ``latents``; the input is then
    copied over the leading rows and columns of that array, so any remaining
    margin keeps the seeded values.

    Args:
        latents: 4-dimensional array; the last two axes are the region that
            gets copied. Anything other than four dimensions fails on unpack.
        seed: Seed forwarded to ``get_latents_from_seed``.
        size: Target size forwarded to ``get_latents_from_seed``.

    Returns:
        The array produced by ``get_latents_from_seed`` with ``latents``
        written into its ``[:, :, :rows, :cols]`` region.

    Note:
        Assumes the seeded array is at least as large as ``latents`` on every
        axis; otherwise numpy rejects the slice assignment with a
        ``ValueError`` -- confirm against callers.
    """
    rows, cols = latents.shape[2:]
    padded = get_latents_from_seed(seed, size, batch=latents.shape[0])
    # overwrite only the region covered by the input, leaving the margin as-is
    padded[:, :, :rows, :cols] = latents
    return padded
def get_tile_latents( def get_tile_latents(
full_latents: np.ndarray, full_latents: np.ndarray,
seed: int, seed: int,
@ -301,11 +312,7 @@ def get_tile_latents(
tile_latents = full_latents[:, :, y:yt, x:xt] tile_latents = full_latents[:, :, y:yt, x:xt]
if tile_latents.shape[2] < t or tile_latents.shape[3] < t: if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0]) tile_latents = expand_latents(tile_latents, seed, size)
extra_latents[
:, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3]
] = tile_latents
tile_latents = extra_latents
return tile_latents return tile_latents