1
0
Fork 0

optimize XL LoRA node matching

This commit is contained in:
Sean Sube 2023-09-03 16:08:24 -05:00
parent ebdc6a00fe
commit 1d5ac7dde5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 9 additions and 8 deletions

View File

@ -120,35 +120,36 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
continue
logger.debug("searching for XL node: /%s/*/%s", block, suffix)
match = None
if block == "text_model":
matches = [
match = next(
node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul"
]
)
else:
matches = [
match = next(
node
for node in nodes
if node.name.startswith(f"/{block}")
and fix_node_name(node.name).endswith(
f"{suffix}_MatMul"
) # needs to be fixed because some places use to_out.0
]
)
if len(matches) == 0:
if match is None:
logger.warning("no matches for XL key: %s", root)
continue
name: str = matches[0].name
name: str = match.name
name = fix_node_name(name.rstrip("/MatMul"))
if name.endswith("proj_o"):
# wtf
name = f"{name}ut"
logger.debug("matching XL key with node: %s -> %s", key, matches[0].name)
logger.debug("matching XL key with node: %s -> %s", key, match.name)
fixed[name] = value
nodes.remove(matches[0])
nodes.remove(match)
return fixed