diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index f14de94d..cee41d52 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -36,7 +36,7 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline from ...diffusers.version_safe_diffusers import AttnProcessor from ...models.cnet import UNet2DConditionModel_CNet from ...utils import run_gc -from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export +from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext, is_torch_2_0, load_tensor, onnx_export from .checkpoint import convert_extract_checkpoint logger = getLogger(__name__) @@ -375,7 +375,7 @@ def convert_diffusion_diffusers( if replace_vae is not None: vae_path = path.join(conversion.model_path, replace_vae) - if replace_vae.endswith(".safetensors"): + if check_ext(replace_vae, RESOLVE_FORMATS): pipeline.vae = AutoencoderKL.from_single_file(vae_path) else: pipeline.vae = AutoencoderKL.from_pretrained(vae_path)