switch LoRA back to fixing node names on the fly
This commit is contained in:
parent
a39fe1d21c
commit
33bd67beb6
|
@ -71,7 +71,7 @@ def fix_node_name(key: str):
|
||||||
return fixed_name
|
return fixed_name
|
||||||
|
|
||||||
|
|
||||||
def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any], List[str]]:
|
def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]:
|
||||||
fixed = {}
|
fixed = {}
|
||||||
remaining = list(nodes)
|
remaining = list(nodes)
|
||||||
|
|
||||||
|
@ -79,6 +79,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any]
|
||||||
root, *rest = key.split(".")
|
root, *rest = key.split(".")
|
||||||
logger.trace("fixing XL node name: %s -> %s", key, root)
|
logger.trace("fixing XL node name: %s -> %s", key, root)
|
||||||
|
|
||||||
|
simple = False
|
||||||
if root.startswith("input"):
|
if root.startswith("input"):
|
||||||
block = "down_blocks"
|
block = "down_blocks"
|
||||||
elif root.startswith("middle"):
|
elif root.startswith("middle"):
|
||||||
|
@ -89,10 +90,13 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any]
|
||||||
block = "text_model"
|
block = "text_model"
|
||||||
elif root.startswith("down_blocks"):
|
elif root.startswith("down_blocks"):
|
||||||
block = "down_blocks"
|
block = "down_blocks"
|
||||||
|
simple = True
|
||||||
elif root.startswith("mid_block"):
|
elif root.startswith("mid_block"):
|
||||||
block = "mid_block"
|
block = "mid_block"
|
||||||
|
simple = True
|
||||||
elif root.startswith("up_blocks"):
|
elif root.startswith("up_blocks"):
|
||||||
block = "up_blocks"
|
block = "up_blocks"
|
||||||
|
simple = True
|
||||||
else:
|
else:
|
||||||
logger.warning("unknown XL key name: %s", key)
|
logger.warning("unknown XL key name: %s", key)
|
||||||
fixed[key] = value
|
fixed[key] = value
|
||||||
|
@ -100,6 +104,10 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any]
|
||||||
|
|
||||||
suffix = None
|
suffix = None
|
||||||
for s in [
|
for s in [
|
||||||
|
"conv",
|
||||||
|
"conv_shortcut",
|
||||||
|
"conv1",
|
||||||
|
"conv2",
|
||||||
"fc1",
|
"fc1",
|
||||||
"fc2",
|
"fc2",
|
||||||
"ff_net_0_proj",
|
"ff_net_0_proj",
|
||||||
|
@ -120,17 +128,26 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.trace("searching for XL node: /%s/*/%s", block, suffix)
|
logger.trace("searching for XL node: /%s/*/%s", block, suffix)
|
||||||
match = None
|
match: Optional[NodeProto] = None
|
||||||
if block == "text_model":
|
if "conv" in suffix:
|
||||||
match = next(
|
match = next(
|
||||||
node for node in remaining if node == f"{root}_MatMul"
|
node for node in remaining if fix_node_name(node.name) == f"{root}_Conv"
|
||||||
|
)
|
||||||
|
elif "time_emb_proj" in root:
|
||||||
|
match = next(
|
||||||
|
node for node in remaining if fix_node_name(node.name) == f"{root}_Gemm"
|
||||||
|
)
|
||||||
|
elif block == "text_model" or simple:
|
||||||
|
match = next(
|
||||||
|
node for node in remaining if fix_node_name(node.name) == f"{root}_MatMul"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# search in order. one side has sparse indices, so they will not match.
|
||||||
match = next(
|
match = next(
|
||||||
node
|
node
|
||||||
for node in remaining
|
for node in remaining
|
||||||
if node.startswith(f"/{block}")
|
if node.name.startswith(f"/{block}")
|
||||||
and node.endswith(
|
and fix_node_name(node.name).endswith(
|
||||||
f"{suffix}_MatMul"
|
f"{suffix}_MatMul"
|
||||||
) # needs to be fixed because some places use to_out.0
|
) # needs to be fixed because some places use to_out.0
|
||||||
)
|
)
|
||||||
|
@ -138,8 +155,12 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any]
|
||||||
if match is None:
|
if match is None:
|
||||||
logger.warning("no matches for XL key: %s", root)
|
logger.warning("no matches for XL key: %s", root)
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
|
logger.trace("matched key: %s -> %s", key, match.name)
|
||||||
|
|
||||||
|
name: str = match.name
|
||||||
|
name = fix_node_name(name).rstrip("/MatMul").rstrip("/Gemm").rstrip("/Conv")
|
||||||
|
|
||||||
name = match.rstrip("/MatMul")
|
|
||||||
if name.endswith("proj_o"):
|
if name.endswith("proj_o"):
|
||||||
# wtf
|
# wtf
|
||||||
name = f"{name}ut"
|
name = f"{name}ut"
|
||||||
|
@ -149,7 +170,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any]
|
||||||
fixed[name] = value
|
fixed[name] = value
|
||||||
remaining.remove(match)
|
remaining.remove(match)
|
||||||
|
|
||||||
return (fixed, remaining)
|
return fixed
|
||||||
|
|
||||||
|
|
||||||
def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
|
def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
|
||||||
|
@ -476,9 +497,8 @@ def blend_loras(
|
||||||
|
|
||||||
# rewrite node names for XL
|
# rewrite node names for XL
|
||||||
if xl:
|
if xl:
|
||||||
blended, remaining = fix_xl_names(blended, fixed_node_names)
|
nodes = list(base_model.graph.node)
|
||||||
if len(remaining) > 0:
|
blended = fix_xl_names(blended, nodes)
|
||||||
logger.warning("could not match some XL keys: %s", remaining)
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"updating %s of %s initializers",
|
"updating %s of %s initializers",
|
||||||
|
|
Loading…
Reference in New Issue