apply lint
This commit is contained in:
parent
0fa03e77ad
commit
78f834a678
|
@ -39,7 +39,11 @@ class VAEWrapper(object):
|
||||||
self.tile_overlap_factor = overlap
|
self.tile_overlap_factor = overlap
|
||||||
|
|
||||||
def __call__(self, latent_sample=None, sample=None, **kwargs):
|
def __call__(self, latent_sample=None, sample=None, **kwargs):
|
||||||
model = self.wrapped.model if hasattr(self.wrapped, "model") else self.wrapped.session
|
model = (
|
||||||
|
self.wrapped.model
|
||||||
|
if hasattr(self.wrapped, "model")
|
||||||
|
else self.wrapped.session
|
||||||
|
)
|
||||||
|
|
||||||
# set timestep dtype to input type
|
# set timestep dtype to input type
|
||||||
sample_dtype = next(
|
sample_dtype = next(
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase
|
import inspect
|
||||||
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipelineMixin
|
|
||||||
from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise_cfg
|
|
||||||
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional, List, Union, Tuple, Callable, Dict
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import inspect
|
import torch
|
||||||
|
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||||
|
from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase
|
||||||
|
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
|
||||||
|
StableDiffusionXLImg2ImgPipelineMixin,
|
||||||
|
)
|
||||||
|
from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise_cfg
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,23 +21,21 @@ DEFAULT_STRIDE = 16
|
||||||
|
|
||||||
class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
|
class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
window: int = DEFAULT_WINDOW,
|
window: int = DEFAULT_WINDOW,
|
||||||
stride: int = DEFAULT_STRIDE,
|
stride: int = DEFAULT_STRIDE,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(self, *args, **kwargs)
|
super().__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
self.window = window
|
self.window = window
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
|
|
||||||
def set_window_size(self, window: int, stride: int):
|
def set_window_size(self, window: int, stride: int):
|
||||||
self.window = window
|
self.window = window
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
|
|
||||||
def get_views(self, panorama_height, panorama_width, window_size, stride):
|
def get_views(self, panorama_height, panorama_width, window_size, stride):
|
||||||
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
|
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
|
||||||
panorama_height /= 8
|
panorama_height /= 8
|
||||||
|
@ -60,21 +61,32 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
|
|
||||||
return views
|
return views
|
||||||
|
|
||||||
|
|
||||||
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||||
def prepare_latents_img2img(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None):
|
def prepare_latents_img2img(
|
||||||
|
self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None
|
||||||
|
):
|
||||||
batch_size = batch_size * num_images_per_prompt
|
batch_size = batch_size * num_images_per_prompt
|
||||||
|
|
||||||
if image.shape[1] == 4:
|
if image.shape[1] == 4:
|
||||||
init_latents = image
|
init_latents = image
|
||||||
else:
|
else:
|
||||||
init_latents = self.vae_encoder(sample=image)[0] * self.vae_decoder.config.get("scaling_factor", 0.18215)
|
init_latents = self.vae_encoder(sample=image)[
|
||||||
|
0
|
||||||
|
] * self.vae_decoder.config.get("scaling_factor", 0.18215)
|
||||||
|
|
||||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
if (
|
||||||
|
batch_size > init_latents.shape[0]
|
||||||
|
and batch_size % init_latents.shape[0] == 0
|
||||||
|
):
|
||||||
# expand init_latents for batch_size
|
# expand init_latents for batch_size
|
||||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||||
init_latents = np.concatenate([init_latents] * additional_image_per_prompt, axis=0)
|
init_latents = np.concatenate(
|
||||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
[init_latents] * additional_image_per_prompt, axis=0
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
batch_size > init_latents.shape[0]
|
||||||
|
and batch_size % init_latents.shape[0] != 0
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||||
)
|
)
|
||||||
|
@ -84,14 +96,29 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
# add noise to latents using the timesteps
|
# add noise to latents using the timesteps
|
||||||
noise = generator.randn(*init_latents.shape).astype(dtype)
|
noise = generator.randn(*init_latents.shape).astype(dtype)
|
||||||
init_latents = self.scheduler.add_noise(
|
init_latents = self.scheduler.add_noise(
|
||||||
torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timestep)
|
torch.from_numpy(init_latents),
|
||||||
|
torch.from_numpy(noise),
|
||||||
|
torch.from_numpy(timestep),
|
||||||
)
|
)
|
||||||
return init_latents.numpy()
|
return init_latents.numpy()
|
||||||
|
|
||||||
|
|
||||||
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||||
def prepare_latents_text2img(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
|
def prepare_latents_text2img(
|
||||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
self,
|
||||||
|
batch_size,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
dtype,
|
||||||
|
generator,
|
||||||
|
latents=None,
|
||||||
|
):
|
||||||
|
shape = (
|
||||||
|
batch_size,
|
||||||
|
num_channels_latents,
|
||||||
|
height // self.vae_scale_factor,
|
||||||
|
width // self.vae_scale_factor,
|
||||||
|
)
|
||||||
if isinstance(generator, list) and len(generator) != batch_size:
|
if isinstance(generator, list) and len(generator) != batch_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||||
|
@ -101,14 +128,15 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
if latents is None:
|
if latents is None:
|
||||||
latents = generator.randn(*shape).astype(dtype)
|
latents = generator.randn(*shape).astype(dtype)
|
||||||
elif latents.shape != shape:
|
elif latents.shape != shape:
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
raise ValueError(
|
||||||
|
f"Unexpected latents shape, got {latents.shape}, expected {shape}"
|
||||||
|
)
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
latents = latents * np.float64(self.scheduler.init_noise_sigma)
|
latents = latents * np.float64(self.scheduler.init_noise_sigma)
|
||||||
|
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||||
def prepare_extra_step_kwargs(self, generator, eta):
|
def prepare_extra_step_kwargs(self, generator, eta):
|
||||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||||
|
@ -118,13 +146,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
|
|
||||||
extra_step_kwargs = {}
|
extra_step_kwargs = {}
|
||||||
|
|
||||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
accepts_eta = "eta" in set(
|
||||||
|
inspect.signature(self.scheduler.step).parameters.keys()
|
||||||
|
)
|
||||||
if accepts_eta:
|
if accepts_eta:
|
||||||
extra_step_kwargs["eta"] = eta
|
extra_step_kwargs["eta"] = eta
|
||||||
|
|
||||||
return extra_step_kwargs
|
return extra_step_kwargs
|
||||||
|
|
||||||
|
|
||||||
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
|
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
|
||||||
def text2img(
|
def text2img(
|
||||||
self,
|
self,
|
||||||
|
@ -294,10 +323,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
add_time_ids = np.array(add_time_ids, dtype=prompt_embeds.dtype)
|
add_time_ids = np.array(add_time_ids, dtype=prompt_embeds.dtype)
|
||||||
|
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
prompt_embeds = np.concatenate((negative_prompt_embeds, prompt_embeds), axis=0)
|
prompt_embeds = np.concatenate(
|
||||||
add_text_embeds = np.concatenate((negative_pooled_prompt_embeds, add_text_embeds), axis=0)
|
(negative_prompt_embeds, prompt_embeds), axis=0
|
||||||
|
)
|
||||||
|
add_text_embeds = np.concatenate(
|
||||||
|
(negative_pooled_prompt_embeds, add_text_embeds), axis=0
|
||||||
|
)
|
||||||
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
|
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
|
||||||
add_time_ids = np.repeat(add_time_ids, batch_size * num_images_per_prompt, axis=0)
|
add_time_ids = np.repeat(
|
||||||
|
add_time_ids, batch_size * num_images_per_prompt, axis=0
|
||||||
|
)
|
||||||
|
|
||||||
# Adapted from diffusers to extend it for other runtimes than ORT
|
# Adapted from diffusers to extend it for other runtimes than ORT
|
||||||
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
|
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
|
||||||
|
@ -318,8 +353,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
|
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
|
||||||
|
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
latent_model_input = np.concatenate([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
|
latent_model_input = (
|
||||||
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
|
np.concatenate([latents_for_view] * 2)
|
||||||
|
if do_classifier_free_guidance
|
||||||
|
else latents_for_view
|
||||||
|
)
|
||||||
|
latent_model_input = self.scheduler.scale_model_input(
|
||||||
|
torch.from_numpy(latent_model_input), t
|
||||||
|
)
|
||||||
latent_model_input = latent_model_input.cpu().numpy()
|
latent_model_input = latent_model_input.cpu().numpy()
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
|
@ -336,14 +377,23 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
# perform guidance
|
# perform guidance
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||||
|
noise_pred_text - noise_pred_uncond
|
||||||
|
)
|
||||||
if guidance_rescale > 0.0:
|
if guidance_rescale > 0.0:
|
||||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
noise_pred = rescale_noise_cfg(
|
||||||
|
noise_pred,
|
||||||
|
noise_pred_text,
|
||||||
|
guidance_rescale=guidance_rescale,
|
||||||
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
scheduler_output = self.scheduler.step(
|
scheduler_output = self.scheduler.step(
|
||||||
torch.from_numpy(noise_pred), t, torch.from_numpy(latents_for_view), **extra_step_kwargs
|
torch.from_numpy(noise_pred),
|
||||||
|
t,
|
||||||
|
torch.from_numpy(latents_for_view),
|
||||||
|
**extra_step_kwargs,
|
||||||
)
|
)
|
||||||
latents_view_denoised = scheduler_output.prev_sample.numpy()
|
latents_view_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
|
@ -354,7 +404,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
latents = np.where(count > 0, value / count, value)
|
latents = np.where(count > 0, value / count, value)
|
||||||
|
|
||||||
# call the callback, if provided
|
# call the callback, if provided
|
||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
if i == len(timesteps) - 1 or (
|
||||||
|
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||||
|
):
|
||||||
if callback is not None and i % callback_steps == 0:
|
if callback is not None and i % callback_steps == 0:
|
||||||
callback(i, t, latents)
|
callback(i, t, latents)
|
||||||
|
|
||||||
|
@ -364,7 +416,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
|
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
|
||||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||||
image = np.concatenate(
|
image = np.concatenate(
|
||||||
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
|
[
|
||||||
|
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
|
||||||
|
for i in range(latents.shape[0])
|
||||||
|
]
|
||||||
)
|
)
|
||||||
image = self.watermark.apply_watermark(image)
|
image = self.watermark.apply_watermark(image)
|
||||||
|
|
||||||
|
@ -379,7 +434,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
|
|
||||||
return StableDiffusionXLPipelineOutput(images=image)
|
return StableDiffusionXLPipelineOutput(images=image)
|
||||||
|
|
||||||
|
|
||||||
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
|
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
|
||||||
def img2img(
|
def img2img(
|
||||||
self,
|
self,
|
||||||
|
@ -481,7 +535,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
(nsfw) content, according to the `safety_checker`.
|
(nsfw) content, according to the `safety_checker`.
|
||||||
"""
|
"""
|
||||||
# 0. Check inputs. Raise error if not correct
|
# 0. Check inputs. Raise error if not correct
|
||||||
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
|
self.check_inputs(
|
||||||
|
prompt,
|
||||||
|
strength,
|
||||||
|
callback_steps,
|
||||||
|
negative_prompt,
|
||||||
|
prompt_embeds,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
# 1. Define call parameters
|
# 1. Define call parameters
|
||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
|
@ -522,8 +583,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
# 4. Prepare timesteps
|
# 4. Prepare timesteps
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
|
timesteps, num_inference_steps = self.get_timesteps(
|
||||||
latent_timestep = np.repeat(timesteps[:1], batch_size * num_images_per_prompt, axis=0)
|
num_inference_steps, strength
|
||||||
|
)
|
||||||
|
latent_timestep = np.repeat(
|
||||||
|
timesteps[:1], batch_size * num_images_per_prompt, axis=0
|
||||||
|
)
|
||||||
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
|
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
|
||||||
|
|
||||||
latents_dtype = prompt_embeds.dtype
|
latents_dtype = prompt_embeds.dtype
|
||||||
|
@ -531,12 +596,19 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
|
|
||||||
# 5. Prepare latent variables
|
# 5. Prepare latent variables
|
||||||
latents = self.prepare_latents_img2img(
|
latents = self.prepare_latents_img2img(
|
||||||
image, latent_timestep, batch_size, num_images_per_prompt, latents_dtype, generator
|
image,
|
||||||
|
latent_timestep,
|
||||||
|
batch_size,
|
||||||
|
num_images_per_prompt,
|
||||||
|
latents_dtype,
|
||||||
|
generator,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. Prepare extra step kwargs
|
# 6. Prepare extra step kwargs
|
||||||
extra_step_kwargs = {}
|
extra_step_kwargs = {}
|
||||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
accepts_eta = "eta" in set(
|
||||||
|
inspect.signature(self.scheduler.step).parameters.keys()
|
||||||
|
)
|
||||||
if accepts_eta:
|
if accepts_eta:
|
||||||
extra_step_kwargs["eta"] = eta
|
extra_step_kwargs["eta"] = eta
|
||||||
|
|
||||||
|
@ -558,10 +630,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
)
|
)
|
||||||
|
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
prompt_embeds = np.concatenate((negative_prompt_embeds, prompt_embeds), axis=0)
|
prompt_embeds = np.concatenate(
|
||||||
add_text_embeds = np.concatenate((negative_pooled_prompt_embeds, add_text_embeds), axis=0)
|
(negative_prompt_embeds, prompt_embeds), axis=0
|
||||||
|
)
|
||||||
|
add_text_embeds = np.concatenate(
|
||||||
|
(negative_pooled_prompt_embeds, add_text_embeds), axis=0
|
||||||
|
)
|
||||||
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
|
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
|
||||||
add_time_ids = np.repeat(add_time_ids, batch_size * num_images_per_prompt, axis=0)
|
add_time_ids = np.repeat(
|
||||||
|
add_time_ids, batch_size * num_images_per_prompt, axis=0
|
||||||
|
)
|
||||||
|
|
||||||
# 8. Panorama additions
|
# 8. Panorama additions
|
||||||
views = self.get_views(height, width, self.window, self.stride)
|
views = self.get_views(height, width, self.window, self.stride)
|
||||||
|
@ -579,8 +657,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
|
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
|
||||||
|
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
latent_model_input = np.concatenate([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
|
latent_model_input = (
|
||||||
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
|
np.concatenate([latents_for_view] * 2)
|
||||||
|
if do_classifier_free_guidance
|
||||||
|
else latents_for_view
|
||||||
|
)
|
||||||
|
latent_model_input = self.scheduler.scale_model_input(
|
||||||
|
torch.from_numpy(latent_model_input), t
|
||||||
|
)
|
||||||
latent_model_input = latent_model_input.cpu().numpy()
|
latent_model_input = latent_model_input.cpu().numpy()
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
|
@ -597,14 +681,23 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
# perform guidance
|
# perform guidance
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||||
|
noise_pred_text - noise_pred_uncond
|
||||||
|
)
|
||||||
if guidance_rescale > 0.0:
|
if guidance_rescale > 0.0:
|
||||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
noise_pred = rescale_noise_cfg(
|
||||||
|
noise_pred,
|
||||||
|
noise_pred_text,
|
||||||
|
guidance_rescale=guidance_rescale,
|
||||||
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
scheduler_output = self.scheduler.step(
|
scheduler_output = self.scheduler.step(
|
||||||
torch.from_numpy(noise_pred), t, torch.from_numpy(latents_for_view), **extra_step_kwargs
|
torch.from_numpy(noise_pred),
|
||||||
|
t,
|
||||||
|
torch.from_numpy(latents_for_view),
|
||||||
|
**extra_step_kwargs,
|
||||||
)
|
)
|
||||||
latents_view_denoised = scheduler_output.prev_sample.numpy()
|
latents_view_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
|
@ -615,7 +708,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
latents = np.where(count > 0, value / count, value)
|
latents = np.where(count > 0, value / count, value)
|
||||||
|
|
||||||
# call the callback, if provided
|
# call the callback, if provided
|
||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
if i == len(timesteps) - 1 or (
|
||||||
|
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||||
|
):
|
||||||
if callback is not None and i % callback_steps == 0:
|
if callback is not None and i % callback_steps == 0:
|
||||||
callback(i, t, latents)
|
callback(i, t, latents)
|
||||||
|
|
||||||
|
@ -625,7 +720,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
|
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
|
||||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||||
image = np.concatenate(
|
image = np.concatenate(
|
||||||
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
|
[
|
||||||
|
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
|
||||||
|
for i in range(latents.shape[0])
|
||||||
|
]
|
||||||
)
|
)
|
||||||
image = self.watermark.apply_watermark(image)
|
image = self.watermark.apply_watermark(image)
|
||||||
|
|
||||||
|
@ -640,7 +738,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
|
|
||||||
return StableDiffusionXLPipelineOutput(images=image)
|
return StableDiffusionXLPipelineOutput(images=image)
|
||||||
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
|
@ -659,6 +756,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
return self.text2img(*args, **kwargs)
|
return self.text2img(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ORTStableDiffusionXLPanoramaPipeline(ORTStableDiffusionXLPipelineBase, StableDiffusionXLPanoramaPipelineMixin):
|
class ORTStableDiffusionXLPanoramaPipeline(
|
||||||
|
ORTStableDiffusionXLPipelineBase, StableDiffusionXLPanoramaPipelineMixin
|
||||||
|
):
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return StableDiffusionXLPanoramaPipelineMixin.__call__(self, *args, **kwargs)
|
return StableDiffusionXLPanoramaPipelineMixin.__call__(self, *args, **kwargs)
|
||||||
|
|
|
@ -259,7 +259,14 @@ class ImageParams:
|
||||||
|
|
||||||
# otherwise, check for additional allowed pipelines
|
# otherwise, check for additional allowed pipelines
|
||||||
if group == "img2img":
|
if group == "img2img":
|
||||||
if pipeline in ["controlnet", "img2img-sdxl", "lpw", "panorama", "panorama-sdxl", "pix2pix"]:
|
if pipeline in [
|
||||||
|
"controlnet",
|
||||||
|
"img2img-sdxl",
|
||||||
|
"lpw",
|
||||||
|
"panorama",
|
||||||
|
"panorama-sdxl",
|
||||||
|
"pix2pix",
|
||||||
|
]:
|
||||||
return pipeline
|
return pipeline
|
||||||
elif pipeline == "txt2img-sdxl":
|
elif pipeline == "txt2img-sdxl":
|
||||||
return "img2img-sdxl"
|
return "img2img-sdxl"
|
||||||
|
|
Loading…
Reference in New Issue