
apply lint

Sean Sube 2023-09-10 11:26:55 -05:00
parent 0fa03e77ad
commit 78f834a678
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 168 additions and 58 deletions

View File

@@ -39,7 +39,11 @@ class VAEWrapper(object):
self.tile_overlap_factor = overlap
def __call__(self, latent_sample=None, sample=None, **kwargs):
model = self.wrapped.model if hasattr(self.wrapped, "model") else self.wrapped.session
model = (
self.wrapped.model
if hasattr(self.wrapped, "model")
else self.wrapped.session
)
# set timestep dtype to input type
sample_dtype = next(
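The hunk above is cut off at the start of the `sample_dtype = next(...)` expression. As a hedged sketch of what that lookup typically does against an onnxruntime-style session (the `get_inputs()` scan and the `"sample"` input name are assumptions, not shown in this diff):

```python
import numpy as np

def resolve_sample_dtype(model) -> type:
    # Hypothetical helper: scan the session inputs for the "sample" tensor
    # and map its ONNX type string (e.g. "tensor(float16)") to a numpy dtype.
    sample_type = next(
        (i.type for i in model.get_inputs() if i.name == "sample"),
        "tensor(float)",
    )
    return np.float16 if "float16" in sample_type else np.float32
```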

View File

@@ -1,13 +1,16 @@
from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipelineMixin
from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise_cfg
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
import inspect
import logging
from typing import Any, Optional, List, Union, Tuple, Callable, Dict
import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL
import inspect
import torch
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
StableDiffusionXLImg2ImgPipelineMixin,
)
from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise_cfg
logger = logging.getLogger(__name__)
@@ -18,23 +21,21 @@ DEFAULT_STRIDE = 16
class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
def __init__(
self,
*args,
window: int = DEFAULT_WINDOW,
stride: int = DEFAULT_STRIDE,
**kwargs,
self,
*args,
window: int = DEFAULT_WINDOW,
stride: int = DEFAULT_STRIDE,
**kwargs,
):
super().__init__(*args, **kwargs)
self.window = window
self.stride = stride
def set_window_size(self, window: int, stride: int):
self.window = window
self.stride = stride
def get_views(self, panorama_height, panorama_width, window_size, stride):
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
panorama_height /= 8
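The hunk is truncated just after the divide-by-8 rescale into latent space. For orientation, here is a sketch of how the MultiDiffusion view mapping (Eq. 7 in arXiv:2302.08113) is usually computed, following the pattern in the diffusers panorama pipeline; treat it as an illustration rather than the exact body of this method:

```python
def get_views(panorama_height, panorama_width, window_size=64, stride=16):
    # Work in latent space: the VAE downsamples by a factor of 8.
    panorama_height //= 8
    panorama_width //= 8
    # Sliding windows advance by `stride`, overlapping their neighbors
    # by (window_size - stride) latent pixels along each axis.
    num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
    num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
    views = []
    for i in range(num_blocks_height * num_blocks_width):
        h_start = (i // num_blocks_width) * stride
        w_start = (i % num_blocks_width) * stride
        views.append((h_start, h_start + window_size, w_start, w_start + window_size))
    return views
```

For a 1024x2048 output with the defaults above, this yields 5x13 = 65 overlapping 64x64 latent windows, each denoised independently and fused back together at every step.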
@@ -60,21 +61,32 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
return views
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents_img2img(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None):
def prepare_latents_img2img(
self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None
):
batch_size = batch_size * num_images_per_prompt
if image.shape[1] == 4:
init_latents = image
else:
init_latents = self.vae_encoder(sample=image)[0] * self.vae_decoder.config.get("scaling_factor", 0.18215)
init_latents = self.vae_encoder(sample=image)[
0
] * self.vae_decoder.config.get("scaling_factor", 0.18215)
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
if (
batch_size > init_latents.shape[0]
and batch_size % init_latents.shape[0] == 0
):
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // init_latents.shape[0]
init_latents = np.concatenate([init_latents] * additional_image_per_prompt, axis=0)
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
init_latents = np.concatenate(
[init_latents] * additional_image_per_prompt, axis=0
)
elif (
batch_size > init_latents.shape[0]
and batch_size % init_latents.shape[0] != 0
):
raise ValueError(
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
@@ -84,14 +96,29 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
# add noise to latents using the timesteps
noise = generator.randn(*init_latents.shape).astype(dtype)
init_latents = self.scheduler.add_noise(
torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timestep)
torch.from_numpy(init_latents),
torch.from_numpy(noise),
torch.from_numpy(timestep),
)
return init_latents.numpy()
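The `torch.from_numpy(...)` / `.numpy()` round-trip exists because the diffusers schedulers are torch-based while this ONNX pipeline keeps its state in numpy. The pattern in isolation, as a minimal sketch:

```python
import numpy as np
import torch

def add_noise_np(scheduler, latents: np.ndarray, noise: np.ndarray, timestep: np.ndarray) -> np.ndarray:
    # Bridge numpy arrays into the torch-based scheduler and back out again.
    noised = scheduler.add_noise(
        torch.from_numpy(latents),
        torch.from_numpy(noise),
        torch.from_numpy(timestep),
    )
    return noised.numpy()
```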
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents_text2img(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
def prepare_latents_text2img(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
generator,
latents=None,
):
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -101,14 +128,15 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
if latents is None:
latents = generator.randn(*shape).astype(dtype)
elif latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
raise ValueError(
f"Unexpected latents shape, got {latents.shape}, expected {shape}"
)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * np.float64(self.scheduler.init_noise_sigma)
return latents
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -118,13 +146,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
extra_step_kwargs = {}
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_eta:
extra_step_kwargs["eta"] = eta
return extra_step_kwargs
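The `inspect.signature` check is a general pattern for forwarding optional kwargs only to schedulers that accept them (DDIM takes `eta`; most other schedulers do not). A generalized sketch of the same idea:

```python
import inspect

def filter_step_kwargs(scheduler, **candidates):
    # Keep only the kwargs that this scheduler's step() signature accepts.
    accepted = set(inspect.signature(scheduler.step).parameters.keys())
    return {k: v for k, v in candidates.items() if k in accepted}

# e.g.: extra_step_kwargs = filter_step_kwargs(scheduler, eta=eta)
```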
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
def text2img(
self,
@@ -294,10 +323,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
add_time_ids = np.array(add_time_ids, dtype=prompt_embeds.dtype)
if do_classifier_free_guidance:
prompt_embeds = np.concatenate((negative_prompt_embeds, prompt_embeds), axis=0)
add_text_embeds = np.concatenate((negative_pooled_prompt_embeds, add_text_embeds), axis=0)
prompt_embeds = np.concatenate(
(negative_prompt_embeds, prompt_embeds), axis=0
)
add_text_embeds = np.concatenate(
(negative_pooled_prompt_embeds, add_text_embeds), axis=0
)
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
add_time_ids = np.repeat(add_time_ids, batch_size * num_images_per_prompt, axis=0)
add_time_ids = np.repeat(
add_time_ids, batch_size * num_images_per_prompt, axis=0
)
# Adapted from diffusers to extend it for other runtimes than ORT
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
@@ -318,8 +353,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = (
np.concatenate([latents_for_view] * 2)
if do_classifier_free_guidance
else latents_for_view
)
latent_model_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_model_input), t
)
latent_model_input = latent_model_input.cpu().numpy()
# predict the noise residual
@@ -336,14 +377,23 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
noise_pred = rescale_noise_cfg(
noise_pred,
noise_pred_text,
guidance_rescale=guidance_rescale,
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents_for_view), **extra_step_kwargs
torch.from_numpy(noise_pred),
t,
torch.from_numpy(latents_for_view),
**extra_step_kwargs,
)
latents_view_denoised = scheduler_output.prev_sample.numpy()
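For reference, `rescale_noise_cfg` (imported from optimum above) implements section 3.4 of arXiv:2305.08891: the guided prediction is rescaled toward the standard deviation of the text-conditioned prediction to counter the over-exposure that large guidance scales cause. A numpy sketch of the idea, not the imported implementation itself:

```python
import numpy as np

def rescale_noise_cfg_sketch(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Match the std of the guided prediction to the text-conditioned one,
    # then blend with the unrescaled prediction by guidance_rescale.
    axes = tuple(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(axis=axes, keepdims=True)
    std_cfg = noise_cfg.std(axis=axes, keepdims=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
```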
@@ -354,7 +404,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
latents = np.where(count > 0, value / count, value)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
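The `np.where(count > 0, value / count, value)` line above is the MultiDiffusion fusion step: each view's denoised latents are accumulated into `value`, a coverage counter into `count`, and overlapping regions are averaged. As a standalone sketch (names mirror the loop above):

```python
import numpy as np

def fuse_views(latents, views, denoised_views):
    # value accumulates every view's denoised latents; count tracks how many
    # windows covered each latent pixel.
    value = np.zeros_like(latents)
    count = np.zeros_like(latents)
    for (h_start, h_end, w_start, w_end), view in zip(views, denoised_views):
        value[:, :, h_start:h_end, w_start:w_end] += view
        count[:, :, h_start:h_end, w_start:w_end] += 1
    # Average wherever at least one window landed; pass zeros through elsewhere.
    return np.where(count > 0, value / count, value)
```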
@@ -364,7 +416,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems like there is a strange result when using a half-precision vae decoder if batchsize > 1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
[
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
for i in range(latents.shape[0])
]
)
image = self.watermark.apply_watermark(image)
@@ -379,7 +434,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
return StableDiffusionXLPipelineOutput(images=image)
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
def img2img(
self,
@@ -481,7 +535,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
(nsfw) content, according to the `safety_checker`.
"""
# 0. Check inputs. Raise error if not correct
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
self.check_inputs(
prompt,
strength,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 1. Define call parameters
if isinstance(prompt, str):
@@ -522,8 +583,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
latent_timestep = np.repeat(timesteps[:1], batch_size * num_images_per_prompt, axis=0)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps, strength
)
latent_timestep = np.repeat(
timesteps[:1], batch_size * num_images_per_prompt, axis=0
)
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
latents_dtype = prompt_embeds.dtype
@@ -531,12 +596,19 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
# 5. Prepare latent variables
latents = self.prepare_latents_img2img(
image, latent_timestep, batch_size, num_images_per_prompt, latents_dtype, generator
image,
latent_timestep,
batch_size,
num_images_per_prompt,
latents_dtype,
generator,
)
# 6. Prepare extra step kwargs
extra_step_kwargs = {}
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_eta:
extra_step_kwargs["eta"] = eta
@@ -558,10 +630,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
)
if do_classifier_free_guidance:
prompt_embeds = np.concatenate((negative_prompt_embeds, prompt_embeds), axis=0)
add_text_embeds = np.concatenate((negative_pooled_prompt_embeds, add_text_embeds), axis=0)
prompt_embeds = np.concatenate(
(negative_prompt_embeds, prompt_embeds), axis=0
)
add_text_embeds = np.concatenate(
(negative_pooled_prompt_embeds, add_text_embeds), axis=0
)
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
add_time_ids = np.repeat(add_time_ids, batch_size * num_images_per_prompt, axis=0)
add_time_ids = np.repeat(
add_time_ids, batch_size * num_images_per_prompt, axis=0
)
# 8. Panorama additions
views = self.get_views(height, width, self.window, self.stride)
@@ -579,8 +657,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = (
np.concatenate([latents_for_view] * 2)
if do_classifier_free_guidance
else latents_for_view
)
latent_model_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_model_input), t
)
latent_model_input = latent_model_input.cpu().numpy()
# predict the noise residual
@@ -597,14 +681,23 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
noise_pred = rescale_noise_cfg(
noise_pred,
noise_pred_text,
guidance_rescale=guidance_rescale,
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents_for_view), **extra_step_kwargs
torch.from_numpy(noise_pred),
t,
torch.from_numpy(latents_for_view),
**extra_step_kwargs,
)
latents_view_denoised = scheduler_output.prev_sample.numpy()
@@ -615,7 +708,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
latents = np.where(count > 0, value / count, value)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
@@ -625,7 +720,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems like there is a strange result when using a half-precision vae decoder if batchsize > 1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
[
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
for i in range(latents.shape[0])
]
)
image = self.watermark.apply_watermark(image)
@@ -640,7 +738,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
return StableDiffusionXLPipelineOutput(images=image)
def __call__(
self,
*args,
@@ -659,6 +756,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
return self.text2img(*args, **kwargs)
class ORTStableDiffusionXLPanoramaPipeline(ORTStableDiffusionXLPipelineBase, StableDiffusionXLPanoramaPipelineMixin):
class ORTStableDiffusionXLPanoramaPipeline(
ORTStableDiffusionXLPipelineBase, StableDiffusionXLPanoramaPipelineMixin
):
def __call__(self, *args, **kwargs):
return StableDiffusionXLPanoramaPipelineMixin.__call__(self, *args, **kwargs)
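A hypothetical usage sketch of the finished pipeline; the model path, prompt, and the assumption that it loads via `from_pretrained` like other optimum ORT pipelines are placeholders, not values from this commit:

```python
# Hypothetical example; assumes a local ONNX export of an SDXL checkpoint.
pipe = ORTStableDiffusionXLPanoramaPipeline.from_pretrained("./models/sdxl-onnx")
pipe.set_window_size(window=64, stride=16)
result = pipe(prompt="a wide mountain panorama at sunset", width=2048, height=1024)
result.images[0].save("panorama.png")
```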

View File

@@ -259,7 +259,14 @@ class ImageParams:
# otherwise, check for additional allowed pipelines
if group == "img2img":
if pipeline in ["controlnet", "img2img-sdxl", "lpw", "panorama", "panorama-sdxl", "pix2pix"]:
if pipeline in [
"controlnet",
"img2img-sdxl",
"lpw",
"panorama",
"panorama-sdxl",
"pix2pix",
]:
return pipeline
elif pipeline == "txt2img-sdxl":
return "img2img-sdxl"