1
0
Fork 0

fix SDXL node names once

This commit is contained in:
Sean Sube 2023-11-24 16:51:03 -06:00
parent 74832fc61b
commit d7c95a4a4f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 15 additions and 16 deletions

View File

@ -73,7 +73,7 @@ def fix_node_name(key: str):
def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]: def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]:
fixed = {} fixed = {}
remaining = list(nodes) names = [fix_node_name(node.name) for node in nodes]
for key, value in keys.items(): for key, value in keys.items():
root, *rest = key.split(".") root, *rest = key.split(".")
@ -128,28 +128,28 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
continue continue
logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix) logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
match: Optional[NodeProto] = None match: Optional[str] = None
if "conv" in suffix: if "conv" in suffix:
match = next( match = next(
node for node in remaining if fix_node_name(node.name) == f"{root}_Conv" node for node in names if node == f"{root}_Conv"
) )
elif "time_emb_proj" in root: elif "time_emb_proj" in root:
match = next( match = next(
node for node in remaining if fix_node_name(node.name) == f"{root}_Gemm" node for node in names if node == f"{root}_Gemm"
) )
elif block == "text_model" or simple: elif block == "text_model" or simple:
match = next( match = next(
node node
for node in remaining for node in names
if fix_node_name(node.name) == f"{root}_MatMul" if node == f"{root}_MatMul"
) )
else: else:
# search in order. one side has sparse indices, so they will not match. # search in order. one side has sparse indices, so they will not match.
match = next( match = next(
node node
for node in remaining for node in names
if node.name.startswith(f"/{block}") if node.startswith(block)
and fix_node_name(node.name).endswith( and node.endswith(
f"{suffix}_MatMul" f"{suffix}_MatMul"
) # needs to be fixed because some places use to_out.0 ) # needs to be fixed because some places use to_out.0
) )
@ -158,10 +158,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
logger.warning("no matches for XL key: %s", root) logger.warning("no matches for XL key: %s", root)
continue continue
else: else:
logger.trace("matched key: %s -> %s", key, match.name) logger.trace("matched key: %s -> %s", key, match)
name: str = match.name name = match
name = fix_node_name(name)
if name.endswith("_MatMul"): if name.endswith("_MatMul"):
name = name[:-7] name = name[:-7]
elif name.endswith("_Gemm"): elif name.endswith("_Gemm"):
@ -169,15 +168,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
elif name.endswith("_Conv"): elif name.endswith("_Conv"):
name = name[:-5] name = name[:-5]
logger.trace("matching XL key with node: %s -> %s, %s", key, match.name, name) logger.trace("matching XL key with node: %s -> %s, %s", key, match, name)
fixed[name] = value fixed[name] = value
remaining.remove(match) names.remove(match)
logger.debug( logger.debug(
"SDXL LoRA key fixup matched %s keys, %s remaining", "SDXL LoRA key fixup matched %s keys, %s remaining",
len(fixed.keys()), len(fixed.keys()),
len(remaining), len(names),
) )
return fixed return fixed
@ -487,7 +486,7 @@ def blend_loras(
blended[base_key] = np_weights blended[base_key] = np_weights
# rewrite node names for XL and flatten layers # rewrite node names for XL and flatten layers
weights = Dict[str, np.ndarray] = {} weights: Dict[str, np.ndarray] = {}
for blended in layers: for blended in layers:
if xl: if xl: