From f4f3bda6f815b68afbe2eded85ee172c3a99d60f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 8 Nov 2023 22:00:56 -0600 Subject: [PATCH] fix(api): allow all supported tensors extensions for VAE files --- api/onnx_web/convert/diffusion/diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index f14de94d..cee41d52 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -36,7 +36,7 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline from ...diffusers.version_safe_diffusers import AttnProcessor from ...models.cnet import UNet2DConditionModel_CNet from ...utils import run_gc -from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export +from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext, is_torch_2_0, load_tensor, onnx_export from .checkpoint import convert_extract_checkpoint logger = getLogger(__name__) @@ -375,7 +375,7 @@ def convert_diffusion_diffusers( if replace_vae is not None: vae_path = path.join(conversion.model_path, replace_vae) - if replace_vae.endswith(".safetensors"): + if check_ext(replace_vae, RESOLVE_FORMATS): pipeline.vae = AutoencoderKL.from_single_file(vae_path) else: pipeline.vae = AutoencoderKL.from_pretrained(vae_path)