correctly remove operator types
This commit is contained in:
parent
33bd67beb6
commit
54dc970a7e
|
@ -127,7 +127,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
|||
logger.warning("new XL key type: %s", root)
|
||||
continue
|
||||
|
||||
logger.trace("searching for XL node: /%s/*/%s", block, suffix)
|
||||
logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
|
||||
match: Optional[NodeProto] = None
|
||||
if "conv" in suffix:
|
||||
match = next(
|
||||
|
@ -159,13 +159,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
|||
logger.trace("matched key: %s -> %s", key, match.name)
|
||||
|
||||
name: str = match.name
|
||||
name = fix_node_name(name).rstrip("/MatMul").rstrip("/Gemm").rstrip("/Conv")
|
||||
name = fix_node_name(name)
|
||||
if name.endswith("_MatMul"):
|
||||
name = name[:-7]
|
||||
elif name.endswith("_Gemm"):
|
||||
name = name[:-5]
|
||||
elif name.endswith("_Conv"):
|
||||
name = name[:-5]
|
||||
|
||||
if name.endswith("proj_o"):
|
||||
# wtf
|
||||
name = f"{name}ut"
|
||||
|
||||
logger.trace("matching XL key with node: %s -> %s", key, match.name)
|
||||
logger.trace("matching XL key with node: %s -> %s, %s", key, match.name, name)
|
||||
|
||||
fixed[name] = value
|
||||
remaining.remove(match)
|
||||
|
|
Loading…
Reference in New Issue