
feat(api): add conversion script for LoRAs from sd-scripts (#213)

Sean Sube 2023-03-14 18:00:26 -05:00
parent 0b1aa26be5
commit 4c17edb267
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 205 additions and 79 deletions
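
For readers skimming the diff below: the change replaces the old per-initializer scan with a two-pass blend. It first computes a delta for each LoRA key as up @ down scaled by alpha / dim and by a per-LoRA blend weight, then writes the summed deltas into the matching MatMul or Conv initializers of the ONNX graph. A minimal sketch of the Linear case, using synthetic tensors; the shapes, names, and values here are illustrative assumptions, not values from the commit:

    import numpy as np

    def blend_linear(base, up, down, alpha, lora_weight):
        # rank of the LoRA decomposition, taken from the down-projection
        dim = down.shape[0]
        # base + (up @ down) * (alpha / dim), scaled by the per-LoRA blend weight
        return base + (up @ down) * (alpha / dim) * lora_weight

    base = np.zeros((320, 320), dtype=np.float32)
    up = np.random.randn(320, 4).astype(np.float32)    # lora_up weight, rank 4
    down = np.random.randn(4, 320).astype(np.float32)  # lora_down weight
    print(blend_linear(base, up, down, alpha=4.0, lora_weight=0.5).shape)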


@@ -1,16 +1,19 @@
-from itertools import groupby
+from argparse import ArgumentParser
 from logging import getLogger
 from os import path
 from sys import argv
-from typing import List, Literal, Tuple
+from typing import Dict, Literal
 
+import numpy as np
 import torch
 from onnx import TensorProto, load, numpy_helper
 from onnx.checker import check_model
-from onnx.external_data_helper import convert_model_to_external_data, write_external_data_tensors
+from onnx.external_data_helper import (
+    convert_model_to_external_data,
+    write_external_data_tensors,
+)
 from safetensors.torch import load_file
 
-# from ..utils import ConversionContext
+from onnx_web.convert.utils import ConversionContext
 
 logger = getLogger(__name__)
@@ -20,65 +23,66 @@ logger = getLogger(__name__)
 ###
-def fix_name(key: str):
+def fix_initializer_name(key: str):
     # lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
     # lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0
     return key.replace(".", "_")
 
 
-def merge_lora(base_name: str, lora_names: str, dest_path: str, dest_type: Literal["text_encoder", "unet"]):
+def fix_node_name(key: str):
+    fixed_name = fix_initializer_name(key.replace("/", "_"))
+    if fixed_name[0] == "_":
+        return fixed_name[1:]
+    else:
+        return fixed_name
+
+
+def merge_lora(
+    base_name: str,
+    lora_names: str,
+    dest_path: str,
+    dest_type: Literal["text_encoder", "unet"],
+    lora_weights: "np.NDArray[np.float64]" = None,
+):
     base_model = load(base_name)
-    lora_models = [load_file(name) for name in lora_names.split(",")]
-
-    lora_nodes: List[Tuple[int, TensorProto]] = []
-
-    fixed_initialized_names = [fix_name(node.name) for node in base_model.graph.initializer]
-    logger.info("fixed initializer names: %s", fixed_initialized_names)
-
-    if dest_type == "text_encoder":
-        lora_prefix = "lora_te_"
-    elif dest_type == "unet":
-        lora_prefix = "lora_unet_"
-    else:
-        lora_prefix = "lora_"
-
-    for i in range(len(fixed_initialized_names)):
-        base_key = fixed_initialized_names[i]
-        base_node = base_model.graph.initializer[i]
-
-        updates = []
-        for lora_model in lora_models:
+    lora_models = [load_file(name) for name in lora_names]
+    lora_count = len(lora_models)
+    lora_weights = lora_weights or (np.ones((lora_count)) / lora_count)
+
+    lora_prefix = f"lora_{dest_type}_"
+
+    blended: Dict[str, np.ndarray] = {}
+    for lora_name, lora_model, lora_weight in zip(lora_names, lora_models, lora_weights):
+        logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
         for key in lora_model.keys():
-            if ".lora_down" in key:
-                original_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
-                bias_key = original_key + "_bias"
-                weight_key = original_key + "_weight"
-
-                if bias_key.startswith(base_key):
-                    print("found bias key:", base_key, bias_key)
-
-                if weight_key == base_key:
-                    print("down for key:", base_key, weight_key)
+            if ".lora_down" in key and lora_prefix in key:
+                base_key = key[: key.index(".lora_down")].replace(
+                    lora_prefix, ""
+                )
 
                 up_key = key.replace("lora_down", "lora_up")
                 alpha_key = key[: key.index("lora_down")] + "alpha"
+                logger.info("blending weights for keys: %s, %s, %s", key, up_key, alpha_key)
 
                 down_weight = lora_model[key].to(dtype=torch.float32)
                 up_weight = lora_model[up_key].to(dtype=torch.float32)
 
                 dim = down_weight.size()[0]
-                alpha = lora_model.get(alpha_key).numpy() or dim
-                np_vals = numpy_helper.to_array(base_node)
-                print("before shape", np_vals.shape, up_weight.shape, down_weight.shape)
+                alpha = lora_model.get(alpha_key, dim).to(torch.float32).numpy()
 
                 try:
                     if len(up_weight.size()) == 2:
-                        squoze = up_weight @ down_weight
-                        print(squoze.shape)
-                        np_vals = np_vals + (squoze.numpy() * (alpha / dim))
-                    else:
-                        squoze = (
+                        # blend for nn.Linear
+                        logger.info("blending weights for Linear node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha)
+                        weights = up_weight @ down_weight
+                        np_weights = (weights.numpy() * (alpha / dim))
+                    elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
+                        # blend for nn.Conv2d 1x1
+                        logger.info("blending weights for Conv node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha)
+                        weights = (
+                            (
                                 up_weight.squeeze(3).squeeze(2)
                                 @ down_weight.squeeze(3).squeeze(2)
@@ -86,40 +90,109 @@ def merge_lora(base_name: str, lora_names: str, dest_path: str, dest_type: Liter
                             )
                             .unsqueeze(2)
                             .unsqueeze(3)
                         )
-                        print(squoze.shape)
-                        np_vals = np_vals + (alpha * squoze.numpy())
-
-                    print("after shape", np_vals.shape)
-                    updates.append(np_vals)
-
-                    break
-                except Exception as e:
-                    logger.exception("error blending weights with key %s", weight_key)
-
-        if len(updates) == 0:
-            logger.debug("no lora found for key %s", base_key)
+                        np_weights = (weights.numpy() * (alpha / dim))
                     else:
-            # blend updates together and append to lora_nodes
-            logger.info("blending %s updated weights for key %s", len(updates), base_key)
+                        # TODO: add support for Conv2d 3x3
+                        logger.warning("unknown LoRA node type at %s: %s", base_key, up_weight.shape[-2:])
+                        continue
 
                     # TODO: allow individual alphas
-            np_vals = sum(updates) / len(updates)
+                    np_weights *= lora_weight
+
+                    if base_key in blended:
+                        blended[base_key] += np_weights
+                    else:
+                        blended[base_key] = np_weights
 
-            retensor = numpy_helper.from_array(np_vals, base_node.name)
-            logger.info("created new tensor with %s bytes", len(retensor.raw_data))
+                except Exception:
+                    logger.exception(
+                        "error blending weights for key %s", base_key
+                    )
 
-            # TypeError: does not support assignment
-            lora_nodes.append((i, retensor))
+    logger.info(
+        "updating %s of %s initializers: %s",
+        len(blended.keys()),
+        len(base_model.graph.initializer),
+        list(blended.keys())
+    )
+
+    fixed_initializer_names = [
+        fix_initializer_name(node.name) for node in base_model.graph.initializer
+    ]
+    # logger.info("fixed initializer names: %s", fixed_initializer_names)
+
+    fixed_node_names = [
+        fix_node_name(node.name) for node in base_model.graph.node
+    ]
+    # logger.info("fixed node names: %s", fixed_node_names)
 
-    logger.info("updating %s of %s nodes", len(lora_nodes), len(base_model.graph.initializer))
-    for idx, node in lora_nodes:
-        del base_model.graph.initializer[idx]
-        base_model.graph.initializer.insert(idx, node)
+    for base_key, weights in blended.items():
+        conv_key = base_key + "_Conv"
+        matmul_key = base_key + "_MatMul"
+        logger.info("key %s has conv: %s, matmul: %s", base_key, conv_key in fixed_node_names, matmul_key in fixed_node_names)
+
+        if conv_key in fixed_node_names:
+            conv_idx = fixed_node_names.index(conv_key)
+            conv_node = base_model.graph.node[conv_idx]
+            logger.info("found conv node: %s", conv_node.name)
+
+            # find weight initializer
+            logger.info("conv inputs: %s", conv_node.input)
+            weight_name = [n for n in conv_node.input if ".weight" in n][0]
+            weight_name = fix_initializer_name(weight_name)
+
+            weight_idx = fixed_initializer_names.index(weight_name)
+            weight_node = base_model.graph.initializer[weight_idx]
+            logger.info("found weight initializer: %s", weight_node.name)
+
+            # blending
+            base_weights = numpy_helper.to_array(weight_node)
+            logger.info("found blended weights for conv: %s, %s", weights.shape, base_weights.shape)
+
+            blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
+            blended = np.expand_dims(blended, (2, 3))
+            logger.info("blended weight shape: %s", blended.shape)
+
+            # replace the original initializer
+            updated_node = numpy_helper.from_array(blended, weight_node.name)
+            del base_model.graph.initializer[weight_idx]
+            base_model.graph.initializer.insert(weight_idx, updated_node)
+        elif matmul_key in fixed_node_names:
+            weight_idx = fixed_node_names.index(matmul_key)
+            weight_node = base_model.graph.node[weight_idx]
+            logger.info("found matmul node: %s", weight_node.name)
+
+            # find the MatMul initializer
+            logger.info("matmul inputs: %s", weight_node.input)
+            matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
+
+            matmul_idx = fixed_initializer_names.index(matmul_name)
+            matmul_node = base_model.graph.initializer[matmul_idx]
+            logger.info("found matmul initializer: %s", matmul_node.name)
+
+            # blending
+            base_weights = numpy_helper.to_array(matmul_node)
+            logger.info("found blended weights for matmul: %s, %s", weights.shape, base_weights.shape)
+
+            blended = base_weights + weights.transpose()
+            logger.info("blended weight shape: %s", blended.shape)
+
+            # replace the original initializer
+            updated_node = numpy_helper.from_array(blended, matmul_node.name)
+            del base_model.graph.initializer[matmul_idx]
+            base_model.graph.initializer.insert(matmul_idx, updated_node)
+        else:
+            logger.info("could not find any nodes for %s", base_key)
+
+    logger.info("node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), len(fixed_node_names), len(base_model.graph.node))
 
     # save it back to disk
     # TODO: save to memory instead
-    convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}-external.pb")
+    convert_model_to_external_data(
+        base_model,
+        all_tensors_to_one_file=True,
+        location=f"lora-{dest_type}-external.pb",
+    )
     bare_model = write_external_data_tensors(base_model, dest_path)
 
     dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
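
A note on the del/insert pattern kept in the new code: the ONNX graph's initializer list is a protobuf repeated field, which raises the TypeError mentioned in the removed comment if you assign into it by index. A minimal sketch of the workaround (the model path is a placeholder):

    from onnx import load, numpy_helper

    model = load("model.onnx")  # placeholder path
    old_init = model.graph.initializer[0]
    tensor = numpy_helper.from_array(numpy_helper.to_array(old_init), old_init.name)

    # model.graph.initializer[0] = tensor  # TypeError: does not support assignment
    del model.graph.initializer[0]
    model.graph.initializer.insert(0, tensor)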
@@ -133,4 +206,13 @@ def merge_lora(base_name: str, lora_names: str, dest_path: str, dest_type: Liter
 
 if __name__ == "__main__":
-    merge_lora(*argv[1:])
+    parser = ArgumentParser()
+    parser.add_argument("--base", type=str)
+    parser.add_argument("--dest", type=str)
+    parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
+    parser.add_argument("--lora_models", nargs='+', type=str)
+    parser.add_argument("--lora_weights", nargs='+', type=float)
+
+    args = parser.parse_args()
+    logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)
+    merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
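
With the new entry point, merge_lora can also be called directly instead of via the argparse flags. A hypothetical call, with placeholder paths; passing plain lists matches what argparse produces, and leaving lora_weights as None falls back to the even 1/N split in the function:

    merge_lora(
        "stable-diffusion/unet/model.onnx",          # base ONNX model (placeholder)
        ["style.safetensors", "detail.safetensors"],  # LoRAs from sd-scripts (placeholders)
        "/tmp/merged",                                # dest_path (placeholder)
        "unet",                                       # or "text_encoder"
        [0.6, 0.4],                                   # per-LoRA blend weights
    )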

api/scripts/onnx-diff.py Normal file

@@ -0,0 +1,44 @@
+from logging import getLogger, basicConfig, DEBUG
+
+from onnx import load_model, ModelProto
+from onnx.numpy_helper import to_array
+from sys import argv, stdout
+
+basicConfig(stream=stdout, level=DEBUG)
+
+logger = getLogger(__name__)
+
+
+def diff_models(ref_model: ModelProto, cmp_model: ModelProto):
+    if len(ref_model.graph.initializer) != len(cmp_model.graph.initializer):
+        logger.warning("different number of initializers: %s vs %s", len(ref_model.graph.initializer), len(cmp_model.graph.initializer))
+    else:
+        for (ref_init, cmp_init) in zip(ref_model.graph.initializer, cmp_model.graph.initializer):
+            if ref_init.name != cmp_init.name:
+                logger.info("different node names: %s vs %s", ref_init.name, cmp_init.name)
+            elif ref_init.data_location != cmp_init.data_location:
+                logger.info("different data locations: %s vs %s", ref_init.data_location, cmp_init.data_location)
+            elif ref_init.data_type != cmp_init.data_type:
+                logger.info("different data types: %s vs %s", ref_init.data_type, cmp_init.data_type)
+            elif len(ref_init.raw_data) != len(cmp_init.raw_data):
+                logger.info("different raw data size: %s vs %s", len(ref_init.raw_data), len(cmp_init.raw_data))
+            elif len(ref_init.raw_data) > 0 and len(cmp_init.raw_data) > 0:
+                ref_data = to_array(ref_init)
+                cmp_data = to_array(cmp_init)
+                data_diff = ref_data - cmp_data
+                if data_diff.max() > 0:
+                    logger.info("raw data differs: %s", data_diff)
+                else:
+                    logger.info("initializers are identical in all checked fields: %s", ref_init.name)
+
+
+if __name__ == "__main__":
+    ref_path = argv[1]
+    cmp_paths = argv[2:]
+
+    logger.info("loading reference model from %s", ref_path)
+    ref_model = load_model(ref_path)
+
+    for cmp_path in cmp_paths:
+        logger.info("loading comparison model from %s", cmp_path)
+        cmp_model = load_model(cmp_path)
+        diff_models(ref_model, cmp_model)
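
The diff script takes a reference model as its first argument and compares every following model against it, which pairs naturally with the merge script above. A hypothetical run, with placeholder paths; the lora-unet.onnx name follows the dest_file pattern from the merge script:

    from subprocess import run

    run([
        "python", "api/scripts/onnx-diff.py",
        "stable-diffusion/unet/model.onnx",  # reference model (placeholder)
        "/tmp/merged/lora-unet.onnx",        # merged model to compare (placeholder)
    ])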