
fix(api): update SD upscale pipeline

Sean Sube 2023-02-05 15:54:17 -06:00
parent a2a0028bd4
commit 49b3aa68bb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 51 additions and 79 deletions


@@ -3,14 +3,20 @@ from typing import Any, Callable, List, Optional, Union

 import numpy as np
 import torch
+import PIL
 from diffusers import DDPMScheduler, OnnxRuntimeModel, StableDiffusionUpscalePipeline
 from diffusers.pipeline_utils import ImagePipelineOutput
-from PIL import Image

 logger = getLogger(__name__)

-num_channels_latents = 4  # self.vae.config.latent_channels
-unet_in_channels = 7  # self.unet.config.in_channels
+# TODO: make this dynamic, from self.vae.config.latent_channels
+num_channels_latents = 4
+# TODO: make this dynamic, from self.unet.config.in_channels
+unet_in_channels = 7

 ###
 # This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
@@ -22,10 +28,10 @@ unet_in_channels = 7 # self.unet.config.in_channels

 def preprocess(image):
     if isinstance(image, torch.Tensor):
         return image
-    elif isinstance(image, Image.Image):
+    elif isinstance(image, PIL.Image.Image):
         image = [image]

-    if isinstance(image[0], Image.Image):
+    if isinstance(image[0], PIL.Image.Image):
         w, h = image[0].size
         w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 32
@@ -52,20 +58,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         scheduler: Any,
         max_noise_level: int = 350,
     ):
-        super().__init__(
-            vae,
-            text_encoder,
-            tokenizer,
-            unet,
-            low_res_scheduler,
-            scheduler,
-            max_noise_level,
-        )
+        super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)

     def __call__(
         self,
         prompt: Union[str, List[str]],
-        image: Union[torch.FloatTensor, Image.Image, List[Image.Image]],
+        image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         noise_level: int = 20,
@@ -92,12 +90,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         # 3. Encode input prompt
         text_embeddings = self._encode_prompt(
-            prompt,
-            device,
-            num_images_per_prompt,
-            do_classifier_free_guidance,
-            negative_prompt,
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
         )
+        text_embeddings_dtype = torch.float32  # TODO: convert text_embeddings.dtype to torch dtype

         # 4. Preprocess image
         image = preprocess(image)
@@ -108,12 +103,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         timesteps = self.scheduler.timesteps

         # 5. Add noise to image
-        text_embeddings_dtype = torch.float32
         noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
-        noise = torch.randn(
-            image.shape, generator=generator, device=device, dtype=text_embeddings_dtype
-        )
+        noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
         image = self.low_res_scheduler.add_noise(image, noise, noise_level)

         batch_multiplier = 2 if do_classifier_free_guidance else 1
@@ -137,11 +128,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         num_channels_image = image.shape[1]
         if num_channels_latents + num_channels_image != unet_in_channels:
             raise ValueError(
-                "Incorrect configuration settings! The config of `pipeline.unet`"
-                f" expects {unet_in_channels} but received `num_channels_latents`:"
-                f" {num_channels_latents} + `num_channels_image`: {num_channels_image} "
-                f" = {num_channels_latents+num_channels_image}. Please verify the"
-                " config of `pipeline.unet` or your `image` input."
+                "Incorrect configuration settings! The config of `pipeline.unet` expects"
+                f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                f" `num_channels_image`: {num_channels_image} "
+                f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+                " `pipeline.unet` or your `image` input."
             )

         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
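Note: for the stock x4 upscaler this check amounts to 4 latent channels plus the 3 RGB channels of the preprocessed low-resolution image, which together match the 7 input channels hard-coded in unet_in_channels above.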
@@ -152,16 +143,10 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = (
-                    np.concatenate([latents] * 2)
-                    if do_classifier_free_guidance
-                    else latents
-                )
+                latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents

                 # concat latents, mask, masked_image_latents in the channel dimension
-                latent_model_input = self.scheduler.scale_model_input(
-                    latent_model_input, t
-                )
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                 latent_model_input = np.concatenate([latent_model_input, image], axis=1)

                 # timestep to tensor
@@ -178,19 +163,13 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (
-                        noise_pred_text - noise_pred_uncond
-                    )
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(
-                    noise_pred, t, latents, **extra_step_kwargs
-                ).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                 # call the callback, if provided
-                if i == len(timesteps) - 1 or (
-                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
-                ):
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
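Note: the guidance blend above is the usual classifier-free guidance formula, noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond); a guidance_scale of 1.0 reproduces the text-conditioned prediction, and the default of 9.0 pushes the result further away from the unconditional one.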
@@ -214,14 +193,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         image = image.transpose((0, 2, 3, 1))

         return image

-    def _encode_prompt(
-        self,
-        prompt,
-        device,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        negative_prompt,
-    ):
+    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
         batch_size = len(prompt) if isinstance(prompt, list) else 1

         text_inputs = self.tokenizer(
@@ -232,34 +204,31 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(
-            prompt, padding="longest", return_tensors="pt"
-        ).input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-            text_input_ids, untruncated_ids
-        ):
-            removed_text = self.tokenizer.batch_decode(
-                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
-            )
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
-                "The following part of your input was truncated because CLIP can only"
-                f" handle sequences up to {self.tokenizer.model_max_length} tokens:"
-                f" {removed_text}"
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )

+        # if hasattr(text_inputs, "attention_mask"):
+        #     attention_mask = text_inputs.attention_mask.to(device)
+        # else:
+        #     attention_mask = None

         # no positional arguments to text_encoder
         text_embeddings = self.text_encoder(
             input_ids=text_input_ids.int().to(device),
-            # TODO: is this needed?
             # attention_mask=attention_mask,
         )
         text_embeddings = text_embeddings[0]

+        bs_embed, seq_len, _ = text_embeddings.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         text_embeddings = text_embeddings.repeat(1, num_images_per_prompt)
-        # TODO: is this needed?
-        # text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        text_embeddings = text_embeddings.reshape(bs_embed * num_images_per_prompt, seq_len, -1)

         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
@@ -268,17 +237,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
                 uncond_tokens = [""] * batch_size
             elif type(prompt) is not type(negative_prompt):
                 raise TypeError(
-                    "`negative_prompt` should be the same type to `prompt`, but got"
-                    f" {type(negative_prompt)} != {type(prompt)}."
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
                 uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size"
-                    f" {len(negative_prompt)}, but `prompt`: {prompt} has batch size"
-                    f" {batch_size}. Please make sure that passed `negative_prompt`"
-                    " matches the batch size of `prompt`."
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
                 )
             else:
                 uncond_tokens = negative_prompt
@@ -292,17 +260,21 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
                 return_tensors="pt",
             )

+            # if hasattr(uncond_input, "attention_mask"):
+            #     attention_mask = uncond_input.attention_mask.to(device)
+            # else:
+            #     attention_mask = None

             uncond_embeddings = self.text_encoder(
                 input_ids=uncond_input.input_ids.int().to(device),
-                # TODO: is this needed?
                 # attention_mask=attention_mask,
             )
             uncond_embeddings = uncond_embeddings[0]

+            seq_len = uncond_embeddings.shape[1]
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
-            # TODO: is this needed?
-            # uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+            uncond_embeddings = uncond_embeddings.reshape(batch_size * num_images_per_prompt, seq_len, -1)

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
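
For reference, a rough usage sketch of the updated pipeline (not part of this commit). It assumes OnnxStableDiffusionUpscalePipeline has already been imported from this module, that an ONNX export of the x4 upscaler exists at the placeholder path below, and that from_pretrained (inherited from the diffusers pipeline base class) can load it; none of that is verified here.

# Hypothetical example; the model path and loading mechanism are assumptions.
from PIL import Image

# Assumes the class defined in this file is importable and that the inherited
# from_pretrained can load a local ONNX export of stable-diffusion-x4-upscaler.
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained("./models/upscaling-stable-diffusion-x4")

low_res = Image.open("input.png").convert("RGB")
result = pipeline(
    prompt="a photo of a cat, highly detailed",
    image=low_res,
    num_inference_steps=75,  # defaults from the signature above
    guidance_scale=9.0,
    noise_level=20,
)
# The pipeline is expected to return an ImagePipelineOutput-style object.
result.images[0].save("output.png")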