diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index cb0e2db9..afa681f2 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -1,22 +1,15 @@ -from argparse import ArgumentParser from logging import getLogger -from os import path from typing import Any, Dict, List, Literal, Optional, Tuple, Union import numpy as np import torch -from onnx import ModelProto, NodeProto, load, numpy_helper -from onnx.checker import check_model -from onnx.external_data_helper import ( - convert_model_to_external_data, - set_external_data, - write_external_data_tensors, -) -from onnxruntime import InferenceSession, OrtValue, SessionOptions +from onnx import ModelProto, NodeProto, TensorProto, load, numpy_helper +from onnx.external_data_helper import set_external_data +from onnxruntime import OrtValue from scipy import interpolate from ...server.context import ServerContext -from ..utils import ConversionContext, load_tensor +from ..utils import load_tensor logger = getLogger(__name__) @@ -161,6 +154,245 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, ) +def blend_weights_loha( + key: str, lora_prefix: str, lora_model: Dict, dtype +) -> Tuple[str, np.ndarray]: + base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "") + + t1_key = key.replace("hada_w1_a", "hada_t1") + t2_key = key.replace("hada_w1_a", "hada_t2") + w1b_key = key.replace("hada_w1_a", "hada_w1_b") + w2a_key = key.replace("hada_w1_a", "hada_w2_a") + w2b_key = key.replace("hada_w1_a", "hada_w2_b") + alpha_key = key[: key.index("hada_w1_a")] + "alpha" + logger.trace( + "blending weights for LoHA keys: %s, %s, %s, %s, %s", + key, + w1b_key, + w2a_key, + w2b_key, + alpha_key, + ) + + w1a_weight = lora_model[key].to(dtype=dtype) + w1b_weight = lora_model[w1b_key].to(dtype=dtype) + w2a_weight = lora_model[w2a_key].to(dtype=dtype) + w2b_weight = lora_model[w2b_key].to(dtype=dtype) + + t1_weight = lora_model.get(t1_key, None) + t2_weight = lora_model.get(t2_key, None) + + dim = w1b_weight.size()[0] + alpha = lora_model.get(alpha_key, dim).to(dtype).numpy() + + if t1_weight is not None and t2_weight is not None: + t1_weight = t1_weight.to(dtype=dtype) + t2_weight = t2_weight.to(dtype=dtype) + + logger.trace( + "composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)", + t1_weight.shape, + w1a_weight.shape, + w1b_weight.shape, + t2_weight.shape, + w2a_weight.shape, + w2b_weight.shape, + ) + weights_1 = torch.einsum( + "i j k l, j r, i p -> p r k l", + t1_weight, + w1b_weight, + w1a_weight, + ) + weights_2 = torch.einsum( + "i j k l, j r, i p -> p r k l", + t2_weight, + w2b_weight, + w2a_weight, + ) + weights = weights_1 * weights_2 + np_weights = weights.numpy() * (alpha / dim) + else: + logger.trace( + "blending weights for LoHA node: (%s @ %s) * (%s @ %s)", + w1a_weight.shape, + w1b_weight.shape, + w2a_weight.shape, + w2b_weight.shape, + ) + weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight) + np_weights = weights.numpy() * (alpha / dim) + + return base_key, np_weights + + +def blend_weights_lora( + key: str, lora_prefix: str, lora_model: Dict, dtype +) -> Tuple[str, np.ndarray]: + base_key = key[: key.index(".lora_down")].replace(lora_prefix, "") + + mid_key = key.replace("lora_down", "lora_mid") + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + logger.trace("blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key) + + down_weight = lora_model[key].to(dtype=dtype) + up_weight = lora_model[up_key].to(dtype=dtype) + + mid_weight = None + if mid_key in lora_model: + mid_weight = lora_model[mid_key].to(dtype=dtype) + + dim = down_weight.size()[0] + alpha = lora_model.get(alpha_key, dim) + + if not isinstance(alpha, int): + alpha = alpha.to(dtype).numpy() + + kernel = down_weight.shape[-2:] + if mid_weight is not None: + kernel = mid_weight.shape[-2:] + + if len(down_weight.size()) == 2: + # blend for nn.Linear + logger.trace( + "blending weights for Linear node: (%s @ %s) * %s", + down_weight.shape, + up_weight.shape, + alpha, + ) + weights = up_weight @ down_weight + np_weights = weights.numpy() * (alpha / dim) + elif len(down_weight.size()) == 4 and kernel == ( + 1, + 1, + ): + # blend for nn.Conv2d 1x1 + logger.trace( + "blending weights for Conv 1x1 node: %s, %s, %s", + down_weight.shape, + up_weight.shape, + alpha, + ) + weights = ( + (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)) + .unsqueeze(2) + .unsqueeze(3) + ) + np_weights = weights.numpy() * (alpha / dim) + elif len(down_weight.size()) == 4 and kernel == ( + 3, + 3, + ): + if mid_weight is not None: + # blend for nn.Conv2d 3x3 with CP decomp + logger.trace( + "composing weights for Conv 3x3 node: %s, %s, %s, %s", + down_weight.shape, + up_weight.shape, + mid_weight.shape, + alpha, + ) + weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel)) + + for w in range(kernel[0]): + for h in range(kernel[1]): + weights[:, :, w, h] = ( + up_weight.squeeze(3).squeeze(2) @ mid_weight[:, :, w, h] + ) @ down_weight.squeeze(3).squeeze(2) + + np_weights = weights.numpy() * (alpha / dim) + else: + # blend for nn.Conv2d 3x3 + logger.trace( + "blending weights for Conv 3x3 node: %s, %s, %s", + down_weight.shape, + up_weight.shape, + alpha, + ) + weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel)) + + for w in range(kernel[0]): + for h in range(kernel[1]): + down_w, down_h = kernel_slice(w, h, down_weight.shape) + up_w, up_h = kernel_slice(w, h, up_weight.shape) + + weights[:, :, w, h] = ( + up_weight[:, :, up_w, up_h] @ down_weight[:, :, down_w, down_h] + ) + + np_weights = weights.numpy() * (alpha / dim) + else: + logger.warning( + "unknown LoRA node type at %s: %s", + base_key, + up_weight.shape[-2:], + ) + # TODO: should this be None? + np_weights = np.zeros((1, 1, 1, 1)) + + return base_key, np_weights + + +def blend_node_conv_gemm(weight_node, weights) -> TensorProto: + # blending + onnx_weights = numpy_helper.to_array(weight_node) + logger.trace( + "found blended weights for conv: %s, %s", + onnx_weights.shape, + weights.shape, + ) + + if onnx_weights.shape[-2:] == (1, 1): + if weights.shape[-2:] == (1, 1): + blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2)) + else: + blended = onnx_weights.squeeze((3, 2)) + weights + + blended = np.expand_dims(blended, (2, 3)) + else: + if onnx_weights.shape != weights.shape: + logger.warning( + "reshaping weights for mismatched Conv node: %s, %s", + onnx_weights.shape, + weights.shape, + ) + # TODO: test if this can be replaced with interpolation, simply reshaping is pretty sus + blended = onnx_weights + weights.reshape(onnx_weights.shape) + else: + blended = onnx_weights + weights + + logger.trace("blended weight shape: %s", blended.shape) + + # replace the original initializer + return numpy_helper.from_array(blended.astype(onnx_weights.dtype), weight_node.name) + + +def blend_node_matmul(matmul_node, weights, matmul_key) -> TensorProto: + onnx_weights = numpy_helper.to_array(matmul_node) + logger.trace( + "found blended weights for matmul: %s, %s", + weights.shape, + onnx_weights.shape, + ) + + t_weights = weights.transpose() + if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape: + logger.warning( + "weight shapes do not match for %s: %s vs %s", + matmul_key, + weights.shape, + onnx_weights.shape, + ) + t_weights = interp_to_match(weights, onnx_weights).transpose() + + blended = onnx_weights + t_weights + logger.trace("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype) + + # replace the original initializer + return numpy_helper.from_array(blended.astype(onnx_weights.dtype), matmul_node.name) + + def blend_loras( _conversion: ServerContext, base_name: Union[str, ModelProto], @@ -194,205 +426,41 @@ def blend_loras( for key in lora_model.keys(): if ".hada_w1_a" in key and lora_prefix in key: # LoHA - base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "") - - t1_key = key.replace("hada_w1_a", "hada_t1") - t2_key = key.replace("hada_w1_a", "hada_t2") - w1b_key = key.replace("hada_w1_a", "hada_w1_b") - w2a_key = key.replace("hada_w1_a", "hada_w2_a") - w2b_key = key.replace("hada_w1_a", "hada_w2_b") - alpha_key = key[: key.index("hada_w1_a")] + "alpha" - logger.trace( - "blending weights for LoHA keys: %s, %s, %s, %s, %s", - key, - w1b_key, - w2a_key, - w2b_key, - alpha_key, + base_key, np_weights = blend_weights_loha( + key, lora_prefix, lora_model, dtype ) - - w1a_weight = lora_model[key].to(dtype=dtype) - w1b_weight = lora_model[w1b_key].to(dtype=dtype) - w2a_weight = lora_model[w2a_key].to(dtype=dtype) - w2b_weight = lora_model[w2b_key].to(dtype=dtype) - - t1_weight = lora_model.get(t1_key, None) - t2_weight = lora_model.get(t2_key, None) - - dim = w1b_weight.size()[0] - alpha = lora_model.get(alpha_key, dim).to(dtype).numpy() - - if t1_weight is not None and t2_weight is not None: - t1_weight = t1_weight.to(dtype=dtype) - t2_weight = t2_weight.to(dtype=dtype) - - logger.trace( - "composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)", - t1_weight.shape, - w1a_weight.shape, - w1b_weight.shape, - t2_weight.shape, - w2a_weight.shape, - w2b_weight.shape, - ) - weights_1 = torch.einsum( - "i j k l, j r, i p -> p r k l", - t1_weight, - w1b_weight, - w1a_weight, - ) - weights_2 = torch.einsum( - "i j k l, j r, i p -> p r k l", - t2_weight, - w2b_weight, - w2a_weight, - ) - weights = weights_1 * weights_2 - np_weights = weights.numpy() * (alpha / dim) - else: - logger.trace( - "blending weights for LoHA node: (%s @ %s) * (%s @ %s)", - w1a_weight.shape, - w1b_weight.shape, - w2a_weight.shape, - w2b_weight.shape, - ) - weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight) - np_weights = weights.numpy() * (alpha / dim) - - np_weights *= lora_weight + np_weights = np_weights * lora_weight if base_key in blended: logger.trace( "summing LoHA weights: %s + %s", blended[base_key].shape, np_weights.shape, ) - blended[base_key] += sum_weights(blended[base_key], np_weights) + blended[base_key] = sum_weights(blended[base_key], np_weights) else: + logger.trace( + "adding LoHA weights: %s", + np_weights.shape, + ) blended[base_key] = np_weights elif ".lora_down" in key and lora_prefix in key: # LoRA or LoCON - base_key = key[: key.index(".lora_down")].replace(lora_prefix, "") - - mid_key = key.replace("lora_down", "lora_mid") - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" - logger.trace( - "blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key + base_key, np_weights = blend_weights_lora( + key, lora_prefix, lora_model, dtype ) - - down_weight = lora_model[key].to(dtype=dtype) - up_weight = lora_model[up_key].to(dtype=dtype) - - mid_weight = None - if mid_key in lora_model: - mid_weight = lora_model[mid_key].to(dtype=dtype) - - dim = down_weight.size()[0] - alpha = lora_model.get(alpha_key, dim) - - if not isinstance(alpha, int): - alpha = alpha.to(dtype).numpy() - - kernel = down_weight.shape[-2:] - if mid_weight is not None: - kernel = mid_weight.shape[-2:] - - if len(down_weight.size()) == 2: - # blend for nn.Linear - logger.trace( - "blending weights for Linear node: (%s @ %s) * %s", - down_weight.shape, - up_weight.shape, - alpha, - ) - weights = up_weight @ down_weight - np_weights = weights.numpy() * (alpha / dim) - elif len(down_weight.size()) == 4 and kernel == ( - 1, - 1, - ): - # blend for nn.Conv2d 1x1 - logger.trace( - "blending weights for Conv 1x1 node: %s, %s, %s", - down_weight.shape, - up_weight.shape, - alpha, - ) - weights = ( - ( - up_weight.squeeze(3).squeeze(2) - @ down_weight.squeeze(3).squeeze(2) - ) - .unsqueeze(2) - .unsqueeze(3) - ) - np_weights = weights.numpy() * (alpha / dim) - elif len(down_weight.size()) == 4 and kernel == ( - 3, - 3, - ): - if mid_weight is not None: - # blend for nn.Conv2d 3x3 with CP decomp - logger.trace( - "composing weights for Conv 3x3 node: %s, %s, %s, %s", - down_weight.shape, - up_weight.shape, - mid_weight.shape, - alpha, - ) - weights = torch.zeros( - (up_weight.shape[0], down_weight.shape[1], *kernel) - ) - - for w in range(kernel[0]): - for h in range(kernel[1]): - weights[:, :, w, h] = ( - up_weight.squeeze(3).squeeze(2) - @ mid_weight[:, :, w, h] - ) @ down_weight.squeeze(3).squeeze(2) - - np_weights = weights.numpy() * (alpha / dim) - else: - # blend for nn.Conv2d 3x3 - logger.trace( - "blending weights for Conv 3x3 node: %s, %s, %s", - down_weight.shape, - up_weight.shape, - alpha, - ) - weights = torch.zeros( - (up_weight.shape[0], down_weight.shape[1], *kernel) - ) - - for w in range(kernel[0]): - for h in range(kernel[1]): - down_w, down_h = kernel_slice(w, h, down_weight.shape) - up_w, up_h = kernel_slice(w, h, up_weight.shape) - - weights[:, :, w, h] = ( - up_weight[:, :, up_w, up_h] - @ down_weight[:, :, down_w, down_h] - ) - - np_weights = weights.numpy() * (alpha / dim) - else: - logger.warning( - "unknown LoRA node type at %s: %s", - base_key, - up_weight.shape[-2:], - ) - continue - - np_weights *= lora_weight + np_weights = np_weights * lora_weight if base_key in blended: logger.trace( - "summing weights: %s + %s", + "summing LoRA weights: %s + %s", blended[base_key].shape, np_weights.shape, ) blended[base_key] = sum_weights(blended[base_key], np_weights) else: + logger.trace( + "adding LoRA weights: %s", + np_weights.shape, + ) blended[base_key] = np_weights # rewrite node names for XL @@ -400,7 +468,7 @@ def blend_loras( nodes = list(base_model.graph.node) blended = fix_xl_names(blended, nodes) - logger.trace( + logger.debug( "updating %s of %s initializers", len(blended.keys()), len(base_model.graph.initializer), @@ -409,10 +477,7 @@ def blend_loras( fixed_initializer_names = [ fix_initializer_name(node.name) for node in base_model.graph.initializer ] - logger.trace("fixed initializer names: %s", fixed_initializer_names) - fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node] - logger.trace("fixed node names: %s", fixed_node_names) unmatched_keys = [] for base_key, weights in blended.items(): @@ -421,9 +486,10 @@ def blend_loras( matmul_key = base_key + "_MatMul" logger.trace( - "key %s has conv: %s, matmul: %s", + "key %s has conv: %s, gemm: %s, matmul: %s", base_key, conv_key in fixed_node_names, + gemm_key in fixed_node_names, matmul_key in fixed_node_names, ) @@ -449,38 +515,9 @@ def blend_loras( weight_node = base_model.graph.initializer[weight_idx] logger.trace("found weight initializer: %s", weight_node.name) - # blending - onnx_weights = numpy_helper.to_array(weight_node) - logger.trace( - "found blended weights for conv: %s, %s", - onnx_weights.shape, - weights.shape, - ) + # replace the previous node + updated_node = blend_node_conv_gemm(weight_node, weights) - if onnx_weights.shape[-2:] == (1, 1): - if weights.shape[-2:] == (1, 1): - blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2)) - else: - blended = onnx_weights.squeeze((3, 2)) + weights - - blended = np.expand_dims(blended, (2, 3)) - else: - if onnx_weights.shape != weights.shape: - logger.warning( - "reshaping weights for mismatched Conv node: %s, %s", - onnx_weights.shape, - weights.shape, - ) - blended = onnx_weights + weights.reshape(onnx_weights.shape) - else: - blended = onnx_weights + weights - - logger.trace("blended weight shape: %s", blended.shape) - - # replace the original initializer - updated_node = numpy_helper.from_array( - blended.astype(onnx_weights.dtype), weight_node.name - ) del base_model.graph.initializer[weight_idx] base_model.graph.initializer.insert(weight_idx, updated_node) elif matmul_key in fixed_node_names: @@ -497,36 +534,9 @@ def blend_loras( matmul_node = base_model.graph.initializer[matmul_idx] logger.trace("found matmul initializer: %s", matmul_node.name) - # blending - onnx_weights = numpy_helper.to_array(matmul_node) - logger.trace( - "found blended weights for matmul: %s, %s", - weights.shape, - onnx_weights.shape, - ) + # replace the previous node + updated_node = blend_node_matmul(matmul_node, weights, matmul_key) - t_weights = weights.transpose() - if ( - weights.shape != onnx_weights.shape - and t_weights.shape != onnx_weights.shape - ): - logger.warning( - "weight shapes do not match for %s: %s vs %s", - matmul_key, - weights.shape, - onnx_weights.shape, - ) - t_weights = interp_to_match(weights, onnx_weights).transpose() - - blended = onnx_weights + t_weights - logger.trace( - "blended weight shape: %s, %s", blended.shape, onnx_weights.dtype - ) - - # replace the original initializer - updated_node = numpy_helper.from_array( - blended.astype(onnx_weights.dtype), matmul_node.name - ) del base_model.graph.initializer[matmul_idx] base_model.graph.initializer.insert(matmul_idx, updated_node) else: @@ -565,63 +575,3 @@ def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray: logger.debug("weights after interpolation: %s", output.shape) return output - - -if __name__ == "__main__": - context = ConversionContext.from_environ() - parser = ArgumentParser() - parser.add_argument("--base", type=str) - parser.add_argument("--dest", type=str) - parser.add_argument("--type", type=str, choices=["text_encoder", "unet"]) - parser.add_argument("--lora_models", nargs="+", type=str, default=[]) - parser.add_argument("--lora_weights", nargs="+", type=float, default=[]) - - args = parser.parse_args() - logger.info( - "merging %s with %s with weights: %s", - args.lora_models, - args.base, - args.lora_weights, - ) - - default_weight = 1.0 / len(args.lora_models) - while len(args.lora_weights) < len(args.lora_models): - args.lora_weights.append(default_weight) - - blend_model = blend_loras( - context, - args.base, - list(zip(args.lora_models, args.lora_weights)), - args.type, - ) - if args.dest is None or args.dest == "" or args.dest == ":load": - # convert to external data and save to memory - (bare_model, external_data) = buffer_external_data_tensors(blend_model) - logger.info("saved external data for %s nodes", len(external_data)) - - external_names, external_values = zip(*external_data) - opts = SessionOptions() - opts.add_external_initializers(list(external_names), list(external_values)) - sess = InferenceSession( - bare_model.SerializeToString(), - sess_options=opts, - providers=["CPUExecutionProvider"], - ) - logger.info( - "successfully loaded blended model: %s", [i.name for i in sess.get_inputs()] - ) - else: - convert_model_to_external_data( - blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb" - ) - bare_model = write_external_data_tensors(blend_model, args.dest) - dest_file = path.join(args.dest, f"lora-{args.type}.onnx") - - with open(dest_file, "w+b") as model_file: - model_file.write(bare_model.SerializeToString()) - - logger.info("successfully saved blended model: %s", dest_file) - - check_model(dest_file) - - logger.info("checked blended model") diff --git a/api/scripts/onnx-lora.py b/api/scripts/onnx-lora.py new file mode 100644 index 00000000..14e72d14 --- /dev/null +++ b/api/scripts/onnx-lora.py @@ -0,0 +1,74 @@ +from argparse import ArgumentParser +from onnx_web.convert.diffusion.lora import blend_loras, buffer_external_data_tensors +from os import path +from onnx.checker import check_model +from onnx.external_data_helper import ( + convert_model_to_external_data, + write_external_data_tensors, +) +from onnxruntime import InferenceSession, SessionOptions +from logging import getLogger + +from onnx_web.convert.utils import ConversionContext + +logger = getLogger(__name__) + + +if __name__ == "__main__": + context = ConversionContext.from_environ() + parser = ArgumentParser() + parser.add_argument("--base", type=str) + parser.add_argument("--dest", type=str) + parser.add_argument("--type", type=str, choices=["text_encoder", "unet"]) + parser.add_argument("--lora_models", nargs="+", type=str, default=[]) + parser.add_argument("--lora_weights", nargs="+", type=float, default=[]) + + args = parser.parse_args() + logger.info( + "merging %s with %s with weights: %s", + args.lora_models, + args.base, + args.lora_weights, + ) + + default_weight = 1.0 / len(args.lora_models) + while len(args.lora_weights) < len(args.lora_models): + args.lora_weights.append(default_weight) + + blend_model = blend_loras( + context, + args.base, + list(zip(args.lora_models, args.lora_weights)), + args.type, + ) + if args.dest is None or args.dest == "" or args.dest == ":load": + # convert to external data and save to memory + (bare_model, external_data) = buffer_external_data_tensors(blend_model) + logger.info("saved external data for %s nodes", len(external_data)) + + external_names, external_values = zip(*external_data) + opts = SessionOptions() + opts.add_external_initializers(list(external_names), list(external_values)) + sess = InferenceSession( + bare_model.SerializeToString(), + sess_options=opts, + providers=["CPUExecutionProvider"], + ) + logger.info( + "successfully loaded blended model: %s", [i.name for i in sess.get_inputs()] + ) + else: + convert_model_to_external_data( + blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb" + ) + bare_model = write_external_data_tensors(blend_model, args.dest) + dest_file = path.join(args.dest, f"lora-{args.type}.onnx") + + with open(dest_file, "w+b") as model_file: + model_file.write(bare_model.SerializeToString()) + + logger.info("successfully saved blended model: %s", dest_file) + + check_model(dest_file) + + logger.info("checked blended model") diff --git a/api/tests/convert/diffusion/test_lora.py b/api/tests/convert/diffusion/test_lora.py index 01372e93..87a7fff0 100644 --- a/api/tests/convert/diffusion/test_lora.py +++ b/api/tests/convert/diffusion/test_lora.py @@ -1,11 +1,16 @@ import unittest import numpy as np +import torch from onnx import GraphProto, ModelProto, NodeProto from onnx.numpy_helper import from_array from onnx_web.convert.diffusion.lora import ( blend_loras, + blend_node_conv_gemm, + blend_node_matmul, + blend_weights_loha, + blend_weights_lora, buffer_external_data_tensors, fix_initializer_name, fix_node_name, @@ -151,6 +156,23 @@ class KernelSliceTests(unittest.TestCase): ) +class InterpToMatchTests(unittest.TestCase): + def test_same_shape(self): + ref = np.zeros((4, 4)) + resize = np.zeros((4, 4)) + self.assertEqual(interp_to_match(ref, resize).shape, (4, 4)) + + def test_different_one_dim(self): + ref = np.zeros((4, 2)) + resize = np.zeros((4, 4)) + self.assertEqual(interp_to_match(ref, resize).shape, (4, 4)) + + def test_different_both_dims(self): + ref = np.zeros((2, 2)) + resize = np.zeros((4, 4)) + self.assertEqual(interp_to_match(ref, resize).shape, (4, 4)) + + class BlendLoRATests(unittest.TestCase): def test_blend_unet(self): """ @@ -183,18 +205,131 @@ class BlendLoRATests(unittest.TestCase): pass -class InterpToMatchTests(unittest.TestCase): - def test_same_shape(self): - ref = np.zeros((4, 4)) - resize = np.zeros((4, 4)) - self.assertEqual(interp_to_match(ref, resize).shape, (4, 4)) +class BlendWeightsLoHATests(unittest.TestCase): + def test_blend_t1_t2(self): + # blend einsum: i j k l, j r, i p -> p r k l + i = 32 + j = 4 + k = 1 + l = 1 + p = 2 + r = 4 - def test_different_one_dim(self): - ref = np.zeros((4, 2)) - resize = np.zeros((4, 4)) - self.assertEqual(interp_to_match(ref, resize).shape, (4, 4)) + model = { + "foo.hada_t1": torch.from_numpy(np.ones((i, j, k, l))), + "foo.hada_t2": torch.from_numpy(np.ones((i, j, k, l))), + "foo.hada_w1_a": torch.from_numpy(np.ones((i, p))), + "foo.hada_w1_b": torch.from_numpy(np.ones((j, r))), + "foo.hada_w2_a": torch.from_numpy(np.ones((i, p))), + "foo.hada_w2_b": torch.from_numpy(np.ones((j, r))), + "foo.alpha": torch.tensor(1), + } + key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32) + self.assertEqual(result.shape, (p, r, k, l)) - def test_different_both_dims(self): - ref = np.zeros((2, 2)) - resize = np.zeros((4, 4)) - self.assertEqual(interp_to_match(ref, resize).shape, (4, 4)) + def test_blend_w1_w2(self): + model = { + "foo.hada_w1_a": torch.from_numpy(np.ones((4, 1))), + "foo.hada_w1_b": torch.from_numpy(np.ones((1, 4))), + "foo.hada_w2_a": torch.from_numpy(np.ones((4, 1))), + "foo.hada_w2_b": torch.from_numpy(np.ones((1, 4))), + "foo.alpha": torch.tensor(1), + } + key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32) + self.assertEqual(result.shape, (4, 4)) + + def test_blend_no_dim(self): + """ + model = { + "foo.hada_w1_a": torch.from_numpy(np.ones((1, 4))), + "foo.hada_w1_b": torch.from_numpy(np.ones((4, 1))), + "foo.hada_w2_a": torch.from_numpy(np.ones((1, 4))), + "foo.hada_w2_b": torch.from_numpy(np.ones((4, 1))), + } + result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32) + self.assertEqual(result.shape, (4, 4)) + """ + +class BlendWeightsLoRATests(unittest.TestCase): + def test_blend_kernel_none(self): + model = { + "foo.lora_down": torch.from_numpy(np.ones((1, 4))), + "foo.lora_up": torch.from_numpy(np.ones((4, 1))), + "foo.alpha": 1, + } + key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32) + self.assertEqual(result.shape, (4, 4)) + + + def test_blend_kernel_1x1(self): + model = { + "foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))), + "foo.lora_up": torch.from_numpy(np.ones((4, 1, 1, 1))), + "foo.alpha": 1, + } + key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32) + self.assertEqual(result.shape, (4, 4, 1, 1)) + + def test_blend_kernel_3x3(self): + model = { + "foo.lora_down": torch.from_numpy(np.ones((1, 4, 3, 3))), + "foo.lora_up": torch.from_numpy(np.ones((4, 1, 3, 3))), + "foo.alpha": 1, + } + key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32) + self.assertEqual(result.shape, (4, 4, 3, 3)) + + def test_blend_kernel_3x3_cp_decomp(self): + model = { + "foo.lora_down": torch.from_numpy(np.ones((2, 4, 1, 1))), + "foo.lora_mid": torch.from_numpy(np.ones((2, 2, 3, 3))), + "foo.lora_up": torch.from_numpy(np.ones((4, 2, 1, 1))), + "foo.alpha": 1, + } + key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32) + self.assertEqual(result.shape, (4, 4, 3, 3)) + + def test_blend_unknown(self): + pass + + +class BlendNodeConvGemmTests(unittest.TestCase): + def test_blend_kernel_1x1_and_1x1(self): + node = from_array(np.ones((4, 4, 1, 1))) + result = blend_node_conv_gemm(node, np.ones((4, 4, 1, 1))) + + self.assertEqual(result.dims, [4, 4, 1, 1]) + self.assertEqual(len(result.raw_data), 4 * 4 * 8) + + def test_blend_kernel_1x1_and_none(self): + node = from_array(np.ones((4, 4, 1, 1))) + result = blend_node_conv_gemm(node, np.ones((4, 4))) + + self.assertEqual(result.dims, [4, 4, 1, 1]) + self.assertEqual(len(result.raw_data), 4 * 4 * 8) + + def test_blend_other_matching(self): + node = from_array(np.ones((4, 4))) + result = blend_node_conv_gemm(node, np.ones((4, 4))) + + self.assertEqual(result.dims, [4, 4]) + self.assertEqual(len(result.raw_data), 4 * 4 * 8) + + def test_blend_other_mismatched(self): + pass + + +class BlendNodeMatMulTests(unittest.TestCase): + def test_blend_matching(self): + node = from_array(np.ones((4, 4))) + result = blend_node_matmul(node, np.ones((4, 4)), "test") + + self.assertEqual(result.dims, [4, 4]) + self.assertEqual(len(result.raw_data), 4 * 4 * 8) + + def test_blend_mismatched(self): + node = from_array(np.ones((4, 4))) + result = blend_node_matmul(node, np.ones((2, 2)), "test") + + self.assertEqual(result.dims, [4, 4]) + self.assertEqual(len(result.raw_data), 4 * 4 * 8)