fix(api): complete panorama tiles for SD pipeline
This commit is contained in:
parent
103d1a449a
commit
b54a57b379
|
@@ -26,9 +26,15 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMSchedu
|
|||
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
||||
from transformers import CLIPImageProcessor, CLIPTokenizer
|
||||
|
||||
from onnx_web.chain.tile import make_tile_mask
|
||||
|
||||
from ..utils import LATENT_CHANNELS, LATENT_FACTOR, parse_regions, repair_nan
|
||||
from ...chain.tile import make_tile_mask
|
||||
from ...params import Size
|
||||
from ..utils import (
|
||||
LATENT_CHANNELS,
|
||||
LATENT_FACTOR,
|
||||
expand_latents,
|
||||
parse_regions,
|
||||
repair_nan,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@@ -373,7 +379,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
w_end = w_start + window_size
|
||||
views.append((h_start, h_end, w_start, w_end))
|
||||
|
||||
return views
|
||||
return (views, (h_end, w_end))
|
||||
|
||||
@torch.no_grad()
|
||||
def text2img(
|
||||
|
@@ -552,10 +558,17 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
|
||||
# panorama additions
|
||||
views = self.get_views(height, width, self.window, self.stride)
|
||||
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||
count = np.zeros_like(latents)
|
||||
value = np.zeros_like(latents)
|
||||
|
||||
latents = expand_latents(
|
||||
latents,
|
||||
generator.randint(),
|
||||
Size(width, height),
|
||||
sigma=self.scheduler.init_noise_sigma,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
last = i == (len(self.scheduler.timesteps) - 1)
|
||||
count.fill(0)
|
||||
|
@@ -707,6 +720,9 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# remove extra margins
|
||||
latents = latents[:, :, 0:height, 0:width]
|
||||
|
||||
latents = np.clip(latents, -4, +4)
|
||||
latents = 1 / 0.18215 * latents
|
||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||
|
|
|
@@ -391,7 +391,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
value = np.zeros_like((latents[0], latents[1], *resize))
|
||||
|
||||
# adjust latents
|
||||
latents = expand_latents(latents, generator.randint(), Size(width, height))
|
||||
latents = expand_latents(
|
||||
latents,
|
||||
generator.randint(),
|
||||
Size(width, height),
|
||||
sigma=self.scheduler.init_noise_sigma,
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
|
|
@@ -280,11 +280,12 @@ def expand_latents(
|
|||
latents: np.ndarray,
|
||||
seed: int,
|
||||
size: Size,
|
||||
sigma: float = 1.0,
|
||||
) -> np.ndarray:
|
||||
batch, _channels, height, width = latents.shape
|
||||
extra_latents = get_latents_from_seed(seed, size, batch=batch)
|
||||
extra_latents[:, :, 0:height, 0:width] = latents
|
||||
return extra_latents
|
||||
return extra_latents * np.float64(sigma)
|
||||
|
||||
|
||||
def get_tile_latents(
|
||||
|
|
Loading…
Reference in New Issue