diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 8493ca1a..4661b846 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -222,4 +222,4 @@ def load_tensor(name: str, map_location=None): checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint ) - return checkpoint \ No newline at end of file + return checkpoint diff --git a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py index 7d89e9c2..4efa96ab 100644 --- a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -23,11 +23,10 @@ logger = getLogger(__name__) NUM_LATENT_CHANNELS = 4 NUM_UNET_INPUT_CHANNELS = 7 -# TODO: should this be a lookup? it needs to match the conversion script -class_labels_dtype = np.long - -# TODO: should this be a lookup or converted? can it vary on ONNX? -text_embeddings_dtype = torch.float32 +TORCH_DTYPES = { + "float16": torch.float16, + "float32": torch.float32, +} def preprocess(image): @@ -98,6 +97,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) + latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)] + # 4. Preprocess image image = preprocess(image) image = image.cpu() @@ -108,7 +109,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): # 5. Add noise to image noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) - noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype) + noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype) image = self.low_res_scheduler.add_noise(image, noise, noise_level) batch_multiplier = 2 if do_classifier_free_guidance else 1 @@ -122,7 +123,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): NUM_LATENT_CHANNELS, height, width, - text_embeddings_dtype, + latents_dtype, device, generator, latents, @@ -142,6 +143,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -154,14 +160,14 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): latent_model_input = np.concatenate([latent_model_input, image], axis=1) # timestep to tensor - timestep = np.array([t], dtype=np.float32) + timestep = np.array([t], dtype=timestep_dtype) # predict the noise residual noise_pred = self.unet( sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings, - class_labels=noise_level.astype(class_labels_dtype), + class_labels=noise_level.astype(np.int64), )[0] # perform guidance