1
0
Fork 0

correctly remove operator types

This commit is contained in:
Sean Sube 2023-11-21 23:12:24 -06:00
parent 33bd67beb6
commit 54dc970a7e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 9 additions and 7 deletions

View File

@ -127,7 +127,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
logger.warning("new XL key type: %s", root) logger.warning("new XL key type: %s", root)
continue continue
logger.trace("searching for XL node: /%s/*/%s", block, suffix) logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
match: Optional[NodeProto] = None match: Optional[NodeProto] = None
if "conv" in suffix: if "conv" in suffix:
match = next( match = next(
@ -159,13 +159,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
logger.trace("matched key: %s -> %s", key, match.name) logger.trace("matched key: %s -> %s", key, match.name)
name: str = match.name name: str = match.name
name = fix_node_name(name).rstrip("/MatMul").rstrip("/Gemm").rstrip("/Conv") name = fix_node_name(name)
if name.endswith("_MatMul"):
name = name[:-7]
elif name.endswith("_Gemm"):
name = name[:-5]
elif name.endswith("_Conv"):
name = name[:-5]
if name.endswith("proj_o"): logger.trace("matching XL key with node: %s -> %s, %s", key, match.name, name)
# wtf
name = f"{name}ut"
logger.trace("matching XL key with node: %s -> %s", key, match.name)
fixed[name] = value fixed[name] = value
remaining.remove(match) remaining.remove(match)