1
0
Fork 0

correctly remove operator types

This commit is contained in:
Sean Sube 2023-11-21 23:12:24 -06:00
parent 33bd67beb6
commit 54dc970a7e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 9 additions and 7 deletions

View File

@ -127,7 +127,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
logger.warning("new XL key type: %s", root)
continue
logger.trace("searching for XL node: /%s/*/%s", block, suffix)
logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
match: Optional[NodeProto] = None
if "conv" in suffix:
match = next(
@ -159,13 +159,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
logger.trace("matched key: %s -> %s", key, match.name)
name: str = match.name
name = fix_node_name(name).rstrip("/MatMul").rstrip("/Gemm").rstrip("/Conv")
name = fix_node_name(name)
if name.endswith("_MatMul"):
name = name[:-7]
elif name.endswith("_Gemm"):
name = name[:-5]
elif name.endswith("_Conv"):
name = name[:-5]
if name.endswith("proj_o"):
# wtf
name = f"{name}ut"
logger.trace("matching XL key with node: %s -> %s", key, match.name)
logger.trace("matching XL key with node: %s -> %s, %s", key, match.name, name)
fixed[name] = value
remaining.remove(match)