correctly remove operator types
This commit is contained in:
parent
33bd67beb6
commit
54dc970a7e
|
@ -127,7 +127,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
||||||
logger.warning("new XL key type: %s", root)
|
logger.warning("new XL key type: %s", root)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.trace("searching for XL node: /%s/*/%s", block, suffix)
|
logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
|
||||||
match: Optional[NodeProto] = None
|
match: Optional[NodeProto] = None
|
||||||
if "conv" in suffix:
|
if "conv" in suffix:
|
||||||
match = next(
|
match = next(
|
||||||
|
@ -159,13 +159,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
||||||
logger.trace("matched key: %s -> %s", key, match.name)
|
logger.trace("matched key: %s -> %s", key, match.name)
|
||||||
|
|
||||||
name: str = match.name
|
name: str = match.name
|
||||||
name = fix_node_name(name).rstrip("/MatMul").rstrip("/Gemm").rstrip("/Conv")
|
name = fix_node_name(name)
|
||||||
|
if name.endswith("_MatMul"):
|
||||||
|
name = name[:-7]
|
||||||
|
elif name.endswith("_Gemm"):
|
||||||
|
name = name[:-5]
|
||||||
|
elif name.endswith("_Conv"):
|
||||||
|
name = name[:-5]
|
||||||
|
|
||||||
if name.endswith("proj_o"):
|
logger.trace("matching XL key with node: %s -> %s, %s", key, match.name, name)
|
||||||
# wtf
|
|
||||||
name = f"{name}ut"
|
|
||||||
|
|
||||||
logger.trace("matching XL key with node: %s -> %s", key, match.name)
|
|
||||||
|
|
||||||
fixed[name] = value
|
fixed[name] = value
|
||||||
remaining.remove(match)
|
remaining.remove(match)
|
||||||
|
|
Loading…
Reference in New Issue