
blend LoRAs into a valid ONNX UNet (#213)

Sean Sube 2023-03-12 13:38:51 -05:00
parent cf429ad715
commit 0b1aa26be5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 96 additions and 172 deletions
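
The per-initializer blend introduced by this commit follows the usual LoRA update, W' = W + (alpha / dim) * (up @ down). A minimal standalone sketch of that math, using hypothetical shapes and tensor names rather than anything taken from the repository:

# Minimal sketch of the LoRA weight blend, under the usual convention
# W' = W + (alpha / dim) * (up @ down); shapes and values below are hypothetical.
import numpy as np
import torch

base = np.zeros((320, 768), dtype=np.float32)  # base Linear weight pulled from an ONNX initializer
down = torch.randn(4, 768)                      # lora_down.weight: (rank, in_features)
up = torch.randn(320, 4)                        # lora_up.weight: (out_features, rank)

dim = down.size()[0]                            # LoRA rank
alpha = 4.0                                     # value of the "...alpha" key, falling back to dim

# 2-D (Linear) case as in the diff; 4-D conv weights squeeze their trailing 1x1 dims first
blended = base + (up @ down).numpy() * (alpha / dim)
print(blended.shape)                            # (320, 768)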


@@ -1,15 +1,16 @@
+from itertools import groupby
 from logging import getLogger
 from os import path
 from sys import argv
-from typing import List, Tuple
+from typing import List, Literal, Tuple
 
-import onnx.checker
 import torch
-from numpy import ndarray
-from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model
-from safetensors import safe_open
+from onnx import TensorProto, load, numpy_helper
+from onnx.checker import check_model
+from onnx.external_data_helper import convert_model_to_external_data, write_external_data_tensors
+from safetensors.torch import load_file
 
-from ..utils import ConversionContext
+# from ..utils import ConversionContext
 
 logger = getLogger(__name__)
@@ -19,194 +20,117 @@ logger = getLogger(__name__)
 ###
 
 
-def load_lora(filename: str):
-    model = load(filename)
-    for weight in model.graph.initializer:
-        # print(weight.name, numpy_helper.to_array(weight).shape)
-        pass
-
-    return model
-
-
-def blend_loras(
-    base: ModelProto, weights: List[ModelProto], alphas: List[float]
-) -> List[Tuple[TensorProto, ndarray]]:
-    total = 1 + sum(alphas)
-    results = []
-
-    for base_node in base.graph.initializer:
-        logger.info("blending initializer node %s", base_node.name)
-        base_weights = numpy_helper.to_array(base_node).copy()
-
-        for weight, alpha in zip(weights, alphas):
-            weight_node = next(
-                iter([f for f in weight.graph.initializer if f.name == base_node.name]),
-                None,
-            )
-            if weight_node is not None:
-                base_weights += numpy_helper.to_array(weight_node) * alpha
-            else:
-                logger.warning(
-                    "missing weights: %s in %s", base_node.name, weight.doc_string
-                )
-
-        results.append((base_node, base_weights / total))
-
-    return results
-
-
-def convert_diffusion_lora(context: ConversionContext, component: str):
-    lora_weights = [
-        f"diffusion-lora-jack/{component}/model.onnx",
-        f"diffusion-lora-taters/{component}/model.onnx",
-    ]
-
-    base = load_lora(f"stable-diffusion-onnx-v1-5/{component}/model.onnx")
-    weights = [load_lora(f) for f in lora_weights]
-    alphas = [1 / len(weights)] * len(weights)
-    logger.info("blending LoRAs with alphas: %s, %s", weights, alphas)
-
-    result = blend_loras(base, weights, alphas)
-    logger.info("blended result keys: %s", len(result))
-
-    del weights
-    del alphas
-
-    tensors = []
-    for node, tensor in result:
-        logger.info("remaking tensor for %s", node.name)
-        tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor))
-
-    del result
-
-    graph = helper.make_graph(
-        base.graph.node,
-        base.graph.name,
-        base.graph.input,
-        base.graph.output,
-        tensors,
-        base.graph.doc_string,
-        base.graph.value_info,
-        base.graph.sparse_initializer,
-    )
-    model = helper.make_model(graph)
-
-    del model.opset_import[:]
-    opset = model.opset_import.add()
-    opset.version = 14
-
-    onnx_path = path.join(context.cache_path, f"lora-{component}.onnx")
-    tensor_path = path.join(context.cache_path, f"lora-{component}.tensors")
-    save_model(
-        model,
-        onnx_path,
-        save_as_external_data=True,
-        all_tensors_to_one_file=True,
-        location=tensor_path,
-    )
-    logger.info(
-        "saved model to %s and tensors to %s",
-        onnx_path,
-        tensor_path,
-    )
-
-
-def fix_key(key: str):
+def fix_name(key: str):
     # lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
     # lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0
     return key.replace(".", "_")
 
 
-def merge_lora():
-    base_name = argv[1]
-    lora_name = argv[2]
-
+def merge_lora(base_name: str, lora_names: str, dest_path: str, dest_type: Literal["text_encoder", "unet"]):
     base_model = load(base_name)
-    lora_model = safe_open(lora_name, framework="pt")
+    lora_models = [load_file(name) for name in lora_names.split(",")]
 
-    lora_nodes = []
-    for base_node in base_model.graph.initializer:
-        base_key = fix_key(base_node.name)
-
-        for key in lora_model.keys():
-            if "lora_down" in key:
-                lora_key = key[: key.index("lora_down")].replace("lora_unet_", "")
-                if lora_key.startswith(base_key):
-                    print("down for key:", base_key, lora_key)
-
-                    up_key = key.replace("lora_down", "lora_up")
-                    alpha_key = key[: key.index("lora_down")] + "alpha"
-
-                    down_weight = lora_model.get_tensor(key).to(dtype=torch.float32)
-                    up_weight = lora_model.get_tensor(up_key).to(dtype=torch.float32)
-
-                    dim = down_weight.size()[0]
-                    alpha = lora_model.get(alpha_key).numpy() or dim
-
-                    np_vals = numpy_helper.to_array(base_node)
-                    print(np_vals.shape, up_weight.shape, down_weight.shape)
-
-                    squoze = (
-                        (
-                            up_weight.squeeze(3).squeeze(2)
-                            @ down_weight.squeeze(3).squeeze(2)
-                        )
-                        .unsqueeze(2)
-                        .unsqueeze(3)
-                    )
-                    print(squoze.shape)
-                    np_vals = np_vals + (alpha * squoze.numpy())
-
-                    try:
-                        if len(up_weight.size()) == 2:
-                            squoze = up_weight @ down_weight
-                            print(squoze.shape)
-                            np_vals = np_vals + (squoze.numpy() * (alpha / dim))
-                        else:
-                            squoze = (
-                                (
-                                    up_weight.squeeze(3).squeeze(2)
-                                    @ down_weight.squeeze(3).squeeze(2)
-                                )
-                                .unsqueeze(2)
-                                .unsqueeze(3)
-                            )
-                            print(squoze.shape)
-                            np_vals = np_vals + (alpha * squoze.numpy())
-
-                        # retensor = numpy_helper.from_array(np_vals, base_node.name)
-                        retensor = helper.make_tensor(
-                            base_node.name,
-                            base_node.data_type,
-                            base_node.dim,
-                            np_vals,
-                            raw=True,
-                        )
-                        print(retensor)
-
-                        # TypeError: does not support assignment
-                        lora_nodes.append(retensor)
-
-                        break
-                    except Exception as e:
-                        print(e)
-
-        if retensor is None:
-            print("no lora found for key", base_key)
-            lora_nodes.append(base_node)
-
-    print(len(lora_nodes), len(base_model.graph.initializer))
-
-    del base_model.graph.initializer[:]
-    base_model.graph.initializer.extend(lora_nodes)
-
-    onnx.checker.check_model(base_model)
+    lora_nodes: List[Tuple[int, TensorProto]] = []
+
+    fixed_initialized_names = [fix_name(node.name) for node in base_model.graph.initializer]
+    logger.info("fixed initializer names: %s", fixed_initialized_names)
+
+    if dest_type == "text_encoder":
+        lora_prefix = "lora_te_"
+    elif dest_type == "unet":
+        lora_prefix = "lora_unet_"
+    else:
+        lora_prefix = "lora_"
+
+    for i in range(len(fixed_initialized_names)):
+        base_key = fixed_initialized_names[i]
+        base_node = base_model.graph.initializer[i]
+
+        updates = []
+        for lora_model in lora_models:
+            for key in lora_model.keys():
+                if ".lora_down" in key:
+                    original_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
+                    bias_key = original_key + "_bias"
+                    weight_key = original_key + "_weight"
+
+                    if bias_key.startswith(base_key):
+                        print("found bias key:", base_key, bias_key)
+
+                    if weight_key == base_key:
+                        print("down for key:", base_key, weight_key)
+
+                        up_key = key.replace("lora_down", "lora_up")
+                        alpha_key = key[: key.index("lora_down")] + "alpha"
+
+                        down_weight = lora_model[key].to(dtype=torch.float32)
+                        up_weight = lora_model[up_key].to(dtype=torch.float32)
+
+                        dim = down_weight.size()[0]
+                        alpha = lora_model.get(alpha_key).numpy() or dim
+
+                        np_vals = numpy_helper.to_array(base_node)
+                        print("before shape", np_vals.shape, up_weight.shape, down_weight.shape)
+
+                        try:
+                            if len(up_weight.size()) == 2:
+                                squoze = up_weight @ down_weight
+                                print(squoze.shape)
+                                np_vals = np_vals + (squoze.numpy() * (alpha / dim))
+                            else:
+                                squoze = (
+                                    (
+                                        up_weight.squeeze(3).squeeze(2)
+                                        @ down_weight.squeeze(3).squeeze(2)
+                                    )
+                                    .unsqueeze(2)
+                                    .unsqueeze(3)
+                                )
+                                print(squoze.shape)
+                                np_vals = np_vals + (alpha * squoze.numpy())
+
+                            print("after shape", np_vals.shape)
+
+                            updates.append(np_vals)
+
+                            break
+                        except Exception as e:
+                            logger.exception("error blending weights with key %s", weight_key)
+
+        if len(updates) == 0:
+            logger.debug("no lora found for key %s", base_key)
+        else:
+            # blend updates together and append to lora_nodes
+            logger.info("blending %s updated weights for key %s", len(updates), base_key)
+
+            # TODO: allow individual alphas
+            np_vals = sum(updates) / len(updates)
+
+            retensor = numpy_helper.from_array(np_vals, base_node.name)
+            logger.info("created new tensor with %s bytes", len(retensor.raw_data))
+
+            # TypeError: does not support assignment
+            lora_nodes.append((i, retensor))
+
+    logger.info("updating %s of %s nodes", len(lora_nodes), len(base_model.graph.initializer))
+
+    for idx, node in lora_nodes:
+        del base_model.graph.initializer[idx]
+        base_model.graph.initializer.insert(idx, node)
+
+    # save it back to disk
+    # TODO: save to memory instead
+    convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}-external.pb")
+    bare_model = write_external_data_tensors(base_model, dest_path)
+
+    dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
+    with open(dest_file, "wb") as model_file:
+        model_file.write(bare_model.SerializeToString())
+
+    logger.info("model saved, checking...")
+    check_model(dest_file)
+    logger.info("model successfully exported")
 
 
 if __name__ == "__main__":
-    context = ConversionContext.from_environ()
-    convert_diffusion_lora(context, "unet")
-    convert_diffusion_lora(context, "text_encoder")
+    merge_lora(*argv[1:])
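
For reference, the new entry point reads its arguments from argv, so an invocation might look like the sketch below. The module name and all file paths are hypothetical; only the argument order (base_name, lora_names, dest_path, dest_type) comes from the diff.

# Hypothetical usage of the new merge_lora entry point; paths and module name are examples only.
# CLI form would be roughly: python lora.py <base.onnx> <a.safetensors,b.safetensors> <dest_dir> <unet|text_encoder>
from lora import merge_lora  # assumed module name for this script

merge_lora(
    "stable-diffusion-onnx-v1-5/unet/model.onnx",      # base_name: ONNX model to blend into
    "style-lora.safetensors,detail-lora.safetensors",   # lora_names: comma-separated safetensors files
    "/tmp/lora-out",                                     # dest_path: directory for the external tensor data
    "unet",                                              # dest_type: selects the lora_unet_ key prefix
)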