From a39fe1d21caac515eb294906a51570ae8a2dc058 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 21 Nov 2023 21:46:34 -0600 Subject: [PATCH] use pre-fixed names for XL LoRA key matching --- api/onnx_web/convert/diffusion/lora.py | 46 ++++++++++++-------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index 91904e11..ac03f713 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -71,8 +71,9 @@ def fix_node_name(key: str): return fixed_name -def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]): +def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any], List[str]]: fixed = {} + remaining = list(nodes) for key, value in keys.items(): root, *rest = key.split(".") @@ -87,14 +88,11 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]): elif root.startswith("text_model"): block = "text_model" elif root.startswith("down_blocks"): - fixed[fix_node_name(key)] = value - continue - elif root.startswith("mid_blocks"): - fixed[fix_node_name(key)] = value - continue + block = "down_blocks" + elif root.startswith("mid_block"): + block = "mid_block" elif root.startswith("up_blocks"): - fixed[fix_node_name(key)] = value - continue + block = "up_blocks" else: logger.warning("unknown XL key name: %s", key) fixed[key] = value @@ -125,14 +123,14 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]): match = None if block == "text_model": match = next( - node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul" + node for node in remaining if node == f"{root}_MatMul" ) else: match = next( node - for node in nodes - if node.name.startswith(f"/{block}") - and fix_node_name(node.name).endswith( + for node in remaining + if node.startswith(f"/{block}") + and node.endswith( f"{suffix}_MatMul" ) # needs to be fixed because some places use to_out.0 ) @@ -141,9 +139,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]): logger.warning("no matches for XL key: %s", root) continue - name: str = match.name - name = fix_node_name(name.rstrip("/MatMul")) - + name = match.rstrip("/MatMul") if name.endswith("proj_o"): # wtf name = f"{name}ut" @@ -151,9 +147,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]): logger.trace("matching XL key with node: %s -> %s", key, match.name) fixed[name] = value - nodes.remove(match) + remaining.remove(match) - return fixed + return (fixed, remaining) def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]: @@ -472,10 +468,17 @@ def blend_loras( ) blended[base_key] = np_weights + # fix node names once + fixed_initializer_names = [ + fix_initializer_name(node.name) for node in base_model.graph.initializer + ] + fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node] + # rewrite node names for XL if xl: - nodes = list(base_model.graph.node) - blended = fix_xl_names(blended, nodes) + blended, remaining = fix_xl_names(blended, fixed_node_names) + if len(remaining) > 0: + logger.warning("could not match some XL keys: %s", remaining) logger.debug( "updating %s of %s initializers", @@ -483,11 +486,6 @@ def blend_loras( len(base_model.graph.initializer), ) - fixed_initializer_names = [ - fix_initializer_name(node.name) for node in base_model.graph.initializer - ] - fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node] - unmatched_keys = [] for base_key, weights in blended.items(): conv_key = base_key + "_Conv"