
fix(api): make version-safe imports compatible with tests

Sean Sube 2023-12-29 23:09:00 -06:00
parent 05ab396b2a
commit 6a004816af
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 12 additions and 11 deletions

View File

@@ -4,11 +4,10 @@ from typing import Union
 import numpy as np
 import torch
 from diffusers import OnnxRuntimeModel
-from diffusers.models.autoencoder_kl import AutoencoderKLOutput
-from diffusers.models.vae import DecoderOutput
 from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
 from ...server import ServerContext
+from ..version_safe_diffusers import AutoencoderKLOutput, DecoderOutput
 logger = getLogger(__name__)
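The module now takes both output classes from the version-safe shim instead of hard-coded diffusers paths, so the same names resolve whichever diffusers layout (or test stub) is installed. As a rough illustration, a caller inside this package might wrap an ONNX VAE decoder as in the sketch below; the function and argument names are illustrative, not part of this commit:

    import torch

    from ..version_safe_diffusers import DecoderOutput

    def decode_latents(vae_decoder, latents):
        # OnnxRuntimeModel returns a list of numpy outputs; wrap the first
        # one in the same DecoderOutput type that torch-based callers expect.
        sample = vae_decoder(latent_sample=latents)[0]
        return DecoderOutput(sample=torch.from_numpy(sample))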

View File

@@ -1,11 +1,4 @@
-import diffusers
 from diffusers import * # NOQA
-from packaging import version
-
-is_diffusers_0_15 = version.parse(
-    version.parse(diffusers.__version__).base_version
-) >= version.parse("0.15")
 try:
     from diffusers import DEISMultistepScheduler
@@ -27,8 +20,17 @@ try:
 except ImportError:
     from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
-if is_diffusers_0_15:
+try:
+    from diffusers.models.modeling_outputs import AutoencoderKLOutput
+except ImportError:
+    from diffusers.models.autoencoder_kl import AutoencoderKLOutput
+
+try:
+    from diffusers.models.autoencoders.vae import DecoderOutput
+except ImportError:
+    from diffusers.models.vae import DecoderOutput
+
+try:
     from diffusers.models.attention_processor import AttnProcessor
-else:
+except ImportError:
     from diffusers.models.cross_attention import CrossAttnProcessor as AttnProcessor
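With the version comparison gone, every symbol that moved between diffusers releases is resolved by import-time feature detection, so the shim also imports cleanly when the test environment carries a diffusers version the old version string did not anticipate. A minimal sketch of a test that exercises the shim (the module path is inferred from the relative imports above, and the test name is illustrative, not part of this commit):

    # Hypothetical test: whichever diffusers layout is installed, the shim
    # should expose the relocated output classes under the same names.
    def test_version_safe_outputs():
        from onnx_web.diffusers.version_safe_diffusers import (
            AutoencoderKLOutput,
            DecoderOutput,
        )

        # Both classes are dataclass-style BaseOutputs with known fields.
        assert "latent_dist" in AutoencoderKLOutput.__dataclass_fields__
        assert "sample" in DecoderOutput.__dataclass_fields__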