76 lines
2.4 KiB
Python
76 lines
2.4 KiB
Python
from logging import getLogger
|
|
from os import path
|
|
from pathlib import Path
|
|
from typing import Dict, Optional
|
|
|
|
import torch
|
|
|
|
from ...constants import ONNX_MODEL
|
|
from ...diffusers.version_safe_diffusers import AttnProcessor, ControlNetModel
|
|
from ..utils import ConversionContext, is_torch_2_0, onnx_export
|
|
|
|
logger = getLogger(__name__)
|
|
|
|
|
|
@torch.no_grad()
|
|
def convert_diffusion_control(
|
|
conversion: ConversionContext,
|
|
model: Dict,
|
|
source: str,
|
|
output_path: str,
|
|
attention_slicing: Optional[str] = None,
|
|
):
|
|
name = model.get("name")
|
|
source = source or model.get("source")
|
|
|
|
device = conversion.training_device
|
|
dtype = conversion.torch_dtype()
|
|
opset = conversion.opset
|
|
logger.debug("using Torch dtype %s for ControlNet", dtype)
|
|
|
|
output_path = Path(output_path)
|
|
logger.info("converting ControlNet model %s: %s -> %s", name, source, output_path)
|
|
if path.exists(output_path):
|
|
logger.info("ONNX model already exists, skipping")
|
|
return
|
|
|
|
controlnet = ControlNetModel.from_pretrained(source, torch_dtype=dtype)
|
|
if attention_slicing is not None:
|
|
logger.info("enabling attention slicing for ControlNet")
|
|
controlnet.set_attention_slice(attention_slicing)
|
|
|
|
# UNET
|
|
if is_torch_2_0:
|
|
controlnet.set_attn_processor(AttnProcessor())
|
|
|
|
cnet_path = output_path / ONNX_MODEL
|
|
onnx_export(
|
|
controlnet,
|
|
model_args=(
|
|
torch.randn(2, 4, 64, 64).to(device=device, dtype=dtype),
|
|
torch.randn(2).to(device=device, dtype=dtype),
|
|
torch.randn(2, 77, 768).to(device=device, dtype=dtype),
|
|
torch.randn(2, 3, 512, 512).to(device=device, dtype=dtype),
|
|
),
|
|
output_path=cnet_path,
|
|
ordered_input_names=[
|
|
"sample",
|
|
"timestep",
|
|
"encoder_hidden_states",
|
|
"controlnet_cond",
|
|
],
|
|
output_names=[
|
|
"down_block_res_samples",
|
|
"mid_block_res_sample",
|
|
], # has to be different from "sample" for correct tracing
|
|
dynamic_axes={
|
|
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
|
"timestep": {0: "batch"},
|
|
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
|
"controlnet_cond": {0: "batch", 2: "height", 3: "width"},
|
|
},
|
|
opset=opset,
|
|
)
|
|
|
|
logger.info("ONNX ControlNet saved to %s", output_path)
|