1
0
Fork 0

fix(api): complete panorama tiles for SD pipeline

This commit is contained in:
Sean Sube 2023-12-02 20:21:31 -06:00
parent 103d1a449a
commit b54a57b379
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 29 additions and 7 deletions

View File

@ -26,9 +26,15 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMSchedu
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from transformers import CLIPImageProcessor, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTokenizer
from onnx_web.chain.tile import make_tile_mask from ...chain.tile import make_tile_mask
from ...params import Size
from ..utils import LATENT_CHANNELS, LATENT_FACTOR, parse_regions, repair_nan from ..utils import (
LATENT_CHANNELS,
LATENT_FACTOR,
expand_latents,
parse_regions,
repair_nan,
)
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -373,7 +379,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
w_end = w_start + window_size w_end = w_start + window_size
views.append((h_start, h_end, w_start, w_end)) views.append((h_start, h_end, w_start, w_end))
return views return (views, (h_end, w_end))
@torch.no_grad() @torch.no_grad()
def text2img( def text2img(
@ -552,10 +558,17 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
# panorama additions # panorama additions
views = self.get_views(height, width, self.window, self.stride) views, resize = self.get_views(height, width, self.window, self.stride)
count = np.zeros_like(latents) count = np.zeros_like(latents)
value = np.zeros_like(latents) value = np.zeros_like(latents)
latents = expand_latents(
latents,
generator.randint(),
Size(width, height),
sigma=self.scheduler.init_noise_sigma,
)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
last = i == (len(self.scheduler.timesteps) - 1) last = i == (len(self.scheduler.timesteps) - 1)
count.fill(0) count.fill(0)
@ -707,6 +720,9 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
# remove extra margins
latents = latents[:, :, 0:height, 0:width]
latents = np.clip(latents, -4, +4) latents = np.clip(latents, -4, +4)
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0] # image = self.vae_decoder(latent_sample=latents)[0]

View File

@ -391,7 +391,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
value = np.zeros_like((latents[0], latents[1], *resize)) value = np.zeros_like((latents[0], latents[1], *resize))
# adjust latents # adjust latents
latents = expand_latents(latents, generator.randint(), Size(width, height)) latents = expand_latents(
latents,
generator.randint(),
Size(width, height),
sigma=self.scheduler.init_noise_sigma,
)
# 8. Denoising loop # 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

View File

@ -280,11 +280,12 @@ def expand_latents(
latents: np.ndarray, latents: np.ndarray,
seed: int, seed: int,
size: Size, size: Size,
sigma: float = 1.0,
) -> np.ndarray: ) -> np.ndarray:
batch, _channels, height, width = latents.shape batch, _channels, height, width = latents.shape
extra_latents = get_latents_from_seed(seed, size, batch=batch) extra_latents = get_latents_from_seed(seed, size, batch=batch)
extra_latents[:, :, 0:height, 0:width] = latents extra_latents[:, :, 0:height, 0:width] = latents
return extra_latents return extra_latents * np.float64(sigma)
def get_tile_latents( def get_tile_latents(