blend LoRAs into a valid ONNX UNet (#213)
parent cf429ad715
commit 0b1aa26be5
@@ -1,15 +1,16 @@
+from itertools import groupby
 from logging import getLogger
 from os import path
 from sys import argv
-from typing import List, Tuple
+from typing import List, Literal, Tuple

-import onnx.checker
 import torch
-from numpy import ndarray
-from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model
-from safetensors import safe_open
+from onnx import TensorProto, load, numpy_helper
+from onnx.checker import check_model
+from onnx.external_data_helper import convert_model_to_external_data, write_external_data_tensors
+from safetensors.torch import load_file

-from ..utils import ConversionContext
+# from ..utils import ConversionContext

 logger = getLogger(__name__)
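Note: the import changes above swap the lazy safe_open handle for safetensors.torch.load_file, which reads the whole file into an ordinary dict of tensors. A minimal sketch of that loading pattern, with a hypothetical file name:

    from safetensors.torch import load_file

    # hypothetical file name; load_file returns a plain Dict[str, torch.Tensor],
    # so weights can be indexed directly instead of fetched via get_tensor()
    state = load_file("lora-example.safetensors")
    for key, tensor in state.items():
        print(key, tuple(tensor.shape))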
@@ -19,147 +20,57 @@ logger = getLogger(__name__)
 ###


-def load_lora(filename: str):
-    model = load(filename)
-
-    for weight in model.graph.initializer:
-        # print(weight.name, numpy_helper.to_array(weight).shape)
-        pass
-
-    return model
-
-
-def blend_loras(
-    base: ModelProto, weights: List[ModelProto], alphas: List[float]
-) -> List[Tuple[TensorProto, ndarray]]:
-    total = 1 + sum(alphas)
-
-    results = []
-
-    for base_node in base.graph.initializer:
-        logger.info("blending initializer node %s", base_node.name)
-        base_weights = numpy_helper.to_array(base_node).copy()
-
-        for weight, alpha in zip(weights, alphas):
-            weight_node = next(
-                iter([f for f in weight.graph.initializer if f.name == base_node.name]),
-                None,
-            )
-
-            if weight_node is not None:
-                base_weights += numpy_helper.to_array(weight_node) * alpha
-            else:
-                logger.warning(
-                    "missing weights: %s in %s", base_node.name, weight.doc_string
-                )
-
-        results.append((base_node, base_weights / total))
-
-    return results
-
-
-def convert_diffusion_lora(context: ConversionContext, component: str):
-    lora_weights = [
-        f"diffusion-lora-jack/{component}/model.onnx",
-        f"diffusion-lora-taters/{component}/model.onnx",
-    ]
-
-    base = load_lora(f"stable-diffusion-onnx-v1-5/{component}/model.onnx")
-    weights = [load_lora(f) for f in lora_weights]
-    alphas = [1 / len(weights)] * len(weights)
-    logger.info("blending LoRAs with alphas: %s, %s", weights, alphas)
-
-    result = blend_loras(base, weights, alphas)
-    logger.info("blended result keys: %s", len(result))
-
-    del weights
-    del alphas
-
-    tensors = []
-    for node, tensor in result:
-        logger.info("remaking tensor for %s", node.name)
-        tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor))
-
-    del result
-
-    graph = helper.make_graph(
-        base.graph.node,
-        base.graph.name,
-        base.graph.input,
-        base.graph.output,
-        tensors,
-        base.graph.doc_string,
-        base.graph.value_info,
-        base.graph.sparse_initializer,
-    )
-    model = helper.make_model(graph)
-
-    del model.opset_import[:]
-    opset = model.opset_import.add()
-    opset.version = 14
-
-    onnx_path = path.join(context.cache_path, f"lora-{component}.onnx")
-    tensor_path = path.join(context.cache_path, f"lora-{component}.tensors")
-    save_model(
-        model,
-        onnx_path,
-        save_as_external_data=True,
-        all_tensors_to_one_file=True,
-        location=tensor_path,
-    )
-    logger.info(
-        "saved model to %s and tensors to %s",
-        onnx_path,
-        tensor_path,
-    )
-
-
-def fix_key(key: str):
+def fix_name(key: str):
     # lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
     # lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0
     return key.replace(".", "_")


-def merge_lora():
-    base_name = argv[1]
-    lora_name = argv[2]
-
+def merge_lora(base_name: str, lora_names: str, dest_path: str, dest_type: Literal["text_encoder", "unet"]):
     base_model = load(base_name)
-    lora_model = safe_open(lora_name, framework="pt")
+    lora_models = [load_file(name) for name in lora_names.split(",")]

-    lora_nodes = []
-    for base_node in base_model.graph.initializer:
-        base_key = fix_key(base_node.name)
+    lora_nodes: List[Tuple[int, TensorProto]] = []
+
+    fixed_initialized_names = [fix_name(node.name) for node in base_model.graph.initializer]
+    logger.info("fixed initializer names: %s", fixed_initialized_names)

+    if dest_type == "text_encoder":
+        lora_prefix = "lora_te_"
+    elif dest_type == "unet":
+        lora_prefix = "lora_unet_"
+    else:
+        lora_prefix = "lora_"
+
+    for i in range(len(fixed_initialized_names)):
+        base_key = fixed_initialized_names[i]
+        base_node = base_model.graph.initializer[i]
+
+        updates = []
+        for lora_model in lora_models:
             for key in lora_model.keys():
-                if "lora_down" in key:
-                    lora_key = key[: key.index("lora_down")].replace("lora_unet_", "")
-                    if lora_key.startswith(base_key):
-                        print("down for key:", base_key, lora_key)
+                if ".lora_down" in key:
+                    original_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
+                    bias_key = original_key + "_bias"
+                    weight_key = original_key + "_weight"
+
+                    if bias_key.startswith(base_key):
+                        print("found bias key:", base_key, bias_key)
+
+                    if weight_key == base_key:
+                        print("down for key:", base_key, weight_key)

                         up_key = key.replace("lora_down", "lora_up")
                         alpha_key = key[: key.index("lora_down")] + "alpha"

-                        down_weight = lora_model.get_tensor(key).to(dtype=torch.float32)
-                        up_weight = lora_model.get_tensor(up_key).to(dtype=torch.float32)
+                        down_weight = lora_model[key].to(dtype=torch.float32)
+                        up_weight = lora_model[up_key].to(dtype=torch.float32)

                         dim = down_weight.size()[0]
                         alpha = lora_model.get(alpha_key).numpy() or dim

                         np_vals = numpy_helper.to_array(base_node)
-                        print(np_vals.shape, up_weight.shape, down_weight.shape)
-
-                        squoze = (
-                            (
-                                up_weight.squeeze(3).squeeze(2)
-                                @ down_weight.squeeze(3).squeeze(2)
-                            )
-                            .unsqueeze(2)
-                            .unsqueeze(3)
-                        )
-                        print(squoze.shape)
-
-                        np_vals = np_vals + (alpha * squoze.numpy())
+                        print("before shape", np_vals.shape, up_weight.shape, down_weight.shape)

                         try:
                             if len(up_weight.size()) == 2:
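The hunk above ends inside the blending loop; the removed squoze block shows the core LoRA math for 1x1 conv layers: drop the kernel dims, multiply the low-rank up and down factors, and restore the conv shape before adding the scaled product to the base weights. A standalone sketch of that computation, with illustrative shapes (a rank-4 LoRA on a 320-channel conv; none of these names or shapes come from the commit):

    import torch

    up_weight = torch.randn(320, 4, 1, 1)    # lora_up factor
    down_weight = torch.randn(4, 320, 1, 1)  # lora_down factor
    base = torch.randn(320, 320, 1, 1)       # base model conv kernel
    alpha = 4.0                              # per-layer scale from the LoRA file

    # squeeze away the 1x1 kernel dims, multiply the rank factors,
    # then unsqueeze back to the conv kernel shape
    squoze = (
        up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
    ).unsqueeze(2).unsqueeze(3)

    blended = base + alpha * squoze  # W' = W + alpha * (up @ down)
    print(blended.shape)  # torch.Size([320, 320, 1, 1])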
@@ -177,36 +88,49 @@ def merge_lora():
                                 )
                             print(squoze.shape)
                             np_vals = np_vals + (alpha * squoze.numpy())
+                            print("after shape", np_vals.shape)

-                            # retensor = numpy_helper.from_array(np_vals, base_node.name)
-                            retensor = helper.make_tensor(
-                                base_node.name,
-                                base_node.data_type,
-                                base_node.dim,
-                                np_vals,
-                                raw=True,
-                            )
-                            print(retensor)
-
-                            # TypeError: does not support assignment
-                            lora_nodes.append(retensor)
+                            updates.append(np_vals)

                             break
                         except Exception as e:
-                            print(e)
+                            logger.exception("error blending weights with key %s", weight_key)

-        if retensor is None:
-            print("no lora found for key", base_key)
-            lora_nodes.append(base_node)
+        if len(updates) == 0:
+            logger.debug("no lora found for key %s", base_key)
+        else:
+            # blend updates together and append to lora_nodes
+            logger.info("blending %s updated weights for key %s", len(updates), base_key)

-    print(len(lora_nodes), len(base_model.graph.initializer))
-    del base_model.graph.initializer[:]
-    base_model.graph.initializer.extend(lora_nodes)
+            # TODO: allow individual alphas
+            np_vals = sum(updates) / len(updates)

-    onnx.checker.check_model(base_model)
+            retensor = numpy_helper.from_array(np_vals, base_node.name)
+            logger.info("created new tensor with %s bytes", len(retensor.raw_data))
+
+            # TypeError: does not support assignment
+            lora_nodes.append((i, retensor))
+
+    logger.info("updating %s of %s nodes", len(lora_nodes), len(base_model.graph.initializer))
+    for idx, node in lora_nodes:
+        del base_model.graph.initializer[idx]
+        base_model.graph.initializer.insert(idx, node)
+
+    # save it back to disk
+    # TODO: save to memory instead
+    convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}-external.pb")
+    bare_model = write_external_data_tensors(base_model, dest_path)
+
+    dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
+    with open(dest_file, "wb") as model_file:
+        model_file.write(bare_model.SerializeToString())
+
+    logger.info("model saved, checking...")
+    check_model(dest_file)
+
+    logger.info("model successfully exported")


 if __name__ == "__main__":
-    context = ConversionContext.from_environ()
-    convert_diffusion_lora(context, "unet")
-    convert_diffusion_lora(context, "text_encoder")
+    merge_lora(*argv[1:])
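With the new __main__ block, the script is driven entirely by positional arguments: merge_lora(*argv[1:]) expects the base ONNX model path, a comma-separated list of safetensors LoRA files, a destination directory, and the model type. A hypothetical invocation (the LoRA file names and output directory are illustrative; the base path matches the one used elsewhere in this file):

    # hypothetical direct call, mirroring merge_lora(*argv[1:])
    merge_lora(
        "stable-diffusion-onnx-v1-5/unet/model.onnx",  # base_name
        "lora-one.safetensors,lora-two.safetensors",   # lora_names, comma-separated
        "/tmp/lora-out",                               # dest_path for the .onnx and external data
        "unet",                                        # dest_type selects the lora_unet_ prefix
    )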