diff --git a/api/onnx_web/diffusers/patches/vae.py b/api/onnx_web/diffusers/patches/vae.py
index 8019a243..d7e3e1f7 100644
--- a/api/onnx_web/diffusers/patches/vae.py
+++ b/api/onnx_web/diffusers/patches/vae.py
@@ -4,10 +4,11 @@ from typing import Union
 import numpy as np
 import torch
 from diffusers import OnnxRuntimeModel
+from diffusers.models.autoencoder_kl import AutoencoderKLOutput
+from diffusers.models.vae import DecoderOutput
 from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
 
 from ...server import ServerContext
-from ..version_safe_diffusers import AutoencoderKLOutput, DecoderOutput
 
 logger = getLogger(__name__)
 
diff --git a/api/onnx_web/diffusers/version_safe_diffusers.py b/api/onnx_web/diffusers/version_safe_diffusers.py
index a4be235b..e0aea62a 100644
--- a/api/onnx_web/diffusers/version_safe_diffusers.py
+++ b/api/onnx_web/diffusers/version_safe_diffusers.py
@@ -5,9 +5,7 @@ from packaging import version
 is_diffusers_0_15 = version.parse(
     version.parse(diffusers.__version__).base_version
 ) >= version.parse("0.15")
-is_diffusers_0_24 = version.parse(
-    version.parse(diffusers.__version__).base_version
-) >= version.parse("0.24")
+
 
 try:
     from diffusers import DEISMultistepScheduler
@@ -35,10 +33,3 @@ if is_diffusers_0_15:
 
 else:
     from diffusers.models.cross_attention import CrossAttnProcessor as AttnProcessor
-
-if is_diffusers_0_24:
-    from diffusers.models.autoencoders.vae import DecoderOutput
-    from diffusers.models.modeling_outputs import AutoencoderKLOutput
-else:
-    from diffusers.models.autoencoder_kl import AutoencoderKLOutput
-    from diffusers.models.vae import DecoderOutput