1
0
Fork 0
onnx-web/api/onnx_web/convert/diffusion/control.py

76 lines
2.4 KiB
Python
Raw Normal View History

from logging import getLogger
from os import path
from pathlib import Path
from typing import Dict, Optional
import torch
from ...constants import ONNX_MODEL
2023-04-15 19:42:14 +00:00
from ...diffusers.version_safe_diffusers import AttnProcessor, ControlNetModel
from ..utils import ConversionContext, is_torch_2_0, onnx_export
# Module-level logger, named after this module.
logger = getLogger(name=__name__)
@torch.no_grad()
def convert_diffusion_control(
    conversion: ConversionContext,
    model: Dict,
    source: str,
    output_path: str,
    attention_slicing: Optional[str] = None,
):
    """
    Convert a diffusers ControlNet model to ONNX.

    The ControlNet is loaded with ``ControlNetModel.from_pretrained``, traced
    with random sample inputs, and written to ``output_path / ONNX_MODEL``.
    If ``output_path`` already exists, the conversion is skipped.

    Args:
        conversion: conversion settings; supplies ``training_device``,
            ``torch_dtype()`` and ``opset``.
        model: model entry; ``name`` is used for logging and ``source`` is
            the fallback when the ``source`` argument is empty.
        source: model source passed to ``ControlNetModel.from_pretrained``.
        output_path: destination directory for the exported model.
        attention_slicing: when not None, passed to ``set_attention_slice``.

    Returns:
        None.
    """
    name = model.get("name")
    source = source or model.get("source")
    device = conversion.training_device
    dtype = conversion.torch_dtype()
    opset = conversion.opset
    logger.debug("using Torch dtype %s for ControlNet", dtype)

    output_path = Path(output_path)

    logger.info("converting ControlNet model %s: %s -> %s", name, source, output_path)
    # NOTE(review): only the output directory is checked, not the model file
    # inside it, so a partially-written directory is also skipped.
    if output_path.exists():
        logger.info("ONNX model already exists, skipping")
        return

    # Put the model on the same device as the sample inputs below; otherwise
    # tracing mixes devices whenever training_device is not the CPU.
    controlnet = ControlNetModel.from_pretrained(source, torch_dtype=dtype).to(device)
    if attention_slicing is not None:
        logger.info("enabling attention slicing for ControlNet")
        controlnet.set_attention_slice(attention_slicing)

    # Under Torch 2.0, swap in the plain AttnProcessor before export —
    # presumably to avoid the Torch 2.0 attention processor during tracing;
    # confirm against onnx_export requirements.
    if is_torch_2_0:
        controlnet.set_attn_processor(AttnProcessor())

    # Width of the text-encoder hidden states. Read it from the model config so
    # the traced input matches the model; 768 (the previous hard-coded value)
    # is kept as the fallback when the config does not expose it.
    hidden_size = getattr(controlnet.config, "cross_attention_dim", 768)

    cnet_path = output_path / ONNX_MODEL
    onnx_export(
        controlnet,
        model_args=(
            torch.randn(2, 4, 64, 64).to(device=device, dtype=dtype),
            torch.randn(2).to(device=device, dtype=dtype),
            torch.randn(2, 77, hidden_size).to(device=device, dtype=dtype),
            torch.randn(2, 3, 512, 512).to(device=device, dtype=dtype),
        ),
        output_path=cnet_path,
        ordered_input_names=[
            "sample",
            "timestep",
            "encoder_hidden_states",
            "controlnet_cond",
        ],
        output_names=[
            "down_block_res_samples",
            "mid_block_res_sample",
        ],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
            "controlnet_cond": {0: "batch", 2: "height", 3: "width"},
        },
        opset=opset,
    )
    logger.info("ONNX ControlNet saved to %s", output_path)