###
# Parts of this file are copied or derived from:
# https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
#
# Originally by https://github.com/huggingface
# Those portions *are not* covered by the MIT license used for the rest of the onnx-web project.
# ...diffusers.pipelines.pipeline_onnx_stable_diffusion_upscale
#
# HuggingFace code used under the Apache License, Version 2.0
# https://github.com/huggingface/diffusers/blob/main/LICENSE
###
from logging import getLogger
from os import mkdir, path
from pathlib import Path
from shutil import rmtree
from typing import Any, Dict, Optional, Tuple, Union

import torch
from diffusers import AutoencoderKL, OnnxRuntimeModel, OnnxStableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)
from onnx import load_model, save_model

from ...constants import ONNX_MODEL, ONNX_WEIGHTS
from ...diffusers.load import available_pipelines, optimize_pipeline
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet
from ...utils import run_gc
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
from .checkpoint import convert_extract_checkpoint

logger = getLogger(__name__)


def get_model_version(
    source,
    map_location,
    size=None,
    version=None,
) -> Tuple[bool, Dict[str, Union[bool, int, str]]]:
    v2 = version is not None and "v2" in version
    opts = {
        "extract_ema": True,
    }

    try:
        checkpoint = load_tensor(source, map_location=map_location)

        if "global_step" in checkpoint:
            global_step = checkpoint["global_step"]
        else:
            logger.trace("global_step key not found in model")
            global_step = None

        if size is None:
            # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
            # as it relies on a brittle global step parameter here
            size = 512 if global_step == 875000 else 768

        opts["image_size"] = size

        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
            v2 = True
            if size != 512:
                # v2.1 needs to upcast attention
                logger.trace("setting upcast_attention")
                opts["upcast_attention"] = True

        if v2 and size != 512:
            opts["model_type"] = "FrozenOpenCLIPEmbedder"
            opts["prediction_type"] = "v_prediction"
        else:
            opts["model_type"] = "FrozenCLIPEmbedder"
            opts["prediction_type"] = "epsilon"
    except Exception:
        logger.debug("unable to load tensor for version check")

    return (v2, opts)


@torch.no_grad()
def convert_diffusion_diffusers_cnet(
    conversion: ConversionContext,
    source: str,
    device: str,
    output_path: Path,
    dtype,
    unet_in_channels,
    unet_sample_size,
    num_tokens,
    text_hidden_size,
    unet: Optional[Any] = None,
    v2: Optional[bool] = False,
):
    # CNet
    if unet is not None:
        logger.debug("creating CNet from existing UNet config")
        pipe_cnet = UNet2DConditionModel_CNet.from_config(unet.config)
    else:
        logger.debug("loading CNet from pretrained UNet config")
        pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")

    pipe_cnet = pipe_cnet.to(device=device, dtype=dtype)
    run_gc()

    if is_torch_2_0:
        pipe_cnet.set_attn_processor(AttnProcessor())

    optimize_pipeline(conversion, pipe_cnet)

    cnet_path = output_path / "cnet" / ONNX_MODEL
    onnx_export(
        pipe_cnet,
        model_args=(
            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
                device=device, dtype=dtype
            ),
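            # the remaining args are the timestep, the encoder hidden states, the
            # twelve down-block residuals, the mid-block residual, and return_dict,
            # matching ordered_input_names below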
            torch.randn(2).to(device=device, dtype=dtype),
            torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
            torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 320, unet_sample_size // 2, unet_sample_size // 2).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 640, unet_sample_size // 4, unet_sample_size // 4).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
                device=device, dtype=dtype
            ),
            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
                device=device, dtype=dtype
            ),
            False,
        ),
        output_path=cnet_path,
        ordered_input_names=[
            "sample",
            "timestep",
            "encoder_hidden_states",
            "down_block_0",
            "down_block_1",
            "down_block_2",
            "down_block_3",
            "down_block_4",
            "down_block_5",
            "down_block_6",
            "down_block_7",
            "down_block_8",
            "down_block_9",
            "down_block_10",
            "down_block_11",
            "mid_block_additional_residual",
            "return_dict",
        ],
        output_names=[
            "out_sample"
        ],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
            "down_block_0": {0: "batch", 2: "height", 3: "width"},
            "down_block_1": {0: "batch", 2: "height", 3: "width"},
            "down_block_2": {0: "batch", 2: "height", 3: "width"},
            "down_block_3": {0: "batch", 2: "height2", 3: "width2"},
            "down_block_4": {0: "batch", 2: "height2", 3: "width2"},
            "down_block_5": {0: "batch", 2: "height2", 3: "width2"},
            "down_block_6": {0: "batch", 2: "height4", 3: "width4"},
            "down_block_7": {0: "batch", 2: "height4", 3: "width4"},
            "down_block_8": {0: "batch", 2: "height4", 3: "width4"},
            "down_block_9": {0: "batch", 2: "height8", 3: "width8"},
            "down_block_10": {0: "batch", 2: "height8", 3: "width8"},
            "down_block_11": {0: "batch", 2: "height8", 3: "width8"},
            "mid_block_additional_residual": {0: "batch", 2: "height8", 3: "width8"},
        },
        opset=conversion.opset,
        half=conversion.half,
        external_data=True,  # UNet is > 2GB, so the weights need to be split
        v2=v2,
    )

    del pipe_cnet
    run_gc()

    return cnet_path


@torch.no_grad()
def collate_cnet(cnet_path):
    logger.debug("collating CNet external tensors")
    cnet_model_path = str(cnet_path.absolute().as_posix())
    cnet_dir = path.dirname(cnet_model_path)
    cnet = load_model(cnet_model_path)

    # clean up existing tensor files
    rmtree(cnet_dir)
    mkdir(cnet_dir)

    # collate external tensor files into one
    save_model(
        cnet,
        cnet_model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=ONNX_WEIGHTS,
        convert_attribute=False,
    )

    del cnet
    run_gc()


@torch.no_grad()
def convert_diffusion_diffusers(
    conversion: ConversionContext,
    model: Dict,
    source: str,
    format: Optional[str],
    hf: bool = False,
) -> Tuple[bool, str]:
    """
    From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
    """
    name = model.get("name")

    # optional
    config = model.get("config", None)
    image_size = model.get("image_size", None)
    pipe_type = model.get("pipeline", "txt2img")
    single_vae = model.get("single_vae", False)
    replace_vae = model.get("vae", None)
    version = model.get("version", None)

    device = conversion.training_device
    dtype = conversion.torch_dtype()
    logger.debug("using Torch dtype %s for pipeline", dtype)

    config_path = (
        None if config is None else path.join(conversion.model_path, "config", config)
    )
    dest_path = path.join(conversion.model_path, name)
    model_index = path.join(dest_path, "model_index.json")
    model_cnet = path.join(dest_path, "cnet", ONNX_MODEL)
    model_hash = path.join(dest_path, "hash.txt")

    # diffusers go into a directory rather than .onnx file
    logger.info(
        "converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path
    )

    if single_vae:
        logger.info("converting model with single VAE")

    cnet_only = False
    if path.exists(dest_path) and path.exists(model_index):
        if "hash" in model and not path.exists(model_hash):
            logger.info("ONNX model does not have hash file, adding one")
            with open(model_hash, "w") as f:
                f.write(model["hash"])

        if not single_vae and not path.exists(model_cnet):
            logger.info(
                "ONNX model was converted without a ControlNet UNet, converting one"
            )
            cnet_only = True
        else:
            logger.info("ONNX model already exists, skipping")
            return (False, dest_path)

    pipe_class = available_pipelines.get(pipe_type)
    v2, pipe_args = get_model_version(
        source, conversion.map_location, size=image_size, version=version
    )

    is_inpainting = False
    if pipe_type == "inpaint":
        pipe_args["num_in_channels"] = 9
        is_inpainting = True

    if format == "safetensors":
        pipe_args["from_safetensors"] = True

    torch_source = None
    if path.exists(source) and path.isdir(source):
        logger.debug("loading pipeline from diffusers directory: %s", source)
        pipeline = pipe_class.from_pretrained(
            source,
            torch_dtype=dtype,
            use_auth_token=conversion.token,
        ).to(device)
    elif path.exists(source) and path.isfile(source):
        if conversion.extract:
            logger.debug("extracting SD checkpoint to Torch models: %s", source)
            torch_source = convert_extract_checkpoint(
                conversion,
                source,
                f"{name}-torch",
                is_inpainting=is_inpainting,
                config_file=config,
                vae_file=replace_vae,
            )
            logger.debug("loading pipeline from extracted checkpoint: %s", torch_source)
            pipeline = pipe_class.from_pretrained(
                torch_source,
                torch_dtype=dtype,
            ).to(device)

            # VAE replacement already happened during extraction, skip
            replace_vae = None
        else:
            logger.debug("loading pipeline from SD checkpoint: %s", source)
            pipeline = download_from_original_stable_diffusion_ckpt(
                source,
                original_config_file=config_path,
                pipeline_class=pipe_class,
                vae_path=replace_vae,
                **pipe_args,
            ).to(device, torch_dtype=dtype)
    elif hf:
        logger.debug("downloading pretrained model from Huggingface hub: %s", source)
        pipeline = pipe_class.from_pretrained(
            source,
            torch_dtype=dtype,
            use_auth_token=conversion.token,
        ).to(device)
    else:
        logger.warning("pipeline source not found or not recognized: %s", source)
        raise ValueError(f"pipeline source not found or not recognized: {source}")

    optimize_pipeline(conversion, pipeline)

    output_path = Path(dest_path)

    # TEXT ENCODER
    num_tokens = pipeline.text_encoder.config.max_position_embeddings
    text_hidden_size = pipeline.text_encoder.config.hidden_size
    text_input = pipeline.tokenizer(
        "A sample prompt",
        padding="max_length",
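        # pad and truncate the sample prompt to the tokenizer's max length so the
        # traced text encoder sees a full-length input_ids tensor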
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    if not cnet_only:
        encoder_path = output_path / "text_encoder" / ONNX_MODEL
        logger.info("exporting text encoder to %s", encoder_path)
        onnx_export(
            pipeline.text_encoder,
            # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
            model_args=(
                text_input.input_ids.to(device=device, dtype=torch.int32),
                None,  # attention mask
                None,  # position ids
                None,  # output attentions
                torch.tensor(True).to(device=device, dtype=torch.bool),
            ),
            output_path=encoder_path,
            ordered_input_names=["input_ids"],
            output_names=["last_hidden_state", "pooler_output", "hidden_states"],
            dynamic_axes={
                "input_ids": {0: "batch", 1: "sequence"},
            },
            opset=conversion.opset,
            half=conversion.half,
        )

    del pipeline.text_encoder
    run_gc()

    # UNET
    logger.debug("UNET config: %s", pipeline.unet.config)
    if single_vae:
        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
        unet_scale = torch.tensor(4).to(device=device, dtype=torch.long)
    else:
        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
        unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)

    if is_torch_2_0:
        pipeline.unet.set_attn_processor(AttnProcessor())

    unet_in_channels = pipeline.unet.config.in_channels
    unet_sample_size = pipeline.unet.config.sample_size
    unet_path = output_path / "unet" / ONNX_MODEL

    if not cnet_only:
        logger.info("exporting UNet to %s", unet_path)
        onnx_export(
            pipeline.unet,
            model_args=(
                torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
                    device=device, dtype=dtype
                ),
                torch.randn(2).to(device=device, dtype=dtype),
                torch.randn(2, num_tokens, text_hidden_size).to(
                    device=device, dtype=dtype
                ),
                unet_scale,
            ),
            output_path=unet_path,
            ordered_input_names=unet_inputs,
            # has to be different from "sample" for correct tracing
            output_names=["out_sample"],
            dynamic_axes={
                "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
                "timestep": {0: "batch"},
                "encoder_hidden_states": {0: "batch", 1: "sequence"},
            },
            opset=conversion.opset,
            half=conversion.half,
            external_data=True,
            v2=v2,
        )

    cnet_path = None
    if conversion.control and not single_vae and conversion.share_unet:
        logger.debug("converting CNet from loaded UNet")
        cnet_path = convert_diffusion_diffusers_cnet(
            conversion,
            source,
            device,
            output_path,
            dtype,
            unet_in_channels,
            unet_sample_size,
            num_tokens,
            text_hidden_size,
            unet=pipeline.unet,
            v2=v2,
        )

    del pipeline.unet
    run_gc()

    if conversion.control and not single_vae and not conversion.share_unet:
        cnet_source = torch_source or source
        logger.info("loading and converting CNet from %s", cnet_source)
        cnet_path = convert_diffusion_diffusers_cnet(
            conversion,
            cnet_source,
            device,
            output_path,
            dtype,
            unet_in_channels,
            unet_sample_size,
            num_tokens,
            text_hidden_size,
            unet=None,
            v2=v2,
        )

    if cnet_path is not None:
        collate_cnet(cnet_path)

    if cnet_only:
        logger.info("done converting CNet")
        return (True, dest_path)

    logger.debug("collating UNet external tensors")
    unet_model_path = str(unet_path.absolute().as_posix())
    unet_dir = path.dirname(unet_model_path)
    unet = load_model(unet_model_path)

    # clean up existing tensor files
    rmtree(unet_dir)
    mkdir(unet_dir)

    # collate external tensor files into one
    save_model(
        unet,
        unet_model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=ONNX_WEIGHTS,
        convert_attribute=False,
    )

    del unet
    run_gc()

    # VAE
    if replace_vae is not None:
        if replace_vae.startswith("."):
            logger.debug(
                "custom VAE appears to be a local path, making it relative to the model path"
            )
            replace_vae = path.join(conversion.model_path, replace_vae)

        logger.info("loading custom VAE: %s", replace_vae)
        vae = AutoencoderKL.from_pretrained(replace_vae)
        pipeline.vae = vae
        run_gc()

    if single_vae:
        logger.debug("VAE config: %s", pipeline.vae.config)

        # SINGLE VAE
        vae_only = pipeline.vae
        vae_latent_channels = vae_only.config.latent_channels

        # forward only through the decoder part
        vae_only.forward = vae_only.decode
        vae_path = output_path / "vae" / ONNX_MODEL
        logger.info("exporting VAE to %s", vae_path)
        onnx_export(
            vae_only,
            model_args=(
                torch.randn(
                    1, vae_latent_channels, unet_sample_size, unet_sample_size
                ).to(device=device, dtype=dtype),
                False,
            ),
            output_path=vae_path,
            ordered_input_names=["latent_sample", "return_dict"],
            output_names=["sample"],
            dynamic_axes={
                "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            },
            opset=conversion.opset,
            half=conversion.half,
        )
    else:
        # VAE ENCODER
        vae_encoder = pipeline.vae
        vae_in_channels = vae_encoder.config.in_channels
        vae_sample_size = vae_encoder.config.sample_size

        # need to get the raw tensor output (sample) from the encoder
        vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
            sample, return_dict
        )[0].sample()
        vae_path = output_path / "vae_encoder" / ONNX_MODEL
        logger.info("exporting VAE encoder to %s", vae_path)
        onnx_export(
            vae_encoder,
            model_args=(
                torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
                    device=device, dtype=dtype
                ),
                False,
            ),
            output_path=vae_path,
            ordered_input_names=["sample", "return_dict"],
            output_names=["latent_sample"],
            dynamic_axes={
                "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            },
            opset=conversion.opset,
            half=False,  # https://github.com/ssube/onnx-web/issues/290
        )

        # VAE DECODER
        vae_decoder = pipeline.vae
        vae_latent_channels = vae_decoder.config.latent_channels

        # forward only through the decoder part
        vae_decoder.forward = vae_encoder.decode
        vae_path = output_path / "vae_decoder" / ONNX_MODEL
        logger.info("exporting VAE decoder to %s", vae_path)
        onnx_export(
            vae_decoder,
            model_args=(
                torch.randn(
                    1, vae_latent_channels, unet_sample_size, unet_sample_size
                ).to(device=device, dtype=dtype),
                False,
            ),
            output_path=vae_path,
            ordered_input_names=["latent_sample", "return_dict"],
            output_names=["sample"],
            dynamic_axes={
                "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            },
            opset=conversion.opset,
            half=conversion.half,
        )

    del pipeline.vae
    run_gc()

    if single_vae:
        logger.debug("reloading diffusion model with upscaling pipeline")
        onnx_pipeline = OnnxStableDiffusionUpscalePipeline(
            vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"),
            text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
            tokenizer=pipeline.tokenizer,
            low_res_scheduler=pipeline.scheduler,
            unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
            scheduler=pipeline.scheduler,
        )
    else:
        logger.debug("reloading diffusion model with default pipeline")
        onnx_pipeline = OnnxStableDiffusionPipeline(
            vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
            vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
            text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
            tokenizer=pipeline.tokenizer,
            unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
            scheduler=pipeline.scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
        )

    logger.info("exporting pretrained ONNX model to %s", output_path)
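    # save_pretrained writes model_index.json and re-saves each ONNX component
    # into its own subfolder under the output path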
    onnx_pipeline.save_pretrained(output_path)
    logger.info("ONNX pipeline saved to %s", output_path)

    del pipeline
    del onnx_pipeline
    run_gc()

    if conversion.reload:
        if single_vae:
            _ = OnnxStableDiffusionUpscalePipeline.from_pretrained(
                output_path, provider="CPUExecutionProvider"
            )
        else:
            _ = OnnxStableDiffusionPipeline.from_pretrained(
                output_path, provider="CPUExecutionProvider"
            )

        logger.info("ONNX pipeline is loadable")
    else:
        logger.debug("skipping ONNX reload test")

    return (True, dest_path)