
fix(api): test LoRA blending code

Sean Sube 2023-09-21 18:24:08 -05:00
parent 761bfa8364
commit 52fdf4f48a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 488 additions and 329 deletions

api/onnx_web/convert/diffusion/lora.py

@@ -1,22 +1,15 @@
-from argparse import ArgumentParser
 from logging import getLogger
-from os import path
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import torch
-from onnx import ModelProto, NodeProto, load, numpy_helper
-from onnx.checker import check_model
-from onnx.external_data_helper import (
-    convert_model_to_external_data,
-    set_external_data,
-    write_external_data_tensors,
-)
-from onnxruntime import InferenceSession, OrtValue, SessionOptions
+from onnx import ModelProto, NodeProto, TensorProto, load, numpy_helper
+from onnx.external_data_helper import set_external_data
+from onnxruntime import OrtValue
 from scipy import interpolate
 
 from ...server.context import ServerContext
-from ..utils import ConversionContext, load_tensor
+from ..utils import load_tensor
 
 logger = getLogger(__name__)
@@ -161,39 +154,9 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int,
     )
 
 
-def blend_loras(
-    _conversion: ServerContext,
-    base_name: Union[str, ModelProto],
-    loras: List[Tuple[str, float]],
-    model_type: Literal["text_encoder", "unet"],
-    model_index: Optional[int] = None,
-    xl: Optional[bool] = False,
-):
-    # always load to CPU for blending
-    device = torch.device("cpu")
-    dtype = torch.float32
-
-    base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
-    lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
-
-    if model_type == "text_encoder":
-        if model_index is None:
-            lora_prefix = "lora_te_"
-        else:
-            lora_prefix = f"lora_te{model_index}_"
-    else:
-        lora_prefix = f"lora_{model_type}_"
-
-    blended: Dict[str, np.ndarray] = {}
-    for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
-        logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
-        if lora_model is None:
-            logger.warning("unable to load tensor for LoRA")
-            continue
-
-        for key in lora_model.keys():
-            if ".hada_w1_a" in key and lora_prefix in key:
-                # LoHA
-                base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")
-                t1_key = key.replace("hada_w1_a", "hada_t1")
+def blend_weights_loha(
+    key: str, lora_prefix: str, lora_model: Dict, dtype
+) -> Tuple[str, np.ndarray]:
+    base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")
+    t1_key = key.replace("hada_w1_a", "hada_t1")
@@ -260,26 +223,18 @@ def blend_loras(
     weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
     np_weights = weights.numpy() * (alpha / dim)
-                np_weights *= lora_weight
-                if base_key in blended:
-                    logger.trace(
-                        "summing LoHA weights: %s + %s",
-                        blended[base_key].shape,
-                        np_weights.shape,
-                    )
-                    blended[base_key] += sum_weights(blended[base_key], np_weights)
-                else:
-                    blended[base_key] = np_weights
-            elif ".lora_down" in key and lora_prefix in key:
-                # LoRA or LoCON
-                base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
-                mid_key = key.replace("lora_down", "lora_mid")
-                up_key = key.replace("lora_down", "lora_up")
-                alpha_key = key[: key.index("lora_down")] + "alpha"
-                logger.trace(
-                    "blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key
-                )
-                down_weight = lora_model[key].to(dtype=dtype)
-                up_weight = lora_model[up_key].to(dtype=dtype)
+    return base_key, np_weights
+
+
+def blend_weights_lora(
+    key: str, lora_prefix: str, lora_model: Dict, dtype
+) -> Tuple[str, np.ndarray]:
+    base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
+    mid_key = key.replace("lora_down", "lora_mid")
+    up_key = key.replace("lora_down", "lora_up")
+    alpha_key = key[: key.index("lora_down")] + "alpha"
+    logger.trace("blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key)
+    down_weight = lora_model[key].to(dtype=dtype)
+    up_weight = lora_model[up_key].to(dtype=dtype)
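
Note: every branch of the extracted blend_weights_lora ultimately produces the same quantity, the low-rank update up @ down scaled by alpha / dim. A minimal numpy sketch of the plain (non-conv) case, with illustrative sizes:

import numpy as np

rank, features = 4, 8
down = np.ones((rank, features), dtype=np.float32)  # lora_down
up = np.ones((features, rank), dtype=np.float32)    # lora_up
alpha = 1.0

# (features, features) delta, later added onto the base weight
delta = (up @ down) * (alpha / rank)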
@@ -320,10 +275,7 @@
         alpha,
     )
     weights = (
-        (
-            up_weight.squeeze(3).squeeze(2)
-            @ down_weight.squeeze(3).squeeze(2)
-        )
+        (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
         .unsqueeze(2)
         .unsqueeze(3)
     )
@@ -341,15 +293,12 @@
         mid_weight.shape,
         alpha,
     )
-    weights = torch.zeros(
-        (up_weight.shape[0], down_weight.shape[1], *kernel)
-    )
+    weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
     for w in range(kernel[0]):
         for h in range(kernel[1]):
             weights[:, :, w, h] = (
-                up_weight.squeeze(3).squeeze(2)
-                @ mid_weight[:, :, w, h]
+                up_weight.squeeze(3).squeeze(2) @ mid_weight[:, :, w, h]
             ) @ down_weight.squeeze(3).squeeze(2)
     np_weights = weights.numpy() * (alpha / dim)
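
The lora_mid branch above recombines a CP-decomposed conv kernel position by position. A standalone sketch with made-up channel and rank sizes (the same shapes exercised by test_blend_kernel_3x3_cp_decomp in the new tests):

import torch

out_c, in_c, rank, kernel = 4, 4, 2, (3, 3)
up = torch.ones((out_c, rank, 1, 1))
mid = torch.ones((rank, rank, *kernel))
down = torch.ones((rank, in_c, 1, 1))

weights = torch.zeros((out_c, in_c, *kernel))
for w in range(kernel[0]):
    for h in range(kernel[1]):
        # (out_c, rank) @ (rank, rank) @ (rank, in_c) -> (out_c, in_c)
        weights[:, :, w, h] = (
            up.squeeze(3).squeeze(2) @ mid[:, :, w, h]
        ) @ down.squeeze(3).squeeze(2)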
@@ -361,9 +310,7 @@
         up_weight.shape,
         alpha,
     )
-    weights = torch.zeros(
-        (up_weight.shape[0], down_weight.shape[1], *kernel)
-    )
+    weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
     for w in range(kernel[0]):
         for h in range(kernel[1]):
@@ -371,8 +318,7 @@
             up_w, up_h = kernel_slice(w, h, up_weight.shape)
             weights[:, :, w, h] = (
-                up_weight[:, :, up_w, up_h]
-                @ down_weight[:, :, down_w, down_h]
+                up_weight[:, :, up_w, up_h] @ down_weight[:, :, down_w, down_h]
             )
     np_weights = weights.numpy() * (alpha / dim)
@@ -382,17 +328,139 @@ def blend_loras(
             base_key,
             up_weight.shape[-2:],
         )
+        # TODO: should this be None?
+        np_weights = np.zeros((1, 1, 1, 1))
+
+    return base_key, np_weights
+
+
+def blend_node_conv_gemm(weight_node, weights) -> TensorProto:
+    # blending
+    onnx_weights = numpy_helper.to_array(weight_node)
+    logger.trace(
+        "found blended weights for conv: %s, %s",
+        onnx_weights.shape,
+        weights.shape,
+    )
+
+    if onnx_weights.shape[-2:] == (1, 1):
+        if weights.shape[-2:] == (1, 1):
+            blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
+        else:
+            blended = onnx_weights.squeeze((3, 2)) + weights
+
+        blended = np.expand_dims(blended, (2, 3))
+    else:
+        if onnx_weights.shape != weights.shape:
+            logger.warning(
+                "reshaping weights for mismatched Conv node: %s, %s",
+                onnx_weights.shape,
+                weights.shape,
+            )
+            # TODO: test if this can be replaced with interpolation, simply reshaping is pretty sus
+            blended = onnx_weights + weights.reshape(onnx_weights.shape)
+        else:
+            blended = onnx_weights + weights
+
+    logger.trace("blended weight shape: %s", blended.shape)
+
+    # replace the original initializer
+    return numpy_helper.from_array(blended.astype(onnx_weights.dtype), weight_node.name)
+
+
+def blend_node_matmul(matmul_node, weights, matmul_key) -> TensorProto:
+    onnx_weights = numpy_helper.to_array(matmul_node)
+    logger.trace(
+        "found blended weights for matmul: %s, %s",
+        weights.shape,
+        onnx_weights.shape,
+    )
+
+    t_weights = weights.transpose()
+
+    if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape:
+        logger.warning(
+            "weight shapes do not match for %s: %s vs %s",
+            matmul_key,
+            weights.shape,
+            onnx_weights.shape,
+        )
+        t_weights = interp_to_match(weights, onnx_weights).transpose()
+
+    blended = onnx_weights + t_weights
+    logger.trace("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype)
+
+    # replace the original initializer
+    return numpy_helper.from_array(blended.astype(onnx_weights.dtype), matmul_node.name)
+
+
+def blend_loras(
+    _conversion: ServerContext,
+    base_name: Union[str, ModelProto],
+    loras: List[Tuple[str, float]],
+    model_type: Literal["text_encoder", "unet"],
+    model_index: Optional[int] = None,
+    xl: Optional[bool] = False,
+):
+    # always load to CPU for blending
+    device = torch.device("cpu")
+    dtype = torch.float32
+
+    base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
+    lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
+
+    if model_type == "text_encoder":
+        if model_index is None:
+            lora_prefix = "lora_te_"
+        else:
+            lora_prefix = f"lora_te{model_index}_"
+    else:
+        lora_prefix = f"lora_{model_type}_"
+
+    blended: Dict[str, np.ndarray] = {}
+    for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
+        logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
+        if lora_model is None:
+            logger.warning("unable to load tensor for LoRA")
             continue
-            np_weights *= lora_weight
+
+        for key in lora_model.keys():
+            if ".hada_w1_a" in key and lora_prefix in key:
+                # LoHA
+                base_key, np_weights = blend_weights_loha(
+                    key, lora_prefix, lora_model, dtype
+                )
+                np_weights = np_weights * lora_weight
                 if base_key in blended:
                     logger.trace(
-                        "summing weights: %s + %s",
+                        "summing LoHA weights: %s + %s",
                         blended[base_key].shape,
                         np_weights.shape,
                     )
                     blended[base_key] = sum_weights(blended[base_key], np_weights)
                 else:
+                    logger.trace(
+                        "adding LoHA weights: %s",
+                        np_weights.shape,
+                    )
+                    blended[base_key] = np_weights
+            elif ".lora_down" in key and lora_prefix in key:
+                # LoRA or LoCON
+                base_key, np_weights = blend_weights_lora(
+                    key, lora_prefix, lora_model, dtype
+                )
+                np_weights = np_weights * lora_weight
+                if base_key in blended:
+                    logger.trace(
+                        "summing LoRA weights: %s + %s",
+                        blended[base_key].shape,
+                        np_weights.shape,
+                    )
+                    blended[base_key] = sum_weights(blended[base_key], np_weights)
+                else:
+                    logger.trace(
+                        "adding LoRA weights: %s",
+                        np_weights.shape,
+                    )
                     blended[base_key] = np_weights
 
     # rewrite node names for XL
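
For context on the LoHA branch, the w1/w2 path composes two low-rank products elementwise (a Hadamard product) before the usual alpha / dim scaling, as in the unchanged line weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight) above. A standalone numpy sketch with illustrative sizes:

import numpy as np

features, rank = 8, 4
w1a, w2a = np.ones((features, rank)), np.ones((features, rank))
w1b, w2b = np.ones((rank, features)), np.ones((rank, features))
alpha, dim = 1.0, float(rank)

# elementwise product of two rank-4 reconstructions, then scale
delta = (w1a @ w1b) * (w2a @ w2b) * (alpha / dim)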
@@ -400,7 +468,7 @@ def blend_loras(
         nodes = list(base_model.graph.node)
         blended = fix_xl_names(blended, nodes)
 
-    logger.trace(
+    logger.debug(
         "updating %s of %s initializers",
         len(blended.keys()),
         len(base_model.graph.initializer),
@@ -409,10 +477,7 @@ def blend_loras(
     fixed_initializer_names = [
         fix_initializer_name(node.name) for node in base_model.graph.initializer
     ]
-    logger.trace("fixed initializer names: %s", fixed_initializer_names)
 
     fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
-    logger.trace("fixed node names: %s", fixed_node_names)
 
     unmatched_keys = []
     for base_key, weights in blended.items():
@@ -421,9 +486,10 @@ def blend_loras(
         matmul_key = base_key + "_MatMul"
         logger.trace(
-            "key %s has conv: %s, matmul: %s",
+            "key %s has conv: %s, gemm: %s, matmul: %s",
             base_key,
             conv_key in fixed_node_names,
+            gemm_key in fixed_node_names,
             matmul_key in fixed_node_names,
         )
@@ -449,38 +515,9 @@ def blend_loras(
             weight_node = base_model.graph.initializer[weight_idx]
             logger.trace("found weight initializer: %s", weight_node.name)
 
-            # blending
-            onnx_weights = numpy_helper.to_array(weight_node)
-            logger.trace(
-                "found blended weights for conv: %s, %s",
-                onnx_weights.shape,
-                weights.shape,
-            )
-
-            if onnx_weights.shape[-2:] == (1, 1):
-                if weights.shape[-2:] == (1, 1):
-                    blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
-                else:
-                    blended = onnx_weights.squeeze((3, 2)) + weights
-
-                blended = np.expand_dims(blended, (2, 3))
-            else:
-                if onnx_weights.shape != weights.shape:
-                    logger.warning(
-                        "reshaping weights for mismatched Conv node: %s, %s",
-                        onnx_weights.shape,
-                        weights.shape,
-                    )
-                    blended = onnx_weights + weights.reshape(onnx_weights.shape)
-                else:
-                    blended = onnx_weights + weights
-
-            logger.trace("blended weight shape: %s", blended.shape)
-
-            # replace the original initializer
-            updated_node = numpy_helper.from_array(
-                blended.astype(onnx_weights.dtype), weight_node.name
-            )
+            # replace the previous node
+            updated_node = blend_node_conv_gemm(weight_node, weights)
             del base_model.graph.initializer[weight_idx]
             base_model.graph.initializer.insert(weight_idx, updated_node)
         elif matmul_key in fixed_node_names:
@@ -497,36 +534,9 @@ def blend_loras(
             matmul_node = base_model.graph.initializer[matmul_idx]
             logger.trace("found matmul initializer: %s", matmul_node.name)
 
-            # blending
-            onnx_weights = numpy_helper.to_array(matmul_node)
-            logger.trace(
-                "found blended weights for matmul: %s, %s",
-                weights.shape,
-                onnx_weights.shape,
-            )
-
-            t_weights = weights.transpose()
-
-            if (
-                weights.shape != onnx_weights.shape
-                and t_weights.shape != onnx_weights.shape
-            ):
-                logger.warning(
-                    "weight shapes do not match for %s: %s vs %s",
-                    matmul_key,
-                    weights.shape,
-                    onnx_weights.shape,
-                )
-                t_weights = interp_to_match(weights, onnx_weights).transpose()
-
-            blended = onnx_weights + t_weights
-            logger.trace(
-                "blended weight shape: %s, %s", blended.shape, onnx_weights.dtype
-            )
-
-            # replace the original initializer
-            updated_node = numpy_helper.from_array(
-                blended.astype(onnx_weights.dtype), matmul_node.name
-            )
+            # replace the previous node
+            updated_node = blend_node_matmul(matmul_node, weights, matmul_key)
             del base_model.graph.initializer[matmul_idx]
             base_model.graph.initializer.insert(matmul_idx, updated_node)
         else:
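
Both new helpers rely on the same onnx.numpy_helper round trip to rebuild an initializer under its original name. A self-contained sketch (tensor name and values here are illustrative):

import numpy as np
from onnx import numpy_helper

init = numpy_helper.from_array(np.zeros((4, 4), dtype=np.float32), "blend.weight")
arr = numpy_helper.to_array(init)                        # TensorProto -> ndarray
patched = numpy_helper.from_array(arr + 1.0, init.name)  # ndarray -> TensorProto, same name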
@@ -565,63 +575,3 @@ def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray:
     logger.debug("weights after interpolation: %s", output.shape)
     return output
-
-
-if __name__ == "__main__":
-    context = ConversionContext.from_environ()
-    parser = ArgumentParser()
-    parser.add_argument("--base", type=str)
-    parser.add_argument("--dest", type=str)
-    parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
-    parser.add_argument("--lora_models", nargs="+", type=str, default=[])
-    parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
-
-    args = parser.parse_args()
-    logger.info(
-        "merging %s with %s with weights: %s",
-        args.lora_models,
-        args.base,
-        args.lora_weights,
-    )
-
-    default_weight = 1.0 / len(args.lora_models)
-    while len(args.lora_weights) < len(args.lora_models):
-        args.lora_weights.append(default_weight)
-
-    blend_model = blend_loras(
-        context,
-        args.base,
-        list(zip(args.lora_models, args.lora_weights)),
-        args.type,
-    )
-
-    if args.dest is None or args.dest == "" or args.dest == ":load":
-        # convert to external data and save to memory
-        (bare_model, external_data) = buffer_external_data_tensors(blend_model)
-        logger.info("saved external data for %s nodes", len(external_data))
-
-        external_names, external_values = zip(*external_data)
-        opts = SessionOptions()
-        opts.add_external_initializers(list(external_names), list(external_values))
-
-        sess = InferenceSession(
-            bare_model.SerializeToString(),
-            sess_options=opts,
-            providers=["CPUExecutionProvider"],
-        )
-        logger.info(
-            "successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
-        )
-    else:
-        convert_model_to_external_data(
-            blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
-        )
-        bare_model = write_external_data_tensors(blend_model, args.dest)
-        dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
-
-        with open(dest_file, "w+b") as model_file:
-            model_file.write(bare_model.SerializeToString())
-
-        logger.info("successfully saved blended model: %s", dest_file)
-
-        check_model(dest_file)
-
-        logger.info("checked blended model")

api/scripts/onnx-lora.py (new file, 74 lines)

@@ -0,0 +1,74 @@
+from argparse import ArgumentParser
+
+from onnx_web.convert.diffusion.lora import blend_loras, buffer_external_data_tensors
+
+from os import path
+
+from onnx.checker import check_model
+from onnx.external_data_helper import (
+    convert_model_to_external_data,
+    write_external_data_tensors,
+)
+from onnxruntime import InferenceSession, SessionOptions
+
+from logging import getLogger
+
+from onnx_web.convert.utils import ConversionContext
+
+logger = getLogger(__name__)
+
+
+if __name__ == "__main__":
+    context = ConversionContext.from_environ()
+    parser = ArgumentParser()
+    parser.add_argument("--base", type=str)
+    parser.add_argument("--dest", type=str)
+    parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
+    parser.add_argument("--lora_models", nargs="+", type=str, default=[])
+    parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
+
+    args = parser.parse_args()
+    logger.info(
+        "merging %s with %s with weights: %s",
+        args.lora_models,
+        args.base,
+        args.lora_weights,
+    )
+
+    default_weight = 1.0 / len(args.lora_models)
+    while len(args.lora_weights) < len(args.lora_models):
+        args.lora_weights.append(default_weight)
+
+    blend_model = blend_loras(
+        context,
+        args.base,
+        list(zip(args.lora_models, args.lora_weights)),
+        args.type,
+    )
+
+    if args.dest is None or args.dest == "" or args.dest == ":load":
+        # convert to external data and save to memory
+        (bare_model, external_data) = buffer_external_data_tensors(blend_model)
+        logger.info("saved external data for %s nodes", len(external_data))
+
+        external_names, external_values = zip(*external_data)
+        opts = SessionOptions()
+        opts.add_external_initializers(list(external_names), list(external_values))
+
+        sess = InferenceSession(
+            bare_model.SerializeToString(),
+            sess_options=opts,
+            providers=["CPUExecutionProvider"],
+        )
+        logger.info(
+            "successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
+        )
+    else:
+        convert_model_to_external_data(
+            blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
+        )
+        bare_model = write_external_data_tensors(blend_model, args.dest)
+        dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
+
+        with open(dest_file, "w+b") as model_file:
+            model_file.write(bare_model.SerializeToString())
+
+        logger.info("successfully saved blended model: %s", dest_file)
+
+        check_model(dest_file)
+
+        logger.info("checked blended model")

api/tests/convert/diffusion/test_lora.py

@@ -1,11 +1,16 @@
 import unittest
 
 import numpy as np
+import torch
 from onnx import GraphProto, ModelProto, NodeProto
 from onnx.numpy_helper import from_array
 
 from onnx_web.convert.diffusion.lora import (
     blend_loras,
+    blend_node_conv_gemm,
+    blend_node_matmul,
+    blend_weights_loha,
+    blend_weights_lora,
     buffer_external_data_tensors,
     fix_initializer_name,
     fix_node_name,
@@ -151,6 +156,23 @@ class KernelSliceTests(unittest.TestCase):
         )
 
 
+class InterpToMatchTests(unittest.TestCase):
+    def test_same_shape(self):
+        ref = np.zeros((4, 4))
+        resize = np.zeros((4, 4))
+        self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
+
+    def test_different_one_dim(self):
+        ref = np.zeros((4, 2))
+        resize = np.zeros((4, 4))
+        self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
+
+    def test_different_both_dims(self):
+        ref = np.zeros((2, 2))
+        resize = np.zeros((4, 4))
+        self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
+
+
 class BlendLoRATests(unittest.TestCase):
     def test_blend_unet(self):
         """
@@ -183,18 +205,131 @@ class BlendLoRATests(unittest.TestCase):
         pass
 
 
-class InterpToMatchTests(unittest.TestCase):
-    def test_same_shape(self):
-        ref = np.zeros((4, 4))
-        resize = np.zeros((4, 4))
-        self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
-
-    def test_different_one_dim(self):
-        ref = np.zeros((4, 2))
-        resize = np.zeros((4, 4))
-        self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
-
-    def test_different_both_dims(self):
-        ref = np.zeros((2, 2))
-        resize = np.zeros((4, 4))
-        self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
+class BlendWeightsLoHATests(unittest.TestCase):
+    def test_blend_t1_t2(self):
+        # blend einsum: i j k l, j r, i p -> p r k l
+        i = 32
+        j = 4
+        k = 1
+        l = 1
+        p = 2
+        r = 4
+        model = {
+            "foo.hada_t1": torch.from_numpy(np.ones((i, j, k, l))),
+            "foo.hada_t2": torch.from_numpy(np.ones((i, j, k, l))),
+            "foo.hada_w1_a": torch.from_numpy(np.ones((i, p))),
+            "foo.hada_w1_b": torch.from_numpy(np.ones((j, r))),
+            "foo.hada_w2_a": torch.from_numpy(np.ones((i, p))),
+            "foo.hada_w2_b": torch.from_numpy(np.ones((j, r))),
+            "foo.alpha": torch.tensor(1),
+        }
+        key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
+        self.assertEqual(result.shape, (p, r, k, l))
+
+    def test_blend_w1_w2(self):
+        model = {
+            "foo.hada_w1_a": torch.from_numpy(np.ones((4, 1))),
+            "foo.hada_w1_b": torch.from_numpy(np.ones((1, 4))),
+            "foo.hada_w2_a": torch.from_numpy(np.ones((4, 1))),
+            "foo.hada_w2_b": torch.from_numpy(np.ones((1, 4))),
+            "foo.alpha": torch.tensor(1),
+        }
+        key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
+        self.assertEqual(result.shape, (4, 4))
+
+    def test_blend_no_dim(self):
+        """
+        model = {
+            "foo.hada_w1_a": torch.from_numpy(np.ones((1, 4))),
+            "foo.hada_w1_b": torch.from_numpy(np.ones((4, 1))),
+            "foo.hada_w2_a": torch.from_numpy(np.ones((1, 4))),
+            "foo.hada_w2_b": torch.from_numpy(np.ones((4, 1))),
+        }
+        result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
+        self.assertEqual(result.shape, (4, 4))
+        """
+
+
+class BlendWeightsLoRATests(unittest.TestCase):
+    def test_blend_kernel_none(self):
+        model = {
+            "foo.lora_down": torch.from_numpy(np.ones((1, 4))),
+            "foo.lora_up": torch.from_numpy(np.ones((4, 1))),
+            "foo.alpha": 1,
+        }
+        key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
+        self.assertEqual(result.shape, (4, 4))
+
+    def test_blend_kernel_1x1(self):
+        model = {
+            "foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))),
+            "foo.lora_up": torch.from_numpy(np.ones((4, 1, 1, 1))),
+            "foo.alpha": 1,
+        }
+        key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
+        self.assertEqual(result.shape, (4, 4, 1, 1))
+
+    def test_blend_kernel_3x3(self):
+        model = {
+            "foo.lora_down": torch.from_numpy(np.ones((1, 4, 3, 3))),
+            "foo.lora_up": torch.from_numpy(np.ones((4, 1, 3, 3))),
+            "foo.alpha": 1,
+        }
+        key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
+        self.assertEqual(result.shape, (4, 4, 3, 3))
+
+    def test_blend_kernel_3x3_cp_decomp(self):
+        model = {
+            "foo.lora_down": torch.from_numpy(np.ones((2, 4, 1, 1))),
+            "foo.lora_mid": torch.from_numpy(np.ones((2, 2, 3, 3))),
+            "foo.lora_up": torch.from_numpy(np.ones((4, 2, 1, 1))),
+            "foo.alpha": 1,
+        }
+        key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
+        self.assertEqual(result.shape, (4, 4, 3, 3))
+
+    def test_blend_unknown(self):
+        pass
+
+
+class BlendNodeConvGemmTests(unittest.TestCase):
+    def test_blend_kernel_1x1_and_1x1(self):
+        node = from_array(np.ones((4, 4, 1, 1)))
+        result = blend_node_conv_gemm(node, np.ones((4, 4, 1, 1)))
+
+        self.assertEqual(result.dims, [4, 4, 1, 1])
+        self.assertEqual(len(result.raw_data), 4 * 4 * 8)
+
+    def test_blend_kernel_1x1_and_none(self):
+        node = from_array(np.ones((4, 4, 1, 1)))
+        result = blend_node_conv_gemm(node, np.ones((4, 4)))
+
+        self.assertEqual(result.dims, [4, 4, 1, 1])
+        self.assertEqual(len(result.raw_data), 4 * 4 * 8)
+
+    def test_blend_other_matching(self):
+        node = from_array(np.ones((4, 4)))
+        result = blend_node_conv_gemm(node, np.ones((4, 4)))
+
+        self.assertEqual(result.dims, [4, 4])
+        self.assertEqual(len(result.raw_data), 4 * 4 * 8)
+
+    def test_blend_other_mismatched(self):
+        pass
+
+
+class BlendNodeMatMulTests(unittest.TestCase):
+    def test_blend_matching(self):
+        node = from_array(np.ones((4, 4)))
+        result = blend_node_matmul(node, np.ones((4, 4)), "test")
+
+        self.assertEqual(result.dims, [4, 4])
+        self.assertEqual(len(result.raw_data), 4 * 4 * 8)
+
+    def test_blend_mismatched(self):
+        node = from_array(np.ones((4, 4)))
+        result = blend_node_matmul(node, np.ones((2, 2)), "test")
+
+        self.assertEqual(result.dims, [4, 4])
+        self.assertEqual(len(result.raw_data), 4 * 4 * 8)
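
The comment in test_blend_t1_t2 spells out the shape algebra of the LoHA t1/t2 branch: contracting t1 (i, j, k, l) with w1_b (j, r) and w1_a (i, p) yields a (p, r, k, l) kernel. As a standalone check, using the shapes from the test itself:

import torch

i, j, k, l, p, r = 32, 4, 1, 1, 2, 4
t1 = torch.ones((i, j, k, l))
w1b = torch.ones((j, r))
w1a = torch.ones((i, p))

# sum over i and j, keeping the kernel dims k and l
out = torch.einsum("ijkl,jr,ip->prkl", t1, w1b, w1a)
assert out.shape == (p, r, k, l)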