diff --git a/api/onnx_web/convert/diffusion/control.py b/api/onnx_web/convert/diffusion/control.py index 6a247c6b..f2bbbe2b 100644 --- a/api/onnx_web/convert/diffusion/control.py +++ b/api/onnx_web/convert/diffusion/control.py @@ -4,8 +4,7 @@ from pathlib import Path from typing import Dict import torch -from diffusers.models.controlnet import ControlNetModel -from diffusers.models.cross_attention import CrossAttnProcessor +from ...diffusers.version_safe_diffusers import AttnProcessor, ControlNetModel from ...constants import ONNX_MODEL from ..utils import ConversionContext, is_torch_2_0, onnx_export @@ -43,7 +42,7 @@ def convert_diffusion_control( # UNET if is_torch_2_0: - controlnet.set_attn_processor(CrossAttnProcessor()) + controlnet.set_attn_processor(AttnProcessor()) cnet_path = output_path / "cnet" / ONNX_MODEL onnx_export( diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py index dc893165..62e36107 100644 --- a/api/onnx_web/convert/diffusion/diffusers.py +++ b/api/onnx_web/convert/diffusion/diffusers.py @@ -22,12 +22,12 @@ from diffusers import ( OnnxStableDiffusionPipeline, StableDiffusionPipeline, ) -from diffusers.models.cross_attention import CrossAttnProcessor from onnx import load_model, save_model from ...constants import ONNX_MODEL, ONNX_WEIGHTS from ...diffusers.load import optimize_pipeline from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline +from ...diffusers.version_safe_diffusers import AttnProcessor from ...models.cnet import UNet2DConditionModel_CNet from ..utils import ConversionContext, is_torch_2_0, onnx_export @@ -51,7 +51,7 @@ def convert_diffusion_diffusers_cnet( ) if is_torch_2_0: - pipe_cnet.set_attn_processor(CrossAttnProcessor()) + pipe_cnet.set_attn_processor(AttnProcessor()) cnet_path = output_path / "cnet" / ONNX_MODEL onnx_export( @@ -262,7 +262,7 @@ def convert_diffusion_diffusers( unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool) if is_torch_2_0: - pipeline.unet.set_attn_processor(CrossAttnProcessor()) + pipeline.unet.set_attn_processor(AttnProcessor()) unet_in_channels = pipeline.unet.config.in_channels unet_sample_size = pipeline.unet.config.sample_size diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index c224dee0..8dadda3d 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -4,9 +4,24 @@ from typing import Any, List, Optional, Tuple import numpy as np import torch -from diffusers import ( +from onnx import load_model +from transformers import CLIPTokenizer + +from ..constants import ONNX_MODEL +from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors +from ..convert.diffusion.textual_inversion import blend_textual_inversions +from ..diffusers.utils import expand_prompt +from ..models.meta import NetworkModel +from ..params import DeviceParams, Size +from ..server import ServerContext +from ..utils import run_gc +from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline +from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline +from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline +from .version_safe_diffusers import ( DDIMScheduler, DDPMScheduler, + DEISMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, @@ -21,31 +36,8 @@ from diffusers import ( OnnxStableDiffusionPipeline, PNDMScheduler, StableDiffusionPipeline, + UniPCMultistepScheduler, ) -from onnx import load_model -from transformers import CLIPTokenizer - -try: - from diffusers import DEISMultistepScheduler -except ImportError: - from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler - -try: - from diffusers import UniPCMultistepScheduler -except ImportError: - from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler - -from ..constants import ONNX_MODEL -from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors -from ..convert.diffusion.textual_inversion import blend_textual_inversions -from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline -from ..diffusers.pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline -from ..diffusers.utils import expand_prompt -from ..models.meta import NetworkModel -from ..params import DeviceParams, Size -from ..server import ServerContext -from ..utils import run_gc -from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline logger = getLogger(__name__) diff --git a/api/onnx_web/diffusers/version_safe_diffusers.py b/api/onnx_web/diffusers/version_safe_diffusers.py new file mode 100644 index 00000000..0add964b --- /dev/null +++ b/api/onnx_web/diffusers/version_safe_diffusers.py @@ -0,0 +1,30 @@ +import diffusers +from diffusers import * # NOQA +from packaging import version + +is_diffusers_0_15 = version.parse( + version.parse(diffusers.__version__).base_version +) >= version.parse("0.15") + + +try: + from diffusers import DEISMultistepScheduler # NOQA +except ImportError: + from ..diffusers.stub_scheduler import ( + StubScheduler as DEISMultistepScheduler, # NOQA + ) + +try: + from diffusers import UniPCMultistepScheduler # NOQA +except ImportError: + from ..diffusers.stub_scheduler import ( + StubScheduler as UniPCMultistepScheduler, # NOQA + ) + + +if is_diffusers_0_15: + from diffusers.models.attention_processor import AttnProcessor # NOQA +else: + from diffusers.models.cross_attention import ( + CrossAttnProcessor as AttnProcessor, # NOQA + ) diff --git a/api/onnx_web/image/laion_face.py b/api/onnx_web/image/laion_face.py index d5a70dc5..b0794037 100644 --- a/api/onnx_web/image/laion_face.py +++ b/api/onnx_web/image/laion_face.py @@ -1,4 +1,5 @@ -# from https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/modules/controlnet/laion_face_common.py +# from https://huggingface.co/CrucibleAI/ControlNetMediaPipeFace/blob/main/laion_face_common.py +# and https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/modules/controlnet/laion_face_common.py from typing import Mapping