optimize XL LoRA node matching
This commit is contained in:
parent
ebdc6a00fe
commit
1d5ac7dde5
|
@ -120,35 +120,36 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
|
|||
continue
|
||||
|
||||
logger.debug("searching for XL node: /%s/*/%s", block, suffix)
|
||||
match = None
|
||||
if block == "text_model":
|
||||
matches = [
|
||||
match = next(
|
||||
node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul"
|
||||
]
|
||||
)
|
||||
else:
|
||||
matches = [
|
||||
match = next(
|
||||
node
|
||||
for node in nodes
|
||||
if node.name.startswith(f"/{block}")
|
||||
and fix_node_name(node.name).endswith(
|
||||
f"{suffix}_MatMul"
|
||||
) # needs to be fixed because some places use to_out.0
|
||||
]
|
||||
)
|
||||
|
||||
if len(matches) == 0:
|
||||
if match is None:
|
||||
logger.warning("no matches for XL key: %s", root)
|
||||
continue
|
||||
|
||||
name: str = matches[0].name
|
||||
name: str = match.name
|
||||
name = fix_node_name(name.rstrip("/MatMul"))
|
||||
|
||||
if name.endswith("proj_o"):
|
||||
# wtf
|
||||
name = f"{name}ut"
|
||||
|
||||
logger.debug("matching XL key with node: %s -> %s", key, matches[0].name)
|
||||
logger.debug("matching XL key with node: %s -> %s", key, match.name)
|
||||
|
||||
fixed[name] = value
|
||||
nodes.remove(matches[0])
|
||||
nodes.remove(match)
|
||||
|
||||
return fixed
|
||||
|
||||
|
|
Loading…
Reference in New Issue