diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index ac03f713..1d2279c3 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -71,7 +71,7 @@ def fix_node_name(key: str): return fixed_name -def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any], List[str]]: +def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]: fixed = {} remaining = list(nodes) @@ -79,6 +79,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any] root, *rest = key.split(".") logger.trace("fixing XL node name: %s -> %s", key, root) + simple = False if root.startswith("input"): block = "down_blocks" elif root.startswith("middle"): @@ -89,10 +90,13 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any] block = "text_model" elif root.startswith("down_blocks"): block = "down_blocks" + simple = True elif root.startswith("mid_block"): block = "mid_block" + simple = True elif root.startswith("up_blocks"): block = "up_blocks" + simple = True else: logger.warning("unknown XL key name: %s", key) fixed[key] = value @@ -100,6 +104,10 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any] suffix = None for s in [ + "conv", + "conv_shortcut", + "conv1", + "conv2", "fc1", "fc2", "ff_net_0_proj", @@ -120,17 +128,26 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any] continue logger.trace("searching for XL node: /%s/*/%s", block, suffix) - match = None - if block == "text_model": + match: Optional[NodeProto] = None + if "conv" in suffix: match = next( - node for node in remaining if node == f"{root}_MatMul" + node for node in remaining if fix_node_name(node.name) == f"{root}_Conv" + ) + elif "time_emb_proj" in root: + match = next( + node for node in remaining if fix_node_name(node.name) == f"{root}_Gemm" + ) + elif block == "text_model" or simple: + match = next( + node for node in remaining if fix_node_name(node.name) == f"{root}_MatMul" ) else: + # search in order. one side has sparse indices, so they will not match. match = next( node for node in remaining - if node.startswith(f"/{block}") - and node.endswith( + if node.name.startswith(f"/{block}") + and fix_node_name(node.name).endswith( f"{suffix}_MatMul" ) # needs to be fixed because some places use to_out.0 ) @@ -138,8 +155,12 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any] if match is None: logger.warning("no matches for XL key: %s", root) continue + else: + logger.trace("matched key: %s -> %s", key, match.name) + + name: str = match.name + name = fix_node_name(name).rstrip("/MatMul").rstrip("/Gemm").rstrip("/Conv") - name = match.rstrip("/MatMul") if name.endswith("proj_o"): # wtf name = f"{name}ut" @@ -149,7 +170,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any] fixed[name] = value remaining.remove(match) - return (fixed, remaining) + return fixed def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]: @@ -476,9 +497,8 @@ def blend_loras( # rewrite node names for XL if xl: - blended, remaining = fix_xl_names(blended, fixed_node_names) - if len(remaining) > 0: - logger.warning("could not match some XL keys: %s", remaining) + nodes = list(base_model.graph.node) + blended = fix_xl_names(blended, nodes) logger.debug( "updating %s of %s initializers",