fix(api): allow all supported tensors extensions for VAE files
This commit is contained in:
parent
59515193a1
commit
f4f3bda6f8
|
@ -36,7 +36,7 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||||
from ...diffusers.version_safe_diffusers import AttnProcessor
|
from ...diffusers.version_safe_diffusers import AttnProcessor
|
||||||
from ...models.cnet import UNet2DConditionModel_CNet
|
from ...models.cnet import UNet2DConditionModel_CNet
|
||||||
from ...utils import run_gc
|
from ...utils import run_gc
|
||||||
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
|
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext, is_torch_2_0, load_tensor, onnx_export
|
||||||
from .checkpoint import convert_extract_checkpoint
|
from .checkpoint import convert_extract_checkpoint
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -375,7 +375,7 @@ def convert_diffusion_diffusers(
|
||||||
|
|
||||||
if replace_vae is not None:
|
if replace_vae is not None:
|
||||||
vae_path = path.join(conversion.model_path, replace_vae)
|
vae_path = path.join(conversion.model_path, replace_vae)
|
||||||
if replace_vae.endswith(".safetensors"):
|
if check_ext(replace_vae, RESOLVE_FORMATS):
|
||||||
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
|
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
|
||||||
else:
|
else:
|
||||||
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
|
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
|
||||||
|
|
Loading…
Reference in New Issue