use pre-fixed names for XL LoRA key matching
This commit is contained in:
parent
a02523c54c
commit
a39fe1d21c
|
@ -71,8 +71,9 @@ def fix_node_name(key: str):
|
|||
return fixed_name
|
||||
|
||||
|
||||
def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]):
|
||||
def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any], List[str]]:
|
||||
fixed = {}
|
||||
remaining = list(nodes)
|
||||
|
||||
for key, value in keys.items():
|
||||
root, *rest = key.split(".")
|
||||
|
@ -87,14 +88,11 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]):
|
|||
elif root.startswith("text_model"):
|
||||
block = "text_model"
|
||||
elif root.startswith("down_blocks"):
|
||||
fixed[fix_node_name(key)] = value
|
||||
continue
|
||||
elif root.startswith("mid_blocks"):
|
||||
fixed[fix_node_name(key)] = value
|
||||
continue
|
||||
block = "down_blocks"
|
||||
elif root.startswith("mid_block"):
|
||||
block = "mid_block"
|
||||
elif root.startswith("up_blocks"):
|
||||
fixed[fix_node_name(key)] = value
|
||||
continue
|
||||
block = "up_blocks"
|
||||
else:
|
||||
logger.warning("unknown XL key name: %s", key)
|
||||
fixed[key] = value
|
||||
|
@ -125,14 +123,14 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]):
|
|||
match = None
|
||||
if block == "text_model":
|
||||
match = next(
|
||||
node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul"
|
||||
node for node in remaining if node == f"{root}_MatMul"
|
||||
)
|
||||
else:
|
||||
match = next(
|
||||
node
|
||||
for node in nodes
|
||||
if node.name.startswith(f"/{block}")
|
||||
and fix_node_name(node.name).endswith(
|
||||
for node in remaining
|
||||
if node.startswith(f"/{block}")
|
||||
and node.endswith(
|
||||
f"{suffix}_MatMul"
|
||||
) # needs to be fixed because some places use to_out.0
|
||||
)
|
||||
|
@ -141,9 +139,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]):
|
|||
logger.warning("no matches for XL key: %s", root)
|
||||
continue
|
||||
|
||||
name: str = match.name
|
||||
name = fix_node_name(name.rstrip("/MatMul"))
|
||||
|
||||
name = match.rstrip("/MatMul")
|
||||
if name.endswith("proj_o"):
|
||||
# wtf
|
||||
name = f"{name}ut"
|
||||
|
@ -151,9 +147,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]):
|
|||
logger.trace("matching XL key with node: %s -> %s", key, match.name)
|
||||
|
||||
fixed[name] = value
|
||||
nodes.remove(match)
|
||||
remaining.remove(match)
|
||||
|
||||
return fixed
|
||||
return (fixed, remaining)
|
||||
|
||||
|
||||
def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
|
||||
|
@ -472,10 +468,17 @@ def blend_loras(
|
|||
)
|
||||
blended[base_key] = np_weights
|
||||
|
||||
# fix node names once
|
||||
fixed_initializer_names = [
|
||||
fix_initializer_name(node.name) for node in base_model.graph.initializer
|
||||
]
|
||||
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
|
||||
|
||||
# rewrite node names for XL
|
||||
if xl:
|
||||
nodes = list(base_model.graph.node)
|
||||
blended = fix_xl_names(blended, nodes)
|
||||
blended, remaining = fix_xl_names(blended, fixed_node_names)
|
||||
if len(remaining) > 0:
|
||||
logger.warning("could not match some XL keys: %s", remaining)
|
||||
|
||||
logger.debug(
|
||||
"updating %s of %s initializers",
|
||||
|
@ -483,11 +486,6 @@ def blend_loras(
|
|||
len(base_model.graph.initializer),
|
||||
)
|
||||
|
||||
fixed_initializer_names = [
|
||||
fix_initializer_name(node.name) for node in base_model.graph.initializer
|
||||
]
|
||||
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
|
||||
|
||||
unmatched_keys = []
|
||||
for base_key, weights in blended.items():
|
||||
conv_key = base_key + "_Conv"
|
||||
|
|
Loading…
Reference in New Issue