
fix(api): check diffusers version before imports (#336)

commit 95841ffe2b (parent ad5c69e7e1)
Author: Sean Sube
Date: 2023-04-15 14:32:22 -05:00
Signed by: ssube (GPG Key ID: 3EED7B957D362AF1)
5 changed files with 54 additions and 32 deletions
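Most of the diff applies one pattern: instead of importing CrossAttnProcessor directly from diffusers (renamed to AttnProcessor and moved to diffusers.models.attention_processor in diffusers 0.15), each module imports a stable name from a new version_safe_diffusers shim, which also absorbs the existing DEISMultistepScheduler and UniPCMultistepScheduler fallbacks. A rough usage sketch, assuming the package root is onnx_web (the diff itself only uses relative imports), follows; the helper function name is made up for illustration:

# Rough usage sketch, not part of the commit. The absolute module path
# onnx_web.diffusers.version_safe_diffusers is assumed from the relative
# imports in the diff below; only the imported names are taken from it.
from onnx_web.diffusers.version_safe_diffusers import (
    AttnProcessor,           # attention_processor class on >= 0.15, cross_attention class before
    DEISMultistepScheduler,  # StubScheduler when the installed diffusers lacks it
    UniPCMultistepScheduler, # StubScheduler when the installed diffusers lacks it
)

def reset_attention(unet) -> None:
    # The same call works on either diffusers branch, because the shim
    # resolved the correct class at import time.
    unet.set_attn_processor(AttnProcessor())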


@@ -4,8 +4,7 @@ from pathlib import Path
 from typing import Dict
 import torch
-from diffusers.models.controlnet import ControlNetModel
-from diffusers.models.cross_attention import CrossAttnProcessor
+from ...diffusers.version_safe_diffusers import AttnProcessor, ControlNetModel
 from ...constants import ONNX_MODEL
 from ..utils import ConversionContext, is_torch_2_0, onnx_export
@@ -43,7+42,7 @@ def convert_diffusion_control(
     # UNET
     if is_torch_2_0:
-        controlnet.set_attn_processor(CrossAttnProcessor())
+        controlnet.set_attn_processor(AttnProcessor())
     cnet_path = output_path / "cnet" / ONNX_MODEL
     onnx_export(


@@ -22,12 +22,12 @@ from diffusers import (
     OnnxStableDiffusionPipeline,
     StableDiffusionPipeline,
 )
-from diffusers.models.cross_attention import CrossAttnProcessor
 from onnx import load_model, save_model
 from ...constants import ONNX_MODEL, ONNX_WEIGHTS
 from ...diffusers.load import optimize_pipeline
 from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
+from ...diffusers.version_safe_diffusers import AttnProcessor
 from ...models.cnet import UNet2DConditionModel_CNet
 from ..utils import ConversionContext, is_torch_2_0, onnx_export
@@ -51,7 +51,7 @@ def convert_diffusion_diffusers_cnet(
     )
     if is_torch_2_0:
-        pipe_cnet.set_attn_processor(CrossAttnProcessor())
+        pipe_cnet.set_attn_processor(AttnProcessor())
     cnet_path = output_path / "cnet" / ONNX_MODEL
     onnx_export(
@@ -262,7 +262,7 @@ def convert_diffusion_diffusers(
     unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)
     if is_torch_2_0:
-        pipeline.unet.set_attn_processor(CrossAttnProcessor())
+        pipeline.unet.set_attn_processor(AttnProcessor())
     unet_in_channels = pipeline.unet.config.in_channels
     unet_sample_size = pipeline.unet.config.sample_size


@@ -4,9 +4,24 @@ from typing import Any, List, Optional, Tuple
 import numpy as np
 import torch
-from diffusers import (
+from onnx import load_model
+from transformers import CLIPTokenizer
+from ..constants import ONNX_MODEL
+from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
+from ..convert.diffusion.textual_inversion import blend_textual_inversions
+from ..diffusers.utils import expand_prompt
+from ..models.meta import NetworkModel
+from ..params import DeviceParams, Size
+from ..server import ServerContext
+from ..utils import run_gc
+from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
+from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
+from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
+from .version_safe_diffusers import (
     DDIMScheduler,
     DDPMScheduler,
+    DEISMultistepScheduler,
     DPMSolverMultistepScheduler,
     DPMSolverSinglestepScheduler,
     EulerAncestralDiscreteScheduler,
@@ -21,31 +36,8 @@ from diffusers import (
     OnnxStableDiffusionPipeline,
     PNDMScheduler,
     StableDiffusionPipeline,
+    UniPCMultistepScheduler,
 )
-from onnx import load_model
-from transformers import CLIPTokenizer
-try:
-    from diffusers import DEISMultistepScheduler
-except ImportError:
-    from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler
-try:
-    from diffusers import UniPCMultistepScheduler
-except ImportError:
-    from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
-from ..constants import ONNX_MODEL
-from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
-from ..convert.diffusion.textual_inversion import blend_textual_inversions
-from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
-from ..diffusers.pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
-from ..diffusers.utils import expand_prompt
-from ..models.meta import NetworkModel
-from ..params import DeviceParams, Size
-from ..server import ServerContext
-from ..utils import run_gc
-from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
 logger = getLogger(__name__)


@@ -0,0 +1,30 @@
+import diffusers
+from diffusers import *  # NOQA
+from packaging import version
+
+is_diffusers_0_15 = version.parse(
+    version.parse(diffusers.__version__).base_version
+) >= version.parse("0.15")
+
+try:
+    from diffusers import DEISMultistepScheduler  # NOQA
+except ImportError:
+    from ..diffusers.stub_scheduler import (
+        StubScheduler as DEISMultistepScheduler,  # NOQA
+    )
+
+try:
+    from diffusers import UniPCMultistepScheduler  # NOQA
+except ImportError:
+    from ..diffusers.stub_scheduler import (
+        StubScheduler as UniPCMultistepScheduler,  # NOQA
+    )
+
+if is_diffusers_0_15:
+    from diffusers.models.attention_processor import AttnProcessor  # NOQA
+else:
+    from diffusers.models.cross_attention import (
+        CrossAttnProcessor as AttnProcessor,  # NOQA
+    )
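One detail in the new shim worth calling out: the version check compares base_version, so pre-release and dev builds of 0.15 still take the new import path. A minimal, self-contained sketch of that behaviour (the helper name is illustrative, not part of the commit):

# Standalone sketch of the base_version comparison used in the shim above.
from packaging import version

def is_at_least_0_15(raw: str) -> bool:
    # base_version strips pre-release/dev suffixes, so "0.15.0.dev0" compares
    # as "0.15.0" instead of sorting below "0.15".
    base = version.parse(raw).base_version
    return version.parse(base) >= version.parse("0.15")

print(is_at_least_0_15("0.14.1"))       # False
print(is_at_least_0_15("0.15.0.dev0"))  # True
print(is_at_least_0_15("0.15.1"))       # True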


@@ -1,4 +1,5 @@
-# from https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/modules/controlnet/laion_face_common.py
+# from https://huggingface.co/CrucibleAI/ControlNetMediaPipeFace/blob/main/laion_face_common.py
+# and https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/modules/controlnet/laion_face_common.py
 from typing import Mapping