# onnx-web/api/onnx_web/convert/diffusion/lora.py

from logging import getLogger
from os import path
from sys import argv
from typing import List, Literal, Tuple
import torch
from onnx import TensorProto, load, numpy_helper
from onnx.checker import check_model
from onnx.external_data_helper import (
    convert_model_to_external_data,
    write_external_data_tensors,
)
from safetensors.torch import load_file
# from ..utils import ConversionContext
logger = getLogger(__name__)
###
# everything in this file is still super experimental and may not produce valid ONNX models
###
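
# NOTE: the key layout handled here (lora_te_/lora_unet_ prefixes with
# lora_down/lora_up/alpha tensors) appears to follow the kohya-ss naming
# convention for LoRA safetensors files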

def fix_name(key: str) -> str:
    # convert dotted names to underscore form, matching LoRA keys like
    # lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
    # which correspond to paths like up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0
    return key.replace(".", "_")
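
# for example (hypothetical initializer name):
#   fix_name("up_blocks.3.attentions.2.attn2.to_out.0.weight")
#   -> "up_blocks_3_attentions_2_attn2_to_out_0_weight"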

def merge_lora(
    base_name: str,
    lora_names: str,
    dest_path: str,
    dest_type: Literal["text_encoder", "unet"],
):
    base_model = load(base_name)
    lora_models = [load_file(name) for name in lora_names.split(",")]
    lora_nodes: List[Tuple[int, TensorProto]] = []

    fixed_initialized_names = [fix_name(node.name) for node in base_model.graph.initializer]
    logger.info("fixed initializer names: %s", fixed_initialized_names)

    if dest_type == "text_encoder":
        lora_prefix = "lora_te_"
    elif dest_type == "unet":
        lora_prefix = "lora_unet_"
    else:
        lora_prefix = "lora_"
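
    # the matching prefix is stripped from each LoRA key below, so the remainder
    # can be compared against the underscore-joined names from fix_name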
    for i, base_key in enumerate(fixed_initialized_names):
        base_node = base_model.graph.initializer[i]
        updates = []

        for lora_model in lora_models:
            for key in lora_model.keys():
                if ".lora_down" in key:
                    original_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
                    bias_key = original_key + "_bias"
                    weight_key = original_key + "_weight"

                    if bias_key.startswith(base_key):
                        logger.debug("found bias key: %s, %s", base_key, bias_key)

                    if weight_key == base_key:
                        logger.debug("down for key: %s, %s", base_key, weight_key)
                        up_key = key.replace("lora_down", "lora_up")
                        alpha_key = key[: key.index("lora_down")] + "alpha"

                        down_weight = lora_model[key].to(dtype=torch.float32)
                        up_weight = lora_model[up_key].to(dtype=torch.float32)

                        dim = down_weight.size()[0]
                        # default alpha to the network dim when the file does not store
                        # one, rather than calling .numpy() on a missing (None) tensor
                        alpha = lora_model[alpha_key].item() if alpha_key in lora_model else dim
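
                        # standard LoRA merge: W' = W + (up @ down) * (alpha / dim),
                        # where dim is the rank taken from the down weight above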
                        np_vals = numpy_helper.to_array(base_node)
                        logger.debug(
                            "before shapes: %s, %s, %s",
                            np_vals.shape,
                            up_weight.shape,
                            down_weight.shape,
                        )

                        try:
                            if len(up_weight.size()) == 2:
                                # 2d linear weights multiply directly
                                squoze = up_weight @ down_weight
                                logger.debug("2d blended shape: %s", squoze.shape)
                                np_vals = np_vals + (squoze.numpy() * (alpha / dim))
                            else:
                                # 4d conv weights: squeeze the trailing 1x1 dims,
                                # multiply, then restore them
                                squoze = (
                                    (
                                        up_weight.squeeze(3).squeeze(2)
                                        @ down_weight.squeeze(3).squeeze(2)
                                    )
                                    .unsqueeze(2)
                                    .unsqueeze(3)
                                )
                                logger.debug("4d blended shape: %s", squoze.shape)
                                # scale by alpha / dim here as well, matching the 2d branch
                                np_vals = np_vals + (squoze.numpy() * (alpha / dim))

                            logger.debug("after shape: %s", np_vals.shape)
                            updates.append(np_vals)
                            break
                        except Exception:
                            logger.exception("error blending weights for key %s", weight_key)
        if len(updates) == 0:
            logger.debug("no lora found for key %s", base_key)
        else:
            # blend updates together and append to lora_nodes
            logger.info("blending %s updated weights for key %s", len(updates), base_key)

            # TODO: allow individual alphas
            np_vals = sum(updates) / len(updates)

            retensor = numpy_helper.from_array(np_vals, base_node.name)
            logger.info("created new tensor with %s bytes", len(retensor.raw_data))

            # graph.initializer does not support item assignment, so collect the
            # replacements and swap them in by index below
            lora_nodes.append((i, retensor))

    logger.info("updating %s of %s nodes", len(lora_nodes), len(base_model.graph.initializer))

    for idx, node in lora_nodes:
        # delete then insert at the same index, which leaves later indices unchanged
        del base_model.graph.initializer[idx]
        base_model.graph.initializer.insert(idx, node)
    # save it back to disk
    # TODO: save to memory instead
    convert_model_to_external_data(
        base_model,
        all_tensors_to_one_file=True,
        location=f"lora-{dest_type}-external.pb",
    )
    bare_model = write_external_data_tensors(base_model, dest_path)

    dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
    with open(dest_file, "wb") as model_file:
        model_file.write(bare_model.SerializeToString())

    logger.info("model saved, checking...")
    check_model(dest_file)
    logger.info("model successfully exported")


if __name__ == "__main__":
    merge_lora(*argv[1:])
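
# example invocation (hypothetical paths, assuming the api directory is on
# PYTHONPATH), merging two LoRA files into an ONNX UNet:
#   python3 -m onnx_web.convert.diffusion.lora \
#       ./models/stable-diffusion/unet/model.onnx \
#       ./loras/style.safetensors,./loras/subject.safetensors \
#       ./output unet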