###
# Parts of this file are copied or derived from:
# https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
#
# Originally by https://github.com/huggingface
# Those portions *are not* covered by the MIT license used for the rest of the onnx-web project.
#
# HuggingFace code used under the Apache License, Version 2.0
# https://github.com/huggingface/diffusers/blob/main/LICENSE
###
from logging import getLogger
from os import mkdir, path
from pathlib import Path
from shutil import rmtree
from typing import Dict, Tuple

import torch
from packaging import version
from diffusers import (
AutoencoderKL,
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionPipeline,
)
from diffusers.models.cross_attention import CrossAttnProcessor
from onnx import load_model, save_model
from onnx.shape_inference import infer_shapes_path
from onnxruntime.transformers.float16 import convert_float_to_float16
from torch.onnx import export

from ...constants import ONNX_MODEL, ONNX_WEIGHTS
from ...diffusers.load import optimize_pipeline
from ...diffusers.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
from ..utils import ConversionContext

logger = getLogger(__name__)

is_torch_2_0 = version.parse(
    version.parse(torch.__version__).base_version
) >= version.parse("2.0")


def onnx_export(
model,
model_args: tuple,
output_path: Path,
ordered_input_names,
output_names,
dynamic_axes,
opset,
half=False,
external_data=False,
):
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
output_path.parent.mkdir(parents=True, exist_ok=True)
output_file = output_path.absolute().as_posix()
export(
model,
model_args,
f=output_file,
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
opset_version=opset,
)
if half:
logger.info("converting model to fp16 internally: %s", output_file)
infer_shapes_path(output_file)
base_model = load_model(output_file)
        opt_model = convert_float_to_float16(
            base_model,
            disable_shape_infer=True,  # shapes were already inferred above
            keep_io_types=True,  # keep fp32 graph inputs and outputs
            force_fp16_initializers=True,
        )
        save_model(
            opt_model,
            output_file,
            save_as_external_data=external_data,
            all_tensors_to_one_file=True,
            location=ONNX_WEIGHTS,
        )
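

# A minimal, illustrative call to onnx_export (not used by this module); the
# torch module, output path, and opset below are hypothetical:
#
#   onnx_export(
#       torch.nn.Linear(8, 8),
#       model_args=(torch.randn(1, 8),),
#       output_path=Path("/tmp/linear") / ONNX_MODEL,
#       ordered_input_names=["input"],
#       output_names=["output"],
#       dynamic_axes={"input": {0: "batch"}},
#       opset=14,
#   )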


@torch.no_grad()
def convert_diffusion_diffusers(
conversion: ConversionContext,
model: Dict,
source: str,
) -> Tuple[bool, str]:
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
name = model.get("name")
source = source or model.get("source")
single_vae = model.get("single_vae")
replace_vae = model.get("vae")
dtype = conversion.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", dtype)
dest_path = path.join(conversion.model_path, name)
model_index = path.join(dest_path, "model_index.json")
    # diffusers pipelines go into a directory rather than a single .onnx file
logger.info(
"converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path
)
if single_vae:
logger.info("converting model with single VAE")
if path.exists(dest_path) and path.exists(model_index):
logger.info("ONNX model already exists, skipping")
return (False, dest_path)
pipeline = StableDiffusionPipeline.from_pretrained(
source,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(conversion.training_device)
output_path = Path(dest_path)
optimize_pipeline(conversion, pipeline)
# TEXT ENCODER
num_tokens = pipeline.text_encoder.config.max_position_embeddings
text_hidden_size = pipeline.text_encoder.config.hidden_size
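    # tokenize a throwaway prompt; only its shape matters for tracing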
text_input = pipeline.tokenizer(
"A sample prompt",
padding="max_length",
max_length=pipeline.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
onnx_export(
pipeline.text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=(
            text_input.input_ids.to(
                device=conversion.training_device, dtype=torch.int32
            ),
            None,  # attention mask
            None,  # position ids
            None,  # output attentions
            torch.tensor(True).to(  # output hidden states
                device=conversion.training_device, dtype=torch.bool
            ),
),
output_path=output_path / "text_encoder" / ONNX_MODEL,
ordered_input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output", "hidden_states"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
},
opset=conversion.opset,
half=conversion.half,
)
del pipeline.text_encoder
logger.debug("UNET config: %s", pipeline.unet.config)
# UNET
    if single_vae:
        # the upscaling UNet takes a noise level through its class_labels input
        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
        unet_scale = torch.tensor(4).to(
            device=conversion.training_device, dtype=torch.long
        )
    else:
        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
        unet_scale = torch.tensor(False).to(
            device=conversion.training_device, dtype=torch.bool
        )
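    # torch 2.0 makes scaled-dot-product attention the default, which the ONNX
    # exporter does not handle; fall back to the pre-2.0 CrossAttnProcessor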
if is_torch_2_0:
pipeline.unet.set_attn_processor(CrossAttnProcessor())
unet_in_channels = pipeline.unet.config.in_channels
unet_sample_size = pipeline.unet.config.sample_size
unet_path = output_path / "unet" / ONNX_MODEL
onnx_export(
pipeline.unet,
model_args=(
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
device=conversion.training_device, dtype=dtype
),
torch.randn(2).to(device=conversion.training_device, dtype=dtype),
torch.randn(2, num_tokens, text_hidden_size).to(
device=conversion.training_device, dtype=dtype
),
unet_scale,
),
output_path=unet_path,
ordered_input_names=unet_inputs,
# has to be different from "sample" for correct tracing
output_names=["out_sample"],
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"timestep": {0: "batch"},
"encoder_hidden_states": {0: "batch", 1: "sequence"},
},
opset=conversion.opset,
half=conversion.half,
external_data=True,
)
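    # the fp32 UNet weights exceed protobuf's 2 GB limit, so the exporter wrote
    # them as external tensors; reload the graph to collate them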
    unet_model_path = unet_path.absolute().as_posix()
unet_dir = path.dirname(unet_model_path)
unet = load_model(unet_model_path)
# clean up existing tensor files
rmtree(unet_dir)
mkdir(unet_dir)
# collate external tensor files into one
save_model(
unet,
unet_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
convert_attribute=False,
)
del pipeline.unet
if replace_vae is not None:
logger.debug("loading custom VAE: %s", replace_vae)
vae = AutoencoderKL.from_pretrained(replace_vae)
pipeline.vae = vae
if single_vae:
logger.debug("VAE config: %s", pipeline.vae.config)
# SINGLE VAE
vae_only = pipeline.vae
vae_latent_channels = vae_only.config.latent_channels
# forward only through the decoder part
vae_only.forward = vae_only.decode
onnx_export(
vae_only,
model_args=(
torch.randn(
1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=conversion.training_device, dtype=dtype),
False,
),
output_path=output_path / "vae" / ONNX_MODEL,
ordered_input_names=["latent_sample", "return_dict"],
output_names=["sample"],
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=conversion.opset,
half=conversion.half,
)
else:
# VAE ENCODER
vae_encoder = pipeline.vae
vae_in_channels = vae_encoder.config.in_channels
vae_sample_size = vae_encoder.config.sample_size
# need to get the raw tensor output (sample) from the encoder
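        # encode() returns a latent distribution; .sample() draws the latents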
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
sample, return_dict
)[0].sample()
onnx_export(
vae_encoder,
model_args=(
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
device=conversion.training_device, dtype=dtype
),
False,
),
output_path=output_path / "vae_encoder" / ONNX_MODEL,
ordered_input_names=["sample", "return_dict"],
output_names=["latent_sample"],
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=conversion.opset,
half=False, # https://github.com/ssube/onnx-web/issues/290
)
        # VAE DECODER
        vae_decoder = pipeline.vae
        vae_latent_channels = vae_decoder.config.latent_channels
        # forward only through the decoder part; vae_decoder and vae_encoder
        # alias the same module, so rebind forward to its decode method
        vae_decoder.forward = vae_decoder.decode
onnx_export(
vae_decoder,
model_args=(
torch.randn(
1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=conversion.training_device, dtype=dtype),
False,
),
output_path=output_path / "vae_decoder" / ONNX_MODEL,
ordered_input_names=["latent_sample", "return_dict"],
output_names=["sample"],
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=conversion.opset,
half=conversion.half,
)
del pipeline.vae
if single_vae:
onnx_pipeline = OnnxStableDiffusionUpscalePipeline(
vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"),
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
low_res_scheduler=pipeline.scheduler,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
)
else:
onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
logger.info("exporting ONNX model")
onnx_pipeline.save_pretrained(output_path)
logger.info("ONNX pipeline saved to %s", output_path)
del pipeline
del onnx_pipeline
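    # reload the converted pipeline as a smoke test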
if single_vae:
_ = OnnxStableDiffusionUpscalePipeline.from_pretrained(
output_path, provider="CPUExecutionProvider"
)
else:
_ = OnnxStableDiffusionPipeline.from_pretrained(
output_path, provider="CPUExecutionProvider"
)
logger.info("ONNX pipeline is loadable")
return (True, dest_path)
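

# A sketch of how this converter might be invoked; the ConversionContext
# arguments shown here are assumptions, its real signature lives in ..utils:
#
#   conversion = ConversionContext(model_path="./models", half=False, opset=14)
#   model = {
#       "name": "stable-diffusion-onnx-v1-5",
#       "source": "runwayml/stable-diffusion-v1-5",
#   }
#   converted, dest = convert_diffusion_diffusers(conversion, model, model["source"])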