from argparse import ArgumentParser
from logging import getLogger
from os import path
from typing import Dict, List, Literal, Optional

import numpy as np
import torch
from onnx import TensorProto, load, numpy_helper
from onnx.checker import check_model
from onnx.external_data_helper import (
    convert_model_to_external_data,
    write_external_data_tensors,
)
from safetensors.torch import load_file

from onnx_web.convert.utils import ConversionContext

logger = getLogger(__name__)


###
# everything in this file is still super experimental and may not produce valid ONNX models
###


def fix_initializer_name(key: str):
    # lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
    # lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0
    return key.replace(".", "_")


def fix_node_name(key: str):
    fixed_name = fix_initializer_name(key.replace("/", "_"))
    if fixed_name[0] == "_":
        return fixed_name[1:]
    else:
        return fixed_name


def merge_lora(
    base_name: str,
    lora_names: List[str],
    dest_path: str,
    dest_type: Literal["text_encoder", "unet"],
    lora_weights: Optional[List[float]] = None,
):
    base_model = load(base_name)
    lora_models = [load_file(name) for name in lora_names]
    lora_count = len(lora_models)

    # default to an equal blend of all LoRA models
    if lora_weights is None:
        lora_weights = np.ones(lora_count) / lora_count

    if dest_type == "text_encoder":
        lora_prefix = "lora_te_"
    else:
        lora_prefix = f"lora_{dest_type}_"

    blended: Dict[str, np.ndarray] = {}
    for lora_name, lora_model, lora_weight in zip(
        lora_names, lora_models, lora_weights
    ):
        logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)

        for key in lora_model.keys():
            if ".lora_down" in key and lora_prefix in key:
                base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")

                up_key = key.replace("lora_down", "lora_up")
                alpha_key = key[: key.index("lora_down")] + "alpha"
                logger.info(
                    "blending weights for keys: %s, %s, %s", key, up_key, alpha_key
                )

                down_weight = lora_model[key].to(dtype=torch.float32)
                up_weight = lora_model[up_key].to(dtype=torch.float32)

                dim = down_weight.size()[0]
                # alpha falls back to the LoRA rank when it is not stored in the model
                alpha = (
                    lora_model.get(alpha_key, torch.tensor(dim))
                    .to(torch.float32)
                    .numpy()
                )

                try:
                    if len(up_weight.size()) == 2:
                        # blend for nn.Linear
                        logger.info(
                            "blending weights for Linear node: %s, %s, %s",
                            down_weight.shape,
                            up_weight.shape,
                            alpha,
                        )
                        weights = up_weight @ down_weight
                        np_weights = weights.numpy() * (alpha / dim)
                    elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
                        # blend for nn.Conv2d 1x1
                        logger.info(
                            "blending weights for Conv node: %s, %s, %s",
                            down_weight.shape,
                            up_weight.shape,
                            alpha,
                        )
                        weights = (
                            (
                                up_weight.squeeze(3).squeeze(2)
                                @ down_weight.squeeze(3).squeeze(2)
                            )
                            .unsqueeze(2)
                            .unsqueeze(3)
                        )
                        np_weights = weights.numpy() * (alpha / dim)
                    else:
                        # TODO: add support for Conv2d 3x3
                        logger.warning(
                            "unknown LoRA node type at %s: %s",
                            base_key,
                            up_weight.shape[-2:],
                        )
                        continue

                    np_weights *= lora_weight
                    if base_key in blended:
                        blended[base_key] += np_weights
                    else:
                        blended[base_key] = np_weights
                except Exception:
                    logger.exception("error blending weights for key %s", base_key)

    logger.info(
        "updating %s of %s initializers: %s",
        len(blended.keys()),
        len(base_model.graph.initializer),
        list(blended.keys()),
    )

    fixed_initializer_names = [
        fix_initializer_name(node.name) for node in base_model.graph.initializer
    ]
    # logger.info("fixed initializer names: %s", fixed_initializer_names)

    fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
    # logger.info("fixed node names: %s", fixed_node_names)

    for base_key, weights in blended.items():
        conv_key = base_key + "_Conv"
        matmul_key = base_key + "_MatMul"
        logger.info(
            "key %s has conv: %s, matmul: %s",
            base_key,
            conv_key in fixed_node_names,
            matmul_key in fixed_node_names,
        )

        if conv_key in fixed_node_names:
            conv_idx = fixed_node_names.index(conv_key)
            conv_node = base_model.graph.node[conv_idx]
            logger.info("found conv node: %s", conv_node.name)

            # find the weight initializer
            logger.info("conv inputs: %s", conv_node.input)
            weight_name = [n for n in conv_node.input if ".weight" in n][0]
            weight_name = fix_initializer_name(weight_name)

            weight_idx = fixed_initializer_names.index(weight_name)
            weight_node = base_model.graph.initializer[weight_idx]
            logger.info("found weight initializer: %s", weight_node.name)

            # blending: drop the trailing 1x1 dims, add the LoRA delta, restore them
            base_weights = numpy_helper.to_array(weight_node)
            logger.info(
                "found blended weights for conv: %s, %s",
                weights.shape,
                base_weights.shape,
            )

            blended_weights = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
            blended_weights = np.expand_dims(blended_weights, (2, 3))
            logger.info("blended weight shape: %s", blended_weights.shape)

            # replace the original initializer
            updated_node = numpy_helper.from_array(blended_weights, weight_node.name)
            del base_model.graph.initializer[weight_idx]
            base_model.graph.initializer.insert(weight_idx, updated_node)
        elif matmul_key in fixed_node_names:
            weight_idx = fixed_node_names.index(matmul_key)
            weight_node = base_model.graph.node[weight_idx]
            logger.info("found matmul node: %s", weight_node.name)

            # find the MatMul initializer
            logger.info("matmul inputs: %s", weight_node.input)
            matmul_name = [n for n in weight_node.input if "MatMul" in n][0]

            matmul_idx = fixed_initializer_names.index(matmul_name)
            matmul_node = base_model.graph.initializer[matmul_idx]
            logger.info("found matmul initializer: %s", matmul_node.name)

            # blending: MatMul initializers are stored transposed relative to nn.Linear weights
            base_weights = numpy_helper.to_array(matmul_node)
            logger.info(
                "found blended weights for matmul: %s, %s",
                weights.shape,
                base_weights.shape,
            )

            blended_weights = base_weights + weights.transpose()
            logger.info("blended weight shape: %s", blended_weights.shape)

            # replace the original initializer
            updated_node = numpy_helper.from_array(blended_weights, matmul_node.name)
            del base_model.graph.initializer[matmul_idx]
            base_model.graph.initializer.insert(matmul_idx, updated_node)
        else:
            logger.info("could not find any nodes for %s", base_key)

    logger.info(
        "node counts: %s -> %s, %s -> %s",
        len(fixed_initializer_names),
        len(base_model.graph.initializer),
        len(fixed_node_names),
        len(base_model.graph.node),
    )

    # save it back to disk
    # TODO: save to memory instead
    convert_model_to_external_data(
        base_model,
        all_tensors_to_one_file=True,
        location=f"lora-{dest_type}-external.pb",
    )
    bare_model = write_external_data_tensors(base_model, dest_path)

    dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
    with open(dest_file, "wb") as model_file:
        model_file.write(bare_model.SerializeToString())

    logger.info("model saved, checking...")
    check_model(dest_file)
    logger.info("model successfully exported")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--base", type=str)
    parser.add_argument("--dest", type=str)
    parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
    parser.add_argument("--lora_models", nargs="+", type=str)
    parser.add_argument("--lora_weights", nargs="+", type=float)

    args = parser.parse_args()
    logger.info(
        "merging %s with %s with weights: %s",
        args.lora_models,
        args.base,
        args.lora_weights,
    )

    merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)