from logging import getLogger
from os import path
from typing import Dict, Optional, Tuple

import onnx
import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline
from onnx.shape_inference import infer_shapes_path
from onnxruntime.transformers.float16 import convert_float_to_float16
from optimum.exporters.onnx import main_export

from ...constants import ONNX_MODEL
from ..client import fetch_model
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext

logger = getLogger(__name__)


@torch.no_grad()
def convert_diffusion_diffusers_xl(
    conversion: ConversionContext,
    model: Dict,
    format: Optional[str],
) -> Tuple[bool, str]:
    """
    From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
    """
    name = str(model.get("name")).strip()
    source = model.get("source")
    replace_vae = model.get("vae", None)
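    # export device and torch dtype come from the conversion context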
    device = conversion.training_device
    dtype = conversion.torch_dtype()
    logger.debug("using Torch dtype %s for pipeline", dtype)

    dest_path = path.join(conversion.model_path, name)
    model_index = path.join(dest_path, "model_index.json")
    model_hash = path.join(dest_path, "hash.txt")

    # diffusers go into a directory rather than .onnx file
    logger.info(
        "converting Stable Diffusion XL model %s: %s -> %s/", name, source, dest_path
    )

    if path.exists(dest_path) and path.exists(model_index):
        logger.info("ONNX model already exists, skipping conversion")
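        # backfill the hash file if it is missing from an existing conversion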
        if "hash" in model and not path.exists(model_hash):
            logger.info("ONNX model does not have hash file, adding one")
            with open(model_hash, "w") as f:
                f.write(model["hash"])

        return (False, dest_path)
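    # fetch the source model into the local cache before loading it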
    cache_path = fetch_model(conversion, name, model["source"], format=format)

    # safetensors -> diffusers directory with torch models
    temp_path = path.join(conversion.cache_path, f"{name}-torch")
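    # single-file safetensors checkpoints load with from_single_file, everything else as a diffusers directory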
    if format == "safetensors":
        pipeline = StableDiffusionXLPipeline.from_single_file(
            cache_path, use_safetensors=True
        )
    else:
        pipeline = StableDiffusionXLPipeline.from_pretrained(cache_path)
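    # optionally replace the VAE, from a single tensor file or a pretrained path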
    if replace_vae is not None:
        vae_path = path.join(conversion.model_path, replace_vae)
        vae_file = check_ext(vae_path, RESOLVE_FORMATS)
        if vae_file[0]:
            logger.debug("loading VAE from single tensor file: %s", vae_path)
            pipeline.vae = AutoencoderKL.from_single_file(vae_path)
        else:
            logger.debug("loading pretrained VAE from path: %s", replace_vae)
            pipeline.vae = AutoencoderKL.from_pretrained(replace_vae)
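    # reuse the cached torch export if it exists, otherwise save the pipeline to disk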
    if path.exists(temp_path):
        logger.debug("torch model already exists for %s: %s", source, temp_path)
    else:
        logger.debug("exporting torch model for %s: %s", source, temp_path)
        pipeline.save_pretrained(temp_path)

    # directory -> onnx using optimum exporters
    main_export(
        temp_path,
        output=dest_path,
        task="stable-diffusion-xl",
        device=device,
        fp16=conversion.has_optimization(
            "torch-fp16"
        ),  # optimum's fp16 mode only works on CUDA or ROCm
        framework="pt",
    )
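    # record the source hash alongside the converted model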
    if "hash" in model:
        logger.debug("adding hash file to ONNX model")
        with open(model_hash, "w") as f:
            f.write(model["hash"])
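    # when half precision is requested, convert the exported UNet weights to fp16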
    if conversion.half:
        unet_path = path.join(dest_path, "unet", ONNX_MODEL)
        infer_shapes_path(unet_path)
        unet = onnx.load(unet_path)
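        # keep graph inputs/outputs and attention ops in fp32 to limit fp16 precision loss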
        opt_model = convert_float_to_float16(
            unet,
            disable_shape_infer=True,
            force_fp16_initializers=True,
            keep_io_types=True,
            op_block_list=["Attention", "MultiHeadAttention"],
        )
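        # save the fp16 UNet with its weights in a single external data file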
        onnx.save_model(
            opt_model,
            unet_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location="weights.pb",
        )
    return False, dest_path