1
0
Fork 0

fix(api): test LoRA blending code

This commit is contained in:
Sean Sube 2023-09-21 18:24:08 -05:00
parent 761bfa8364
commit 52fdf4f48a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 488 additions and 329 deletions

View File

@ -1,22 +1,15 @@
from argparse import ArgumentParser
from logging import getLogger
from os import path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from onnx import ModelProto, NodeProto, load, numpy_helper
from onnx.checker import check_model
from onnx.external_data_helper import (
convert_model_to_external_data,
set_external_data,
write_external_data_tensors,
)
from onnxruntime import InferenceSession, OrtValue, SessionOptions
from onnx import ModelProto, NodeProto, TensorProto, load, numpy_helper
from onnx.external_data_helper import set_external_data
from onnxruntime import OrtValue
from scipy import interpolate
from ...server.context import ServerContext
from ..utils import ConversionContext, load_tensor
from ..utils import load_tensor
logger = getLogger(__name__)
@ -161,6 +154,245 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int,
)
def blend_weights_loha(
key: str, lora_prefix: str, lora_model: Dict, dtype
) -> Tuple[str, np.ndarray]:
base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")
t1_key = key.replace("hada_w1_a", "hada_t1")
t2_key = key.replace("hada_w1_a", "hada_t2")
w1b_key = key.replace("hada_w1_a", "hada_w1_b")
w2a_key = key.replace("hada_w1_a", "hada_w2_a")
w2b_key = key.replace("hada_w1_a", "hada_w2_b")
alpha_key = key[: key.index("hada_w1_a")] + "alpha"
logger.trace(
"blending weights for LoHA keys: %s, %s, %s, %s, %s",
key,
w1b_key,
w2a_key,
w2b_key,
alpha_key,
)
w1a_weight = lora_model[key].to(dtype=dtype)
w1b_weight = lora_model[w1b_key].to(dtype=dtype)
w2a_weight = lora_model[w2a_key].to(dtype=dtype)
w2b_weight = lora_model[w2b_key].to(dtype=dtype)
t1_weight = lora_model.get(t1_key, None)
t2_weight = lora_model.get(t2_key, None)
dim = w1b_weight.size()[0]
alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()
if t1_weight is not None and t2_weight is not None:
t1_weight = t1_weight.to(dtype=dtype)
t2_weight = t2_weight.to(dtype=dtype)
logger.trace(
"composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)",
t1_weight.shape,
w1a_weight.shape,
w1b_weight.shape,
t2_weight.shape,
w2a_weight.shape,
w2b_weight.shape,
)
weights_1 = torch.einsum(
"i j k l, j r, i p -> p r k l",
t1_weight,
w1b_weight,
w1a_weight,
)
weights_2 = torch.einsum(
"i j k l, j r, i p -> p r k l",
t2_weight,
w2b_weight,
w2a_weight,
)
weights = weights_1 * weights_2
np_weights = weights.numpy() * (alpha / dim)
else:
logger.trace(
"blending weights for LoHA node: (%s @ %s) * (%s @ %s)",
w1a_weight.shape,
w1b_weight.shape,
w2a_weight.shape,
w2b_weight.shape,
)
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
np_weights = weights.numpy() * (alpha / dim)
return base_key, np_weights
def blend_weights_lora(
key: str, lora_prefix: str, lora_model: Dict, dtype
) -> Tuple[str, np.ndarray]:
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
mid_key = key.replace("lora_down", "lora_mid")
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
logger.trace("blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key)
down_weight = lora_model[key].to(dtype=dtype)
up_weight = lora_model[up_key].to(dtype=dtype)
mid_weight = None
if mid_key in lora_model:
mid_weight = lora_model[mid_key].to(dtype=dtype)
dim = down_weight.size()[0]
alpha = lora_model.get(alpha_key, dim)
if not isinstance(alpha, int):
alpha = alpha.to(dtype).numpy()
kernel = down_weight.shape[-2:]
if mid_weight is not None:
kernel = mid_weight.shape[-2:]
if len(down_weight.size()) == 2:
# blend for nn.Linear
logger.trace(
"blending weights for Linear node: (%s @ %s) * %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = up_weight @ down_weight
np_weights = weights.numpy() * (alpha / dim)
elif len(down_weight.size()) == 4 and kernel == (
1,
1,
):
# blend for nn.Conv2d 1x1
logger.trace(
"blending weights for Conv 1x1 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = (
(up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
.unsqueeze(2)
.unsqueeze(3)
)
np_weights = weights.numpy() * (alpha / dim)
elif len(down_weight.size()) == 4 and kernel == (
3,
3,
):
if mid_weight is not None:
# blend for nn.Conv2d 3x3 with CP decomp
logger.trace(
"composing weights for Conv 3x3 node: %s, %s, %s, %s",
down_weight.shape,
up_weight.shape,
mid_weight.shape,
alpha,
)
weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
for w in range(kernel[0]):
for h in range(kernel[1]):
weights[:, :, w, h] = (
up_weight.squeeze(3).squeeze(2) @ mid_weight[:, :, w, h]
) @ down_weight.squeeze(3).squeeze(2)
np_weights = weights.numpy() * (alpha / dim)
else:
# blend for nn.Conv2d 3x3
logger.trace(
"blending weights for Conv 3x3 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
for w in range(kernel[0]):
for h in range(kernel[1]):
down_w, down_h = kernel_slice(w, h, down_weight.shape)
up_w, up_h = kernel_slice(w, h, up_weight.shape)
weights[:, :, w, h] = (
up_weight[:, :, up_w, up_h] @ down_weight[:, :, down_w, down_h]
)
np_weights = weights.numpy() * (alpha / dim)
else:
logger.warning(
"unknown LoRA node type at %s: %s",
base_key,
up_weight.shape[-2:],
)
# TODO: should this be None?
np_weights = np.zeros((1, 1, 1, 1))
return base_key, np_weights
def blend_node_conv_gemm(weight_node, weights) -> TensorProto:
# blending
onnx_weights = numpy_helper.to_array(weight_node)
logger.trace(
"found blended weights for conv: %s, %s",
onnx_weights.shape,
weights.shape,
)
if onnx_weights.shape[-2:] == (1, 1):
if weights.shape[-2:] == (1, 1):
blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
else:
blended = onnx_weights.squeeze((3, 2)) + weights
blended = np.expand_dims(blended, (2, 3))
else:
if onnx_weights.shape != weights.shape:
logger.warning(
"reshaping weights for mismatched Conv node: %s, %s",
onnx_weights.shape,
weights.shape,
)
# TODO: test if this can be replaced with interpolation, simply reshaping is pretty sus
blended = onnx_weights + weights.reshape(onnx_weights.shape)
else:
blended = onnx_weights + weights
logger.trace("blended weight shape: %s", blended.shape)
# replace the original initializer
return numpy_helper.from_array(blended.astype(onnx_weights.dtype), weight_node.name)
def blend_node_matmul(matmul_node, weights, matmul_key) -> TensorProto:
onnx_weights = numpy_helper.to_array(matmul_node)
logger.trace(
"found blended weights for matmul: %s, %s",
weights.shape,
onnx_weights.shape,
)
t_weights = weights.transpose()
if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape:
logger.warning(
"weight shapes do not match for %s: %s vs %s",
matmul_key,
weights.shape,
onnx_weights.shape,
)
t_weights = interp_to_match(weights, onnx_weights).transpose()
blended = onnx_weights + t_weights
logger.trace("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype)
# replace the original initializer
return numpy_helper.from_array(blended.astype(onnx_weights.dtype), matmul_node.name)
def blend_loras(
_conversion: ServerContext,
base_name: Union[str, ModelProto],
@ -194,205 +426,41 @@ def blend_loras(
for key in lora_model.keys():
if ".hada_w1_a" in key and lora_prefix in key:
# LoHA
base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")
t1_key = key.replace("hada_w1_a", "hada_t1")
t2_key = key.replace("hada_w1_a", "hada_t2")
w1b_key = key.replace("hada_w1_a", "hada_w1_b")
w2a_key = key.replace("hada_w1_a", "hada_w2_a")
w2b_key = key.replace("hada_w1_a", "hada_w2_b")
alpha_key = key[: key.index("hada_w1_a")] + "alpha"
logger.trace(
"blending weights for LoHA keys: %s, %s, %s, %s, %s",
key,
w1b_key,
w2a_key,
w2b_key,
alpha_key,
base_key, np_weights = blend_weights_loha(
key, lora_prefix, lora_model, dtype
)
w1a_weight = lora_model[key].to(dtype=dtype)
w1b_weight = lora_model[w1b_key].to(dtype=dtype)
w2a_weight = lora_model[w2a_key].to(dtype=dtype)
w2b_weight = lora_model[w2b_key].to(dtype=dtype)
t1_weight = lora_model.get(t1_key, None)
t2_weight = lora_model.get(t2_key, None)
dim = w1b_weight.size()[0]
alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()
if t1_weight is not None and t2_weight is not None:
t1_weight = t1_weight.to(dtype=dtype)
t2_weight = t2_weight.to(dtype=dtype)
logger.trace(
"composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)",
t1_weight.shape,
w1a_weight.shape,
w1b_weight.shape,
t2_weight.shape,
w2a_weight.shape,
w2b_weight.shape,
)
weights_1 = torch.einsum(
"i j k l, j r, i p -> p r k l",
t1_weight,
w1b_weight,
w1a_weight,
)
weights_2 = torch.einsum(
"i j k l, j r, i p -> p r k l",
t2_weight,
w2b_weight,
w2a_weight,
)
weights = weights_1 * weights_2
np_weights = weights.numpy() * (alpha / dim)
else:
logger.trace(
"blending weights for LoHA node: (%s @ %s) * (%s @ %s)",
w1a_weight.shape,
w1b_weight.shape,
w2a_weight.shape,
w2b_weight.shape,
)
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
np_weights = weights.numpy() * (alpha / dim)
np_weights *= lora_weight
np_weights = np_weights * lora_weight
if base_key in blended:
logger.trace(
"summing LoHA weights: %s + %s",
blended[base_key].shape,
np_weights.shape,
)
blended[base_key] += sum_weights(blended[base_key], np_weights)
blended[base_key] = sum_weights(blended[base_key], np_weights)
else:
logger.trace(
"adding LoHA weights: %s",
np_weights.shape,
)
blended[base_key] = np_weights
elif ".lora_down" in key and lora_prefix in key:
# LoRA or LoCON
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
mid_key = key.replace("lora_down", "lora_mid")
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
logger.trace(
"blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key
base_key, np_weights = blend_weights_lora(
key, lora_prefix, lora_model, dtype
)
down_weight = lora_model[key].to(dtype=dtype)
up_weight = lora_model[up_key].to(dtype=dtype)
mid_weight = None
if mid_key in lora_model:
mid_weight = lora_model[mid_key].to(dtype=dtype)
dim = down_weight.size()[0]
alpha = lora_model.get(alpha_key, dim)
if not isinstance(alpha, int):
alpha = alpha.to(dtype).numpy()
kernel = down_weight.shape[-2:]
if mid_weight is not None:
kernel = mid_weight.shape[-2:]
if len(down_weight.size()) == 2:
# blend for nn.Linear
logger.trace(
"blending weights for Linear node: (%s @ %s) * %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = up_weight @ down_weight
np_weights = weights.numpy() * (alpha / dim)
elif len(down_weight.size()) == 4 and kernel == (
1,
1,
):
# blend for nn.Conv2d 1x1
logger.trace(
"blending weights for Conv 1x1 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = (
(
up_weight.squeeze(3).squeeze(2)
@ down_weight.squeeze(3).squeeze(2)
)
.unsqueeze(2)
.unsqueeze(3)
)
np_weights = weights.numpy() * (alpha / dim)
elif len(down_weight.size()) == 4 and kernel == (
3,
3,
):
if mid_weight is not None:
# blend for nn.Conv2d 3x3 with CP decomp
logger.trace(
"composing weights for Conv 3x3 node: %s, %s, %s, %s",
down_weight.shape,
up_weight.shape,
mid_weight.shape,
alpha,
)
weights = torch.zeros(
(up_weight.shape[0], down_weight.shape[1], *kernel)
)
for w in range(kernel[0]):
for h in range(kernel[1]):
weights[:, :, w, h] = (
up_weight.squeeze(3).squeeze(2)
@ mid_weight[:, :, w, h]
) @ down_weight.squeeze(3).squeeze(2)
np_weights = weights.numpy() * (alpha / dim)
else:
# blend for nn.Conv2d 3x3
logger.trace(
"blending weights for Conv 3x3 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = torch.zeros(
(up_weight.shape[0], down_weight.shape[1], *kernel)
)
for w in range(kernel[0]):
for h in range(kernel[1]):
down_w, down_h = kernel_slice(w, h, down_weight.shape)
up_w, up_h = kernel_slice(w, h, up_weight.shape)
weights[:, :, w, h] = (
up_weight[:, :, up_w, up_h]
@ down_weight[:, :, down_w, down_h]
)
np_weights = weights.numpy() * (alpha / dim)
else:
logger.warning(
"unknown LoRA node type at %s: %s",
base_key,
up_weight.shape[-2:],
)
continue
np_weights *= lora_weight
np_weights = np_weights * lora_weight
if base_key in blended:
logger.trace(
"summing weights: %s + %s",
"summing LoRA weights: %s + %s",
blended[base_key].shape,
np_weights.shape,
)
blended[base_key] = sum_weights(blended[base_key], np_weights)
else:
logger.trace(
"adding LoRA weights: %s",
np_weights.shape,
)
blended[base_key] = np_weights
# rewrite node names for XL
@ -400,7 +468,7 @@ def blend_loras(
nodes = list(base_model.graph.node)
blended = fix_xl_names(blended, nodes)
logger.trace(
logger.debug(
"updating %s of %s initializers",
len(blended.keys()),
len(base_model.graph.initializer),
@ -409,10 +477,7 @@ def blend_loras(
fixed_initializer_names = [
fix_initializer_name(node.name) for node in base_model.graph.initializer
]
logger.trace("fixed initializer names: %s", fixed_initializer_names)
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
logger.trace("fixed node names: %s", fixed_node_names)
unmatched_keys = []
for base_key, weights in blended.items():
@ -421,9 +486,10 @@ def blend_loras(
matmul_key = base_key + "_MatMul"
logger.trace(
"key %s has conv: %s, matmul: %s",
"key %s has conv: %s, gemm: %s, matmul: %s",
base_key,
conv_key in fixed_node_names,
gemm_key in fixed_node_names,
matmul_key in fixed_node_names,
)
@ -449,38 +515,9 @@ def blend_loras(
weight_node = base_model.graph.initializer[weight_idx]
logger.trace("found weight initializer: %s", weight_node.name)
# blending
onnx_weights = numpy_helper.to_array(weight_node)
logger.trace(
"found blended weights for conv: %s, %s",
onnx_weights.shape,
weights.shape,
)
# replace the previous node
updated_node = blend_node_conv_gemm(weight_node, weights)
if onnx_weights.shape[-2:] == (1, 1):
if weights.shape[-2:] == (1, 1):
blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
else:
blended = onnx_weights.squeeze((3, 2)) + weights
blended = np.expand_dims(blended, (2, 3))
else:
if onnx_weights.shape != weights.shape:
logger.warning(
"reshaping weights for mismatched Conv node: %s, %s",
onnx_weights.shape,
weights.shape,
)
blended = onnx_weights + weights.reshape(onnx_weights.shape)
else:
blended = onnx_weights + weights
logger.trace("blended weight shape: %s", blended.shape)
# replace the original initializer
updated_node = numpy_helper.from_array(
blended.astype(onnx_weights.dtype), weight_node.name
)
del base_model.graph.initializer[weight_idx]
base_model.graph.initializer.insert(weight_idx, updated_node)
elif matmul_key in fixed_node_names:
@ -497,36 +534,9 @@ def blend_loras(
matmul_node = base_model.graph.initializer[matmul_idx]
logger.trace("found matmul initializer: %s", matmul_node.name)
# blending
onnx_weights = numpy_helper.to_array(matmul_node)
logger.trace(
"found blended weights for matmul: %s, %s",
weights.shape,
onnx_weights.shape,
)
# replace the previous node
updated_node = blend_node_matmul(matmul_node, weights, matmul_key)
t_weights = weights.transpose()
if (
weights.shape != onnx_weights.shape
and t_weights.shape != onnx_weights.shape
):
logger.warning(
"weight shapes do not match for %s: %s vs %s",
matmul_key,
weights.shape,
onnx_weights.shape,
)
t_weights = interp_to_match(weights, onnx_weights).transpose()
blended = onnx_weights + t_weights
logger.trace(
"blended weight shape: %s, %s", blended.shape, onnx_weights.dtype
)
# replace the original initializer
updated_node = numpy_helper.from_array(
blended.astype(onnx_weights.dtype), matmul_node.name
)
del base_model.graph.initializer[matmul_idx]
base_model.graph.initializer.insert(matmul_idx, updated_node)
else:
@ -565,63 +575,3 @@ def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray:
logger.debug("weights after interpolation: %s", output.shape)
return output
if __name__ == "__main__":
context = ConversionContext.from_environ()
parser = ArgumentParser()
parser.add_argument("--base", type=str)
parser.add_argument("--dest", type=str)
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
parser.add_argument("--lora_models", nargs="+", type=str, default=[])
parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
args = parser.parse_args()
logger.info(
"merging %s with %s with weights: %s",
args.lora_models,
args.base,
args.lora_weights,
)
default_weight = 1.0 / len(args.lora_models)
while len(args.lora_weights) < len(args.lora_models):
args.lora_weights.append(default_weight)
blend_model = blend_loras(
context,
args.base,
list(zip(args.lora_models, args.lora_weights)),
args.type,
)
if args.dest is None or args.dest == "" or args.dest == ":load":
# convert to external data and save to memory
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
logger.info("saved external data for %s nodes", len(external_data))
external_names, external_values = zip(*external_data)
opts = SessionOptions()
opts.add_external_initializers(list(external_names), list(external_values))
sess = InferenceSession(
bare_model.SerializeToString(),
sess_options=opts,
providers=["CPUExecutionProvider"],
)
logger.info(
"successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
)
else:
convert_model_to_external_data(
blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
)
bare_model = write_external_data_tensors(blend_model, args.dest)
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
with open(dest_file, "w+b") as model_file:
model_file.write(bare_model.SerializeToString())
logger.info("successfully saved blended model: %s", dest_file)
check_model(dest_file)
logger.info("checked blended model")

74
api/scripts/onnx-lora.py Normal file
View File

@ -0,0 +1,74 @@
from argparse import ArgumentParser
from onnx_web.convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from os import path
from onnx.checker import check_model
from onnx.external_data_helper import (
convert_model_to_external_data,
write_external_data_tensors,
)
from onnxruntime import InferenceSession, SessionOptions
from logging import getLogger
from onnx_web.convert.utils import ConversionContext
logger = getLogger(__name__)
if __name__ == "__main__":
context = ConversionContext.from_environ()
parser = ArgumentParser()
parser.add_argument("--base", type=str)
parser.add_argument("--dest", type=str)
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
parser.add_argument("--lora_models", nargs="+", type=str, default=[])
parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
args = parser.parse_args()
logger.info(
"merging %s with %s with weights: %s",
args.lora_models,
args.base,
args.lora_weights,
)
default_weight = 1.0 / len(args.lora_models)
while len(args.lora_weights) < len(args.lora_models):
args.lora_weights.append(default_weight)
blend_model = blend_loras(
context,
args.base,
list(zip(args.lora_models, args.lora_weights)),
args.type,
)
if args.dest is None or args.dest == "" or args.dest == ":load":
# convert to external data and save to memory
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
logger.info("saved external data for %s nodes", len(external_data))
external_names, external_values = zip(*external_data)
opts = SessionOptions()
opts.add_external_initializers(list(external_names), list(external_values))
sess = InferenceSession(
bare_model.SerializeToString(),
sess_options=opts,
providers=["CPUExecutionProvider"],
)
logger.info(
"successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
)
else:
convert_model_to_external_data(
blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
)
bare_model = write_external_data_tensors(blend_model, args.dest)
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
with open(dest_file, "w+b") as model_file:
model_file.write(bare_model.SerializeToString())
logger.info("successfully saved blended model: %s", dest_file)
check_model(dest_file)
logger.info("checked blended model")

View File

@ -1,11 +1,16 @@
import unittest
import numpy as np
import torch
from onnx import GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import from_array
from onnx_web.convert.diffusion.lora import (
blend_loras,
blend_node_conv_gemm,
blend_node_matmul,
blend_weights_loha,
blend_weights_lora,
buffer_external_data_tensors,
fix_initializer_name,
fix_node_name,
@ -151,6 +156,23 @@ class KernelSliceTests(unittest.TestCase):
)
class InterpToMatchTests(unittest.TestCase):
def test_same_shape(self):
ref = np.zeros((4, 4))
resize = np.zeros((4, 4))
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
def test_different_one_dim(self):
ref = np.zeros((4, 2))
resize = np.zeros((4, 4))
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
def test_different_both_dims(self):
ref = np.zeros((2, 2))
resize = np.zeros((4, 4))
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
class BlendLoRATests(unittest.TestCase):
def test_blend_unet(self):
"""
@ -183,18 +205,131 @@ class BlendLoRATests(unittest.TestCase):
pass
class InterpToMatchTests(unittest.TestCase):
def test_same_shape(self):
ref = np.zeros((4, 4))
resize = np.zeros((4, 4))
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
class BlendWeightsLoHATests(unittest.TestCase):
def test_blend_t1_t2(self):
# blend einsum: i j k l, j r, i p -> p r k l
i = 32
j = 4
k = 1
l = 1
p = 2
r = 4
def test_different_one_dim(self):
ref = np.zeros((4, 2))
resize = np.zeros((4, 4))
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
model = {
"foo.hada_t1": torch.from_numpy(np.ones((i, j, k, l))),
"foo.hada_t2": torch.from_numpy(np.ones((i, j, k, l))),
"foo.hada_w1_a": torch.from_numpy(np.ones((i, p))),
"foo.hada_w1_b": torch.from_numpy(np.ones((j, r))),
"foo.hada_w2_a": torch.from_numpy(np.ones((i, p))),
"foo.hada_w2_b": torch.from_numpy(np.ones((j, r))),
"foo.alpha": torch.tensor(1),
}
key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
self.assertEqual(result.shape, (p, r, k, l))
def test_different_both_dims(self):
ref = np.zeros((2, 2))
resize = np.zeros((4, 4))
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
def test_blend_w1_w2(self):
model = {
"foo.hada_w1_a": torch.from_numpy(np.ones((4, 1))),
"foo.hada_w1_b": torch.from_numpy(np.ones((1, 4))),
"foo.hada_w2_a": torch.from_numpy(np.ones((4, 1))),
"foo.hada_w2_b": torch.from_numpy(np.ones((1, 4))),
"foo.alpha": torch.tensor(1),
}
key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
self.assertEqual(result.shape, (4, 4))
def test_blend_no_dim(self):
"""
model = {
"foo.hada_w1_a": torch.from_numpy(np.ones((1, 4))),
"foo.hada_w1_b": torch.from_numpy(np.ones((4, 1))),
"foo.hada_w2_a": torch.from_numpy(np.ones((1, 4))),
"foo.hada_w2_b": torch.from_numpy(np.ones((4, 1))),
}
result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
self.assertEqual(result.shape, (4, 4))
"""
class BlendWeightsLoRATests(unittest.TestCase):
def test_blend_kernel_none(self):
model = {
"foo.lora_down": torch.from_numpy(np.ones((1, 4))),
"foo.lora_up": torch.from_numpy(np.ones((4, 1))),
"foo.alpha": 1,
}
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
self.assertEqual(result.shape, (4, 4))
def test_blend_kernel_1x1(self):
model = {
"foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))),
"foo.lora_up": torch.from_numpy(np.ones((4, 1, 1, 1))),
"foo.alpha": 1,
}
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
self.assertEqual(result.shape, (4, 4, 1, 1))
def test_blend_kernel_3x3(self):
model = {
"foo.lora_down": torch.from_numpy(np.ones((1, 4, 3, 3))),
"foo.lora_up": torch.from_numpy(np.ones((4, 1, 3, 3))),
"foo.alpha": 1,
}
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
self.assertEqual(result.shape, (4, 4, 3, 3))
def test_blend_kernel_3x3_cp_decomp(self):
model = {
"foo.lora_down": torch.from_numpy(np.ones((2, 4, 1, 1))),
"foo.lora_mid": torch.from_numpy(np.ones((2, 2, 3, 3))),
"foo.lora_up": torch.from_numpy(np.ones((4, 2, 1, 1))),
"foo.alpha": 1,
}
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
self.assertEqual(result.shape, (4, 4, 3, 3))
def test_blend_unknown(self):
pass
class BlendNodeConvGemmTests(unittest.TestCase):
def test_blend_kernel_1x1_and_1x1(self):
node = from_array(np.ones((4, 4, 1, 1)))
result = blend_node_conv_gemm(node, np.ones((4, 4, 1, 1)))
self.assertEqual(result.dims, [4, 4, 1, 1])
self.assertEqual(len(result.raw_data), 4 * 4 * 8)
def test_blend_kernel_1x1_and_none(self):
node = from_array(np.ones((4, 4, 1, 1)))
result = blend_node_conv_gemm(node, np.ones((4, 4)))
self.assertEqual(result.dims, [4, 4, 1, 1])
self.assertEqual(len(result.raw_data), 4 * 4 * 8)
def test_blend_other_matching(self):
node = from_array(np.ones((4, 4)))
result = blend_node_conv_gemm(node, np.ones((4, 4)))
self.assertEqual(result.dims, [4, 4])
self.assertEqual(len(result.raw_data), 4 * 4 * 8)
def test_blend_other_mismatched(self):
pass
class BlendNodeMatMulTests(unittest.TestCase):
def test_blend_matching(self):
node = from_array(np.ones((4, 4)))
result = blend_node_matmul(node, np.ones((4, 4)), "test")
self.assertEqual(result.dims, [4, 4])
self.assertEqual(len(result.raw_data), 4 * 4 * 8)
def test_blend_mismatched(self):
node = from_array(np.ones((4, 4)))
result = blend_node_matmul(node, np.ones((2, 2)), "test")
self.assertEqual(result.dims, [4, 4])
self.assertEqual(len(result.raw_data), 4 * 4 * 8)