fix(api): match SDXL keys per LoRA
parent 8d4410305e
commit 74832fc61b
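Each LoRA's weights are now collected into their own `blended` dict and appended to a `layers` list, instead of being merged into one shared dict as they are loaded. `fix_xl_names` then runs once per layer, so SDXL node names are matched against each LoRA's keys individually, and the layers are flattened afterwards into a single `weights` dict, summing any keys that appear in more than one LoRA via `sum_weights`.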
```diff
@@ -452,13 +452,16 @@ def blend_loras(
     else:
         lora_prefix = f"lora_{model_type}_"
 
-    blended: Dict[str, np.ndarray] = {}
+    layers = []
     for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
         logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
         if lora_model is None:
             logger.warning("unable to load tensor for LoRA")
             continue
 
+        blended: Dict[str, np.ndarray] = {}
+        layers.append(blended)
+
         for key in lora_model.keys():
             if ".hada_w1_a" in key and lora_prefix in key:
                 # LoHA
```
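The first hunk moves `blended` inside the per-LoRA loop and records each dict in `layers`. A minimal sketch of that accumulation pattern, using invented key names and toy tensors in place of real LoRA weights:

```python
from typing import Dict, List, Tuple

import numpy as np

# hypothetical stand-ins for the loras / lora_models pairs in the diff
loras: List[Tuple[str, float]] = [("style", 0.6), ("detail", 0.4)]
lora_models = [
    {"unet.down_block.0": np.ones((4, 4), dtype=np.float32)},
    {"unet.down_block.0": np.full((4, 4), 2.0, dtype=np.float32)},
]

layers: List[Dict[str, np.ndarray]] = []
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
    blended: Dict[str, np.ndarray] = {}
    layers.append(blended)  # one dict per LoRA, fixed up individually later
    for key, np_weights in lora_model.items():
        # scale by the per-LoRA blend weight, as in the diff
        blended[key] = np_weights * lora_weight

assert len(layers) == len(loras)
```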
```diff
@@ -466,38 +469,36 @@ def blend_loras(
                     key, lora_prefix, lora_model, dtype
                 )
                 np_weights = np_weights * lora_weight
-                if base_key in blended:
-                    logger.trace(
-                        "summing LoHA weights: %s + %s",
-                        blended[base_key].shape,
-                        np_weights.shape,
-                    )
-                    blended[base_key] = sum_weights(blended[base_key], np_weights)
-                else:
-                    logger.trace(
-                        "adding LoHA weights: %s",
-                        np_weights.shape,
-                    )
-                    blended[base_key] = np_weights
+                logger.trace(
+                    "adding LoHA weights: %s",
+                    np_weights.shape,
+                )
+                blended[base_key] = np_weights
             elif ".lora_down" in key and lora_prefix in key:
                 # LoRA or LoCON
                 base_key, np_weights = blend_weights_lora(
                     key, lora_prefix, lora_model, dtype
                 )
                 np_weights = np_weights * lora_weight
-                if base_key in blended:
-                    logger.trace(
-                        "summing LoRA weights: %s + %s",
-                        blended[base_key].shape,
-                        np_weights.shape,
-                    )
-                    blended[base_key] = sum_weights(blended[base_key], np_weights)
-                else:
-                    logger.trace(
-                        "adding LoRA weights: %s",
-                        np_weights.shape,
-                    )
-                    blended[base_key] = np_weights
+                logger.trace(
+                    "adding LoRA weights: %s",
+                    np_weights.shape,
+                )
+                blended[base_key] = np_weights
+
+    # rewrite node names for XL and flatten layers
+    weights: Dict[str, np.ndarray] = {}
+
+    for blended in layers:
+        if xl:
+            nodes = list(base_model.graph.node)
+            blended = fix_xl_names(blended, nodes)
+
+        for key, value in blended.items():
+            if key in weights:
+                weights[key] = sum_weights(weights[key], value)
+            else:
+                weights[key] = value
 
     # fix node names once
     fixed_initializer_names = [
```
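The flatten step at the end of this hunk merges the per-LoRA dicts, summing keys that appear in more than one layer. `sum_weights` itself is not shown in this diff; a minimal sketch, assuming it reduces to elementwise addition for arrays of matching shape:

```python
from typing import Dict, List

import numpy as np


def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # assumption: identical shapes; the real helper may also reconcile
    # mismatched conv/linear shapes
    return a + b


def flatten_layers(layers: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
    """Merge per-LoRA weight dicts, summing values that share a key."""
    weights: Dict[str, np.ndarray] = {}
    for blended in layers:
        for key, value in blended.items():
            if key in weights:
                weights[key] = sum_weights(weights[key], value)
            else:
                weights[key] = value
    return weights
```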
```diff
@@ -505,19 +506,14 @@ def blend_loras(
     ]
     fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
 
-    # rewrite node names for XL
-    if xl:
-        nodes = list(base_model.graph.node)
-        blended = fix_xl_names(blended, nodes)
-
     logger.debug(
         "updating %s of %s initializers",
-        len(blended.keys()),
+        len(weights.keys()),
         len(base_model.graph.initializer),
     )
 
     unmatched_keys = []
-    for base_key, weights in blended.items():
+    for base_key, weights in weights.items():
         conv_key = base_key + "_Conv"
         gemm_key = base_key + "_Gemm"
         matmul_key = base_key + "_MatMul"
```
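Each base key is then matched against ONNX graph nodes by trying the operator suffixes built above. A hypothetical standalone version of that lookup (the actual matching logic between these context lines is not part of this diff, and the first-match-wins order is an assumption):

```python
from typing import Optional, Set


def find_node_key(base_key: str, node_names: Set[str]) -> Optional[str]:
    # try the suffixes in the same order the diff builds them
    for suffix in ("_Conv", "_Gemm", "_MatMul"):
        candidate = base_key + suffix
        if candidate in node_names:
            return candidate
    return None
```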
```diff
@@ -579,7 +575,7 @@ def blend_loras(
         else:
             unmatched_keys.append(base_key)
 
-    logger.debug(
+    logger.trace(
         "node counts: %s -> %s, %s -> %s",
         len(fixed_initializer_names),
         len(base_model.graph.initializer),
```
```diff
@@ -588,7 +584,7 @@ def blend_loras(
     )
 
     if len(unmatched_keys) > 0:
-        logger.warning("could not find nodes for some keys: %s", unmatched_keys)
+        logger.warning("could not find nodes for some LoRA keys: %s", unmatched_keys)
 
     return base_model
```