1
0
Fork 0

use pre-fixed names for XL LoRA key matching

This commit is contained in:
Sean Sube 2023-11-21 21:46:34 -06:00
parent a02523c54c
commit a39fe1d21c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 22 additions and 24 deletions

View File

@ -71,8 +71,9 @@ def fix_node_name(key: str):
return fixed_name return fixed_name
def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]): def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any], List[str]]:
fixed = {} fixed = {}
remaining = list(nodes)
for key, value in keys.items(): for key, value in keys.items():
root, *rest = key.split(".") root, *rest = key.split(".")
@ -87,14 +88,11 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]):
elif root.startswith("text_model"): elif root.startswith("text_model"):
block = "text_model" block = "text_model"
elif root.startswith("down_blocks"): elif root.startswith("down_blocks"):
fixed[fix_node_name(key)] = value block = "down_blocks"
continue elif root.startswith("mid_block"):
elif root.startswith("mid_blocks"): block = "mid_block"
fixed[fix_node_name(key)] = value
continue
elif root.startswith("up_blocks"): elif root.startswith("up_blocks"):
fixed[fix_node_name(key)] = value block = "up_blocks"
continue
else: else:
logger.warning("unknown XL key name: %s", key) logger.warning("unknown XL key name: %s", key)
fixed[key] = value fixed[key] = value
@ -125,14 +123,14 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]):
match = None match = None
if block == "text_model": if block == "text_model":
match = next( match = next(
node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul" node for node in remaining if node == f"{root}_MatMul"
) )
else: else:
match = next( match = next(
node node
for node in nodes for node in remaining
if node.name.startswith(f"/{block}") if node.startswith(f"/{block}")
and fix_node_name(node.name).endswith( and node.endswith(
f"{suffix}_MatMul" f"{suffix}_MatMul"
) # needs to be fixed because some places use to_out.0 ) # needs to be fixed because some places use to_out.0
) )
@ -141,9 +139,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]):
logger.warning("no matches for XL key: %s", root) logger.warning("no matches for XL key: %s", root)
continue continue
name: str = match.name name = match.rstrip("/MatMul")
name = fix_node_name(name.rstrip("/MatMul"))
if name.endswith("proj_o"): if name.endswith("proj_o"):
# wtf # wtf
name = f"{name}ut" name = f"{name}ut"
@ -151,9 +147,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]):
logger.trace("matching XL key with node: %s -> %s", key, match.name) logger.trace("matching XL key with node: %s -> %s", key, match.name)
fixed[name] = value fixed[name] = value
nodes.remove(match) remaining.remove(match)
return fixed return (fixed, remaining)
def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]: def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
@ -472,10 +468,17 @@ def blend_loras(
) )
blended[base_key] = np_weights blended[base_key] = np_weights
# fix node names once
fixed_initializer_names = [
fix_initializer_name(node.name) for node in base_model.graph.initializer
]
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
# rewrite node names for XL # rewrite node names for XL
if xl: if xl:
nodes = list(base_model.graph.node) blended, remaining = fix_xl_names(blended, fixed_node_names)
blended = fix_xl_names(blended, nodes) if len(remaining) > 0:
logger.warning("could not match some XL keys: %s", remaining)
logger.debug( logger.debug(
"updating %s of %s initializers", "updating %s of %s initializers",
@ -483,11 +486,6 @@ def blend_loras(
len(base_model.graph.initializer), len(base_model.graph.initializer),
) )
fixed_initializer_names = [
fix_initializer_name(node.name) for node in base_model.graph.initializer
]
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
unmatched_keys = [] unmatched_keys = []
for base_key, weights in blended.items(): for base_key, weights in blended.items():
conv_key = base_key + "_Conv" conv_key = base_key + "_Conv"