fix(api): make version-safe imports compatible with tests
This commit is contained in:
parent
05ab396b2a
commit
6a004816af
|
@ -4,11 +4,10 @@ from typing import Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers import OnnxRuntimeModel
|
from diffusers import OnnxRuntimeModel
|
||||||
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
|
|
||||||
from diffusers.models.vae import DecoderOutput
|
|
||||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
|
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
|
||||||
|
|
||||||
from ...server import ServerContext
|
from ...server import ServerContext
|
||||||
|
from ..version_safe_diffusers import AutoencoderKLOutput, DecoderOutput
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,4 @@
|
||||||
import diffusers
|
|
||||||
from diffusers import * # NOQA
|
from diffusers import * # NOQA
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
is_diffusers_0_15 = version.parse(
|
|
||||||
version.parse(diffusers.__version__).base_version
|
|
||||||
) >= version.parse("0.15")
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from diffusers import DEISMultistepScheduler
|
from diffusers import DEISMultistepScheduler
|
||||||
|
@ -27,8 +20,17 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
|
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
|
||||||
|
|
||||||
|
try:
|
||||||
|
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
||||||
|
except ImportError:
|
||||||
|
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
|
||||||
|
|
||||||
if is_diffusers_0_15:
|
try:
|
||||||
|
from diffusers.models.autoencoders.vae import DecoderOutput
|
||||||
|
except ImportError:
|
||||||
|
from diffusers.models.vae import DecoderOutput
|
||||||
|
|
||||||
|
try:
|
||||||
from diffusers.models.attention_processor import AttnProcessor
|
from diffusers.models.attention_processor import AttnProcessor
|
||||||
else:
|
except ImportError:
|
||||||
from diffusers.models.cross_attention import CrossAttnProcessor as AttnProcessor
|
from diffusers.models.cross_attention import CrossAttnProcessor as AttnProcessor
|
||||||
|
|
Loading…
Reference in New Issue