1
0
Fork 0
onnx-web/api/tests/convert/diffusion/test_lora.py

180 lines
4.9 KiB
Python
Raw Normal View History

2023-09-15 00:35:48 +00:00
import unittest
import numpy as np
from onnx import GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import from_array
from onnx_web.convert.diffusion.lora import (
blend_loras,
buffer_external_data_tensors,
fix_initializer_name,
fix_node_name,
fix_xl_names,
interp_to_match,
kernel_slice,
sum_weights,
)
class SumWeightsTests(unittest.TestCase):
    """Shape tests for ``sum_weights``."""

    def test_same_shape(self):
        # two weights of identical shape are expected to keep that shape
        weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4)))
        self.assertEqual(weights.shape, (4, 4))

    def test_1x1_kernel(self):
        # combining a 2D weight with a 4D 1x1-kernel weight is expected to
        # produce the 4D shape, whichever argument carries the kernel
        weights = sum_weights(np.zeros((4, 4, 1, 1)), np.ones((4, 4)))
        self.assertEqual(weights.shape, (4, 4, 1, 1))

        weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1)))
        self.assertEqual(weights.shape, (4, 4, 1, 1))

    @unittest.skip("disabled draft: expected result for a 3x3 kernel is unconfirmed")
    def test_3x3_kernel(self):
        # NOTE(review): this draft was previously disabled by wrapping it in a
        # string literal, which made the test silently pass. Its expected shape
        # (4, 4, 1, 1) looks copied from the 1x1 case; confirm the intended
        # result for a 3x3 kernel before removing the skip.
        weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4)))
        self.assertEqual(weights.shape, (4, 4, 1, 1))
class BufferExternalDataTensorTests(unittest.TestCase):
    """Tests for ``buffer_external_data_tensors``."""

    def test_basic_external(self):
        # build a model whose graph holds exactly one initializer tensor
        initializers = [from_array(np.zeros((4, 4)))]
        model = ModelProto(graph=GraphProto(initializer=initializers))

        slim_model, external_weights = buffer_external_data_tensors(model)

        # the slimmed model is expected to keep the same number of
        # initializers, with one tensor reported as external weights
        original_count = len(model.graph.initializer)
        self.assertEqual(len(slim_model.graph.initializer), original_count)
        self.assertEqual(len(external_weights), 1)
class FixInitializerKeyTests(unittest.TestCase):
    """Tests for ``fix_initializer_name``."""

    def test_fix_name(self):
        # each case pairs a raw LoRA key with its expected initializer name;
        # here the dots between key segments are expected to become underscores.
        # Pairing them in tuples avoids zip() silently dropping unmatched entries.
        cases = [
            (
                "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight",
                "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight",
            ),
        ]
        for name, expected in cases:
            # subTest reports exactly which case failed (and keeps checking the rest)
            with self.subTest(name=name):
                self.assertEqual(fix_initializer_name(name), expected)
class FixNodeNameTests(unittest.TestCase):
    """Tests for ``fix_node_name``."""

    def test_fix_name(self):
        # each case pairs a raw node name with its expected normalized form:
        # slashes and dots are expected to become underscores, and a leading
        # underscore is expected to be stripped. Tuples keep inputs and
        # outputs from drifting out of step, which zip() would hide.
        cases = [
            (
                "lora_unet/up_blocks/3/attentions/2/transformer_blocks/0/attn2_to_out/0.lora_down.weight",
                "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight",
            ),
            (
                "_prefix",
                "prefix",
            ),
        ]
        for name, expected in cases:
            # subTest reports exactly which case failed (and keeps checking the rest)
            with self.subTest(name=name):
                self.assertEqual(fix_node_name(name), expected)
