1
0
Fork 0
onnx-web/api/onnx_web/convert/diffusion/lora.py

186 lines
5.8 KiB
Python

from logging import getLogger
from typing import List, Tuple
from numpy import ndarray
from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model
from sys import argv
from safetensors import safe_open
import torch
import onnx.checker
# Module-level logger, named after this module.
logger = getLogger(__name__)
def load_lora(filename: str):
    """Load an ONNX model from *filename* and return it.

    Args:
        filename: Path to the ``.onnx`` file to read.

    Returns:
        The model object produced by ``onnx.load``.
    """
    # The previous version walked every initializer with a bare ``pass``
    # (left over from debugging); that no-op loop has been dropped.
    return load(filename)
def blend_loras(
    base: ModelProto, weights: List[ModelProto], alphas: List[float]
) -> List[Tuple[TensorProto, ndarray]]:
    """Blend the initializers of several models into the base model's.

    For every initializer in ``base`` the result is::

        (base + sum(alpha_i * weight_i)) / (1 + sum(alphas))

    where ``weight_i`` is the initializer with the same name in the i-th
    model. A model that lacks the initializer contributes nothing to the
    sum (a warning is logged), but the divisor is not adjusted for it.

    Args:
        base: Model whose initializers define the names to blend.
        weights: Models to blend into ``base``.
        alphas: One blend factor per entry of ``weights``; pairs are formed
            with ``zip``, so extra entries on either side are ignored.

    Returns:
        A list of ``(base initializer, blended array)`` pairs, in the order
        of ``base.graph.initializer``.
    """
    total = 1 + sum(alphas)

    # Index each model's initializers by name once, instead of rescanning
    # every initializer for every base node. ``setdefault`` keeps the first
    # initializer when a name is repeated, matching a linear first-match scan.
    sources = []
    for weight, alpha in zip(weights, alphas):
        by_name = {}
        for node in weight.graph.initializer:
            by_name.setdefault(node.name, node)
        sources.append((weight, alpha, by_name))

    results = []
    for base_node in base.graph.initializer:
        logger.info("blending initializer node %s", base_node.name)
        # Copy so the in-place accumulation below works on our own buffer.
        base_weights = numpy_helper.to_array(base_node).copy()

        for weight, alpha, by_name in sources:
            weight_node = by_name.get(base_node.name)
            if weight_node is not None:
                base_weights += numpy_helper.to_array(weight_node) * alpha
            else:
                logger.warning(
                    "missing weights: %s in %s", base_node.name, weight.doc_string
                )

        results.append((base_node, base_weights / total))

    return results
def convert_diffusion_lora(part: str):
    """Blend two LoRA ONNX models into the base model for one sub-model.

    Loads ``stable-diffusion-onnx-v1-5/{part}/model.onnx`` as the base and
    the ``diffusion-lora-jack`` / ``diffusion-lora-taters`` models for the
    same ``part``, blends them with equal alphas via ``blend_loras``,
    rebuilds the graph with the blended initializers, and writes the result
    to ``/tmp/lora-{part}.onnx`` with all tensors stored in the external
    data file ``/tmp/lora-{part}.tensors``.

    Args:
        part: Sub-directory name of the model part (the script entry point
            uses ``"unet"`` and ``"text_encoder"``).
    """
    lora_weights = [
        f"diffusion-lora-jack/{part}/model.onnx",
        f"diffusion-lora-taters/{part}/model.onnx",
    ]

    base = load_lora(f"stable-diffusion-onnx-v1-5/{part}/model.onnx")
    weights = [load_lora(f) for f in lora_weights]
    alphas = [1 / len(weights)] * len(weights)
    # Log the file names, not the loaded models: formatting a list of
    # ModelProto objects would dump every tensor into the log.
    logger.info("blending LoRAs with alphas: %s, %s", lora_weights, alphas)

    result = blend_loras(base, weights, alphas)
    logger.info("blended result keys: %s", len(result))

    # Release the source models before building the new tensors.
    del weights
    del alphas

    tensors = []
    for node, tensor in result:
        logger.info("remaking tensor for %s", node.name)
        tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor))

    del result

    # Same graph as the base model, with the initializers swapped out.
    graph = helper.make_graph(
        base.graph.node,
        base.graph.name,
        base.graph.input,
        base.graph.output,
        tensors,
        base.graph.doc_string,
        base.graph.value_info,
        base.graph.sparse_initializer,
    )
    model = helper.make_model(graph)

    # Replace the default opset import with a single opset 14 entry.
    del model.opset_import[:]
    opset = model.opset_import.add()
    opset.version = 14

    model_path = f"/tmp/lora-{part}.onnx"
    # The external-data location is interpreted relative to the model file,
    # so pass only the file name; it lands next to the model in /tmp.
    tensors_name = f"lora-{part}.tensors"
    save_model(
        model,
        model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=tensors_name,
    )
    logger.info(
        "saved model to %s and tensors to %s",
        model_path,
        f"/tmp/{tensors_name}",
    )
