From 2b29b099f0ec9bb0d23df87036f8430798a6e4fc Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 15 Feb 2023 17:16:20 -0600 Subject: [PATCH] fix(api): make SD upscaling compatible with more schedulers --- api/onnx_web/convert/diffusion_stable.py | 4 ++- .../pipeline_onnx_stable_diffusion_upscale.py | 27 +++++++++++-------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/api/onnx_web/convert/diffusion_stable.py b/api/onnx_web/convert/diffusion_stable.py index eb34b533..18156c3c 100644 --- a/api/onnx_web/convert/diffusion_stable.py +++ b/api/onnx_web/convert/diffusion_stable.py @@ -126,7 +126,9 @@ def convert_diffusion_stable( # UNET if single_vae: unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"] - unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int) + unet_scale = torch.tensor(4).to( + device=ctx.training_device, dtype=torch.long + ) else: unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"] unet_scale = torch.tensor(False).to( diff --git a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py index 5aefdb49..7d89e9c2 100644 --- a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -19,11 +19,15 @@ from diffusers.pipeline_utils import ImagePipelineOutput logger = getLogger(__name__) -# TODO: make this dynamic, from self.vae.config.latent_channels -num_channels_latents = 4 -# TODO: make this dynamic, from self.unet.config.in_channels -unet_in_channels = 7 +NUM_LATENT_CHANNELS = 4 +NUM_UNET_INPUT_CHANNELS = 7 + +# TODO: should this be a lookup? it needs to match the conversion script +class_labels_dtype = np.long + +# TODO: should this be a lookup or converted? can it vary on ONNX? +text_embeddings_dtype = torch.float32 def preprocess(image): @@ -93,7 +97,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): text_embeddings = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) - text_embeddings_dtype = torch.float32 # TODO: convert text_embeddings.dtype to torch dtype # 4. Preprocess image image = preprocess(image) @@ -116,7 +119,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): height, width = image.shape[2:] latents = self.prepare_latents( batch_size * num_images_per_prompt, - num_channels_latents, + NUM_LATENT_CHANNELS, height, width, text_embeddings_dtype, @@ -127,12 +130,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): # 7. Check that sizes of image and latents match num_channels_image = image.shape[1] - if num_channels_latents + num_channels_image != unet_in_channels: + if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS: raise ValueError( "Incorrect configuration settings! The config of `pipeline.unet` expects" - f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +" f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents+num_channels_image}. Please verify the config of" + f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) @@ -158,7 +161,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings, - class_labels=noise_level, + class_labels=noise_level.astype(class_labels_dtype), )[0] # perform guidance @@ -167,7 +170,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step( + torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs + ).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):