From 0b1aa26be5c12f6f4578564fd55efdbb49181985 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 12 Mar 2023 13:38:51 -0500
Subject: [PATCH] blend LoRAs into a valid ONNX UNet (#213)

---
 api/onnx_web/convert/diffusion/lora.py | 268 +++++++++----------------
 1 file changed, 96 insertions(+), 172 deletions(-)

diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 57e85d79..1608fd12 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -1,15 +1,16 @@
+from itertools import groupby
 from logging import getLogger
 from os import path
 from sys import argv
-from typing import List, Tuple
+from typing import List, Literal, Tuple
 
-import onnx.checker
 import torch
-from numpy import ndarray
-from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model
-from safetensors import safe_open
+from onnx import TensorProto, load, numpy_helper
+from onnx.checker import check_model
+from onnx.external_data_helper import convert_model_to_external_data, write_external_data_tensors
+from safetensors.torch import load_file
 
-from ..utils import ConversionContext
+# from ..utils import ConversionContext
 
 logger = getLogger(__name__)
 
@@ -19,194 +20,117 @@ logger = getLogger(__name__)
 ###
 
-def load_lora(filename: str):
-    model = load(filename)
-
-    for weight in model.graph.initializer:
-        # print(weight.name, numpy_helper.to_array(weight).shape)
-        pass
-
-    return model
-
-
-def blend_loras(
-    base: ModelProto, weights: List[ModelProto], alphas: List[float]
-) -> List[Tuple[TensorProto, ndarray]]:
-    total = 1 + sum(alphas)
-
-    results = []
-
-    for base_node in base.graph.initializer:
-        logger.info("blending initializer node %s", base_node.name)
-        base_weights = numpy_helper.to_array(base_node).copy()
-
-        for weight, alpha in zip(weights, alphas):
-            weight_node = next(
-                iter([f for f in weight.graph.initializer if f.name == base_node.name]),
-                None,
-            )
-
-            if weight_node is not None:
-                base_weights += numpy_helper.to_array(weight_node) * alpha
-            else:
-                logger.warning(
-                    "missing weights: %s in %s", base_node.name, weight.doc_string
-                )
-
-        results.append((base_node, base_weights / total))
-
-    return results
-
-
-def convert_diffusion_lora(context: ConversionContext, component: str):
-    lora_weights = [
-        f"diffusion-lora-jack/{component}/model.onnx",
-        f"diffusion-lora-taters/{component}/model.onnx",
-    ]
-
-    base = load_lora(f"stable-diffusion-onnx-v1-5/{component}/model.onnx")
-    weights = [load_lora(f) for f in lora_weights]
-    alphas = [1 / len(weights)] * len(weights)
-    logger.info("blending LoRAs with alphas: %s, %s", weights, alphas)
-
-    result = blend_loras(base, weights, alphas)
-    logger.info("blended result keys: %s", len(result))
-
-    del weights
-    del alphas
-
-    tensors = []
-    for node, tensor in result:
-        logger.info("remaking tensor for %s", node.name)
-        tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor))
-
-    del result
-
-    graph = helper.make_graph(
-        base.graph.node,
-        base.graph.name,
-        base.graph.input,
-        base.graph.output,
-        tensors,
-        base.graph.doc_string,
-        base.graph.value_info,
-        base.graph.sparse_initializer,
-    )
-    model = helper.make_model(graph)
-
-    del model.opset_import[:]
-    opset = model.opset_import.add()
-    opset.version = 14
-
-    onnx_path = path.join(context.cache_path, f"lora-{component}.onnx")
-    tensor_path = path.join(context.cache_path, f"lora-{component}.tensors")
-    save_model(
-        model,
-        onnx_path,
-        save_as_external_data=True,
-        all_tensors_to_one_file=True,
-        location=tensor_path,
-    )
-    logger.info(
-        "saved model to %s and tensors to %s",
-        onnx_path,
-        tensor_path,
-    )
-
-
-def fix_key(key: str):
+def fix_name(key: str):
     # lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
     # lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0
     return key.replace(".", "_")
 
 
-def merge_lora():
-    base_name = argv[1]
-    lora_name = argv[2]
-
-    base_model = load(base_name)
-    lora_model = safe_open(lora_name, framework="pt")
-
-    lora_nodes = []
-    for base_node in base_model.graph.initializer:
-        base_key = fix_key(base_node.name)
-
-        for key in lora_model.keys():
-            if "lora_down" in key:
-                lora_key = key[: key.index("lora_down")].replace("lora_unet_", "")
-                if lora_key.startswith(base_key):
-                    print("down for key:", base_key, lora_key)
-
-                    up_key = key.replace("lora_down", "lora_up")
-                    alpha_key = key[: key.index("lora_down")] + "alpha"
-
-                    down_weight = lora_model.get_tensor(key).to(dtype=torch.float32)
-                    up_weight = lora_model.get_tensor(up_key).to(dtype=torch.float32)
-
-                    dim = down_weight.size()[0]
-                    alpha = lora_model.get(alpha_key).numpy() or dim
-
-                    np_vals = numpy_helper.to_array(base_node)
-                    print(np_vals.shape, up_weight.shape, down_weight.shape)
-
-                    squoze = (
-                        (
-                            up_weight.squeeze(3).squeeze(2)
-                            @ down_weight.squeeze(3).squeeze(2)
-                        )
-                        .unsqueeze(2)
-                        .unsqueeze(3)
-                    )
-                    print(squoze.shape)
-
-                    np_vals = np_vals + (alpha * squoze.numpy())
-
-                    try:
-                        if len(up_weight.size()) == 2:
-                            squoze = up_weight @ down_weight
-                            print(squoze.shape)
-                            np_vals = np_vals + (squoze.numpy() * (alpha / dim))
-                        else:
-                            squoze = (
-                                (
-                                    up_weight.squeeze(3).squeeze(2)
-                                    @ down_weight.squeeze(3).squeeze(2)
-                                )
-                                .unsqueeze(2)
-                                .unsqueeze(3)
-                            )
-                            print(squoze.shape)
-                            np_vals = np_vals + (alpha * squoze.numpy())
-
-                        # retensor = numpy_helper.from_array(np_vals, base_node.name)
-                        retensor = helper.make_tensor(
-                            base_node.name,
-                            base_node.data_type,
-                            base_node.dim,
-                            np_vals,
-                            raw=True,
-                        )
-                        print(retensor)
-
-                        # TypeError: does not support assignment
-                        lora_nodes.append(retensor)
-
-                        break
-                    except Exception as e:
-                        print(e)
-
-        if retensor is None:
-            print("no lora found for key", base_key)
-            lora_nodes.append(base_node)
-
-    print(len(lora_nodes), len(base_model.graph.initializer))
-    del base_model.graph.initializer[:]
-    base_model.graph.initializer.extend(lora_nodes)
-
-    onnx.checker.check_model(base_model)
+def merge_lora(base_name: str, lora_names: str, dest_path: str, dest_type: Literal["text_encoder", "unet"]):
+    base_model = load(base_name)
+    lora_models = [load_file(name) for name in lora_names.split(",")]
+
+    lora_nodes: List[Tuple[int, TensorProto]] = []
+
+    fixed_initialized_names = [fix_name(node.name) for node in base_model.graph.initializer]
+    logger.info("fixed initializer names: %s", fixed_initialized_names)
+
+    if dest_type == "text_encoder":
+        lora_prefix = "lora_te_"
+    elif dest_type == "unet":
+        lora_prefix = "lora_unet_"
+    else:
+        lora_prefix = "lora_"
+
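+    # walk the base model's initializers, looking for a matching
+    # lora_down/lora_up pair for each one in every loaded LoRA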
+    for i in range(len(fixed_initialized_names)):
+        base_key = fixed_initialized_names[i]
+        base_node = base_model.graph.initializer[i]
+
+        updates = []
+        for lora_model in lora_models:
+            for key in lora_model.keys():
+                if ".lora_down" in key:
+                    original_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
+                    bias_key = original_key + "_bias"
+                    weight_key = original_key + "_weight"
+
+                    if bias_key.startswith(base_key):
+                        logger.debug("found bias key: %s %s", base_key, bias_key)
+
+                    if weight_key == base_key:
+                        logger.debug("down for key: %s %s", base_key, weight_key)
+
+                        up_key = key.replace("lora_down", "lora_up")
+                        alpha_key = key[: key.index("lora_down")] + "alpha"
+
+                        down_weight = lora_model[key].to(dtype=torch.float32)
+                        up_weight = lora_model[up_key].to(dtype=torch.float32)
+
+                        dim = down_weight.size()[0]
+                        alpha = lora_model[alpha_key].item() if alpha_key in lora_model else dim
+
+                        np_vals = numpy_helper.to_array(base_node)
+                        logger.debug("before shape: %s %s %s", np_vals.shape, up_weight.shape, down_weight.shape)
+
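+                        # apply the usual low-rank update rule:
+                        #   W' = W + (alpha / dim) * (up @ down)
+                        # conv weights are squeezed to 2d for the matmul
+                        # and unsqueezed back to 4d afterwards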
+                        try:
+                            if len(up_weight.size()) == 2:
+                                squoze = up_weight @ down_weight
+                                logger.debug("squoze shape: %s", squoze.shape)
+                                np_vals = np_vals + (squoze.numpy() * (alpha / dim))
+                            else:
+                                squoze = (
+                                    (
+                                        up_weight.squeeze(3).squeeze(2)
+                                        @ down_weight.squeeze(3).squeeze(2)
+                                    )
+                                    .unsqueeze(2)
+                                    .unsqueeze(3)
+                                )
+                                logger.debug("squoze shape: %s", squoze.shape)
+                                np_vals = np_vals + (squoze.numpy() * (alpha / dim))
+
+                            logger.debug("after shape: %s", np_vals.shape)
+
+                            updates.append(np_vals)
+
+                            break
+                        except Exception:
+                            logger.exception("error blending weights with key %s", weight_key)
+
+        if len(updates) == 0:
+            logger.debug("no lora found for key %s", base_key)
+        else:
+            # blend updates together and append to lora_nodes
+            logger.info("blending %s updated weights for key %s", len(updates), base_key)
+
+            # TODO: allow individual alphas
+            np_vals = sum(updates) / len(updates)
+
+            retensor = numpy_helper.from_array(np_vals, base_node.name)
+            logger.info("created new tensor with %s bytes", len(retensor.raw_data))
+
+            # initializer entries do not support direct assignment,
+            # so collect the replacements and swap them in below
+            lora_nodes.append((i, retensor))
+
+    logger.info("updating %s of %s nodes", len(lora_nodes), len(base_model.graph.initializer))
+    for idx, node in lora_nodes:
+        del base_model.graph.initializer[idx]
+        base_model.graph.initializer.insert(idx, node)
+
+    # save it back to disk
+    # TODO: save to memory instead
+    convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}-external.pb")
+    bare_model = write_external_data_tensors(base_model, dest_path)
+
+    dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
+    with open(dest_file, "wb") as model_file:
+        model_file.write(bare_model.SerializeToString())
+
+    logger.info("model saved, checking...")
+    check_model(dest_file)
+
+    logger.info("model successfully exported")
 
 
 if __name__ == "__main__":
-    context = ConversionContext.from_environ()
-    convert_diffusion_lora(context, "unet")
-    convert_diffusion_lora(context, "text_encoder")
+    merge_lora(*argv[1:])
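
Usage note: with this change, the script takes its arguments from the
command line instead of a ConversionContext (the relative import is
commented out so the file can run standalone). A hypothetical invocation,
with illustrative paths rather than ones from the repo, might look like:

    python3 lora.py \
        stable-diffusion-onnx-v1-5/unet/model.onnx \
        lora-one.safetensors,lora-two.safetensors \
        ./output unet

This blends both LoRA files into the UNet weights, writes lora-unet.onnx
plus its external tensor data into ./output, and runs the ONNX checker
on the result.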