From 34f19737070a75df76e7cd234b135b3980f6927c Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 9 Apr 2023 16:07:06 -0500 Subject: [PATCH] feat(api): add support for PyTorch 2.0 (#292) --- api/onnx_web/convert/diffusion/diffusers.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py index 1f4459c9..1d88fc3b 100644 --- a/api/onnx_web/convert/diffusion/diffusers.py +++ b/api/onnx_web/convert/diffusion/diffusers.py @@ -11,6 +11,7 @@ from logging import getLogger from os import mkdir, path +from packaging import version from pathlib import Path from shutil import rmtree from typing import Dict, Tuple @@ -22,6 +23,7 @@ from diffusers import ( OnnxStableDiffusionPipeline, StableDiffusionPipeline, ) +from diffusers.models.cross_attention import CrossAttnProcessor from onnx import load_model, save_model from onnx.shape_inference import infer_shapes_path from onnxruntime.transformers.float16 import convert_float_to_float16 @@ -36,6 +38,8 @@ from ..utils import ConversionContext logger = getLogger(__name__) +is_torch_2_0 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0") + def onnx_export( model, @@ -168,6 +172,9 @@ def convert_diffusion_diffusers( device=ctx.training_device, dtype=torch.bool ) + if is_torch_2_0: + pipeline.unet.set_attn_processor(CrossAttnProcessor()) + unet_in_channels = pipeline.unet.config.in_channels unet_sample_size = pipeline.unet.config.sample_size unet_path = output_path / "unet" / ONNX_MODEL