1
0
Fork 0

adjust node names in projection test

This commit is contained in:
Sean Sube 2023-09-15 06:56:01 -05:00
parent b851c234fe
commit 3137a465ab
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 3 additions and 3 deletions

View File

@ -126,14 +126,14 @@ class FixXLNameTests(unittest.TestCase):
def test_output_projection(self):
nodes = {
"output_block_proj_o.lora_down.weight": {},
"output_block_proj_out.lora_down.weight": {},
}
fixed = fix_xl_names(nodes, [
NodeProto(name="/up_blocks_proj_o/MatMul"),
NodeProto(name="/up_blocks_proj_out/MatMul"),
])
self.assertEqual(fixed, {
"up_blocks_proj_out": nodes["output_block_proj_o.lora_down.weight"],
"up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"],
})