fix(api): match SDXL keys per LoRA
This commit is contained in:
parent
8d4410305e
commit
74832fc61b
|
@ -452,13 +452,16 @@ def blend_loras(
|
|||
else:
|
||||
lora_prefix = f"lora_{model_type}_"
|
||||
|
||||
blended: Dict[str, np.ndarray] = {}
|
||||
layers = []
|
||||
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
|
||||
logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
||||
if lora_model is None:
|
||||
logger.warning("unable to load tensor for LoRA")
|
||||
continue
|
||||
|
||||
blended: Dict[str, np.ndarray] = {}
|
||||
layers.append(blended)
|
||||
|
||||
for key in lora_model.keys():
|
||||
if ".hada_w1_a" in key and lora_prefix in key:
|
||||
# LoHA
|
||||
|
@ -466,14 +469,6 @@ def blend_loras(
|
|||
key, lora_prefix, lora_model, dtype
|
||||
)
|
||||
np_weights = np_weights * lora_weight
|
||||
if base_key in blended:
|
||||
logger.trace(
|
||||
"summing LoHA weights: %s + %s",
|
||||
blended[base_key].shape,
|
||||
np_weights.shape,
|
||||
)
|
||||
blended[base_key] = sum_weights(blended[base_key], np_weights)
|
||||
else:
|
||||
logger.trace(
|
||||
"adding LoHA weights: %s",
|
||||
np_weights.shape,
|
||||
|
@ -485,39 +480,40 @@ def blend_loras(
|
|||
key, lora_prefix, lora_model, dtype
|
||||
)
|
||||
np_weights = np_weights * lora_weight
|
||||
if base_key in blended:
|
||||
logger.trace(
|
||||
"summing LoRA weights: %s + %s",
|
||||
blended[base_key].shape,
|
||||
np_weights.shape,
|
||||
)
|
||||
blended[base_key] = sum_weights(blended[base_key], np_weights)
|
||||
else:
|
||||
logger.trace(
|
||||
"adding LoRA weights: %s",
|
||||
np_weights.shape,
|
||||
)
|
||||
blended[base_key] = np_weights
|
||||
|
||||
# rewrite node names for XL and flatten layers
|
||||
weights = Dict[str, np.ndarray] = {}
|
||||
|
||||
for blended in layers:
|
||||
if xl:
|
||||
nodes = list(base_model.graph.node)
|
||||
blended = fix_xl_names(blended, nodes)
|
||||
|
||||
for key, value in blended.items():
|
||||
if key in weights:
|
||||
weights[key] = sum_weights(weights[key], value)
|
||||
else:
|
||||
weights[key] = value
|
||||
|
||||
# fix node names once
|
||||
fixed_initializer_names = [
|
||||
fix_initializer_name(node.name) for node in base_model.graph.initializer
|
||||
]
|
||||
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
|
||||
|
||||
# rewrite node names for XL
|
||||
if xl:
|
||||
nodes = list(base_model.graph.node)
|
||||
blended = fix_xl_names(blended, nodes)
|
||||
|
||||
logger.debug(
|
||||
"updating %s of %s initializers",
|
||||
len(blended.keys()),
|
||||
len(weights.keys()),
|
||||
len(base_model.graph.initializer),
|
||||
)
|
||||
|
||||
unmatched_keys = []
|
||||
for base_key, weights in blended.items():
|
||||
for base_key, weights in weights.items():
|
||||
conv_key = base_key + "_Conv"
|
||||
gemm_key = base_key + "_Gemm"
|
||||
matmul_key = base_key + "_MatMul"
|
||||
|
@ -579,7 +575,7 @@ def blend_loras(
|
|||
else:
|
||||
unmatched_keys.append(base_key)
|
||||
|
||||
logger.debug(
|
||||
logger.trace(
|
||||
"node counts: %s -> %s, %s -> %s",
|
||||
len(fixed_initializer_names),
|
||||
len(base_model.graph.initializer),
|
||||
|
@ -588,7 +584,7 @@ def blend_loras(
|
|||
)
|
||||
|
||||
if len(unmatched_keys) > 0:
|
||||
logger.warning("could not find nodes for some keys: %s", unmatched_keys)
|
||||
logger.warning("could not find nodes for some LoRA keys: %s", unmatched_keys)
|
||||
|
||||
return base_model
|
||||
|
||||
|
|
Loading…
Reference in New Issue