diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index f2dbf75e..6e8ecc6e 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -73,7 +73,7 @@ def fix_node_name(key: str): def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]: fixed = {} - remaining = list(nodes) + names = [fix_node_name(node.name) for node in nodes] for key, value in keys.items(): root, *rest = key.split(".") @@ -128,28 +128,28 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any] continue logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix) - match: Optional[NodeProto] = None + match: Optional[str] = None if "conv" in suffix: match = next( - node for node in remaining if fix_node_name(node.name) == f"{root}_Conv" + node for node in names if node == f"{root}_Conv" ) elif "time_emb_proj" in root: match = next( - node for node in remaining if fix_node_name(node.name) == f"{root}_Gemm" + node for node in names if node == f"{root}_Gemm" ) elif block == "text_model" or simple: match = next( node - for node in remaining - if fix_node_name(node.name) == f"{root}_MatMul" + for node in names + if node == f"{root}_MatMul" ) else: # search in order. one side has sparse indices, so they will not match. match = next( node - for node in remaining - if node.name.startswith(f"/{block}") - and fix_node_name(node.name).endswith( + for node in names + if node.startswith(block) + and node.endswith( f"{suffix}_MatMul" ) # needs to be fixed because some places use to_out.0 ) @@ -158,10 +158,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any] logger.warning("no matches for XL key: %s", root) continue else: - logger.trace("matched key: %s -> %s", key, match.name) + logger.trace("matched key: %s -> %s", key, match) - name: str = match.name - name = fix_node_name(name) + name = match if name.endswith("_MatMul"): name = name[:-7] elif name.endswith("_Gemm"): @@ -169,15 +168,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any] elif name.endswith("_Conv"): name = name[:-5] - logger.trace("matching XL key with node: %s -> %s, %s", key, match.name, name) + logger.trace("matching XL key with node: %s -> %s, %s", key, match, name) fixed[name] = value - remaining.remove(match) + names.remove(match) logger.debug( "SDXL LoRA key fixup matched %s keys, %s remaining", len(fixed.keys()), - len(remaining), + len(names), ) return fixed @@ -487,7 +486,7 @@ def blend_loras( blended[base_key] = np_weights # rewrite node names for XL and flatten layers - weights = Dict[str, np.ndarray] = {} + weights: Dict[str, np.ndarray] = {} for blended in layers: if xl: