diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py index 64ab7be7..52d69960 100644 --- a/api/onnx_web/convert/diffusion_original.py +++ b/api/onnx_web/convert/diffusion_original.py @@ -23,7 +23,6 @@ from logging import getLogger from typing import Dict, List import huggingface_hub.utils.tqdm -import safetensors.torch import torch from diffusers import ( AutoencoderKL, @@ -1147,8 +1146,8 @@ def extract_checkpoint( extract_ema=False, train_unfrozen=False, is_512=True, - config_file: str =None, - vae_file: str =None, + config_file: str = None, + vae_file: str = None, ): """ diff --git a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py index d07fc064..b697dd18 100644 --- a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -38,7 +38,7 @@ ORT_TO_NP_TYPE = { "tensor(double)": np.float64, } -TORCH_DTYPES = { +ORT_TO_PT_TYPE = { "float16": torch.float16, "float32": torch.float32, } @@ -112,7 +112,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) - latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)] + latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)] # 4. Preprocess image image = preprocess(image)