From 49b3aa68bbe297b77ba275a2ddb401b40f704b56 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 5 Feb 2023 15:54:17 -0600
Subject: [PATCH] fix(api): update SD upscale pipeline

---
 .../pipeline_onnx_stable_diffusion_upscale.py | 130 +++++++-----------
 1 file changed, 51 insertions(+), 79 deletions(-)

diff --git a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py
index b696d349..9641827a 100644
--- a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py
+++ b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py
@@ -3,14 +3,20 @@ from typing import Any, Callable, List, Optional, Union
 
 import numpy as np
 import torch
+
+import PIL
+
 from diffusers import DDPMScheduler, OnnxRuntimeModel, StableDiffusionUpscalePipeline
 from diffusers.pipeline_utils import ImagePipelineOutput
-from PIL import Image
+
 
 logger = getLogger(__name__)
 
-num_channels_latents = 4  # self.vae.config.latent_channels
-unet_in_channels = 7  # self.unet.config.in_channels
+# TODO: make this dynamic, from self.vae.config.latent_channels
+num_channels_latents = 4
+
+# TODO: make this dynamic, from self.unet.config.in_channels
+unet_in_channels = 7
 
 ###
 # This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
@@ -22,10 +28,10 @@ unet_in_channels = 7  # self.unet.config.in_channels
 def preprocess(image):
     if isinstance(image, torch.Tensor):
         return image
-    elif isinstance(image, Image.Image):
+    elif isinstance(image, PIL.Image.Image):
         image = [image]
 
-    if isinstance(image[0], Image.Image):
+    if isinstance(image[0], PIL.Image.Image):
         w, h = image[0].size
         w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
@@ -52,20 +58,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         scheduler: Any,
         max_noise_level: int = 350,
     ):
-        super().__init__(
-            vae,
-            text_encoder,
-            tokenizer,
-            unet,
-            low_res_scheduler,
-            scheduler,
-            max_noise_level,
-        )
+        super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
 
     def __call__(
         self,
         prompt: Union[str, List[str]],
-        image: Union[torch.FloatTensor, Image.Image, List[Image.Image]],
+        image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         noise_level: int = 20,
@@ -92,12 +90,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
 
         # 3. Encode input prompt
         text_embeddings = self._encode_prompt(
-            prompt,
-            device,
-            num_images_per_prompt,
-            do_classifier_free_guidance,
-            negative_prompt,
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
         )
+        text_embeddings_dtype = torch.float32  # TODO: convert text_embeddings.dtype to torch dtype
 
         # 4. Preprocess image
         image = preprocess(image)
@@ -108,12 +103,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         timesteps = self.scheduler.timesteps
 
         # 5. Add noise to image
-        text_embeddings_dtype = torch.float32
-        noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
-        noise = torch.randn(
-            image.shape, generator=generator, device=device, dtype=text_embeddings_dtype
-        )
+        noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
         image = self.low_res_scheduler.add_noise(image, noise, noise_level)
 
         batch_multiplier = 2 if do_classifier_free_guidance else 1
@@ -137,11 +128,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         num_channels_image = image.shape[1]
         if num_channels_latents + num_channels_image != unet_in_channels:
             raise ValueError(
-                "Incorrect configuration settings! The config of `pipeline.unet`"
-                f" expects {unet_in_channels} but received `num_channels_latents`:"
-                f" {num_channels_latents} + `num_channels_image`: {num_channels_image} "
-                f" = {num_channels_latents+num_channels_image}. Please verify the"
-                " config of `pipeline.unet` or your `image` input."
+                "Incorrect configuration settings! The config of `pipeline.unet` expects"
+                f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                f" `num_channels_image`: {num_channels_image} "
+                f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+                " `pipeline.unet` or your `image` input."
             )
 
         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -152,16 +143,10 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = (
-                    np.concatenate([latents] * 2)
-                    if do_classifier_free_guidance
-                    else latents
-                )
+                latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
 
                 # concat latents, mask, masked_image_latents in the channel dimension
-                latent_model_input = self.scheduler.scale_model_input(
-                    latent_model_input, t
-                )
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                 latent_model_input = np.concatenate([latent_model_input, image], axis=1)
 
                 # timestep to tensor
@@ -178,19 +163,13 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (
-                        noise_pred_text - noise_pred_uncond
-                    )
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(
-                    noise_pred, t, latents, **extra_step_kwargs
-                ).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
                 # call the callback, if provided
-                if i == len(timesteps) - 1 or (
-                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
-                ):
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
@@ -214,14 +193,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         image = image.transpose((0, 2, 3, 1))
         return image
 
-    def _encode_prompt(
-        self,
-        prompt,
-        device,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        negative_prompt,
-    ):
+    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
 
         text_inputs = self.tokenizer(
@@ -232,34 +204,31 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(
-            prompt, padding="longest", return_tensors="pt"
-        ).input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-            text_input_ids, untruncated_ids
-        ):
-            removed_text = self.tokenizer.batch_decode(
-                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
-            )
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
-                "The following part of your input was truncated because CLIP can only"
-                f" handle sequences up to {self.tokenizer.model_max_length} tokens:"
-                f" {removed_text}"
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
 
+        # if hasattr(text_inputs, "attention_mask"):
+        #     attention_mask = text_inputs.attention_mask.to(device)
+        # else:
+        #     attention_mask = None
+
+        # no positional arguments to text_encoder
         text_embeddings = self.text_encoder(
            input_ids=text_input_ids.int().to(device),
-            # TODO: is this needed?
             # attention_mask=attention_mask,
         )
         text_embeddings = text_embeddings[0]
+
         bs_embed, seq_len, _ = text_embeddings.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         text_embeddings = text_embeddings.repeat(1, num_images_per_prompt)
-        # TODO: is this needed?
-        # text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        text_embeddings = text_embeddings.reshape(bs_embed * num_images_per_prompt, seq_len, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
@@ -268,17 +237,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
                 uncond_tokens = [""] * batch_size
             elif type(prompt) is not type(negative_prompt):
                 raise TypeError(
-                    "`negative_prompt` should be the same type to `prompt`, but got"
-                    f" {type(negative_prompt)} != {type(prompt)}."
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
                 uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size"
-                    f" {len(negative_prompt)}, but `prompt`: {prompt} has batch size"
-                    f" {batch_size}. Please make sure that passed `negative_prompt`"
-                    " matches the batch size of `prompt`."
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt`"
+                    " matches the batch size of `prompt`."
                 )
             else:
                 uncond_tokens = negative_prompt
@@ -292,17 +260,21 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
                 return_tensors="pt",
             )
 
+            # if hasattr(uncond_input, "attention_mask"):
+            #     attention_mask = uncond_input.attention_mask.to(device)
+            # else:
+            #     attention_mask = None
+
             uncond_embeddings = self.text_encoder(
                 input_ids=uncond_input.input_ids.int().to(device),
-                # TODO: is this needed?
                 # attention_mask=attention_mask,
             )
             uncond_embeddings = uncond_embeddings[0]
+
             seq_len = uncond_embeddings.shape[1]
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
-            # TODO: is this needed?
-            # uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+            uncond_embeddings = uncond_embeddings.reshape(batch_size * num_images_per_prompt, seq_len, -1)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
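
For context, a minimal usage sketch of the patched pipeline follows; it is not part of the commit. The model path and prompt are illustrative, and it assumes `from_pretrained` accepts the `provider` keyword the way the other ONNX pipelines in diffusers do:

    # Hypothetical usage sketch, not part of this patch. Assumes a local ONNX
    # export of the stable-diffusion-x4-upscaler model; the path and prompt
    # below are placeholders, not values taken from this repository.
    import PIL.Image

    from onnx_web.diffusion.pipeline_onnx_stable_diffusion_upscale import (
        OnnxStableDiffusionUpscalePipeline,
    )

    pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
        "./models/upscaling-stable-diffusion-x4",  # hypothetical local path
        provider="CPUExecutionProvider",
    )

    # a small source image; the model upscales by 4x
    low_res = PIL.Image.open("input.png").convert("RGB")

    result = pipeline(
        "a photo of a stone castle, highly detailed",  # illustrative prompt
        image=low_res,
        num_inference_steps=75,  # defaults from the patched __call__
        guidance_scale=9.0,
        noise_level=20,
    )
    result.images[0].save("output.png")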