From 4c17edb2673956833d0c769fdd35879c65a75dde Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 14 Mar 2023 18:00:26 -0500 Subject: [PATCH] feat(api): add conversion script for LoRAs from sd-scripts (#213) --- api/onnx_web/convert/diffusion/lora.py | 240 +++++++++++++++++-------- api/scripts/onnx-diff.py | 44 +++++ 2 files changed, 205 insertions(+), 79 deletions(-) create mode 100644 api/scripts/onnx-diff.py diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index 1608fd12..69af4db3 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -1,16 +1,19 @@ -from itertools import groupby +from argparse import ArgumentParser from logging import getLogger from os import path -from sys import argv -from typing import List, Literal, Tuple +from typing import Dict, Literal +import numpy as np import torch from onnx import TensorProto, load, numpy_helper from onnx.checker import check_model -from onnx.external_data_helper import convert_model_to_external_data, write_external_data_tensors +from onnx.external_data_helper import ( + convert_model_to_external_data, + write_external_data_tensors, +) from safetensors.torch import load_file -# from ..utils import ConversionContext +from onnx_web.convert.utils import ConversionContext logger = getLogger(__name__) @@ -20,106 +23,176 @@ logger = getLogger(__name__) ### -def fix_name(key: str): +def fix_initializer_name(key: str): # lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight # lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0 return key.replace(".", "_") -def merge_lora(base_name: str, lora_names: str, dest_path: str, dest_type: Literal["text_encoder", "unet"]): +def fix_node_name(key: str): + fixed_name = fix_initializer_name(key.replace("/", "_")) + if fixed_name[0] == "_": + return fixed_name[1:] + else: + return fixed_name + + +def merge_lora( + base_name: str, + lora_names: str, + dest_path: str, + dest_type: Literal["text_encoder", "unet"], + lora_weights: "np.NDArray[np.float64]" = None, +): base_model = load(base_name) - lora_models = [load_file(name) for name in lora_names.split(",")] - - lora_nodes: List[Tuple[int, TensorProto]] = [] - - fixed_initialized_names = [fix_name(node.name) for node in base_model.graph.initializer] - logger.info("fixed initializer names: %s", fixed_initialized_names) + lora_models = [load_file(name) for name in lora_names] + lora_count = len(lora_models) + lora_weights = lora_weights or (np.ones((lora_count)) / lora_count) if dest_type == "text_encoder": lora_prefix = "lora_te_" - elif dest_type == "unet": - lora_prefix = "lora_unet_" else: - lora_prefix = "lora_" + lora_prefix = f"lora_{dest_type}_" - for i in range(len(fixed_initialized_names)): - base_key = fixed_initialized_names[i] - base_node = base_model.graph.initializer[i] + blended: Dict[str, np.ndarray] = {} + for lora_name, lora_model, lora_weight in zip(lora_names, lora_models, lora_weights): + logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight) + for key in lora_model.keys(): + if ".lora_down" in key and lora_prefix in key: + base_key = key[: key.index(".lora_down")].replace( + lora_prefix, "" + ) - updates = [] - for lora_model in lora_models: - for key in lora_model.keys(): - if ".lora_down" in key: - original_key = key[: key.index(".lora_down")].replace(lora_prefix, "") - bias_key = original_key + "_bias" - weight_key = original_key + "_weight" + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + logger.info("blending weights for keys: %s, %s, %s", key, up_key, alpha_key) - if bias_key.startswith(base_key): - print("found bias key:", base_key, bias_key) + down_weight = lora_model[key].to(dtype=torch.float32) + up_weight = lora_model[up_key].to(dtype=torch.float32) - if weight_key == base_key: - print("down for key:", base_key, weight_key) + dim = down_weight.size()[0] + alpha = lora_model.get(alpha_key, dim).to(torch.float32).numpy() - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" + try: + if len(up_weight.size()) == 2: + # blend for nn.Linear + logger.info("blending weights for Linear node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha) + weights = up_weight @ down_weight + np_weights = (weights.numpy() * (alpha / dim)) + elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1): + # blend for nn.Conv2d 1x1 + logger.info("blending weights for Conv node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha) + weights = ( + ( + up_weight.squeeze(3).squeeze(2) + @ down_weight.squeeze(3).squeeze(2) + ) + .unsqueeze(2) + .unsqueeze(3) + ) + np_weights = (weights.numpy() * (alpha / dim)) + else: + # TODO: add support for Conv2d 3x3 + logger.warning("unknown LoRA node type at %s: %s", base_key, up_weight.shape[-2:]) + continue - down_weight = lora_model[key].to(dtype=torch.float32) - up_weight = lora_model[up_key].to(dtype=torch.float32) + np_weights *= lora_weight + if base_key in blended: + blended[base_key] += np_weights + else: + blended[base_key] = np_weights - dim = down_weight.size()[0] - alpha = lora_model.get(alpha_key).numpy() or dim + except Exception: + logger.exception( + "error blending weights for key %s", base_key + ) - np_vals = numpy_helper.to_array(base_node) - print("before shape", np_vals.shape, up_weight.shape, down_weight.shape) + logger.info( + "updating %s of %s initializers: %s", + len(blended.keys()), + len(base_model.graph.initializer), + list(blended.keys()) + ) - try: - if len(up_weight.size()) == 2: - squoze = up_weight @ down_weight - print(squoze.shape) - np_vals = np_vals + (squoze.numpy() * (alpha / dim)) - else: - squoze = ( - ( - up_weight.squeeze(3).squeeze(2) - @ down_weight.squeeze(3).squeeze(2) - ) - .unsqueeze(2) - .unsqueeze(3) - ) - print(squoze.shape) - np_vals = np_vals + (alpha * squoze.numpy()) - print("after shape", np_vals.shape) + fixed_initializer_names = [ + fix_initializer_name(node.name) for node in base_model.graph.initializer + ] + # logger.info("fixed initializer names: %s", fixed_initializer_names) - updates.append(np_vals) + fixed_node_names = [ + fix_node_name(node.name) for node in base_model.graph.node + ] + # logger.info("fixed node names: %s", fixed_node_names) - break - except Exception as e: - logger.exception("error blending weights with key %s", weight_key) - if len(updates) == 0: - logger.debug("no lora found for key %s", base_key) + for base_key, weights in blended.items(): + conv_key = base_key + "_Conv" + matmul_key = base_key + "_MatMul" + + logger.info("key %s has conv: %s, matmul: %s", base_key, conv_key in fixed_node_names, matmul_key in fixed_node_names) + + if conv_key in fixed_node_names: + conv_idx = fixed_node_names.index(conv_key) + conv_node = base_model.graph.node[conv_idx] + logger.info("found conv node: %s", conv_node.name) + + # find weight initializer + logger.info("conv inputs: %s", conv_node.input) + weight_name = [n for n in conv_node.input if ".weight" in n][0] + weight_name = fix_initializer_name(weight_name) + + weight_idx = fixed_initializer_names.index(weight_name) + weight_node = base_model.graph.initializer[weight_idx] + logger.info("found weight initializer: %s", weight_node.name) + + # blending + base_weights = numpy_helper.to_array(weight_node) + logger.info("found blended weights for conv: %s, %s", weights.shape, base_weights.shape) + + blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2)) + blended = np.expand_dims(blended, (2, 3)) + logger.info("blended weight shape: %s", blended.shape) + + # replace the original initializer + updated_node = numpy_helper.from_array(blended, weight_node.name) + del base_model.graph.initializer[weight_idx] + base_model.graph.initializer.insert(weight_idx, updated_node) + elif matmul_key in fixed_node_names: + weight_idx = fixed_node_names.index(matmul_key) + weight_node = base_model.graph.node[weight_idx] + logger.info("found matmul node: %s", weight_node.name) + + # find the MatMul initializer + logger.info("matmul inputs: %s", weight_node.input) + matmul_name = [n for n in weight_node.input if "MatMul" in n][0] + + matmul_idx = fixed_initializer_names.index(matmul_name) + matmul_node = base_model.graph.initializer[matmul_idx] + logger.info("found matmul initializer: %s", matmul_node.name) + + # blending + base_weights = numpy_helper.to_array(matmul_node) + logger.info("found blended weights for matmul: %s, %s", weights.shape, base_weights.shape) + + blended = base_weights + weights.transpose() + logger.info("blended weight shape: %s", blended.shape) + + # replace the original initializer + updated_node = numpy_helper.from_array(blended, matmul_node.name) + del base_model.graph.initializer[matmul_idx] + base_model.graph.initializer.insert(matmul_idx, updated_node) else: - # blend updates together and append to lora_nodes - logger.info("blending %s updated weights for key %s", len(updates), base_key) + logger.info("could not find any nodes for %s", base_key) - # TODO: allow individual alphas - np_vals = sum(updates) / len(updates) - - retensor = numpy_helper.from_array(np_vals, base_node.name) - logger.info("created new tensor with %s bytes", len(retensor.raw_data)) - - # TypeError: does not support assignment - lora_nodes.append((i, retensor)) - - - logger.info("updating %s of %s nodes", len(lora_nodes), len(base_model.graph.initializer)) - for idx, node in lora_nodes: - del base_model.graph.initializer[idx] - base_model.graph.initializer.insert(idx, node) + logger.info("node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), len(fixed_node_names), len(base_model.graph.node)) # save it back to disk # TODO: save to memory instead - convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}-external.pb") + convert_model_to_external_data( + base_model, + all_tensors_to_one_file=True, + location=f"lora-{dest_type}-external.pb", + ) bare_model = write_external_data_tensors(base_model, dest_path) dest_file = path.join(dest_path, f"lora-{dest_type}.onnx") @@ -133,4 +206,13 @@ def merge_lora(base_name: str, lora_names: str, dest_path: str, dest_type: Liter if __name__ == "__main__": - merge_lora(*argv[1:]) + parser = ArgumentParser() + parser.add_argument("--base", type=str) + parser.add_argument("--dest", type=str) + parser.add_argument("--type", type=str, choices=["text_encoder", "unet"]) + parser.add_argument("--lora_models", nargs='+', type=str) + parser.add_argument("--lora_weights", nargs='+', type=float) + + args = parser.parse_args() + logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights) + merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights) diff --git a/api/scripts/onnx-diff.py b/api/scripts/onnx-diff.py new file mode 100644 index 00000000..ba4b4cf4 --- /dev/null +++ b/api/scripts/onnx-diff.py @@ -0,0 +1,44 @@ +from logging import getLogger, basicConfig, DEBUG +from onnx import load_model, ModelProto +from onnx.numpy_helper import to_array +from sys import argv, stdout + + +basicConfig(stream=stdout, level=DEBUG) + +logger = getLogger(__name__) + +def diff_models(ref_model: ModelProto, cmp_model: ModelProto): + if len(ref_model.graph.initializer) != len(cmp_model.graph.initializer): + logger.warning("different number of initializers: %s vs %s", len(ref_model.graph.initializer), len(cmp_model.graph.initializer)) + else: + for (ref_init, cmp_init) in zip(ref_model.graph.initializer, cmp_model.graph.initializer): + if ref_init.name != cmp_init.name: + logger.info("different node names: %s vs %s", ref_init.name, cmp_init.name) + elif ref_init.data_location != cmp_init.data_location: + logger.info("different data locations: %s vs %s", ref_init.data_location, cmp_init.data_location) + elif ref_init.data_type != cmp_init.data_type: + logger.info("different data types: %s vs %s", ref_init.data_type, cmp_init.data_type) + elif len(ref_init.raw_data) != len(cmp_init.raw_data): + logger.info("different raw data size: %s vs %s", len(ref_init.raw_data), len(cmp_init.raw_data)) + elif len(ref_init.raw_data) > 0 and len(cmp_init.raw_data) > 0: + ref_data = to_array(ref_init) + cmp_data = to_array(cmp_init) + data_diff = ref_data - cmp_data + if data_diff.max() > 0: + logger.info("raw data differs: %s", data_diff) + else: + logger.info("initializers are identical in all checked fields: %s", ref_init.name) + + +if __name__ == "__main__": + ref_path = argv[1] + cmp_paths = argv[2:] + + logger.info("loading reference model from %s", ref_path) + ref_model = load_model(ref_path) + + for cmp_path in cmp_paths: + logger.info("loading comparison model from %s", cmp_path) + cmp_model = load_model(cmp_path) + diff_models(ref_model, cmp_model)