# onnx-web/api/onnx_web/convert/diffusion/diffusers.py

###
# Parts of this file are copied or derived from:
# https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
#
# Originally by https://github.com/huggingface
# Those portions *are not* covered by the MIT license used for the rest of the onnx-web project.
# ...diffusers.pipelines.pipeline_onnx_stable_diffusion_upscale
# HuggingFace code used under the Apache License, Version 2.0
# https://github.com/huggingface/diffusers/blob/main/LICENSE
###
from functools import partial
from logging import getLogger
from os import mkdir, path
from pathlib import Path
from shutil import rmtree
from typing import Any, Dict, Optional, Tuple, Union
import torch
from diffusers import (
AutoencoderKL,
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)
from onnx import load_model, save_model
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
from ...diffusers.load import optimize_pipeline
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
logger = getLogger(__name__)
available_pipelines = {
"controlnet": StableDiffusionControlNetPipeline,
"img2img": StableDiffusionPipeline,
"inpaint": StableDiffusionPipeline,
"lpw": StableDiffusionPipeline,
"panorama": StableDiffusionPipeline,
"pix2pix": StableDiffusionInstructPix2PixPipeline,
"txt2img": StableDiffusionPipeline,
"upscale": StableDiffusionUpscalePipeline,
}
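
# Several pipeline types (img2img, inpaint, lpw, panorama, txt2img) share the
# base StableDiffusionPipeline; these torch classes are only used to load the
# source model during conversion.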
def get_model_version(
source,
map_location,
size=None,
version=None,
) -> Tuple[bool, Dict[str, Union[bool, int, str]]]:
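    """
    Inspect an original SD checkpoint to guess whether it is a v2 model, and
    build the extra kwargs passed to download_from_original_stable_diffusion_ckpt
    (image size, upcast_attention, text encoder type, and prediction type).
    """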
v2 = version is not None and "v2" in version
opts = {
"extract_ema": True,
}
try:
checkpoint = load_tensor(source, map_location=map_location)
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
else:
            logger.debug("global_step key not found in model")
global_step = None
if size is None:
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
# as it relies on a brittle global step parameter here
size = 512 if global_step == 875000 else 768
opts["image_size"] = size
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
v2 = True
if size != 512:
# v2.1 needs to upcast attention
logger.debug("setting upcast_attention")
opts["upcast_attention"] = True
if v2 and size != 512:
opts["model_type"] = "FrozenOpenCLIPEmbedder"
opts["prediction_type"] = "v_prediction"
else:
opts["model_type"] = "FrozenCLIPEmbedder"
opts["prediction_type"] = "epsilon"
    except Exception:
        logger.debug("unable to load tensor for version check")
return (v2, opts)
def convert_diffusion_diffusers_cnet(
conversion: ConversionContext,
source: str,
device: str,
output_path: Path,
dtype,
unet_in_channels,
unet_sample_size,
num_tokens,
text_hidden_size,
unet: Optional[Any] = None,
):
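    """
    Export the ControlNet-compatible UNet (UNet2DConditionModel_CNet) to ONNX.
    Unlike the plain UNet, it also takes the 12 down-block residuals and the
    mid-block residual produced by a ControlNet as extra inputs.
    """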
# CNet
if unet is not None:
pipe_cnet = UNet2DConditionModel_CNet.from_config(unet.config)
else:
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")
pipe_cnet = pipe_cnet.to(device=device, dtype=dtype)
if is_torch_2_0:
pipe_cnet.set_attn_processor(AttnProcessor())
optimize_pipeline(conversion, pipe_cnet)
cnet_path = output_path / "cnet" / ONNX_MODEL
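    # dummy inputs for tracing: the 320/640/1280-channel residuals at full,
    # half, quarter, and eighth sample size match the down-block outputs of
    # the SD 1.x/2.x UNet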
onnx_export(
pipe_cnet,
model_args=(
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
device=device, dtype=dtype
),
torch.randn(2).to(device=device, dtype=dtype),
torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
device=device, dtype=dtype
),
torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
device=device, dtype=dtype
),
torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
device=device, dtype=dtype
),
torch.randn(2, 320, unet_sample_size // 2, unet_sample_size // 2).to(
device=device, dtype=dtype
),
torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
device=device, dtype=dtype
),
torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
device=device, dtype=dtype
),
torch.randn(2, 640, unet_sample_size // 4, unet_sample_size // 4).to(
device=device, dtype=dtype
),
torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
device=device, dtype=dtype
),
torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
device=device, dtype=dtype
),
torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
device=device, dtype=dtype
),
torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
device=device, dtype=dtype
),
torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
device=device, dtype=dtype
),
torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
device=device, dtype=dtype
),
False,
),
output_path=cnet_path,
ordered_input_names=[
"sample",
"timestep",
"encoder_hidden_states",
"down_block_0",
"down_block_1",
"down_block_2",
"down_block_3",
"down_block_4",
"down_block_5",
"down_block_6",
"down_block_7",
"down_block_8",
"down_block_9",
"down_block_10",
"down_block_11",
"mid_block_additional_residual",
"return_dict",
],
output_names=[
"out_sample"
], # has to be different from "sample" for correct tracing
dynamic_axes={
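            # heightN/widthN are the latent height/width divided by N, so
            # residuals at the same downsampling level share symbolic dims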
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"timestep": {0: "batch"},
"encoder_hidden_states": {0: "batch", 1: "sequence"},
"down_block_0": {0: "batch", 2: "height", 3: "width"},
"down_block_1": {0: "batch", 2: "height", 3: "width"},
"down_block_2": {0: "batch", 2: "height", 3: "width"},
"down_block_3": {0: "batch", 2: "height2", 3: "width2"},
"down_block_4": {0: "batch", 2: "height2", 3: "width2"},
"down_block_5": {0: "batch", 2: "height2", 3: "width2"},
"down_block_6": {0: "batch", 2: "height4", 3: "width4"},
"down_block_7": {0: "batch", 2: "height4", 3: "width4"},
"down_block_8": {0: "batch", 2: "height4", 3: "width4"},
"down_block_9": {0: "batch", 2: "height8", 3: "width8"},
"down_block_10": {0: "batch", 2: "height8", 3: "width8"},
"down_block_11": {0: "batch", 2: "height8", 3: "width8"},
"mid_block_additional_residual": {0: "batch", 2: "height8", 3: "width8"},
},
opset=conversion.opset,
half=conversion.half,
external_data=True, # UNet is > 2GB, so the weights need to be split
)
cnet_model_path = str(cnet_path.absolute().as_posix())
cnet_dir = path.dirname(cnet_model_path)
cnet = load_model(cnet_model_path)
# clean up existing tensor files
rmtree(cnet_dir)
mkdir(cnet_dir)
# collate external tensor files into one
save_model(
cnet,
cnet_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
convert_attribute=False,
)
del pipe_cnet
@torch.no_grad()
def convert_diffusion_diffusers(
conversion: ConversionContext,
model: Dict,
source: str,
format: str,
) -> Tuple[bool, str]:
    """
    Convert a Stable Diffusion torch model, from a diffusers directory or an
    original checkpoint, into a directory of ONNX models.

    Based on https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
    """
config = model.get("config", None)
image_size = model.get("image_size", None)
name = model.get("name")
pipe_type = model.get("pipeline", "txt2img")
single_vae = model.get("single_vae")
source = source or model.get("source")
replace_vae = model.get("vae")
version = model.get("version", None)
device = conversion.training_device
dtype = conversion.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", dtype)
config_path = (
None if config is None else path.join(conversion.model_path, "config", config)
)
dest_path = path.join(conversion.model_path, name)
model_index = path.join(dest_path, "model_index.json")
model_cnet = path.join(dest_path, "cnet", ONNX_MODEL)
    # diffusers models go into a directory rather than a single .onnx file
logger.info(
"converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path
)
if single_vae:
logger.info("converting model with single VAE")
cnet_only = False
if path.exists(dest_path) and path.exists(model_index):
if not path.exists(model_cnet):
logger.info(
"ONNX model was converted without a ControlNet UNet, converting one"
)
cnet_only = True
else:
logger.info("ONNX model already exists, skipping")
return (False, dest_path)
pipe_class = available_pipelines.get(pipe_type)
_v2, pipe_args = get_model_version(
source, conversion.map_location, size=image_size, version=version
)
if pipe_type == "inpaint":
pipe_args["num_in_channels"] = 9
if format == "safetensors":
pipe_args["from_safetensors"] = True
if path.exists(source) and path.isdir(source):
logger.debug("loading pipeline from diffusers directory: %s", source)
pipeline = pipe_class.from_pretrained(
source,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(device)
elif path.exists(source) and path.isfile(source):
logger.debug("loading pipeline from SD checkpoint: %s", source)
pipe_ctor = partial(pipe_class, torch_dtype=dtype)
pipeline = download_from_original_stable_diffusion_ckpt(
source,
original_config_file=config_path,
pipeline_class=pipe_ctor,
**pipe_args,
).to(device)
else:
logger.warning("pipeline source not found or not recognized: %s", source)
raise ValueError(f"pipeline source not found or not recognized: {source}")
optimize_pipeline(conversion, pipeline)
output_path = Path(dest_path)
# TEXT ENCODER
num_tokens = pipeline.text_encoder.config.max_position_embeddings
text_hidden_size = pipeline.text_encoder.config.hidden_size
text_input = pipeline.tokenizer(
"A sample prompt",
padding="max_length",
max_length=pipeline.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
if not cnet_only:
onnx_export(
pipeline.text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=(
text_input.input_ids.to(device=device, dtype=torch.int32),
None, # attention mask
None, # position ids
None, # output attentions
torch.tensor(True).to(device=device, dtype=torch.bool),
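                # output_hidden_states, exported as the "hidden_states" output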
),
output_path=output_path / "text_encoder" / ONNX_MODEL,
ordered_input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output", "hidden_states"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
},
opset=conversion.opset,
half=conversion.half,
)
del pipeline.text_encoder
# UNET
logger.debug("UNET config: %s", pipeline.unet.config)
if single_vae:
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
unet_scale = torch.tensor(4).to(device=device, dtype=torch.long)
else:
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)
if is_torch_2_0:
pipeline.unet.set_attn_processor(AttnProcessor())
unet_in_channels = pipeline.unet.config.in_channels
unet_sample_size = pipeline.unet.config.sample_size
unet_path = output_path / "unet" / ONNX_MODEL
if not cnet_only:
onnx_export(
pipeline.unet,
model_args=(
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
device=device, dtype=dtype
),
torch.randn(2).to(device=device, dtype=dtype),
torch.randn(2, num_tokens, text_hidden_size).to(
device=device, dtype=dtype
),
unet_scale,
),
output_path=unet_path,
ordered_input_names=unet_inputs,
# has to be different from "sample" for correct tracing
output_names=["out_sample"],
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"timestep": {0: "batch"},
"encoder_hidden_states": {0: "batch", 1: "sequence"},
},
opset=conversion.opset,
half=conversion.half,
external_data=True,
)
unet_model_path = str(unet_path.absolute().as_posix())
unet_dir = path.dirname(unet_model_path)
unet = load_model(unet_model_path)
# clean up existing tensor files
rmtree(unet_dir)
mkdir(unet_dir)
# collate external tensor files into one
save_model(
unet,
unet_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
convert_attribute=False,
)
if not single_vae:
# if converting only the CNet, the rest of the model has already been converted
convert_diffusion_diffusers_cnet(
conversion,
source,
device,
output_path,
dtype,
unet_in_channels,
unet_sample_size,
num_tokens,
text_hidden_size,
unet=pipeline.unet,
)
else:
logger.debug("skipping CNet for single-VAE model")
del pipeline.unet
if cnet_only:
logger.info("done converting CNet")
return (True, dest_path)
# VAE
if replace_vae is not None:
logger.debug("loading custom VAE: %s", replace_vae)
vae = AutoencoderKL.from_pretrained(replace_vae)
pipeline.vae = vae
if single_vae:
logger.debug("VAE config: %s", pipeline.vae.config)
# SINGLE VAE
vae_only = pipeline.vae
vae_latent_channels = vae_only.config.latent_channels
# forward only through the decoder part
vae_only.forward = vae_only.decode
onnx_export(
vae_only,
model_args=(
torch.randn(
1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=device, dtype=dtype),
False,
),
output_path=output_path / "vae" / ONNX_MODEL,
ordered_input_names=["latent_sample", "return_dict"],
output_names=["sample"],
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=conversion.opset,
half=conversion.half,
)
else:
# VAE ENCODER
vae_encoder = pipeline.vae
vae_in_channels = vae_encoder.config.in_channels
vae_sample_size = vae_encoder.config.sample_size
# need to get the raw tensor output (sample) from the encoder
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
sample, return_dict
)[0].sample()
onnx_export(
vae_encoder,
model_args=(
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
device=device, dtype=dtype
),
False,
),
output_path=output_path / "vae_encoder" / ONNX_MODEL,
ordered_input_names=["sample", "return_dict"],
output_names=["latent_sample"],
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=conversion.opset,
half=False, # https://github.com/ssube/onnx-web/issues/290
)
# VAE DECODER
vae_decoder = pipeline.vae
vae_latent_channels = vae_decoder.config.latent_channels
# forward only through the decoder part
        vae_decoder.forward = vae_decoder.decode
onnx_export(
vae_decoder,
model_args=(
torch.randn(
1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=device, dtype=dtype),
False,
),
output_path=output_path / "vae_decoder" / ONNX_MODEL,
ordered_input_names=["latent_sample", "return_dict"],
output_names=["sample"],
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=conversion.opset,
half=conversion.half,
)
del pipeline.vae
if single_vae:
onnx_pipeline = OnnxStableDiffusionUpscalePipeline(
vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"),
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
low_res_scheduler=pipeline.scheduler,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
)
else:
onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
logger.info("exporting ONNX model")
onnx_pipeline.save_pretrained(output_path)
logger.info("ONNX pipeline saved to %s", output_path)
del pipeline
del onnx_pipeline
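    # smoke test: reload the exported pipeline from disk to verify it is loadable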
if single_vae:
_ = OnnxStableDiffusionUpscalePipeline.from_pretrained(
output_path, provider="CPUExecutionProvider"
)
else:
_ = OnnxStableDiffusionPipeline.from_pretrained(
output_path, provider="CPUExecutionProvider"
)
logger.info("ONNX pipeline is loadable")
return (True, dest_path)
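
# A minimal usage sketch with hypothetical values: "ctx" stands in for a
# ConversionContext, and the model name and source below are examples only:
#
#   model = {
#       "name": "example-v1",
#       "source": "/models/example-v1.safetensors",
#       "pipeline": "txt2img",
#   }
#   converted, dest = convert_diffusion_diffusers(ctx, model, model["source"], "safetensors")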