blend LoRAs into existing model
This commit is contained in:
parent f8658c88cd
commit ce74183e97

@@ -3,6 +3,11 @@ from typing import List, Tuple
from numpy import ndarray
from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model
from sys import argv
from safetensors import safe_open

import torch
import onnx.checker

logger = getLogger(__name__)

@@ -100,6 +105,81 @@ def convert_diffusion_lora(part: str):
    )


def fix_key(key: str):
    # ONNX initializer names are dot-separated while LoRA keys are underscore-separated, e.g.
    #   lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
    # corresponds to
    #   up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0
    # so normalize the ONNX name to the underscore form before matching.
    return key.replace(".", "_")


def merge_lora():
    base_name = argv[1]
    lora_name = argv[2]

    base_model = load(base_name)
    lora_model = safe_open(lora_name, framework="pt")

    lora_nodes = []
    for base_node in base_model.graph.initializer:
        base_key = fix_key(base_node.name)
        retensor = None

        for key in lora_model.keys():
            if "lora_down" in key:
                lora_key = key[:key.index("lora_down")].replace("lora_unet_", "")
                if lora_key.startswith(base_key):
                    print("down for key:", base_key, lora_key)

                    up_key = key.replace("lora_down", "lora_up")
                    alpha_key = key[:key.index("lora_down")] + "alpha"

                    down_weight = lora_model.get_tensor(key).to(dtype=torch.float32)
                    up_weight = lora_model.get_tensor(up_key).to(dtype=torch.float32)

                    # scale the LoRA delta by alpha / rank, defaulting alpha to the rank
                    dim = down_weight.size()[0]
                    if alpha_key in lora_model.keys():
                        alpha = lora_model.get_tensor(alpha_key).item()
                    else:
                        alpha = dim
                    scale = alpha / dim

                    np_vals = numpy_helper.to_array(base_node)
                    print(np_vals.shape, up_weight.shape, down_weight.shape)

                    try:
                        if len(up_weight.size()) == 2:
                            # linear weights: W' = W + scale * (up @ down)
                            squoze = up_weight @ down_weight
                        else:
                            # 1x1 conv weights: drop the kernel dims, multiply, then restore them
                            squoze = (
                                up_weight.squeeze(3).squeeze(2)
                                @ down_weight.squeeze(3).squeeze(2)
                            ).unsqueeze(2).unsqueeze(3)

                        print(squoze.shape)
                        np_vals = (np_vals + (scale * squoze.numpy())).astype(np_vals.dtype)

                        # initializer entries do not support item assignment, so build a
                        # replacement tensor and rebuild the initializer list afterwards
                        retensor = helper.make_tensor(
                            base_node.name,
                            base_node.data_type,
                            base_node.dims,
                            np_vals.tobytes(),
                            raw=True,
                        )
                        print(retensor)

                        lora_nodes.append(retensor)
                        break
                    except Exception as e:
                        print(e)

        if retensor is None:
            print("no lora found for key", base_key)
            lora_nodes.append(base_node)

    print(len(lora_nodes), len(base_model.graph.initializer))
    del base_model.graph.initializer[:]
    base_model.graph.initializer.extend(lora_nodes)

    onnx.checker.check_model(base_model)


if __name__ == "__main__":
    convert_diffusion_lora("unet")
    convert_diffusion_lora("text_encoder")
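
A minimal usage sketch, not part of this commit: merge_lora() as committed reads its paths from argv and only validates the merged graph with onnx.checker, so the return value, the third argument, and the save_model() call below are assumptions about how the blended model might be written back out.

    # hypothetical invocation: python lora.py base_unet.onnx lora.safetensors blended_unet.onnx
    if __name__ == "__main__":
        merged = merge_lora()  # assumes merge_lora() is extended to return the patched ModelProto
        save_model(merged, argv[3])  # output path assumed as a third CLI argument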