fix SDXL node names once
This commit is contained in:
parent
74832fc61b
commit
d7c95a4a4f
|
@ -73,7 +73,7 @@ def fix_node_name(key: str):
|
||||||
|
|
||||||
def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]:
|
def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]:
|
||||||
fixed = {}
|
fixed = {}
|
||||||
remaining = list(nodes)
|
names = [fix_node_name(node.name) for node in nodes]
|
||||||
|
|
||||||
for key, value in keys.items():
|
for key, value in keys.items():
|
||||||
root, *rest = key.split(".")
|
root, *rest = key.split(".")
|
||||||
|
@ -128,28 +128,28 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
|
logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
|
||||||
match: Optional[NodeProto] = None
|
match: Optional[str] = None
|
||||||
if "conv" in suffix:
|
if "conv" in suffix:
|
||||||
match = next(
|
match = next(
|
||||||
node for node in remaining if fix_node_name(node.name) == f"{root}_Conv"
|
node for node in names if node == f"{root}_Conv"
|
||||||
)
|
)
|
||||||
elif "time_emb_proj" in root:
|
elif "time_emb_proj" in root:
|
||||||
match = next(
|
match = next(
|
||||||
node for node in remaining if fix_node_name(node.name) == f"{root}_Gemm"
|
node for node in names if node == f"{root}_Gemm"
|
||||||
)
|
)
|
||||||
elif block == "text_model" or simple:
|
elif block == "text_model" or simple:
|
||||||
match = next(
|
match = next(
|
||||||
node
|
node
|
||||||
for node in remaining
|
for node in names
|
||||||
if fix_node_name(node.name) == f"{root}_MatMul"
|
if node == f"{root}_MatMul"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# search in order. one side has sparse indices, so they will not match.
|
# search in order. one side has sparse indices, so they will not match.
|
||||||
match = next(
|
match = next(
|
||||||
node
|
node
|
||||||
for node in remaining
|
for node in names
|
||||||
if node.name.startswith(f"/{block}")
|
if node.startswith(block)
|
||||||
and fix_node_name(node.name).endswith(
|
and node.endswith(
|
||||||
f"{suffix}_MatMul"
|
f"{suffix}_MatMul"
|
||||||
) # needs to be fixed because some places use to_out.0
|
) # needs to be fixed because some places use to_out.0
|
||||||
)
|
)
|
||||||
|
@ -158,10 +158,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
||||||
logger.warning("no matches for XL key: %s", root)
|
logger.warning("no matches for XL key: %s", root)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
logger.trace("matched key: %s -> %s", key, match.name)
|
logger.trace("matched key: %s -> %s", key, match)
|
||||||
|
|
||||||
name: str = match.name
|
name = match
|
||||||
name = fix_node_name(name)
|
|
||||||
if name.endswith("_MatMul"):
|
if name.endswith("_MatMul"):
|
||||||
name = name[:-7]
|
name = name[:-7]
|
||||||
elif name.endswith("_Gemm"):
|
elif name.endswith("_Gemm"):
|
||||||
|
@ -169,15 +168,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
||||||
elif name.endswith("_Conv"):
|
elif name.endswith("_Conv"):
|
||||||
name = name[:-5]
|
name = name[:-5]
|
||||||
|
|
||||||
logger.trace("matching XL key with node: %s -> %s, %s", key, match.name, name)
|
logger.trace("matching XL key with node: %s -> %s, %s", key, match, name)
|
||||||
|
|
||||||
fixed[name] = value
|
fixed[name] = value
|
||||||
remaining.remove(match)
|
names.remove(match)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"SDXL LoRA key fixup matched %s keys, %s remaining",
|
"SDXL LoRA key fixup matched %s keys, %s remaining",
|
||||||
len(fixed.keys()),
|
len(fixed.keys()),
|
||||||
len(remaining),
|
len(names),
|
||||||
)
|
)
|
||||||
|
|
||||||
return fixed
|
return fixed
|
||||||
|
@ -487,7 +486,7 @@ def blend_loras(
|
||||||
blended[base_key] = np_weights
|
blended[base_key] = np_weights
|
||||||
|
|
||||||
# rewrite node names for XL and flatten layers
|
# rewrite node names for XL and flatten layers
|
||||||
weights = Dict[str, np.ndarray] = {}
|
weights: Dict[str, np.ndarray] = {}
|
||||||
|
|
||||||
for blended in layers:
|
for blended in layers:
|
||||||
if xl:
|
if xl:
|
||||||
|
|
Loading…
Reference in New Issue