def fix_key(key: str):
    """Return *key* with every ``.`` replaced by ``_``.

    LoRA checkpoint keys are underscore-separated up to the final
    ``.lora_down.weight`` suffix, e.g.
    ``lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight``,
    while the corresponding module path is dotted
    (``up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0``).
    Presumably this normalization is what lets the two be compared; confirm
    against the caller.
    """
    segments = key.split(".")
    return "_".join(segments)
def merge_lora():
    """Merge LoRA weights from a safetensors file into an ONNX model.

    Reads the base ONNX model path from ``argv[1]`` and the LoRA safetensors
    path from ``argv[2]``. For each initializer in the base model, looks for
    a ``lora_down`` key whose prefix (with ``lora_unet_`` removed) starts
    with the normalized initializer name, and if one is found adds
    ``(up @ down) * (alpha / dim)`` to the initializer. Initializers without
    a usable LoRA match are kept unchanged.

    The merged initializers replace the originals in the in-memory model,
    which is then validated with ``onnx.checker.check_model``. The model is
    not written back to disk by this function.

    NOTE(review): the key-matching rule is kept exactly as it was written;
    whether it matches real initializer names should be confirmed against an
    actual model.
    """
    base_name = argv[1]
    lora_name = argv[2]

    base_model = load(base_name)

    lora_nodes = []
    # Use the context manager so the safetensors file handle is released.
    with safe_open(lora_name, framework="pt") as lora_model:
        # The key list does not change between initializers; read it once.
        lora_keys = list(lora_model.keys())
        known_keys = set(lora_keys)

        for base_node in base_model.graph.initializer:
            base_key = fix_key(base_node.name)
            # Reset per initializer so a match from an earlier node can never
            # leak into this one (and so the name is always bound).
            retensor = None

            for key in lora_keys:
                if "lora_down" not in key:
                    continue

                prefix = key[: key.index("lora_down")]
                lora_key = prefix.replace("lora_unet_", "")
                if not lora_key.startswith(base_key):
                    continue

                print("down for key:", base_key, lora_key)
                up_key = key.replace("lora_down", "lora_up")
                alpha_key = prefix + "alpha"

                # Best-effort merge: a candidate that fails (missing up key,
                # incompatible shapes, ...) is reported and the next
                # candidate is tried.
                try:
                    down_weight = lora_model.get_tensor(key).to(dtype=torch.float32)
                    up_weight = lora_model.get_tensor(up_key).to(dtype=torch.float32)

                    dim = down_weight.size()[0]
                    # Fall back to ``dim`` (scale 1.0) when the file has no
                    # alpha entry or the stored alpha is zero.
                    if alpha_key in known_keys:
                        alpha = lora_model.get_tensor(alpha_key).item() or dim
                    else:
                        alpha = dim
                    scale = alpha / dim

                    np_vals = numpy_helper.to_array(base_node)
                    print(np_vals.shape, up_weight.shape, down_weight.shape)

                    if len(up_weight.size()) == 2:
                        delta = up_weight @ down_weight
                    else:
                        # Assumes 4-D weights with 1x1 trailing dims (TODO
                        # confirm): drop them, multiply, then restore them.
                        delta = (
                            up_weight.squeeze(3).squeeze(2)
                            @ down_weight.squeeze(3).squeeze(2)
                        ).unsqueeze(2).unsqueeze(3)
                    print(delta.shape)

                    # Apply the delta exactly once, with the same scale in
                    # both branches, and keep the initializer's dtype.
                    merged = (np_vals + delta.numpy() * scale).astype(np_vals.dtype)
                    retensor = numpy_helper.from_array(merged, base_node.name)
                    break
                except Exception as e:
                    print(e)

            if retensor is None:
                print("no lora found for key", base_key)
                retensor = base_node

            lora_nodes.append(retensor)

    print(len(lora_nodes), len(base_model.graph.initializer))

    # Swap the initializers in place, then validate the resulting model.
    del base_model.graph.initializer[:]
    base_model.graph.initializer.extend(lora_nodes)

    onnx.checker.check_model(base_model)
if __name__ == "__main__":
    # Script entry point: blend the LoRAs for each model part in turn.
    for part in ("unet", "text_encoder"):
        convert_diffusion_lora(part)