1
0
Fork 0

feat(api): add support for PyTorch 2.0 (#292)

This commit is contained in:
Sean Sube 2023-04-09 16:07:06 -05:00
parent a291dc8980
commit 34f1973707
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 7 additions and 0 deletions

View File

@ -11,6 +11,7 @@
from logging import getLogger from logging import getLogger
from os import mkdir, path from os import mkdir, path
from packaging import version
from pathlib import Path from pathlib import Path
from shutil import rmtree from shutil import rmtree
from typing import Dict, Tuple from typing import Dict, Tuple
@ -22,6 +23,7 @@ from diffusers import (
OnnxStableDiffusionPipeline, OnnxStableDiffusionPipeline,
StableDiffusionPipeline, StableDiffusionPipeline,
) )
from diffusers.models.cross_attention import CrossAttnProcessor
from onnx import load_model, save_model from onnx import load_model, save_model
from onnx.shape_inference import infer_shapes_path from onnx.shape_inference import infer_shapes_path
from onnxruntime.transformers.float16 import convert_float_to_float16 from onnxruntime.transformers.float16 import convert_float_to_float16
@ -36,6 +38,8 @@ from ..utils import ConversionContext
logger = getLogger(__name__) logger = getLogger(__name__)
is_torch_2_0 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0")
def onnx_export( def onnx_export(
model, model,
@ -168,6 +172,9 @@ def convert_diffusion_diffusers(
device=ctx.training_device, dtype=torch.bool device=ctx.training_device, dtype=torch.bool
) )
if is_torch_2_0:
pipeline.unet.set_attn_processor(CrossAttnProcessor())
unet_in_channels = pipeline.unet.config.in_channels unet_in_channels = pipeline.unet.config.in_channels
unet_sample_size = pipeline.unet.config.sample_size unet_sample_size = pipeline.unet.config.sample_size
unet_path = output_path / "unet" / ONNX_MODEL unet_path = output_path / "unet" / ONNX_MODEL