1
0
Fork 0

switch LoRA back to fixing node names on the fly

This commit is contained in:
Sean Sube 2023-11-21 22:33:17 -06:00
parent a39fe1d21c
commit 33bd67beb6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 31 additions and 11 deletions

View File

@ -71,7 +71,7 @@ def fix_node_name(key: str):
return fixed_name return fixed_name
def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any], List[str]]: def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]:
fixed = {} fixed = {}
remaining = list(nodes) remaining = list(nodes)
@ -79,6 +79,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any]
root, *rest = key.split(".") root, *rest = key.split(".")
logger.trace("fixing XL node name: %s -> %s", key, root) logger.trace("fixing XL node name: %s -> %s", key, root)
simple = False
if root.startswith("input"): if root.startswith("input"):
block = "down_blocks" block = "down_blocks"
elif root.startswith("middle"): elif root.startswith("middle"):
@ -89,10 +90,13 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any]
block = "text_model" block = "text_model"
elif root.startswith("down_blocks"): elif root.startswith("down_blocks"):
block = "down_blocks" block = "down_blocks"
simple = True
elif root.startswith("mid_block"): elif root.startswith("mid_block"):
block = "mid_block" block = "mid_block"
simple = True
elif root.startswith("up_blocks"): elif root.startswith("up_blocks"):
block = "up_blocks" block = "up_blocks"
simple = True
else: else:
logger.warning("unknown XL key name: %s", key) logger.warning("unknown XL key name: %s", key)
fixed[key] = value fixed[key] = value
@ -100,6 +104,10 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any]
suffix = None suffix = None
for s in [ for s in [
"conv",
"conv_shortcut",
"conv1",
"conv2",
"fc1", "fc1",
"fc2", "fc2",
"ff_net_0_proj", "ff_net_0_proj",
@ -120,17 +128,26 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any]
continue continue
logger.trace("searching for XL node: /%s/*/%s", block, suffix) logger.trace("searching for XL node: /%s/*/%s", block, suffix)
match = None match: Optional[NodeProto] = None
if block == "text_model": if "conv" in suffix:
match = next( match = next(
node for node in remaining if node == f"{root}_MatMul" node for node in remaining if fix_node_name(node.name) == f"{root}_Conv"
)
elif "time_emb_proj" in root:
match = next(
node for node in remaining if fix_node_name(node.name) == f"{root}_Gemm"
)
elif block == "text_model" or simple:
match = next(
node for node in remaining if fix_node_name(node.name) == f"{root}_MatMul"
) )
else: else:
# search in order. one side has sparse indices, so they will not match.
match = next( match = next(
node node
for node in remaining for node in remaining
if node.startswith(f"/{block}") if node.name.startswith(f"/{block}")
and node.endswith( and fix_node_name(node.name).endswith(
f"{suffix}_MatMul" f"{suffix}_MatMul"
) # needs to be fixed because some places use to_out.0 ) # needs to be fixed because some places use to_out.0
) )
@ -138,8 +155,12 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any]
if match is None: if match is None:
logger.warning("no matches for XL key: %s", root) logger.warning("no matches for XL key: %s", root)
continue continue
else:
logger.trace("matched key: %s -> %s", key, match.name)
name: str = match.name
name = fix_node_name(name).rstrip("/MatMul").rstrip("/Gemm").rstrip("/Conv")
name = match.rstrip("/MatMul")
if name.endswith("proj_o"): if name.endswith("proj_o"):
# wtf # wtf
name = f"{name}ut" name = f"{name}ut"
@ -149,7 +170,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[str]) -> Tuple[Dict[str, Any]
fixed[name] = value fixed[name] = value
remaining.remove(match) remaining.remove(match)
return (fixed, remaining) return fixed
def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]: def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
@ -476,9 +497,8 @@ def blend_loras(
# rewrite node names for XL # rewrite node names for XL
if xl: if xl:
blended, remaining = fix_xl_names(blended, fixed_node_names) nodes = list(base_model.graph.node)
if len(remaining) > 0: blended = fix_xl_names(blended, nodes)
logger.warning("could not match some XL keys: %s", remaining)
logger.debug( logger.debug(
"updating %s of %s initializers", "updating %s of %s initializers",