feat(api): add support for PyTorch 2.0 (#292)
This commit is contained in:
parent
a291dc8980
commit
34f1973707
|
@ -11,6 +11,7 @@
|
||||||
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import mkdir, path
|
from os import mkdir, path
|
||||||
|
from packaging import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
@ -22,6 +23,7 @@ from diffusers import (
|
||||||
OnnxStableDiffusionPipeline,
|
OnnxStableDiffusionPipeline,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
)
|
)
|
||||||
|
from diffusers.models.cross_attention import CrossAttnProcessor
|
||||||
from onnx import load_model, save_model
|
from onnx import load_model, save_model
|
||||||
from onnx.shape_inference import infer_shapes_path
|
from onnx.shape_inference import infer_shapes_path
|
||||||
from onnxruntime.transformers.float16 import convert_float_to_float16
|
from onnxruntime.transformers.float16 import convert_float_to_float16
|
||||||
|
@ -36,6 +38,8 @@ from ..utils import ConversionContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
is_torch_2_0 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0")
|
||||||
|
|
||||||
|
|
||||||
def onnx_export(
|
def onnx_export(
|
||||||
model,
|
model,
|
||||||
|
@ -168,6 +172,9 @@ def convert_diffusion_diffusers(
|
||||||
device=ctx.training_device, dtype=torch.bool
|
device=ctx.training_device, dtype=torch.bool
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_torch_2_0:
|
||||||
|
pipeline.unet.set_attn_processor(CrossAttnProcessor())
|
||||||
|
|
||||||
unet_in_channels = pipeline.unet.config.in_channels
|
unet_in_channels = pipeline.unet.config.in_channels
|
||||||
unet_sample_size = pipeline.unet.config.sample_size
|
unet_sample_size = pipeline.unet.config.sample_size
|
||||||
unet_path = output_path / "unet" / ONNX_MODEL
|
unet_path = output_path / "unet" / ONNX_MODEL
|
||||||
|
|
Loading…
Reference in New Issue