fix(api): resize latents to complete panorama blocks
This commit is contained in:
parent
0b31ad0ab6
commit
103d1a449a
|
@ -1,5 +1,6 @@
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
from math import ceil
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -13,9 +14,9 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
|
||||||
)
|
)
|
||||||
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg
|
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg
|
||||||
|
|
||||||
from onnx_web.chain.tile import make_tile_mask
|
from ...chain.tile import make_tile_mask
|
||||||
|
from ...params import Size
|
||||||
from ..utils import LATENT_FACTOR, parse_regions, repair_nan
|
from ..utils import LATENT_FACTOR, expand_latents, parse_regions, repair_nan
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -41,13 +42,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
self.window = window
|
self.window = window
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
def get_views(self, panorama_height, panorama_width, window_size, stride):
|
def get_views(
|
||||||
|
self, panorama_height: int, panorama_width: int, window_size: int, stride: int
|
||||||
|
) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
|
||||||
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
|
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
|
||||||
panorama_height /= 8
|
panorama_height /= 8
|
||||||
panorama_width /= 8
|
panorama_width /= 8
|
||||||
|
|
||||||
num_blocks_height = abs((panorama_height - window_size) // stride) + 1
|
num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1
|
||||||
num_blocks_width = abs((panorama_width - window_size) // stride) + 1
|
num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1
|
||||||
|
|
||||||
total_num_blocks = int(num_blocks_height * num_blocks_width)
|
total_num_blocks = int(num_blocks_height * num_blocks_width)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"panorama generated %s views, %s by %s blocks",
|
"panorama generated %s views, %s by %s blocks",
|
||||||
|
@ -64,7 +68,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
w_end = w_start + window_size
|
w_end = w_start + window_size
|
||||||
views.append((h_start, h_end, w_start, w_end))
|
views.append((h_start, h_end, w_start, w_end))
|
||||||
|
|
||||||
return views
|
return (views, (h_end, w_end))
|
||||||
|
|
||||||
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||||
def prepare_latents_img2img(
|
def prepare_latents_img2img(
|
||||||
|
@ -382,9 +386,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
|
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
|
||||||
|
|
||||||
# 8. Panorama additions
|
# 8. Panorama additions
|
||||||
views = self.get_views(height, width, self.window, self.stride)
|
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||||
count = np.zeros_like(latents)
|
count = np.zeros_like((latents[0], latents[1], *resize))
|
||||||
value = np.zeros_like(latents)
|
value = np.zeros_like((latents[0], latents[1], *resize))
|
||||||
|
|
||||||
|
# adjust latents: embed the current latents in seeded latents generated for the requested size
|
||||||
|
latents = expand_latents(latents, generator.randint(), Size(width, height))
|
||||||
|
|
||||||
# 8. Denoising loop
|
# 8. Denoising loop
|
||||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||||
|
@ -560,6 +567,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
if callback is not None and i % callback_steps == 0:
|
if callback is not None and i % callback_steps == 0:
|
||||||
callback(i, t, latents)
|
callback(i, t, latents)
|
||||||
|
|
||||||
|
# remove extra margins
|
||||||
|
latents = latents[:, :, 0:height, 0:width]
|
||||||
if output_type == "latent":
|
if output_type == "latent":
|
||||||
image = latents
|
image = latents
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -276,6 +276,17 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
||||||
return image_latents
|
return image_latents
|
||||||
|
|
||||||
|
|
||||||
|
def expand_latents(
    latents: np.ndarray,
    seed: int,
    size: Size,
) -> np.ndarray:
    """
    Place `latents` in the leading corner of a freshly generated latent array.

    A new array is requested from `get_latents_from_seed` for `seed` and
    `size`, using the batch count taken from `latents`. The region
    `[:, :, :height, :width]` of that array (where `height` and `width` are
    the last two dimensions of `latents`) is then overwritten with `latents`.

    `latents` must be 4-dimensional; it is only read, never modified.

    NOTE(review): assumes the generated array is at least as large as
    `latents` along the last two axes and has a matching channel count -
    confirm against `get_latents_from_seed`.
    """
    count, _, rows, cols = latents.shape

    # start from seeded latents of the target size, then copy the existing
    # values over the top-left region so only the remainder keeps new values
    expanded = get_latents_from_seed(seed, size, batch=count)
    expanded[:, :, :rows, :cols] = latents

    return expanded
|
||||||
|
|
||||||
|
|
||||||
def get_tile_latents(
|
def get_tile_latents(
|
||||||
full_latents: np.ndarray,
|
full_latents: np.ndarray,
|
||||||
seed: int,
|
seed: int,
|
||||||
|
@ -301,11 +312,7 @@ def get_tile_latents(
|
||||||
tile_latents = full_latents[:, :, y:yt, x:xt]
|
tile_latents = full_latents[:, :, y:yt, x:xt]
|
||||||
|
|
||||||
if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
|
if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
|
||||||
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
|
tile_latents = expand_latents(tile_latents, seed, size)
|
||||||
extra_latents[
|
|
||||||
:, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3]
|
|
||||||
] = tile_latents
|
|
||||||
tile_latents = extra_latents
|
|
||||||
|
|
||||||
return tile_latents
|
return tile_latents
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue