feat(api): add conversion script for LoRAs from sd-scripts (#213)

parent 0b1aa26be5
commit 4c17edb267
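In short: this reworks the experimental LoRA conversion script so that one or more sd-scripts LoRAs (safetensors files) can be blended into the initializers of a converted ONNX text encoder or UNet, with optional per-model weights. Blended updates are now matched against the graph's Conv and MatMul nodes rather than patched into initializers by position, and the positional sys.argv interface is replaced with an argparse CLI. A small helper script for diffing ONNX models is also added.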
@@ -1,16 +1,19 @@
-from itertools import groupby
+from argparse import ArgumentParser
 from logging import getLogger
 from os import path
 from sys import argv
-from typing import List, Literal, Tuple
+from typing import Dict, List, Literal
 
 import numpy as np
 import torch
 from onnx import TensorProto, load, numpy_helper
 from onnx.checker import check_model
-from onnx.external_data_helper import convert_model_to_external_data, write_external_data_tensors
+from onnx.external_data_helper import (
+    convert_model_to_external_data,
+    write_external_data_tensors,
+)
 from safetensors.torch import load_file
 
-# from ..utils import ConversionContext
+from onnx_web.convert.utils import ConversionContext
 
 logger = getLogger(__name__)
@@ -20,65 +23,66 @@ logger = getLogger(__name__)
 ###
 
 
-def fix_name(key: str):
+def fix_initializer_name(key: str):
     # lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
     # lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0
     return key.replace(".", "_")
 
 
-def merge_lora(base_name: str, lora_names: str, dest_path: str, dest_type: Literal["text_encoder", "unet"]):
+def fix_node_name(key: str):
+    fixed_name = fix_initializer_name(key.replace("/", "_"))
+    if fixed_name[0] == "_":
+        return fixed_name[1:]
+    else:
+        return fixed_name
+
+
+def merge_lora(
+    base_name: str,
+    lora_names: List[str],
+    dest_path: str,
+    dest_type: Literal["text_encoder", "unet"],
+    lora_weights: "np.NDArray[np.float64]" = None,
+):
     base_model = load(base_name)
-    lora_models = [load_file(name) for name in lora_names.split(",")]
-
-    lora_nodes: List[Tuple[int, TensorProto]] = []
-
-    fixed_initialized_names = [fix_name(node.name) for node in base_model.graph.initializer]
-    logger.info("fixed initializer names: %s", fixed_initialized_names)
-
-    if dest_type == "text_encoder":
-        lora_prefix = "lora_te_"
-    elif dest_type == "unet":
-        lora_prefix = "lora_unet_"
-    else:
-        lora_prefix = "lora_"
+    lora_models = [load_file(name) for name in lora_names]
+    lora_count = len(lora_models)
+    if lora_weights is None:
+        lora_weights = np.ones((lora_count)) / lora_count
+
+    lora_prefix = f"lora_{dest_type}_"
 
-    for i in range(len(fixed_initialized_names)):
-        base_key = fixed_initialized_names[i]
-        base_node = base_model.graph.initializer[i]
-
-        updates = []
-        for lora_model in lora_models:
-            for key in lora_model.keys():
-                if ".lora_down" in key:
-                    original_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
-                    bias_key = original_key + "_bias"
-                    weight_key = original_key + "_weight"
-
-                    if bias_key.startswith(base_key):
-                        print("found bias key:", base_key, bias_key)
-
-                    if weight_key == base_key:
-                        print("down for key:", base_key, weight_key)
-
+    blended: Dict[str, np.ndarray] = {}
+    for lora_name, lora_model, lora_weight in zip(lora_names, lora_models, lora_weights):
+        logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
+        for key in lora_model.keys():
+            if ".lora_down" in key and lora_prefix in key:
+                base_key = key[: key.index(".lora_down")].replace(
+                    lora_prefix, ""
+                )
+
                 up_key = key.replace("lora_down", "lora_up")
                 alpha_key = key[: key.index("lora_down")] + "alpha"
+                logger.info("blending weights for keys: %s, %s, %s", key, up_key, alpha_key)
 
                 down_weight = lora_model[key].to(dtype=torch.float32)
                 up_weight = lora_model[up_key].to(dtype=torch.float32)
 
                 dim = down_weight.size()[0]
-                alpha = lora_model.get(alpha_key).numpy() or dim
-
-                np_vals = numpy_helper.to_array(base_node)
-                print("before shape", np_vals.shape, up_weight.shape, down_weight.shape)
+                alpha = lora_model.get(alpha_key, torch.tensor(dim)).to(torch.float32).numpy()
 
                 try:
                     if len(up_weight.size()) == 2:
-                        squoze = up_weight @ down_weight
-                        print(squoze.shape)
-                        np_vals = np_vals + (squoze.numpy() * (alpha / dim))
-                    else:
-                        squoze = (
-                            up_weight.squeeze(3).squeeze(2)
-                            @ down_weight.squeeze(3).squeeze(2)
+                        # blend for nn.Linear
+                        logger.info("blending weights for Linear node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha)
+                        weights = up_weight @ down_weight
+                        np_weights = (weights.numpy() * (alpha / dim))
+                    elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
+                        # blend for nn.Conv2d 1x1
+                        logger.info("blending weights for Conv node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha)
+                        weights = (
+                            (
+                                up_weight.squeeze(3).squeeze(2)
+                                @ down_weight.squeeze(3).squeeze(2)
+                            )
@@ -86,40 +90,109 @@ def merge_lora(base_name: str, lora_names: str, dest_path: str, dest_type: Liter
                             .unsqueeze(2)
                             .unsqueeze(3)
                         )
-                        print(squoze.shape)
-                        np_vals = np_vals + (alpha * squoze.numpy())
-                    print("after shape", np_vals.shape)
-
-                    updates.append(np_vals)
-
-                    break
-                except Exception as e:
-                    logger.exception("error blending weights with key %s", weight_key)
-
-        if len(updates) == 0:
-            logger.debug("no lora found for key %s", base_key)
-        else:
-            # blend updates together and append to lora_nodes
-            logger.info("blending %s updated weights for key %s", len(updates), base_key)
-
-            # TODO: allow individual alphas
-            np_vals = sum(updates) / len(updates)
-
-            retensor = numpy_helper.from_array(np_vals, base_node.name)
-            logger.info("created new tensor with %s bytes", len(retensor.raw_data))
-
-            # TypeError: does not support assignment
-            lora_nodes.append((i, retensor))
-
-    logger.info("updating %s of %s nodes", len(lora_nodes), len(base_model.graph.initializer))
-    for idx, node in lora_nodes:
-        del base_model.graph.initializer[idx]
-        base_model.graph.initializer.insert(idx, node)
+                        np_weights = (weights.numpy() * (alpha / dim))
+                    else:
+                        # TODO: add support for Conv2d 3x3
+                        logger.warning("unknown LoRA node type at %s: %s", base_key, up_weight.shape[-2:])
+                        continue
+
+                    np_weights *= lora_weight
+                    if base_key in blended:
+                        blended[base_key] += np_weights
+                    else:
+                        blended[base_key] = np_weights
+                except Exception:
+                    logger.exception(
+                        "error blending weights for key %s", base_key
+                    )
+
+    logger.info(
+        "updating %s of %s initializers: %s",
+        len(blended.keys()),
+        len(base_model.graph.initializer),
+        list(blended.keys())
+    )
+
+    fixed_initializer_names = [
+        fix_initializer_name(node.name) for node in base_model.graph.initializer
+    ]
+    # logger.info("fixed initializer names: %s", fixed_initializer_names)
+
+    fixed_node_names = [
+        fix_node_name(node.name) for node in base_model.graph.node
+    ]
+    # logger.info("fixed node names: %s", fixed_node_names)
+
+    for base_key, weights in blended.items():
+        conv_key = base_key + "_Conv"
+        matmul_key = base_key + "_MatMul"
+
+        logger.info("key %s has conv: %s, matmul: %s", base_key, conv_key in fixed_node_names, matmul_key in fixed_node_names)
+
+        if conv_key in fixed_node_names:
+            conv_idx = fixed_node_names.index(conv_key)
+            conv_node = base_model.graph.node[conv_idx]
+            logger.info("found conv node: %s", conv_node.name)
+
+            # find weight initializer
+            logger.info("conv inputs: %s", conv_node.input)
+            weight_name = [n for n in conv_node.input if ".weight" in n][0]
+            weight_name = fix_initializer_name(weight_name)
+
+            weight_idx = fixed_initializer_names.index(weight_name)
+            weight_node = base_model.graph.initializer[weight_idx]
+            logger.info("found weight initializer: %s", weight_node.name)
+
+            # blending
+            base_weights = numpy_helper.to_array(weight_node)
+            logger.info("found blended weights for conv: %s, %s", weights.shape, base_weights.shape)
+
+            blended_weights = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
+            blended_weights = np.expand_dims(blended_weights, (2, 3))
+            logger.info("blended weight shape: %s", blended_weights.shape)
+
+            # replace the original initializer
+            updated_node = numpy_helper.from_array(blended_weights, weight_node.name)
+            del base_model.graph.initializer[weight_idx]
+            base_model.graph.initializer.insert(weight_idx, updated_node)
+        elif matmul_key in fixed_node_names:
+            weight_idx = fixed_node_names.index(matmul_key)
+            weight_node = base_model.graph.node[weight_idx]
+            logger.info("found matmul node: %s", weight_node.name)
+
+            # find the MatMul initializer
+            logger.info("matmul inputs: %s", weight_node.input)
+            matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
+
+            matmul_idx = fixed_initializer_names.index(matmul_name)
+            matmul_node = base_model.graph.initializer[matmul_idx]
+            logger.info("found matmul initializer: %s", matmul_node.name)
+
+            # blending
+            base_weights = numpy_helper.to_array(matmul_node)
+            logger.info("found blended weights for matmul: %s, %s", weights.shape, base_weights.shape)
+
+            blended_weights = base_weights + weights.transpose()
+            logger.info("blended weight shape: %s", blended_weights.shape)
+
+            # replace the original initializer
+            updated_node = numpy_helper.from_array(blended_weights, matmul_node.name)
+            del base_model.graph.initializer[matmul_idx]
+            base_model.graph.initializer.insert(matmul_idx, updated_node)
+        else:
+            logger.info("could not find any nodes for %s", base_key)
+
+    logger.info("node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), len(fixed_node_names), len(base_model.graph.node))
 
     # save it back to disk
     # TODO: save to memory instead
-    convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}-external.pb")
+    convert_model_to_external_data(
+        base_model,
+        all_tensors_to_one_file=True,
+        location=f"lora-{dest_type}-external.pb",
+    )
     bare_model = write_external_data_tensors(base_model, dest_path)
 
     dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
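Note on the math: the per-key update accumulated above reduces to W' = W + w * (alpha / dim) * (up @ down). A minimal numpy sketch of the same blend, with made-up shapes and illustrative names (nothing here is taken from the script itself):

import numpy as np

rank, n_in, n_out = 4, 320, 320
base = np.random.randn(n_out, n_in).astype(np.float32)  # existing ONNX initializer
down = np.random.randn(rank, n_in).astype(np.float32)   # lora_down weight
up = np.random.randn(n_out, rank).astype(np.float32)    # lora_up weight
alpha, lora_weight = 4.0, 0.5                           # per-LoRA scale and blend weight

# the update that merge_lora stores in blended[base_key]; dim == rank here
update = (up @ down) * (alpha / rank) * lora_weight
merged = base + update  # MatMul initializers get the transpose of the update instead
assert merged.shape == base.shape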
@@ -133,4 +206,13 @@ def merge_lora(base_name: str, lora_names: str, dest_path: str, dest_type: Liter
 
 
 if __name__ == "__main__":
-    merge_lora(*argv[1:])
+    parser = ArgumentParser()
+    parser.add_argument("--base", type=str)
+    parser.add_argument("--dest", type=str)
+    parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
+    parser.add_argument("--lora_models", nargs='+', type=str)
+    parser.add_argument("--lora_weights", nargs='+', type=float)
+
+    args = parser.parse_args()
+    logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)
+    merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
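Assuming the script is saved as onnx-lora.py (the file path is not shown in this excerpt), the new CLI would be invoked along these lines, with all paths as placeholders:

python3 onnx-lora.py --base unet/model.onnx --dest /tmp/merged --type unet --lora_models one.safetensors two.safetensors --lora_weights 0.8 0.2

When --lora_weights is omitted, the models are averaged with an even weight of 1/N each.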
@@ -0,0 +1,44 @@
+from logging import getLogger, basicConfig, DEBUG
+from onnx import load_model, ModelProto
+from onnx.numpy_helper import to_array
+from sys import argv, stdout
+
+
+basicConfig(stream=stdout, level=DEBUG)
+
+logger = getLogger(__name__)
+
+
+def diff_models(ref_model: ModelProto, cmp_model: ModelProto):
+    if len(ref_model.graph.initializer) != len(cmp_model.graph.initializer):
+        logger.warning("different number of initializers: %s vs %s", len(ref_model.graph.initializer), len(cmp_model.graph.initializer))
+    else:
+        for (ref_init, cmp_init) in zip(ref_model.graph.initializer, cmp_model.graph.initializer):
+            if ref_init.name != cmp_init.name:
+                logger.info("different node names: %s vs %s", ref_init.name, cmp_init.name)
+            elif ref_init.data_location != cmp_init.data_location:
+                logger.info("different data locations: %s vs %s", ref_init.data_location, cmp_init.data_location)
+            elif ref_init.data_type != cmp_init.data_type:
+                logger.info("different data types: %s vs %s", ref_init.data_type, cmp_init.data_type)
+            elif len(ref_init.raw_data) != len(cmp_init.raw_data):
+                logger.info("different raw data size: %s vs %s", len(ref_init.raw_data), len(cmp_init.raw_data))
+            elif len(ref_init.raw_data) > 0 and len(cmp_init.raw_data) > 0:
+                ref_data = to_array(ref_init)
+                cmp_data = to_array(cmp_init)
+                data_diff = ref_data - cmp_data
+                if abs(data_diff).max() > 0:
+                    logger.info("raw data differs: %s", data_diff)
+            else:
+                logger.info("initializers are identical in all checked fields: %s", ref_init.name)
+
+
+if __name__ == "__main__":
+    ref_path = argv[1]
+    cmp_paths = argv[2:]
+
+    logger.info("loading reference model from %s", ref_path)
+    ref_model = load_model(ref_path)
+
+    for cmp_path in cmp_paths:
+        logger.info("loading comparison model from %s", cmp_path)
+        cmp_model = load_model(cmp_path)
+        diff_models(ref_model, cmp_model)
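The diff helper takes a reference model followed by any number of comparison models, for example (paths again placeholders):

python3 onnx-diff.py reference-unet.onnx lora-unet.onnx

Because it zips the two initializer lists in order, models whose graphs were reordered will report name mismatches rather than tensor-level differences.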