diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 47c09f59..59140f9d 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -3,6 +3,11 @@ from typing import List, Tuple
 from numpy import ndarray
 from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model
+from sys import argv
+from safetensors import safe_open
+
+import torch
+import onnx.checker
 
 logger = getLogger(__name__)
 
 
@@ -100,6 +105,81 @@ def convert_diffusion_lora(part: str):
     )
 
 
+def fix_key(key: str):
+    # LoRA keys look like lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
+    # and map to base model nodes like up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0,
+    # so convert the base node names into the same underscore format before comparing
+    return key.replace(".", "_")
+
+
+def merge_lora():
+    base_name = argv[1]
+    lora_name = argv[2]
+
+    base_model = load(base_name)
+    lora_model = safe_open(lora_name, framework="pt")
+    lora_keys = list(lora_model.keys())
+
+    lora_nodes = []
+    for base_node in base_model.graph.initializer:
+        base_key = fix_key(base_node.name)
+
+        retensor = None
+        for key in lora_keys:
+            if "lora_down" not in key:
+                continue
+
+            lora_key = key[: key.index(".lora_down")].replace("lora_unet_", "")
+            if not base_key.startswith(lora_key):
+                continue
+
+            logger.info("blending weights for key: %s, %s", base_key, lora_key)
+
+            up_key = key.replace("lora_down", "lora_up")
+            alpha_key = key[: key.index("lora_down")] + "alpha"
+
+            down_weight = lora_model.get_tensor(key).to(dtype=torch.float32)
+            up_weight = lora_model.get_tensor(up_key).to(dtype=torch.float32)
+
+            # LoRA updates are scaled by alpha over the rank of the down weight,
+            # falling back to alpha = rank (scale 1) when no alpha key is stored
+            dim = down_weight.size()[0]
+            alpha = (
+                lora_model.get_tensor(alpha_key).item()
+                if alpha_key in lora_keys
+                else dim
+            )
+            scale = alpha / dim
+
+            np_vals = numpy_helper.to_array(base_node)
+            logger.debug("shapes: %s, %s, %s", np_vals.shape, up_weight.shape, down_weight.shape)
+
+            if len(up_weight.size()) == 2:
+                # linear layer: matmul the LoRA halves back into a full-rank update
+                squoze = up_weight @ down_weight
+            else:
+                # 1x1 conv layer: squeeze out the kernel dims, matmul, then restore them
+                squoze = (
+                    up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
+                ).unsqueeze(2).unsqueeze(3)
+
+            np_vals = np_vals + (scale * squoze.numpy())
+
+            # initializers do not support assignment, so build a replacement tensor
+            retensor = numpy_helper.from_array(np_vals, base_node.name)
+            lora_nodes.append(retensor)
+            break
+
+        if retensor is None:
+            logger.debug("no lora found for key: %s", base_key)
+            lora_nodes.append(base_node)
+
+    logger.info("updated %s nodes of %s", len(lora_nodes), len(base_model.graph.initializer))
+    del base_model.graph.initializer[:]
+    base_model.graph.initializer.extend(lora_nodes)
+
+    onnx.checker.check_model(base_model)
+
+
 if __name__ == "__main__":
     convert_diffusion_lora("unet")
     convert_diffusion_lora("text_encoder")
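
As a sanity check on the patch above, here is a minimal sketch of the update that merge_lora applies to a matched linear initializer. The shapes and rank are made up for illustration, but the formula matches the linear branch: W' = W + (alpha / rank) * (up @ down), with alpha falling back to the rank (scale 1) when the checkpoint stores no alpha key.

import numpy as np

# hypothetical shapes: a 320x320 linear weight with a rank-4 LoRA
rank = 4
base = np.random.randn(320, 320).astype(np.float32)
up = np.random.randn(320, rank).astype(np.float32)
down = np.random.randn(rank, 320).astype(np.float32)
alpha = 4.0

# the same update merge_lora applies to the base weights
blended = base + (alpha / rank) * (up @ down)
assert blended.shape == base.shape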