###
# Parts of this file are copied or derived from:
# https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
#
# Originally by https://github.com/huggingface
# Those portions *are not* covered by the MIT license used for the rest of the onnx-web project.
#
# HuggingFace code used under the Apache License, Version 2.0
# https://github.com/huggingface/diffusers/blob/main/LICENSE
###
from logging import getLogger
from os import mkdir, path
from pathlib import Path
from shutil import rmtree
from typing import Any, Dict, Optional, Tuple, Union

import torch
from diffusers import (
    AutoencoderKL,
    OnnxRuntimeModel,
    OnnxStableDiffusionPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionInstructPix2PixPipeline,
    StableDiffusionPipeline,
    StableDiffusionUpscalePipeline,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)
from onnx import load_model, save_model

from ...constants import ONNX_MODEL, ONNX_WEIGHTS
from ...diffusers.load import optimize_pipeline
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet
from ...utils import run_gc
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
from .checkpoint import convert_extract_checkpoint

logger = getLogger(__name__)
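
# map user-facing pipeline names to the Torch pipeline classes used to load
# the source model; several names share StableDiffusionPipeline and only
# diverge after conversion, in the ONNX pipelines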
available_pipelines = {
    "controlnet": StableDiffusionControlNetPipeline,
    "img2img": StableDiffusionPipeline,
    "inpaint": StableDiffusionPipeline,
    "lpw": StableDiffusionPipeline,
    "panorama": StableDiffusionPipeline,
    "pix2pix": StableDiffusionInstructPix2PixPipeline,
    "txt2img": StableDiffusionPipeline,
    "upscale": StableDiffusionUpscalePipeline,
}


def get_model_version(
    source,
    map_location,
    size=None,
    version=None,
) -> Tuple[bool, Dict[str, Union[bool, int, str]]]:
    v2 = version is not None and "v2" in version
    opts = {
        "extract_ema": True,
    }

    try:
        checkpoint = load_tensor(source, map_location=map_location)

        if "global_step" in checkpoint:
            global_step = checkpoint["global_step"]
        else:
            logger.trace("global_step key not found in model")
            global_step = None

        if size is None:
            # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
            # as it relies on a brittle global step parameter here
            size = 512 if global_step == 875000 else 768

        opts["image_size"] = size
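
        # the width of this cross-attention key weight distinguishes the text
        # encoders: 1024 for the OpenCLIP encoder used by SD v2 models, 768 for
        # the CLIP encoder used by SD v1 models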
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
            v2 = True
            if size != 512:
                # v2.1 needs to upcast attention
                logger.trace("setting upcast_attention")
                opts["upcast_attention"] = True

        if v2 and size != 512:
            opts["model_type"] = "FrozenOpenCLIPEmbedder"
            opts["prediction_type"] = "v_prediction"
        else:
            opts["model_type"] = "FrozenCLIPEmbedder"
            opts["prediction_type"] = "epsilon"
    except Exception:
        logger.debug("unable to load tensor for version check")

    return (v2, opts)


@torch.no_grad()
def convert_diffusion_diffusers_cnet(
    conversion: ConversionContext,
    source: str,
    device: str,
    output_path: Path,
    dtype,
    unet_in_channels,
    unet_sample_size,
    num_tokens,
    text_hidden_size,
    unet: Optional[Any] = None,
    v2: Optional[bool] = False,
):
    # build the ControlNet-capable UNet, either from the in-memory UNet config
    # or from the pretrained source
    if unet is not None:
        logger.debug("creating CNet from existing UNet config")
        pipe_cnet = UNet2DConditionModel_CNet.from_config(unet.config)
    else:
        logger.debug("loading CNet from pretrained UNet config")
        pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")

    pipe_cnet = pipe_cnet.to(device=device, dtype=dtype)
    run_gc()

    if is_torch_2_0:
        pipe_cnet.set_attn_processor(AttnProcessor())

    optimize_pipeline(conversion, pipe_cnet)

    cnet_path = output_path / "cnet" / ONNX_MODEL
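
    # export with example inputs for tracing: latent sample, timestep, text
    # embeddings, the 12 down block residuals at each resolution level, and the
    # mid block residual, matching the SD UNet channel counts (320/640/1280)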
    onnx_export(
        pipe_cnet,
        model_args=(
            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
                device=device, dtype=dtype
            ),
            torch.randn(2).to(device=device, dtype=dtype),
            torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
            torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 320, unet_sample_size // 2, unet_sample_size // 2).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 640, unet_sample_size // 4, unet_sample_size // 4).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
                device=device, dtype=dtype
            ),
            False,
        ),
        output_path=cnet_path,
        ordered_input_names=[
            "sample",
            "timestep",
            "encoder_hidden_states",
            "down_block_0",
            "down_block_1",
            "down_block_2",
            "down_block_3",
            "down_block_4",
            "down_block_5",
            "down_block_6",
            "down_block_7",
            "down_block_8",
            "down_block_9",
            "down_block_10",
            "down_block_11",
            "mid_block_additional_residual",
            "return_dict",
        ],
        # has to be different from "sample" for correct tracing
        output_names=["out_sample"],
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
            "down_block_0": {0: "batch", 2: "height", 3: "width"},
            "down_block_1": {0: "batch", 2: "height", 3: "width"},
            "down_block_2": {0: "batch", 2: "height", 3: "width"},
            "down_block_3": {0: "batch", 2: "height2", 3: "width2"},
            "down_block_4": {0: "batch", 2: "height2", 3: "width2"},
            "down_block_5": {0: "batch", 2: "height2", 3: "width2"},
            "down_block_6": {0: "batch", 2: "height4", 3: "width4"},
            "down_block_7": {0: "batch", 2: "height4", 3: "width4"},
            "down_block_8": {0: "batch", 2: "height4", 3: "width4"},
            "down_block_9": {0: "batch", 2: "height8", 3: "width8"},
            "down_block_10": {0: "batch", 2: "height8", 3: "width8"},
            "down_block_11": {0: "batch", 2: "height8", 3: "width8"},
            "mid_block_additional_residual": {0: "batch", 2: "height8", 3: "width8"},
        },
        opset=conversion.opset,
        half=conversion.half,
        external_data=True,  # UNet is > 2GB, so the weights need to be split
        v2=v2,
    )

    del pipe_cnet
    run_gc()

    return cnet_path


@torch.no_grad()
def collate_cnet(cnet_path):
    logger.debug("collating CNet external tensors")
    cnet_model_path = str(cnet_path.absolute().as_posix())
    cnet_dir = path.dirname(cnet_model_path)
    cnet = load_model(cnet_model_path)

    # clean up existing tensor files
    rmtree(cnet_dir)
    mkdir(cnet_dir)
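
    # load_model reads the external tensors into memory, so the directory of
    # per-tensor files can be recreated before saving a single weights file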
    # collate external tensor files into one
    save_model(
        cnet,
        cnet_model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=ONNX_WEIGHTS,
        convert_attribute=False,
    )

    del cnet
    run_gc()


@torch.no_grad()
def convert_diffusion_diffusers(
    conversion: ConversionContext,
    model: Dict,
    source: str,
    format: str,
    hf: bool = False,
) -> Tuple[bool, str]:
    """
    From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
    """
    name = model.get("name")

    # optional
    config = model.get("config", None)
    image_size = model.get("image_size", None)
    pipe_type = model.get("pipeline", "txt2img")
    single_vae = model.get("single_vae", False)
    replace_vae = model.get("vae", None)
    version = model.get("version", None)

    device = conversion.training_device
    dtype = conversion.torch_dtype()
    logger.debug("using Torch dtype %s for pipeline", dtype)

    config_path = (
        None if config is None else path.join(conversion.model_path, "config", config)
    )
    dest_path = path.join(conversion.model_path, name)
    model_index = path.join(dest_path, "model_index.json")
    model_cnet = path.join(dest_path, "cnet", ONNX_MODEL)

    # diffusers go into a directory rather than .onnx file
    logger.info(
        "converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path
    )

    if single_vae:
        logger.info("converting model with single VAE")

    cnet_only = False
    if path.exists(dest_path) and path.exists(model_index):
        if not single_vae and not path.exists(model_cnet):
            logger.info(
                "ONNX model was converted without a ControlNet UNet, converting one"
            )
            cnet_only = True
        else:
            logger.info("ONNX model already exists, skipping")
            return (False, dest_path)

    pipe_class = available_pipelines.get(pipe_type)
    v2, pipe_args = get_model_version(
        source, conversion.map_location, size=image_size, version=version
    )

    is_inpainting = False
    if pipe_type == "inpaint":
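        # inpainting UNets take 9 input channels: 4 latent channels, plus 4
        # more for the masked-image latents and 1 for the mask itself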
        pipe_args["num_in_channels"] = 9
        is_inpainting = True

    if format == "safetensors":
        pipe_args["from_safetensors"] = True

    torch_source = None
    if path.exists(source) and path.isdir(source):
        logger.debug("loading pipeline from diffusers directory: %s", source)
        pipeline = pipe_class.from_pretrained(
            source,
            torch_dtype=dtype,
            use_auth_token=conversion.token,
        ).to(device)
    elif path.exists(source) and path.isfile(source):
        if conversion.extract:
            logger.debug("extracting SD checkpoint to Torch models: %s", source)
            torch_source = convert_extract_checkpoint(
                conversion,
                source,
                f"{name}-torch",
                is_inpainting=is_inpainting,
                config_file=config,
                vae_file=replace_vae,
            )
            logger.debug("loading pipeline from extracted checkpoint: %s", torch_source)
            pipeline = pipe_class.from_pretrained(
                torch_source,
                torch_dtype=dtype,
            ).to(device)

            # VAE replacement already happened during extraction, skip
            replace_vae = None
        else:
            logger.debug("loading pipeline from SD checkpoint: %s", source)
            pipeline = download_from_original_stable_diffusion_ckpt(
                source,
                original_config_file=config_path,
                pipeline_class=pipe_class,
                **pipe_args,
            ).to(device, torch_dtype=dtype)
    elif hf:
        logger.debug("downloading pretrained model from Huggingface hub: %s", source)
        pipeline = pipe_class.from_pretrained(
            source,
            torch_dtype=dtype,
            use_auth_token=conversion.token,
        ).to(device)
    else:
        logger.warning("pipeline source not found or not recognized: %s", source)
        raise ValueError(f"pipeline source not found or not recognized: {source}")

    optimize_pipeline(conversion, pipeline)

    output_path = Path(dest_path)

    # TEXT ENCODER
    num_tokens = pipeline.text_encoder.config.max_position_embeddings
    text_hidden_size = pipeline.text_encoder.config.hidden_size
    text_input = pipeline.tokenizer(
        "A sample prompt",
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
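    # the tokenized prompt is only an example input for tracing; the sequence
    # axis is dynamic, so real prompts may differ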

    if not cnet_only:
        encoder_path = output_path / "text_encoder" / ONNX_MODEL
        logger.info("exporting text encoder to %s", encoder_path)
        onnx_export(
            pipeline.text_encoder,
            # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
            model_args=(
                text_input.input_ids.to(device=device, dtype=torch.int32),
                None,  # attention mask
                None,  # position ids
                None,  # output attentions
                torch.tensor(True).to(device=device, dtype=torch.bool),
            ),
            output_path=encoder_path,
            ordered_input_names=["input_ids"],
            output_names=["last_hidden_state", "pooler_output", "hidden_states"],
            dynamic_axes={
                "input_ids": {0: "batch", 1: "sequence"},
            },
            opset=conversion.opset,
            half=conversion.half,
        )

    del pipeline.text_encoder
    run_gc()

    # UNET
    logger.debug("UNET config: %s", pipeline.unet.config)
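
    # the upscaling UNet is conditioned on a noise level through class_labels
    # rather than taking a return_dict flag, so the traced signatures differ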
    if single_vae:
        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
        unet_scale = torch.tensor(4).to(device=device, dtype=torch.long)
    else:
        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
        unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)

    if is_torch_2_0:
        pipeline.unet.set_attn_processor(AttnProcessor())

    unet_in_channels = pipeline.unet.config.in_channels
    unet_sample_size = pipeline.unet.config.sample_size
    unet_path = output_path / "unet" / ONNX_MODEL

    if not cnet_only:
        logger.info("exporting UNet to %s", unet_path)
        onnx_export(
            pipeline.unet,
            model_args=(
                torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
                    device=device, dtype=dtype
                ),
                torch.randn(2).to(device=device, dtype=dtype),
                torch.randn(2, num_tokens, text_hidden_size).to(
                    device=device, dtype=dtype
                ),
                unet_scale,
            ),
            output_path=unet_path,
            ordered_input_names=unet_inputs,
            # has to be different from "sample" for correct tracing
            output_names=["out_sample"],
            dynamic_axes={
                "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
                "timestep": {0: "batch"},
                "encoder_hidden_states": {0: "batch", 1: "sequence"},
            },
            opset=conversion.opset,
            half=conversion.half,
            external_data=True,
            v2=v2,
        )

    cnet_path = None

    if conversion.control and not single_vae and conversion.share_unet:
        logger.debug("converting CNet from loaded UNet")
        cnet_path = convert_diffusion_diffusers_cnet(
            conversion,
            source,
            device,
            output_path,
            dtype,
            unet_in_channels,
            unet_sample_size,
            num_tokens,
            text_hidden_size,
            unet=pipeline.unet,
            v2=v2,
        )

    del pipeline.unet
    run_gc()
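
    # without share_unet, rebuild the CNet from the original source (or the
    # extracted Torch checkpoint) rather than the in-memory UNet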
    if conversion.control and not single_vae and not conversion.share_unet:
        cnet_source = torch_source or source
        logger.info("loading and converting CNet from %s", cnet_source)
        cnet_path = convert_diffusion_diffusers_cnet(
            conversion,
            cnet_source,
            device,
            output_path,
            dtype,
            unet_in_channels,
            unet_sample_size,
            num_tokens,
            text_hidden_size,
            unet=None,
            v2=v2,
        )

    if cnet_path is not None:
        collate_cnet(cnet_path)

    if cnet_only:
        logger.info("done converting CNet")
        return (True, dest_path)

    logger.debug("collating UNet external tensors")
    unet_model_path = str(unet_path.absolute().as_posix())
    unet_dir = path.dirname(unet_model_path)
    unet = load_model(unet_model_path)

    # clean up existing tensor files
    rmtree(unet_dir)
    mkdir(unet_dir)

    # collate external tensor files into one
    save_model(
        unet,
        unet_model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=ONNX_WEIGHTS,
        convert_attribute=False,
    )

    del unet
    run_gc()

    # VAE
    if replace_vae is not None:
        if replace_vae.startswith("."):
            logger.debug(
                "custom VAE appears to be a local path, making it relative to the model path"
            )
            replace_vae = path.join(conversion.model_path, replace_vae)

        logger.info("loading custom VAE: %s", replace_vae)
        vae = AutoencoderKL.from_pretrained(replace_vae)
        pipeline.vae = vae
        run_gc()
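
    # the upscaling pipeline uses a single decode-only VAE, while the standard
    # pipeline exports the encoder and decoder as separate ONNX models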
    if single_vae:
        logger.debug("VAE config: %s", pipeline.vae.config)

        # SINGLE VAE
        vae_only = pipeline.vae
        vae_latent_channels = vae_only.config.latent_channels
        # forward only through the decoder part
        vae_only.forward = vae_only.decode

        vae_path = output_path / "vae" / ONNX_MODEL
        logger.info("exporting VAE to %s", vae_path)
        onnx_export(
            vae_only,
            model_args=(
                torch.randn(
                    1, vae_latent_channels, unet_sample_size, unet_sample_size
                ).to(device=device, dtype=dtype),
                False,
            ),
            output_path=vae_path,
            ordered_input_names=["latent_sample", "return_dict"],
            output_names=["sample"],
            dynamic_axes={
                "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            },
            opset=conversion.opset,
            half=conversion.half,
        )
    else:
        # VAE ENCODER
        vae_encoder = pipeline.vae
        vae_in_channels = vae_encoder.config.in_channels
        vae_sample_size = vae_encoder.config.sample_size
        # need to get the raw tensor output (sample) from the encoder
        vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
            sample, return_dict
        )[0].sample()

        vae_path = output_path / "vae_encoder" / ONNX_MODEL
        logger.info("exporting VAE encoder to %s", vae_path)
        onnx_export(
            vae_encoder,
            model_args=(
                torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
                    device=device, dtype=dtype
                ),
                False,
            ),
            output_path=vae_path,
            ordered_input_names=["sample", "return_dict"],
            output_names=["latent_sample"],
            dynamic_axes={
                "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            },
            opset=conversion.opset,
            half=False,  # https://github.com/ssube/onnx-web/issues/290
        )

        # VAE DECODER
        vae_decoder = pipeline.vae
        vae_latent_channels = vae_decoder.config.latent_channels
        # forward only through the decoder part
        vae_decoder.forward = vae_decoder.decode

        vae_path = output_path / "vae_decoder" / ONNX_MODEL
        logger.info("exporting VAE decoder to %s", vae_path)
        onnx_export(
            vae_decoder,
            model_args=(
                torch.randn(
                    1, vae_latent_channels, unet_sample_size, unet_sample_size
                ).to(device=device, dtype=dtype),
                False,
            ),
            output_path=vae_path,
            ordered_input_names=["latent_sample", "return_dict"],
            output_names=["sample"],
            dynamic_axes={
                "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            },
            opset=conversion.opset,
            half=conversion.half,
        )

    del pipeline.vae
    run_gc()

    if single_vae:
        logger.debug("reloading diffusion model with upscaling pipeline")
        onnx_pipeline = OnnxStableDiffusionUpscalePipeline(
            vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"),
            text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
            tokenizer=pipeline.tokenizer,
            low_res_scheduler=pipeline.scheduler,
            unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
            scheduler=pipeline.scheduler,
        )
    else:
        logger.debug("reloading diffusion model with default pipeline")
        onnx_pipeline = OnnxStableDiffusionPipeline(
            vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
            vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
            text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
            tokenizer=pipeline.tokenizer,
            unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
            scheduler=pipeline.scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
        )

    logger.info("exporting pretrained ONNX model to %s", output_path)
    onnx_pipeline.save_pretrained(output_path)
    logger.info("ONNX pipeline saved to %s", output_path)

    del pipeline
    del onnx_pipeline
    run_gc()
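
    # optional sanity check: reload the exported pipeline with the CPU provider
    # to confirm the saved files are loadable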
    if conversion.reload:
        if single_vae:
            _ = OnnxStableDiffusionUpscalePipeline.from_pretrained(
                output_path, provider="CPUExecutionProvider"
            )
        else:
            _ = OnnxStableDiffusionPipeline.from_pretrained(
                output_path, provider="CPUExecutionProvider"
            )

        logger.info("ONNX pipeline is loadable")
    else:
        logger.debug("skipping ONNX reload test")

    return (True, dest_path)