optimize XL LoRA node matching
This commit is contained in:
parent
ebdc6a00fe
commit
1d5ac7dde5
|
@ -120,35 +120,36 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.debug("searching for XL node: /%s/*/%s", block, suffix)
|
logger.debug("searching for XL node: /%s/*/%s", block, suffix)
|
||||||
|
match = None
|
||||||
if block == "text_model":
|
if block == "text_model":
|
||||||
matches = [
|
match = next(
|
||||||
node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul"
|
node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul"
|
||||||
]
|
)
|
||||||
else:
|
else:
|
||||||
matches = [
|
match = next(
|
||||||
node
|
node
|
||||||
for node in nodes
|
for node in nodes
|
||||||
if node.name.startswith(f"/{block}")
|
if node.name.startswith(f"/{block}")
|
||||||
and fix_node_name(node.name).endswith(
|
and fix_node_name(node.name).endswith(
|
||||||
f"{suffix}_MatMul"
|
f"{suffix}_MatMul"
|
||||||
) # needs to be fixed because some places use to_out.0
|
) # needs to be fixed because some places use to_out.0
|
||||||
]
|
)
|
||||||
|
|
||||||
if len(matches) == 0:
|
if match is None:
|
||||||
logger.warning("no matches for XL key: %s", root)
|
logger.warning("no matches for XL key: %s", root)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
name: str = matches[0].name
|
name: str = match.name
|
||||||
name = fix_node_name(name.rstrip("/MatMul"))
|
name = fix_node_name(name.rstrip("/MatMul"))
|
||||||
|
|
||||||
if name.endswith("proj_o"):
|
if name.endswith("proj_o"):
|
||||||
# wtf
|
# wtf
|
||||||
name = f"{name}ut"
|
name = f"{name}ut"
|
||||||
|
|
||||||
logger.debug("matching XL key with node: %s -> %s", key, matches[0].name)
|
logger.debug("matching XL key with node: %s -> %s", key, match.name)
|
||||||
|
|
||||||
fixed[name] = value
|
fixed[name] = value
|
||||||
nodes.remove(matches[0])
|
nodes.remove(match)
|
||||||
|
|
||||||
return fixed
|
return fixed
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue