diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index c2e14d04..ae5b910a 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -120,35 +120,36 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]): continue logger.debug("searching for XL node: /%s/*/%s", block, suffix) + match = None if block == "text_model": - matches = [ + match = next( node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul" - ] + ) else: - matches = [ + match = next( node for node in nodes if node.name.startswith(f"/{block}") and fix_node_name(node.name).endswith( f"{suffix}_MatMul" ) # needs to be fixed because some places use to_out.0 - ] + ) - if len(matches) == 0: + if match is None: logger.warning("no matches for XL key: %s", root) continue - name: str = matches[0].name + name: str = match.name name = fix_node_name(name.rstrip("/MatMul")) if name.endswith("proj_o"): # wtf name = f"{name}ut" - logger.debug("matching XL key with node: %s -> %s", key, matches[0].name) + logger.debug("matching XL key with node: %s -> %s", key, match.name) fixed[name] = value - nodes.remove(matches[0]) + nodes.remove(match) return fixed