384 lines
12 KiB
Python
384 lines
12 KiB
Python
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
import torch
|
|
from onnx import GraphProto, ModelProto, NodeProto
|
|
from onnx.numpy_helper import from_array
|
|
|
|
from onnx_web.convert.diffusion.lora import (
|
|
blend_loras,
|
|
blend_node_conv_gemm,
|
|
blend_node_matmul,
|
|
blend_weights_loha,
|
|
blend_weights_lora,
|
|
buffer_external_data_tensors,
|
|
fix_initializer_name,
|
|
fix_node_name,
|
|
fix_xl_names,
|
|
interp_to_match,
|
|
kernel_slice,
|
|
sum_weights,
|
|
)
|
|
|
|
|
|
class SumWeightsTests(unittest.TestCase):
|
|
def test_same_shape(self):
|
|
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4)))
|
|
self.assertEqual(weights.shape, (4, 4))
|
|
|
|
def test_1x1_kernel(self):
|
|
weights = sum_weights(np.zeros((4, 4, 1, 1)), np.ones((4, 4)))
|
|
self.assertEqual(weights.shape, (4, 4, 1, 1))
|
|
|
|
weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1)))
|
|
self.assertEqual(weights.shape, (4, 4, 1, 1))
|
|
|
|
def test_3x3_kernel(self):
|
|
"""
|
|
weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4)))
|
|
self.assertEqual(weights.shape, (4, 4, 1, 1))
|
|
"""
|
|
pass
|
|
|
|
|
|
class BufferExternalDataTensorTests(unittest.TestCase):
    """Checks for ``buffer_external_data_tensors`` on a one-initializer model."""

    def test_basic_external(self):
        initializers = [from_array(np.zeros((4, 4)))]
        model = ModelProto(graph=GraphProto(initializer=initializers))

        slim_model, external_weights = buffer_external_data_tensors(model)

        # the slim model keeps as many initializers as the input model has
        expected_count = len(model.graph.initializer)
        self.assertEqual(len(slim_model.graph.initializer), expected_count)
        # one buffered entry comes back for the single initializer
        self.assertEqual(len(external_weights), 1)
|
|
|
|
|
|
class FixInitializerKeyTests(unittest.TestCase):
    """Checks for ``fix_initializer_name``."""

    def test_fix_name(self):
        # (raw key, expected initializer name) pairs. Keeping each pair together
        # avoids zip() silently dropping an input whose expected value is missing.
        cases = [
            (
                "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight",
                "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight",
            ),
        ]

        for raw_name, expected in cases:
            # subTest reports every failing pair instead of stopping at the first
            with self.subTest(raw_name=raw_name):
                self.assertEqual(fix_initializer_name(raw_name), expected)
|
|
|
|
|
|
class FixNodeNameTests(unittest.TestCase):
    """Checks for ``fix_node_name``."""

    def test_fix_name(self):
        # (raw node name, expected fixed name) pairs. Keeping each pair together
        # avoids zip() silently dropping an input whose expected value is missing.
        cases = [
            (
                "lora_unet/up_blocks/3/attentions/2/transformer_blocks/0/attn2_to_out/0.lora_down.weight",
                "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight",
            ),
            # a leading underscore is removed
            ("_prefix", "prefix"),
        ]

        for raw_name, expected in cases:
            # subTest reports every failing pair instead of stopping at the first
            with self.subTest(raw_name=raw_name):
                self.assertEqual(fix_node_name(raw_name), expected)
|
|
|
|
|
|
class FixXLNameTests(unittest.TestCase):
|
|
def test_empty(self):
|
|
nodes = {}
|
|
fixed = fix_xl_names(nodes, [])
|
|
|
|
self.assertEqual(fixed, {})
|
|
|
|
def test_input_block(self):
|
|
nodes = {
|
|
"input_block_proj.lora_down.weight": {},
|
|
}
|
|
fixed = fix_xl_names(
|
|
nodes,
|
|
[
|
|
NodeProto(name="/down_blocks_proj/MatMul"),
|
|
],
|
|
)
|
|
|
|
self.assertEqual(
|
|
fixed,
|
|
{
|
|
"down_blocks_proj": nodes["input_block_proj.lora_down.weight"],
|
|
},
|
|
)
|
|
|
|
def test_middle_block(self):
|
|
nodes = {
|
|
"middle_block_proj.lora_down.weight": {},
|
|
}
|
|
fixed = fix_xl_names(
|
|
nodes,
|
|
[
|
|
NodeProto(name="/mid_blocks_proj/MatMul"),
|
|
],
|
|
)
|
|
|
|
self.assertEqual(
|
|
fixed,
|
|
{
|
|
"mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"],
|
|
},
|
|
)
|
|
|
|
def test_output_block(self):
|
|
pass
|
|
|
|
def test_text_model(self):
|
|
pass
|
|
|
|
def test_unknown_block(self):
|
|
pass
|
|
|
|
def test_unmatched_block(self):
|
|
nodes = {
|
|
"lora_unet.input_block.lora_down.weight": {},
|
|
}
|
|
fixed = fix_xl_names(nodes, [NodeProto(name="test")])
|
|
|
|
self.assertEqual(fixed, nodes)
|
|
|
|
def test_output_projection(self):
|
|
nodes = {
|
|
"output_block_proj_out.lora_down.weight": {},
|
|
}
|
|
fixed = fix_xl_names(
|
|
nodes,
|
|
[
|
|
NodeProto(name="/up_blocks_proj_out/MatMul"),
|
|
],
|
|
)
|
|
|
|
self.assertEqual(
|
|
fixed,
|
|
{
|
|
"up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"],
|
|
},
|
|
)
|
|
|
|
|
|
class KernelSliceTests(unittest.TestCase):
    """Checks for ``kernel_slice`` against a (3, 3, 3, 3) kernel shape."""

    def test_within_kernel(self):
        shape = (3, 3, 3, 3)
        # indices inside the kernel are returned as-is
        expected = (1, 1)
        self.assertEqual(kernel_slice(1, 1, shape), expected)

    def test_outside_kernel(self):
        shape = (3, 3, 3, 3)
        # indices past the kernel edge map to (2, 2), the last index of a size-3 axis
        expected = (2, 2)
        self.assertEqual(kernel_slice(9, 9, shape), expected)
