feat(api): add support for PyTorch 2.0 (#292)
This commit is contained in:
parent
a291dc8980
commit
34f1973707
|
@ -11,6 +11,7 @@
|
|||
|
||||
from logging import getLogger
|
||||
from os import mkdir, path
|
||||
from packaging import version
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Dict, Tuple
|
||||
|
@ -22,6 +23,7 @@ from diffusers import (
|
|||
OnnxStableDiffusionPipeline,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from diffusers.models.cross_attention import CrossAttnProcessor
|
||||
from onnx import load_model, save_model
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
from onnxruntime.transformers.float16 import convert_float_to_float16
|
||||
|
@ -36,6 +38,8 @@ from ..utils import ConversionContext
|
|||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
is_torch_2_0 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0")
|
||||
|
||||
|
||||
def onnx_export(
|
||||
model,
|
||||
|
@ -168,6 +172,9 @@ def convert_diffusion_diffusers(
|
|||
device=ctx.training_device, dtype=torch.bool
|
||||
)
|
||||
|
||||
if is_torch_2_0:
|
||||
pipeline.unet.set_attn_processor(CrossAttnProcessor())
|
||||
|
||||
unet_in_channels = pipeline.unet.config.in_channels
|
||||
unet_sample_size = pipeline.unet.config.sample_size
|
||||
unet_path = output_path / "unet" / ONNX_MODEL
|
||||
|
|
Loading…
Reference in New Issue