1
0
Fork 0

fix(api): allow all supported tensors extensions for VAE files

This commit is contained in:
Sean Sube 2023-11-08 22:00:56 -06:00
parent 59515193a1
commit f4f3bda6f8
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 2 additions and 2 deletions

View File

@ -36,7 +36,7 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet
from ...utils import run_gc
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext, is_torch_2_0, load_tensor, onnx_export
from .checkpoint import convert_extract_checkpoint
logger = getLogger(__name__)
@ -375,7 +375,7 @@ def convert_diffusion_diffusers(
if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
if replace_vae.endswith(".safetensors"):
if check_ext(replace_vae, RESOLVE_FORMATS):
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
else:
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)