|
|
|
|
|
|
class InterpToMatchTests(unittest.TestCase):
    """Checks that ``interp_to_match`` output has the shape of its second argument."""

    def _assert_result_shape(self, ref_shape, resize_shape, expected_shape):
        # builds zero-filled inputs of the given shapes and checks the output shape
        reference = np.zeros(ref_shape)
        to_resize = np.zeros(resize_shape)
        self.assertEqual(interp_to_match(reference, to_resize).shape, expected_shape)

    def test_same_shape(self):
        self._assert_result_shape((4, 4), (4, 4), (4, 4))

    def test_different_one_dim(self):
        self._assert_result_shape((4, 2), (4, 4), (4, 4))

    def test_different_both_dims(self):
        self._assert_result_shape((2, 2), (4, 4), (4, 4))
|
|
|
|
|
|
class BlendLoRATests(unittest.TestCase):
|
|
def test_blend_unet(self):
|
|
"""
|
|
blend_loras(None, "test", [], "unet")
|
|
"""
|
|
pass
|
|
|
|
def test_blend_text_encoder(self):
|
|
"""
|
|
blend_loras(None, "test", [], "text_encoder")
|
|
"""
|
|
pass
|
|
|
|
def test_blend_text_encoder_index(self):
|
|
"""
|
|
blend_loras(None, "test", [], "text_encoder", model_index=2)
|
|
"""
|
|
pass
|
|
|
|
def test_unmatched_keys(self):
|
|
pass
|
|
|
|
def test_xl_keys(self):
|
|
"""
|
|
blend_loras(None, "test", [], "unet", xl=True)
|
|
"""
|
|
pass
|
|
|
|
def test_node_dtype(self):
|
|
pass
|
|
|
|
@patch("onnx_web.convert.diffusion.lora.load")
|
|
@patch("onnx_web.convert.diffusion.lora.load_tensor")
|
|
def test_blend_loras_load_str(self, mock_load_tensor, mock_load):
|
|
base_name = "model.onnx"
|
|
loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)]
|
|
model_type = "unet"
|
|
model_index = 2
|
|
xl = True
|
|
|
|
mock_load.return_value = MagicMock()
|
|
mock_load_tensor.return_value = MagicMock()
|
|
|
|
# Call the blend_loras function
|
|
blended_model = blend_loras(None, base_name, loras, model_type, model_index, xl)
|
|
|
|
# Assert that the InferenceSession is called with the correct arguments
|
|
mock_load.assert_called_once_with(base_name)
|
|
|
|
# Assert that the model is loaded successfully
|
|
self.assertEqual(blended_model, mock_load.return_value)
|
|
|
|
# Assert that the blending logic is executed correctly
|
|
# (assertions specific to the blending logic can be added here)
|
|
|
|
|
|
class BlendWeightsLoHATests(unittest.TestCase):
|
|
def test_blend_t1_t2(self):
|
|
# blend einsum: i j k l, j r, i p -> p r k l
|
|
i = 32
|
|
j = 4
|
|
k = 1
|
|
l = 1 # NOQA
|
|
p = 2
|
|
r = 4
|
|
|
|
model = {
|
|
"foo.hada_t1": torch.from_numpy(np.ones((i, j, k, l))),
|
|
"foo.hada_t2": torch.from_numpy(np.ones((i, j, k, l))),
|
|
"foo.hada_w1_a": torch.from_numpy(np.ones((i, p))),
|
|
"foo.hada_w1_b": torch.from_numpy(np.ones((j, r))),
|
|
"foo.hada_w2_a": torch.from_numpy(np.ones((i, p))),
|
|
"foo.hada_w2_b": torch.from_numpy(np.ones((j, r))),
|
|
"foo.alpha": torch.tensor(1),
|
|
}
|
|
key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
|
|
self.assertEqual(result.shape, (p, r, k, l))
|
|
|
|
def test_blend_w1_w2(self):
|
|
model = {
|
|
"foo.hada_w1_a": torch.from_numpy(np.ones((4, 1))),
|
|
"foo.hada_w1_b": torch.from_numpy(np.ones((1, 4))),
|
|
"foo.hada_w2_a": torch.from_numpy(np.ones((4, 1))),
|
|
"foo.hada_w2_b": torch.from_numpy(np.ones((1, 4))),
|
|
"foo.alpha": torch.tensor(1),
|
|
}
|
|
key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
|
|
self.assertEqual(result.shape, (4, 4))
|
|
|
|
def test_blend_no_dim(self):
|
|
"""
|
|
model = {
|
|
"foo.hada_w1_a": torch.from_numpy(np.ones((1, 4))),
|
|
"foo.hada_w1_b": torch.from_numpy(np.ones((4, 1))),
|
|
"foo.hada_w2_a": torch.from_numpy(np.ones((1, 4))),
|
|
"foo.hada_w2_b": torch.from_numpy(np.ones((4, 1))),
|
|
}
|
|
result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
|
|
self.assertEqual(result.shape, (4, 4))
|
|
"""
|
|
|
|
|
|
class BlendWeightsLoRATests(unittest.TestCase):
|
|
def test_blend_kernel_none(self):
|
|
model = {
|
|
"foo.lora_down": torch.from_numpy(np.ones((1, 4))),
|
|
"foo.lora_up": torch.from_numpy(np.ones((4, 1))),
|
|
"foo.alpha": 1,
|
|
}
|
|
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
|
|
self.assertEqual(result.shape, (4, 4))
|
|
|
|
def test_blend_kernel_1x1(self):
|
|
model = {
|
|
"foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))),
|
|
"foo.lora_up": torch.from_numpy(np.ones((4, 1, 1, 1))),
|
|
"foo.alpha": 1,
|
|
}
|
|
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
|
|
self.assertEqual(result.shape, (4, 4, 1, 1))
|
|
|
|
def test_blend_kernel_3x3(self):
|
|
model = {
|
|
"foo.lora_down": torch.from_numpy(np.ones((1, 4, 3, 3))),
|
|
"foo.lora_up": torch.from_numpy(np.ones((4, 1, 3, 3))),
|
|
"foo.alpha": 1,
|
|
}
|
|
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
|
|
self.assertEqual(result.shape, (4, 4, 3, 3))
|
|
|
|
def test_blend_kernel_3x3_cp_decomp(self):
|
|
model = {
|
|
"foo.lora_down": torch.from_numpy(np.ones((2, 4, 1, 1))),
|
|
"foo.lora_mid": torch.from_numpy(np.ones((2, 2, 3, 3))),
|
|
"foo.lora_up": torch.from_numpy(np.ones((4, 2, 1, 1))),
|
|
"foo.alpha": 1,
|
|
}
|
|
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
|
|
self.assertEqual(result.shape, (4, 4, 3, 3))
|
|
|
|
def test_blend_unknown(self):
|
|
pass
|
|
|
|
|
|
class BlendNodeConvGemmTests(unittest.TestCase):
|
|
def test_blend_kernel_1x1_and_1x1(self):
|
|
node = from_array(np.ones((4, 4, 1, 1)))
|
|
result = blend_node_conv_gemm(node, np.ones((4, 4, 1, 1)))
|
|
|
|
self.assertEqual(result.dims, [4, 4, 1, 1])
|
|
self.assertEqual(len(result.raw_data), 4 * 4 * 8)
|
|
|
|
def test_blend_kernel_1x1_and_none(self):
|
|
node = from_array(np.ones((4, 4, 1, 1)))
|
|
result = blend_node_conv_gemm(node, np.ones((4, 4)))
|
|
|
|
self.assertEqual(result.dims, [4, 4, 1, 1])
|
|
self.assertEqual(len(result.raw_data), 4 * 4 * 8)
|
|
|
|
def test_blend_other_matching(self):
|
|
node = from_array(np.ones((4, 4)))
|
|
result = blend_node_conv_gemm(node, np.ones((4, 4)))
|
|
|
|
self.assertEqual(result.dims, [4, 4])
|
|
self.assertEqual(len(result.raw_data), 4 * 4 * 8)
|
|
|
|
def test_blend_other_mismatched(self):
|
|
pass
|
|
|
|
|
|
class BlendNodeMatMulTests(unittest.TestCase):
    """Checks that ``blend_node_matmul`` returns a 4x4 float64 tensor for a 4x4 node."""

    def _assert_4x4_float64(self, tensor):
        # 16 float64 elements at 8 bytes each
        self.assertEqual(tensor.dims, [4, 4])
        self.assertEqual(len(tensor.raw_data), 4 * 4 * 8)

    def test_blend_matching(self):
        base = from_array(np.ones((4, 4)))
        self._assert_4x4_float64(blend_node_matmul(base, np.ones((4, 4)), "test"))

    def test_blend_mismatched(self):
        # (2, 2) weights against a (4, 4) node still yield a (4, 4) result
        base = from_array(np.ones((4, 4)))
        self._assert_4x4_float64(blend_node_matmul(base, np.ones((2, 2)), "test"))
|