# onnx-web/api/tests/convert/diffusion/test_lora.py
import unittest
from unittest.mock import MagicMock, patch

import numpy as np
import torch
from onnx import GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import from_array

from onnx_web.constants import ONNX_MODEL
from onnx_web.convert.diffusion.lora import (
    blend_loras,
    blend_node_conv_gemm,
    blend_node_matmul,
    blend_weights_loha,
    blend_weights_lora,
    buffer_external_data_tensors,
    fix_initializer_name,
    fix_node_name,
    fix_xl_names,
    interp_to_match,
    kernel_slice,
    sum_weights,
)


class SumWeightsTests(unittest.TestCase):
    def test_same_shape(self):
        weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4)))
        self.assertEqual(weights.shape, (4, 4))

    def test_1x1_kernel(self):
        weights = sum_weights(np.zeros((4, 4, 1, 1)), np.ones((4, 4)))
        self.assertEqual(weights.shape, (4, 4, 1, 1))

        weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1)))
        self.assertEqual(weights.shape, (4, 4, 1, 1))

    def test_3x3_kernel(self):
        """
        weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4)))
        self.assertEqual(weights.shape, (4, 4, 1, 1))
        """
        pass


class BufferExternalDataTensorTests(unittest.TestCase):
    def test_basic_external(self):
        model = ModelProto(
            graph=GraphProto(
                initializer=[
                    from_array(np.zeros((4, 4))),
                ],
            )
        )
        (slim_model, external_weights) = buffer_external_data_tensors(model)

        self.assertEqual(
            len(slim_model.graph.initializer), len(model.graph.initializer)
        )
        self.assertEqual(len(external_weights), 1)


class FixInitializerKeyTests(unittest.TestCase):
    def test_fix_name(self):
        inputs = [
            "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"
        ]
        outputs = [
            "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"
        ]

        for input, output in zip(inputs, outputs):
            self.assertEqual(fix_initializer_name(input), output)


class FixNodeNameTests(unittest.TestCase):
    def test_fix_name(self):
        inputs = [
            "lora_unet/up_blocks/3/attentions/2/transformer_blocks/0/attn2_to_out/0.lora_down.weight",
            "_prefix",
        ]
        outputs = [
            "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight",
            "prefix",
        ]

        for input, output in zip(inputs, outputs):
            self.assertEqual(fix_node_name(input), output)


class FixXLNameTests(unittest.TestCase):
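    # These cases exercise the mapping from SDXL LoRA key prefixes to the node
    # names used in the converted ONNX graph: input_block -> down_blocks,
    # middle_block -> mid_blocks, and output_block -> up_blocks.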
    def test_empty(self):
        nodes = {}
        fixed = fix_xl_names(nodes, [])

        self.assertEqual(fixed, {})

    def test_input_block(self):
        nodes = {
            "input_block_proj.lora_down.weight": {},
        }
        fixed = fix_xl_names(
            nodes,
            [
                NodeProto(name="/down_blocks_proj/MatMul"),
            ],
        )

        self.assertEqual(
            fixed,
            {
                "down_blocks_proj": nodes["input_block_proj.lora_down.weight"],
            },
        )

    def test_middle_block(self):
        nodes = {
            "middle_block_proj.lora_down.weight": {},
        }
        fixed = fix_xl_names(
            nodes,
            [
                NodeProto(name="/mid_blocks_proj/MatMul"),
            ],
        )

        self.assertEqual(
            fixed,
            {
                "mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"],
            },
        )

    def test_output_block(self):
        pass

    def test_text_model(self):
        pass

    def test_unknown_block(self):
        pass

    def test_unmatched_block(self):
        nodes = {
            "lora_unet.input_block.lora_down.weight": {},
        }
        fixed = fix_xl_names(nodes, [NodeProto(name="test")])

        self.assertEqual(fixed, nodes)

    def test_output_projection(self):
        nodes = {
            "output_block_proj_out.lora_down.weight": {},
        }
        fixed = fix_xl_names(
            nodes,
            [
                NodeProto(name="/up_blocks_proj_out/MatMul"),
            ],
        )

        self.assertEqual(
            fixed,
            {
                "up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"],
            },
        )


class KernelSliceTests(unittest.TestCase):
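    # Based on the cases below, kernel_slice clamps an (x, y) index into the
    # bounds of the kernel shape, so out-of-range coordinates land on the last
    # valid position.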
    def test_within_kernel(self):
        self.assertEqual(
            kernel_slice(1, 1, (3, 3, 3, 3)),
            (1, 1),
        )

    def test_outside_kernel(self):
        self.assertEqual(
            kernel_slice(9, 9, (3, 3, 3, 3)),
            (2, 2),
        )


class InterpToMatchTests(unittest.TestCase):
    def test_same_shape(self):
        ref = np.zeros((4, 4))
        resize = np.zeros((4, 4))
        self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))

    def test_different_one_dim(self):
        ref = np.zeros((4, 2))
        resize = np.zeros((4, 4))
        self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))

    def test_different_both_dims(self):
        ref = np.zeros((2, 2))
        resize = np.zeros((4, 4))
        self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))


class BlendLoRATests(unittest.TestCase):
    def test_blend_unet(self):
        """
        blend_loras(None, "test", [], "unet")
        """
        pass

    def test_blend_text_encoder(self):
        """
        blend_loras(None, "test", [], "text_encoder")
        """
        pass

    def test_blend_text_encoder_index(self):
        """
        blend_loras(None, "test", [], "text_encoder", model_index=2)
        """
        pass

    def test_unmatched_keys(self):
        pass

    def test_xl_keys(self):
        """
        blend_loras(None, "test", [], "unet", xl=True)
        """
        pass

    def test_node_dtype(self):
        pass
@patch("onnx_web.convert.diffusion.lora.load")
@patch("onnx_web.convert.diffusion.lora.load_tensor")
def test_blend_loras_load_str(self, mock_load_tensor, mock_load):
loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)]
model_type = "unet"
model_index = 2
xl = True
mock_load.return_value = MagicMock()
mock_load_tensor.return_value = MagicMock()
# Call the blend_loras function
blended_model = blend_loras(
None, ONNX_MODEL, loras, model_type, model_index, xl
)
# Assert that the InferenceSession is called with the correct arguments
mock_load.assert_called_once_with(ONNX_MODEL)
# Assert that the model is loaded successfully
self.assertEqual(blended_model, mock_load.return_value)
# Assert that the blending logic is executed correctly
# (assertions specific to the blending logic can be added here)


class BlendWeightsLoHATests(unittest.TestCase):
    def test_blend_t1_t2(self):
        # blend einsum: i j k l, j r, i p -> p r k l
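        # With t1/t2 of shape (i, j, k, l), the w*_b factors of shape (j, r),
        # and the w*_a factors of shape (i, p), contracting over i and j should
        # leave a (p, r, k, l) result, which is the shape asserted below.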
        i = 32
        j = 4
        k = 1
        l = 1  # NOQA
        p = 2
        r = 4

        model = {
            "foo.hada_t1": torch.from_numpy(np.ones((i, j, k, l))),
            "foo.hada_t2": torch.from_numpy(np.ones((i, j, k, l))),
            "foo.hada_w1_a": torch.from_numpy(np.ones((i, p))),
            "foo.hada_w1_b": torch.from_numpy(np.ones((j, r))),
            "foo.hada_w2_a": torch.from_numpy(np.ones((i, p))),
            "foo.hada_w2_b": torch.from_numpy(np.ones((j, r))),
            "foo.alpha": torch.tensor(1),
        }
        key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)

        self.assertEqual(result.shape, (p, r, k, l))
    def test_blend_w1_w2(self):
        model = {
            "foo.hada_w1_a": torch.from_numpy(np.ones((4, 1))),
            "foo.hada_w1_b": torch.from_numpy(np.ones((1, 4))),
            "foo.hada_w2_a": torch.from_numpy(np.ones((4, 1))),
            "foo.hada_w2_b": torch.from_numpy(np.ones((1, 4))),
            "foo.alpha": torch.tensor(1),
        }
        key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)

        self.assertEqual(result.shape, (4, 4))

    def test_blend_no_dim(self):
        """
        model = {
            "foo.hada_w1_a": torch.from_numpy(np.ones((1, 4))),
            "foo.hada_w1_b": torch.from_numpy(np.ones((4, 1))),
            "foo.hada_w2_a": torch.from_numpy(np.ones((1, 4))),
            "foo.hada_w2_b": torch.from_numpy(np.ones((4, 1))),
        }
        result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
        self.assertEqual(result.shape, (4, 4))
        """


class BlendWeightsLoRATests(unittest.TestCase):
    def test_blend_kernel_none(self):
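        # A sketch of the expected math: a plain LoRA blend is roughly
        # up @ down scaled by alpha / rank, so a (4, 1) up factor and a (1, 4)
        # down factor should reconstruct a full (4, 4) weight.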
        model = {
            "foo.lora_down": torch.from_numpy(np.ones((1, 4))),
            "foo.lora_up": torch.from_numpy(np.ones((4, 1))),
            "foo.alpha": 1,
        }
        key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)

        self.assertEqual(result.shape, (4, 4))

    def test_blend_kernel_1x1(self):
        model = {
            "foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))),
            "foo.lora_up": torch.from_numpy(np.ones((4, 1, 1, 1))),
            "foo.alpha": 1,
        }
        key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)

        self.assertEqual(result.shape, (4, 4, 1, 1))

    def test_blend_kernel_3x3(self):
        model = {
            "foo.lora_down": torch.from_numpy(np.ones((1, 4, 3, 3))),
            "foo.lora_up": torch.from_numpy(np.ones((4, 1, 3, 3))),
            "foo.alpha": 1,
        }
        key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)

        self.assertEqual(result.shape, (4, 4, 3, 3))

    def test_blend_kernel_3x3_cp_decomp(self):
        model = {
            "foo.lora_down": torch.from_numpy(np.ones((2, 4, 1, 1))),
            "foo.lora_mid": torch.from_numpy(np.ones((2, 2, 3, 3))),
            "foo.lora_up": torch.from_numpy(np.ones((4, 2, 1, 1))),
            "foo.alpha": 1,
        }
        key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)

        self.assertEqual(result.shape, (4, 4, 3, 3))

    def test_blend_unknown(self):
        pass


class BlendNodeConvGemmTests(unittest.TestCase):
    def test_blend_kernel_1x1_and_1x1(self):
        node = from_array(np.ones((4, 4, 1, 1)))
        result = blend_node_conv_gemm(node, np.ones((4, 4, 1, 1)))

        self.assertEqual(result.dims, [4, 4, 1, 1])
        self.assertEqual(len(result.raw_data), 4 * 4 * 8)

    def test_blend_kernel_1x1_and_none(self):
        node = from_array(np.ones((4, 4, 1, 1)))
        result = blend_node_conv_gemm(node, np.ones((4, 4)))

        self.assertEqual(result.dims, [4, 4, 1, 1])
        self.assertEqual(len(result.raw_data), 4 * 4 * 8)

    def test_blend_other_matching(self):
        node = from_array(np.ones((4, 4)))
        result = blend_node_conv_gemm(node, np.ones((4, 4)))

        self.assertEqual(result.dims, [4, 4])
        self.assertEqual(len(result.raw_data), 4 * 4 * 8)

    def test_blend_other_mismatched(self):
        pass


class BlendNodeMatMulTests(unittest.TestCase):
    def test_blend_matching(self):
        node = from_array(np.ones((4, 4)))
        result = blend_node_matmul(node, np.ones((4, 4)), "test")

        self.assertEqual(result.dims, [4, 4])
        self.assertEqual(len(result.raw_data), 4 * 4 * 8)

    def test_blend_mismatched(self):
        node = from_array(np.ones((4, 4)))
        result = blend_node_matmul(node, np.ones((2, 2)), "test")

        self.assertEqual(result.dims, [4, 4])
        self.assertEqual(len(result.raw_data), 4 * 4 * 8)
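

if __name__ == "__main__":
    # Optional entry point so this module can be run directly; under a normal
    # test runner (pytest or unittest discovery) it is not needed.
    unittest.main()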