onnx-web/api/onnx_web/convert/diffusion/lora.py

from logging import getLogger
from os import path
from sys import argv
from typing import List, Tuple

import onnx.checker
import torch
from numpy import ndarray
from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model
from safetensors import safe_open

from ..utils import ConversionContext

logger = getLogger(__name__)

###
# everything in this file is still super experimental and may not produce valid ONNX models
###
def load_lora(filename: str):
    model = load(filename)

    for weight in model.graph.initializer:
        # print(weight.name, numpy_helper.to_array(weight).shape)
        pass

    return model

def blend_loras(
    base: ModelProto, weights: List[ModelProto], alphas: List[float]
) -> List[Tuple[TensorProto, ndarray]]:
    total = 1 + sum(alphas)

    results = []
    for base_node in base.graph.initializer:
        logger.info("blending initializer node %s", base_node.name)
        base_weights = numpy_helper.to_array(base_node).copy()

        for weight, alpha in zip(weights, alphas):
            weight_node = next(
                iter([f for f in weight.graph.initializer if f.name == base_node.name]),
                None,
            )
            if weight_node is not None:
                base_weights += numpy_helper.to_array(weight_node) * alpha
            else:
                logger.warning(
                    "missing weights: %s in %s", base_node.name, weight.doc_string
                )

        results.append((base_node, base_weights / total))

    return results
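
# Worked example of the blend above (illustrative only, not part of the original
# module): with two LoRA models and alphas of 0.5 each, total = 2.0 and every
# blended initializer becomes (base + 0.5 * lora_a + 0.5 * lora_b) / 2.0, i.e. a
# weighted average in which the base model always contributes with weight 1.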

def convert_diffusion_lora(context: ConversionContext, component: str):
    lora_weights = [
        f"diffusion-lora-jack/{component}/model.onnx",
        f"diffusion-lora-taters/{component}/model.onnx",
    ]

    base = load_lora(f"stable-diffusion-onnx-v1-5/{component}/model.onnx")
    weights = [load_lora(f) for f in lora_weights]
    alphas = [1 / len(weights)] * len(weights)
    # log the file names rather than the loaded ModelProtos
    logger.info("blending LoRAs %s with alphas: %s", lora_weights, alphas)

    result = blend_loras(base, weights, alphas)
    logger.info("blended result keys: %s", len(result))

    del weights
    del alphas

    tensors = []
    for node, tensor in result:
        logger.info("remaking tensor for %s", node.name)
        tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor))

    del result

    graph = helper.make_graph(
        base.graph.node,
        base.graph.name,
        base.graph.input,
        base.graph.output,
        tensors,
        base.graph.doc_string,
        base.graph.value_info,
        base.graph.sparse_initializer,
    )
    model = helper.make_model(graph)

    del model.opset_import[:]
    opset = model.opset_import.add()
    opset.version = 14

    onnx_path = path.join(context.cache_path, f"lora-{component}.onnx")
    tensor_path = path.join(context.cache_path, f"lora-{component}.tensors")
    save_model(
        model,
        onnx_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=tensor_path,
    )
    logger.info(
        "saved model to %s and tensors to %s",
        onnx_path,
        tensor_path,
    )

def fix_key(key: str):
    # lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
    # lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0
    return key.replace(".", "_")
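
# Illustrative only: fix_key is applied to the ONNX initializer names so they can
# be compared against the underscore-delimited safetensors keys, e.g.
# fix_key("attn2.to_out.0.weight") == "attn2_to_out_0_weight".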

def merge_lora():
    base_name = argv[1]
    lora_name = argv[2]

    base_model = load(base_name)
    lora_model = safe_open(lora_name, framework="pt")

    lora_nodes = []
    for base_node in base_model.graph.initializer:
        base_key = fix_key(base_node.name)
        retensor = None

        for key in lora_model.keys():
            if "lora_down" in key:
                lora_key = key[: key.index("lora_down")].replace("lora_unet_", "")
                if lora_key.startswith(base_key):
                    print("down for key:", base_key, lora_key)

                    up_key = key.replace("lora_down", "lora_up")
                    alpha_key = key[: key.index("lora_down")] + "alpha"

                    down_weight = lora_model.get_tensor(key).to(dtype=torch.float32)
                    up_weight = lora_model.get_tensor(up_key).to(dtype=torch.float32)

                    dim = down_weight.size()[0]
                    if alpha_key in lora_model.keys():
                        # stored as a 0-dim tensor in the safetensors file
                        alpha = lora_model.get_tensor(alpha_key).item()
                    else:
                        # fall back to the LoRA rank when no alpha is present
                        alpha = dim

                    np_vals = numpy_helper.to_array(base_node)
                    print(np_vals.shape, up_weight.shape, down_weight.shape)

                    try:
                        if len(up_weight.size()) == 2:
                            # linear layer: W' = W + (alpha / dim) * up @ down
                            squoze = up_weight @ down_weight
                            print(squoze.shape)
                            np_vals = np_vals + (squoze.numpy() * (alpha / dim))
                        else:
                            # 1x1 conv layer: collapse the trailing dims, matmul,
                            # then restore them
                            squoze = (
                                (
                                    up_weight.squeeze(3).squeeze(2)
                                    @ down_weight.squeeze(3).squeeze(2)
                                )
                                .unsqueeze(2)
                                .unsqueeze(3)
                            )
                            print(squoze.shape)
                            np_vals = np_vals + (alpha * squoze.numpy())

                        # retensor = numpy_helper.from_array(np_vals, base_node.name)
                        retensor = helper.make_tensor(
                            base_node.name,
                            base_node.data_type,
                            base_node.dims,
                            np_vals.tobytes(),
                            raw=True,
                        )
                        print(retensor.name)

                        # TypeError: does not support assignment
                        lora_nodes.append(retensor)
                        break
                    except Exception as e:
                        print(e)

        if retensor is None:
            print("no lora found for key", base_key)
            lora_nodes.append(base_node)

    print(len(lora_nodes), len(base_model.graph.initializer))

    del base_model.graph.initializer[:]
    base_model.graph.initializer.extend(lora_nodes)

    onnx.checker.check_model(base_model)

if __name__ == "__main__":
    context = ConversionContext.from_environ()
    convert_diffusion_lora(context, "unet")
    convert_diffusion_lora(context, "text_encoder")