fix(api): match SDXL keys per LoRA

Sean Sube 2023-11-24 15:22:07 -06:00
parent 8d4410305e
commit 74832fc61b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 32 additions and 36 deletions


@@ -452,13 +452,16 @@ def blend_loras(
     else:
         lora_prefix = f"lora_{model_type}_"
 
-    blended: Dict[str, np.ndarray] = {}
+    layers = []
     for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
         logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
         if lora_model is None:
             logger.warning("unable to load tensor for LoRA")
             continue
 
+        blended: Dict[str, np.ndarray] = {}
+        layers.append(blended)
+
         for key in lora_model.keys():
             if ".hada_w1_a" in key and lora_prefix in key:
                 # LoHA
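This hunk is the core of the fix: instead of one shared blended dict across every LoRA, each LoRA now accumulates into its own dict, collected in layers, so the SDXL key rewrite later in the function can run per LoRA before anything is merged. A minimal sketch of the accumulator pattern, assuming the tensors are already loaded and skipping the LoHA/LoRA key parsing (collect_layers is a hypothetical name, not part of the real module):

from typing import Dict, List, Tuple

import numpy as np


def collect_layers(
    loras: List[Tuple[str, float]],
    lora_models: List[Dict[str, np.ndarray]],
) -> List[Dict[str, np.ndarray]]:
    # one dict per LoRA, kept separate until keys have been fixed
    layers = []
    for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
        if lora_model is None:
            continue  # skip LoRAs whose tensors failed to load

        blended: Dict[str, np.ndarray] = {}
        layers.append(blended)
        for key, np_weights in lora_model.items():
            blended[key] = np_weights * lora_weight

    return layers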
@@ -466,38 +469,36 @@ def blend_loras(
                     key, lora_prefix, lora_model, dtype
                 )
                 np_weights = np_weights * lora_weight
-                if base_key in blended:
-                    logger.trace(
-                        "summing LoHA weights: %s + %s",
-                        blended[base_key].shape,
-                        np_weights.shape,
-                    )
-                    blended[base_key] = sum_weights(blended[base_key], np_weights)
-                else:
-                    logger.trace(
-                        "adding LoHA weights: %s",
-                        np_weights.shape,
-                    )
-                    blended[base_key] = np_weights
+                logger.trace(
+                    "adding LoHA weights: %s",
+                    np_weights.shape,
+                )
+                blended[base_key] = np_weights
             elif ".lora_down" in key and lora_prefix in key:
                 # LoRA or LoCON
                 base_key, np_weights = blend_weights_lora(
                     key, lora_prefix, lora_model, dtype
                 )
                 np_weights = np_weights * lora_weight
-                if base_key in blended:
-                    logger.trace(
-                        "summing LoRA weights: %s + %s",
-                        blended[base_key].shape,
-                        np_weights.shape,
-                    )
-                    blended[base_key] = sum_weights(blended[base_key], np_weights)
-                else:
-                    logger.trace(
-                        "adding LoRA weights: %s",
-                        np_weights.shape,
-                    )
-                    blended[base_key] = np_weights
+                logger.trace(
+                    "adding LoRA weights: %s",
+                    np_weights.shape,
+                )
+                blended[base_key] = np_weights
 
+    # rewrite node names for XL and flatten layers
+    weights: Dict[str, np.ndarray] = {}
+    for blended in layers:
+        if xl:
+            nodes = list(base_model.graph.node)
+            blended = fix_xl_names(blended, nodes)
+
+        for key, value in blended.items():
+            if key in weights:
+                weights[key] = sum_weights(weights[key], value)
+            else:
+                weights[key] = value
 
     # fix node names once
     fixed_initializer_names = [
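The new flatten pass replaces the old inline "if base_key in blended" summing: each per-LoRA dict goes through fix_xl_names first, and only then are tensors that resolve to the same base key added together, which is what makes SDXL key matching happen per LoRA. A condensed sketch of the merge step, with a simplified sum_weights standing in for the project's helper (judging by the shape logging removed above, the real one may also reconcile differing shapes), and flatten_layers as a hypothetical name:

from typing import Dict, List

import numpy as np


def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # simplified stand-in for the project's sum_weights helper
    return a + b


def flatten_layers(
    layers: List[Dict[str, np.ndarray]],
) -> Dict[str, np.ndarray]:
    # merge the per-LoRA dicts after their keys have been rewritten,
    # summing tensors that now share the same base key
    weights: Dict[str, np.ndarray] = {}
    for blended in layers:
        for key, value in blended.items():
            if key in weights:
                weights[key] = sum_weights(weights[key], value)
            else:
                weights[key] = value
    return weights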
@ -505,19 +506,14 @@ def blend_loras(
]
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
# rewrite node names for XL
if xl:
nodes = list(base_model.graph.node)
blended = fix_xl_names(blended, nodes)
logger.debug(
"updating %s of %s initializers",
len(blended.keys()),
len(weights.keys()),
len(base_model.graph.initializer),
)
unmatched_keys = []
for base_key, weights in blended.items():
for base_key, weights in weights.items():
conv_key = base_key + "_Conv"
gemm_key = base_key + "_Gemm"
matmul_key = base_key + "_MatMul"
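With the single end-of-function XL rewrite removed, the matching loop now iterates over the flattened weights dict, and the _Conv, _Gemm, and _MatMul suffixes build the candidate node names each base key is checked against in the matching code elided from this diff. A hypothetical condensed version of that lookup, assuming fixed_node_names is the list of cleaned node names built above:

from typing import List, Optional


def find_node_key(base_key: str, fixed_node_names: List[str]) -> Optional[str]:
    # probe the op-type suffixes used for initializer matching
    for suffix in ("_Conv", "_Gemm", "_MatMul"):
        candidate = base_key + suffix
        if candidate in fixed_node_names:
            return candidate
    return None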
@@ -579,7 +575,7 @@ def blend_loras(
         else:
             unmatched_keys.append(base_key)
 
-    logger.debug(
+    logger.trace(
         "node counts: %s -> %s, %s -> %s",
         len(fixed_initializer_names),
         len(base_model.graph.initializer),
@@ -588,7 +584,7 @@ def blend_loras(
     )
 
     if len(unmatched_keys) > 0:
-        logger.warning("could not find nodes for some keys: %s", unmatched_keys)
+        logger.warning("could not find nodes for some LoRA keys: %s", unmatched_keys)
 
     return base_model