diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index 1d2279c3..b912b107 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -127,7 +127,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any] logger.warning("new XL key type: %s", root) continue - logger.trace("searching for XL node: /%s/*/%s", block, suffix) + logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix) match: Optional[NodeProto] = None if "conv" in suffix: match = next( @@ -159,13 +159,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any] logger.trace("matched key: %s -> %s", key, match.name) name: str = match.name - name = fix_node_name(name).rstrip("/MatMul").rstrip("/Gemm").rstrip("/Conv") + name = fix_node_name(name) + if name.endswith("_MatMul"): + name = name[:-7] + elif name.endswith("_Gemm"): + name = name[:-5] + elif name.endswith("_Conv"): + name = name[:-5] - if name.endswith("proj_o"): - # wtf - name = f"{name}ut" - - logger.trace("matching XL key with node: %s -> %s", key, match.name) + logger.trace("matching XL key with node: %s -> %s, %s", key, match.name, name) fixed[name] = value remaining.remove(match)