class FixXLNameTests(unittest.TestCase):
    """Tests for ``fix_xl_names``, which renames XL-style LoRA keys to match ONNX node names."""

    def _assert_renamed(self, key, node_name, expected):
        """Assert that ``key`` is renamed to ``expected`` given one node named ``node_name``.

        The weight stored under ``key`` is expected to be carried over unchanged.
        """
        weights = {key: {}}
        fixed = fix_xl_names(weights, [NodeProto(name=node_name)])
        self.assertEqual(fixed, {expected: weights[key]})

    def test_empty(self):
        fixed = fix_xl_names({}, [])
        self.assertEqual(fixed, {})

    def test_input_block(self):
        self._assert_renamed(
            "input_block_proj.lora_down.weight",
            "/down_blocks_proj/MatMul",
            "down_blocks_proj",
        )

    def test_middle_block(self):
        self._assert_renamed(
            "middle_block_proj.lora_down.weight",
            "/mid_blocks_proj/MatMul",
            "mid_blocks_proj",
        )

    @unittest.skip("placeholder: test not written yet")
    def test_output_block(self):
        pass

    @unittest.skip("placeholder: test not written yet")
    def test_text_model(self):
        pass

    @unittest.skip("placeholder: test not written yet")
    def test_unknown_block(self):
        pass

    def test_unmatched_block(self):
        # a key that does not match any known XL block is expected to be
        # returned unchanged
        weights = {
            "lora_unet.input_block.lora_down.weight": {},
        }
        # pass a real NodeProto (previously a bare str), matching the node
        # type every other test in this class supplies
        fixed = fix_xl_names(weights, [NodeProto(name="")])
        self.assertEqual(fixed, weights)

    def test_output_projection(self):
        self._assert_renamed(
            "output_block_proj_out.lora_down.weight",
            "/up_blocks_proj_out/MatMul",
            "up_blocks_proj_out",
        )
class KernelSliceTests(unittest.TestCase):
    """Tests for ``kernel_slice`` against a (3, 3, 3, 3) kernel shape."""

    def test_within_kernel(self):
        # indices that lie inside the kernel are expected to come back as-is
        shape = (3, 3, 3, 3)
        expected = (1, 1)
        self.assertEqual(kernel_slice(1, 1, shape), expected)

    def test_outside_kernel(self):
        # indices beyond the kernel are expected to map to (2, 2) --
        # presumably clamped to the last index of a size-3 kernel
        shape = (3, 3, 3, 3)
        expected = (2, 2)
        self.assertEqual(kernel_slice(9, 9, shape), expected)
2023-09-15 13:40:56 +00:00
2023-09-15 00:35:48 +00:00
class BlendLoRATests(unittest.TestCase):
    """Placeholder tests for ``blend_loras``.

    None of these are implemented yet. They are marked as skipped so the
    test report does not count them as passing coverage.
    """

    @unittest.skip("placeholder: test not written yet")
    def test_blend_unet(self):
        pass

    @unittest.skip("placeholder: test not written yet")
    def test_blend_text_encoder(self):
        pass

    @unittest.skip("placeholder: test not written yet")
    def test_blend_text_encoder_index(self):
        pass
2023-09-15 00:35:48 +00:00
class InterpToMatchTests(unittest.TestCase):
    """Shape tests for ``interp_to_match``.

    Every case expects the result to take the shape of the ``resize`` argument.
    NOTE(review): all cases use a square ``resize``, so a swapped-axis result
    would go unnoticed; consider adding a non-square case.
    """

    def _check_shape(self, ref_shape, resize_shape, expected_shape):
        """Interpolate zero arrays of the given shapes and compare the result shape."""
        result = interp_to_match(np.zeros(ref_shape), np.zeros(resize_shape))
        self.assertEqual(result.shape, expected_shape)

    def test_same_shape(self):
        self._check_shape((4, 4), (4, 4), (4, 4))

    def test_different_one_dim(self):
        self._check_shape((4, 2), (4, 4), (4, 4))

    def test_different_both_dims(self):
        self._check_shape((2, 2), (4, 4), (4, 4))