diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index d10157ee..f2dbf75e 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -452,13 +452,16 @@ def blend_loras(
     else:
         lora_prefix = f"lora_{model_type}_"

-    blended: Dict[str, np.ndarray] = {}
+    layers = []
     for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
         logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
         if lora_model is None:
             logger.warning("unable to load tensor for LoRA")
             continue

+        blended: Dict[str, np.ndarray] = {}
+        layers.append(blended)
+
         for key in lora_model.keys():
             if ".hada_w1_a" in key and lora_prefix in key:
                 # LoHA
@@ -466,38 +469,36 @@ def blend_loras(
                     key, lora_prefix, lora_model, dtype
                 )
                 np_weights = np_weights * lora_weight
-                if base_key in blended:
-                    logger.trace(
-                        "summing LoHA weights: %s + %s",
-                        blended[base_key].shape,
-                        np_weights.shape,
-                    )
-                    blended[base_key] = sum_weights(blended[base_key], np_weights)
-                else:
-                    logger.trace(
-                        "adding LoHA weights: %s",
-                        np_weights.shape,
-                    )
-                    blended[base_key] = np_weights
+                logger.trace(
+                    "adding LoHA weights: %s",
+                    np_weights.shape,
+                )
+                blended[base_key] = np_weights
             elif ".lora_down" in key and lora_prefix in key:
                 # LoRA or LoCON
                 base_key, np_weights = blend_weights_lora(
                     key, lora_prefix, lora_model, dtype
                 )
                 np_weights = np_weights * lora_weight
-                if base_key in blended:
-                    logger.trace(
-                        "summing LoRA weights: %s + %s",
-                        blended[base_key].shape,
-                        np_weights.shape,
-                    )
-                    blended[base_key] = sum_weights(blended[base_key], np_weights)
-                else:
-                    logger.trace(
-                        "adding LoRA weights: %s",
-                        np_weights.shape,
-                    )
-                    blended[base_key] = np_weights
+                logger.trace(
+                    "adding LoRA weights: %s",
+                    np_weights.shape,
+                )
+                blended[base_key] = np_weights
+
+    # rewrite node names for XL and flatten layers
+    weights: Dict[str, np.ndarray] = {}
+
+    for blended in layers:
+        if xl:
+            nodes = list(base_model.graph.node)
+            blended = fix_xl_names(blended, nodes)
+
+        for key, value in blended.items():
+            if key in weights:
+                weights[key] = sum_weights(weights[key], value)
+            else:
+                weights[key] = value

     # fix node names once
     fixed_initializer_names = [
@@ -505,19 +506,14 @@ def blend_loras(
     ]
     fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]

-    # rewrite node names for XL
-    if xl:
-        nodes = list(base_model.graph.node)
-        blended = fix_xl_names(blended, nodes)
-
     logger.debug(
         "updating %s of %s initializers",
-        len(blended.keys()),
+        len(weights.keys()),
         len(base_model.graph.initializer),
     )

     unmatched_keys = []
-    for base_key, weights in blended.items():
+    for base_key, weights in weights.items():
         conv_key = base_key + "_Conv"
         gemm_key = base_key + "_Gemm"
         matmul_key = base_key + "_MatMul"
@@ -579,7 +575,7 @@ def blend_loras(
         else:
             unmatched_keys.append(base_key)

-    logger.debug(
+    logger.trace(
         "node counts: %s -> %s, %s -> %s",
         len(fixed_initializer_names),
         len(base_model.graph.initializer),
@@ -588,7 +584,7 @@ def blend_loras(
     )

     if len(unmatched_keys) > 0:
-        logger.warning("could not find nodes for some keys: %s", unmatched_keys)
+        logger.warning("could not find nodes for some LoRA keys: %s", unmatched_keys)

     return base_model
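
The behavioral change here: instead of summing each LoRA's weights into one shared `blended` dict as keys are read, each LoRA now fills its own per-model dict appended to `layers`, the XL node-name rewrite (`fix_xl_names`) runs on each layer individually, and a final pass flattens the layers into `weights`, summing any keys shared across models. Below is a minimal, self-contained sketch of that flattening pass; `flatten_layers` is a hypothetical stand-in for the inline loop, plain ndarray addition replaces the repo's `sum_weights` helper, and the XL rewrite is omitted:

```python
from typing import Dict, List

import numpy as np


def flatten_layers(layers: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
    # Merge per-LoRA weight dicts into a single dict, summing any key
    # that appears in more than one layer. Plain addition stands in for
    # sum_weights, which in the real code also reconciles tensors of
    # mismatched rank (e.g. Conv vs. MatMul weights).
    weights: Dict[str, np.ndarray] = {}
    for blended in layers:
        for key, value in blended.items():
            if key in weights:
                weights[key] = weights[key] + value
            else:
                weights[key] = value
    return weights


# Two LoRA layers that overlap on one key: the shared key is summed,
# the unique key passes through unchanged.
layer_a = {"attn1": np.ones((2, 2)), "attn2": np.ones((2, 2))}
layer_b = {"attn1": np.full((2, 2), 0.5)}
merged = flatten_layers([layer_a, layer_b])
assert np.allclose(merged["attn1"], 1.5)
assert np.allclose(merged["attn2"], 1.0)
```

The practical effect of the reordering is that `fix_xl_names` now sees each LoRA's keys before any summing, rather than running once on the already-merged dict, so keys from different models are only combined after they have been mapped to their final node names.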