From 54dc970a7e95ab331aaa9110760f5133b511e285 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 21 Nov 2023 23:12:24 -0600 Subject: [PATCH] correctly remove operator types --- api/onnx_web/convert/diffusion/lora.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index 1d2279c3..b912b107 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -127,7 +127,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any] logger.warning("new XL key type: %s", root) continue - logger.trace("searching for XL node: /%s/*/%s", block, suffix) + logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix) match: Optional[NodeProto] = None if "conv" in suffix: match = next( @@ -159,13 +159,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any] logger.trace("matched key: %s -> %s", key, match.name) name: str = match.name - name = fix_node_name(name).rstrip("/MatMul").rstrip("/Gemm").rstrip("/Conv") + name = fix_node_name(name) + if name.endswith("_MatMul"): + name = name[:-7] + elif name.endswith("_Gemm"): + name = name[:-5] + elif name.endswith("_Conv"): + name = name[:-5] - if name.endswith("proj_o"): - # wtf - name = f"{name}ut" - - logger.trace("matching XL key with node: %s -> %s", key, match.name) + logger.trace("matching XL key with node: %s -> %s, %s", key, match.name, name) fixed[name] = value remaining.remove(match)