fix(api): test LoRA blending code
This commit is contained in:
parent
761bfa8364
commit
52fdf4f48a
|
@ -1,22 +1,15 @@
|
||||||
from argparse import ArgumentParser
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from onnx import ModelProto, NodeProto, load, numpy_helper
|
from onnx import ModelProto, NodeProto, TensorProto, load, numpy_helper
|
||||||
from onnx.checker import check_model
|
from onnx.external_data_helper import set_external_data
|
||||||
from onnx.external_data_helper import (
|
from onnxruntime import OrtValue
|
||||||
convert_model_to_external_data,
|
|
||||||
set_external_data,
|
|
||||||
write_external_data_tensors,
|
|
||||||
)
|
|
||||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
|
||||||
from scipy import interpolate
|
from scipy import interpolate
|
||||||
|
|
||||||
from ...server.context import ServerContext
|
from ...server.context import ServerContext
|
||||||
from ..utils import ConversionContext, load_tensor
|
from ..utils import load_tensor
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -161,39 +154,9 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def blend_loras(
|
def blend_weights_loha(
|
||||||
_conversion: ServerContext,
|
key: str, lora_prefix: str, lora_model: Dict, dtype
|
||||||
base_name: Union[str, ModelProto],
|
) -> Tuple[str, np.ndarray]:
|
||||||
loras: List[Tuple[str, float]],
|
|
||||||
model_type: Literal["text_encoder", "unet"],
|
|
||||||
model_index: Optional[int] = None,
|
|
||||||
xl: Optional[bool] = False,
|
|
||||||
):
|
|
||||||
# always load to CPU for blending
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
|
|
||||||
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
|
|
||||||
|
|
||||||
if model_type == "text_encoder":
|
|
||||||
if model_index is None:
|
|
||||||
lora_prefix = "lora_te_"
|
|
||||||
else:
|
|
||||||
lora_prefix = f"lora_te{model_index}_"
|
|
||||||
else:
|
|
||||||
lora_prefix = f"lora_{model_type}_"
|
|
||||||
|
|
||||||
blended: Dict[str, np.ndarray] = {}
|
|
||||||
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
|
|
||||||
logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
|
||||||
if lora_model is None:
|
|
||||||
logger.warning("unable to load tensor for LoRA")
|
|
||||||
continue
|
|
||||||
|
|
||||||
for key in lora_model.keys():
|
|
||||||
if ".hada_w1_a" in key and lora_prefix in key:
|
|
||||||
# LoHA
|
|
||||||
base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")
|
base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")
|
||||||
|
|
||||||
t1_key = key.replace("hada_w1_a", "hada_t1")
|
t1_key = key.replace("hada_w1_a", "hada_t1")
|
||||||
|
@ -260,26 +223,18 @@ def blend_loras(
|
||||||
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
|
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
|
|
||||||
np_weights *= lora_weight
|
return base_key, np_weights
|
||||||
if base_key in blended:
|
|
||||||
logger.trace(
|
|
||||||
"summing LoHA weights: %s + %s",
|
def blend_weights_lora(
|
||||||
blended[base_key].shape,
|
key: str, lora_prefix: str, lora_model: Dict, dtype
|
||||||
np_weights.shape,
|
) -> Tuple[str, np.ndarray]:
|
||||||
)
|
|
||||||
blended[base_key] += sum_weights(blended[base_key], np_weights)
|
|
||||||
else:
|
|
||||||
blended[base_key] = np_weights
|
|
||||||
elif ".lora_down" in key and lora_prefix in key:
|
|
||||||
# LoRA or LoCON
|
|
||||||
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
|
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
|
||||||
|
|
||||||
mid_key = key.replace("lora_down", "lora_mid")
|
mid_key = key.replace("lora_down", "lora_mid")
|
||||||
up_key = key.replace("lora_down", "lora_up")
|
up_key = key.replace("lora_down", "lora_up")
|
||||||
alpha_key = key[: key.index("lora_down")] + "alpha"
|
alpha_key = key[: key.index("lora_down")] + "alpha"
|
||||||
logger.trace(
|
logger.trace("blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key)
|
||||||
"blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key
|
|
||||||
)
|
|
||||||
|
|
||||||
down_weight = lora_model[key].to(dtype=dtype)
|
down_weight = lora_model[key].to(dtype=dtype)
|
||||||
up_weight = lora_model[up_key].to(dtype=dtype)
|
up_weight = lora_model[up_key].to(dtype=dtype)
|
||||||
|
@ -320,10 +275,7 @@ def blend_loras(
|
||||||
alpha,
|
alpha,
|
||||||
)
|
)
|
||||||
weights = (
|
weights = (
|
||||||
(
|
(up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
|
||||||
up_weight.squeeze(3).squeeze(2)
|
|
||||||
@ down_weight.squeeze(3).squeeze(2)
|
|
||||||
)
|
|
||||||
.unsqueeze(2)
|
.unsqueeze(2)
|
||||||
.unsqueeze(3)
|
.unsqueeze(3)
|
||||||
)
|
)
|
||||||
|
@ -341,15 +293,12 @@ def blend_loras(
|
||||||
mid_weight.shape,
|
mid_weight.shape,
|
||||||
alpha,
|
alpha,
|
||||||
)
|
)
|
||||||
weights = torch.zeros(
|
weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
|
||||||
(up_weight.shape[0], down_weight.shape[1], *kernel)
|
|
||||||
)
|
|
||||||
|
|
||||||
for w in range(kernel[0]):
|
for w in range(kernel[0]):
|
||||||
for h in range(kernel[1]):
|
for h in range(kernel[1]):
|
||||||
weights[:, :, w, h] = (
|
weights[:, :, w, h] = (
|
||||||
up_weight.squeeze(3).squeeze(2)
|
up_weight.squeeze(3).squeeze(2) @ mid_weight[:, :, w, h]
|
||||||
@ mid_weight[:, :, w, h]
|
|
||||||
) @ down_weight.squeeze(3).squeeze(2)
|
) @ down_weight.squeeze(3).squeeze(2)
|
||||||
|
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
|
@ -361,9 +310,7 @@ def blend_loras(
|
||||||
up_weight.shape,
|
up_weight.shape,
|
||||||
alpha,
|
alpha,
|
||||||
)
|
)
|
||||||
weights = torch.zeros(
|
weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
|
||||||
(up_weight.shape[0], down_weight.shape[1], *kernel)
|
|
||||||
)
|
|
||||||
|
|
||||||
for w in range(kernel[0]):
|
for w in range(kernel[0]):
|
||||||
for h in range(kernel[1]):
|
for h in range(kernel[1]):
|
||||||
|
@ -371,8 +318,7 @@ def blend_loras(
|
||||||
up_w, up_h = kernel_slice(w, h, up_weight.shape)
|
up_w, up_h = kernel_slice(w, h, up_weight.shape)
|
||||||
|
|
||||||
weights[:, :, w, h] = (
|
weights[:, :, w, h] = (
|
||||||
up_weight[:, :, up_w, up_h]
|
up_weight[:, :, up_w, up_h] @ down_weight[:, :, down_w, down_h]
|
||||||
@ down_weight[:, :, down_w, down_h]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
|
@ -382,17 +328,139 @@ def blend_loras(
|
||||||
base_key,
|
base_key,
|
||||||
up_weight.shape[-2:],
|
up_weight.shape[-2:],
|
||||||
)
|
)
|
||||||
|
# TODO: should this be None?
|
||||||
|
np_weights = np.zeros((1, 1, 1, 1))
|
||||||
|
|
||||||
|
return base_key, np_weights
|
||||||
|
|
||||||
|
|
||||||
|
def blend_node_conv_gemm(weight_node, weights) -> TensorProto:
|
||||||
|
# blending
|
||||||
|
onnx_weights = numpy_helper.to_array(weight_node)
|
||||||
|
logger.trace(
|
||||||
|
"found blended weights for conv: %s, %s",
|
||||||
|
onnx_weights.shape,
|
||||||
|
weights.shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
if onnx_weights.shape[-2:] == (1, 1):
|
||||||
|
if weights.shape[-2:] == (1, 1):
|
||||||
|
blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
||||||
|
else:
|
||||||
|
blended = onnx_weights.squeeze((3, 2)) + weights
|
||||||
|
|
||||||
|
blended = np.expand_dims(blended, (2, 3))
|
||||||
|
else:
|
||||||
|
if onnx_weights.shape != weights.shape:
|
||||||
|
logger.warning(
|
||||||
|
"reshaping weights for mismatched Conv node: %s, %s",
|
||||||
|
onnx_weights.shape,
|
||||||
|
weights.shape,
|
||||||
|
)
|
||||||
|
# TODO: test if this can be replaced with interpolation, simply reshaping is pretty sus
|
||||||
|
blended = onnx_weights + weights.reshape(onnx_weights.shape)
|
||||||
|
else:
|
||||||
|
blended = onnx_weights + weights
|
||||||
|
|
||||||
|
logger.trace("blended weight shape: %s", blended.shape)
|
||||||
|
|
||||||
|
# replace the original initializer
|
||||||
|
return numpy_helper.from_array(blended.astype(onnx_weights.dtype), weight_node.name)
|
||||||
|
|
||||||
|
|
||||||
|
def blend_node_matmul(matmul_node, weights, matmul_key) -> TensorProto:
|
||||||
|
onnx_weights = numpy_helper.to_array(matmul_node)
|
||||||
|
logger.trace(
|
||||||
|
"found blended weights for matmul: %s, %s",
|
||||||
|
weights.shape,
|
||||||
|
onnx_weights.shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
t_weights = weights.transpose()
|
||||||
|
if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape:
|
||||||
|
logger.warning(
|
||||||
|
"weight shapes do not match for %s: %s vs %s",
|
||||||
|
matmul_key,
|
||||||
|
weights.shape,
|
||||||
|
onnx_weights.shape,
|
||||||
|
)
|
||||||
|
t_weights = interp_to_match(weights, onnx_weights).transpose()
|
||||||
|
|
||||||
|
blended = onnx_weights + t_weights
|
||||||
|
logger.trace("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype)
|
||||||
|
|
||||||
|
# replace the original initializer
|
||||||
|
return numpy_helper.from_array(blended.astype(onnx_weights.dtype), matmul_node.name)
|
||||||
|
|
||||||
|
|
||||||
|
def blend_loras(
|
||||||
|
_conversion: ServerContext,
|
||||||
|
base_name: Union[str, ModelProto],
|
||||||
|
loras: List[Tuple[str, float]],
|
||||||
|
model_type: Literal["text_encoder", "unet"],
|
||||||
|
model_index: Optional[int] = None,
|
||||||
|
xl: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
# always load to CPU for blending
|
||||||
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
|
||||||
|
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
|
||||||
|
|
||||||
|
if model_type == "text_encoder":
|
||||||
|
if model_index is None:
|
||||||
|
lora_prefix = "lora_te_"
|
||||||
|
else:
|
||||||
|
lora_prefix = f"lora_te{model_index}_"
|
||||||
|
else:
|
||||||
|
lora_prefix = f"lora_{model_type}_"
|
||||||
|
|
||||||
|
blended: Dict[str, np.ndarray] = {}
|
||||||
|
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
|
||||||
|
logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
||||||
|
if lora_model is None:
|
||||||
|
logger.warning("unable to load tensor for LoRA")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
np_weights *= lora_weight
|
for key in lora_model.keys():
|
||||||
|
if ".hada_w1_a" in key and lora_prefix in key:
|
||||||
|
# LoHA
|
||||||
|
base_key, np_weights = blend_weights_loha(
|
||||||
|
key, lora_prefix, lora_model, dtype
|
||||||
|
)
|
||||||
|
np_weights = np_weights * lora_weight
|
||||||
if base_key in blended:
|
if base_key in blended:
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"summing weights: %s + %s",
|
"summing LoHA weights: %s + %s",
|
||||||
blended[base_key].shape,
|
blended[base_key].shape,
|
||||||
np_weights.shape,
|
np_weights.shape,
|
||||||
)
|
)
|
||||||
blended[base_key] = sum_weights(blended[base_key], np_weights)
|
blended[base_key] = sum_weights(blended[base_key], np_weights)
|
||||||
else:
|
else:
|
||||||
|
logger.trace(
|
||||||
|
"adding LoHA weights: %s",
|
||||||
|
np_weights.shape,
|
||||||
|
)
|
||||||
|
blended[base_key] = np_weights
|
||||||
|
elif ".lora_down" in key and lora_prefix in key:
|
||||||
|
# LoRA or LoCON
|
||||||
|
base_key, np_weights = blend_weights_lora(
|
||||||
|
key, lora_prefix, lora_model, dtype
|
||||||
|
)
|
||||||
|
np_weights = np_weights * lora_weight
|
||||||
|
if base_key in blended:
|
||||||
|
logger.trace(
|
||||||
|
"summing LoRA weights: %s + %s",
|
||||||
|
blended[base_key].shape,
|
||||||
|
np_weights.shape,
|
||||||
|
)
|
||||||
|
blended[base_key] = sum_weights(blended[base_key], np_weights)
|
||||||
|
else:
|
||||||
|
logger.trace(
|
||||||
|
"adding LoRA weights: %s",
|
||||||
|
np_weights.shape,
|
||||||
|
)
|
||||||
blended[base_key] = np_weights
|
blended[base_key] = np_weights
|
||||||
|
|
||||||
# rewrite node names for XL
|
# rewrite node names for XL
|
||||||
|
@ -400,7 +468,7 @@ def blend_loras(
|
||||||
nodes = list(base_model.graph.node)
|
nodes = list(base_model.graph.node)
|
||||||
blended = fix_xl_names(blended, nodes)
|
blended = fix_xl_names(blended, nodes)
|
||||||
|
|
||||||
logger.trace(
|
logger.debug(
|
||||||
"updating %s of %s initializers",
|
"updating %s of %s initializers",
|
||||||
len(blended.keys()),
|
len(blended.keys()),
|
||||||
len(base_model.graph.initializer),
|
len(base_model.graph.initializer),
|
||||||
|
@ -409,10 +477,7 @@ def blend_loras(
|
||||||
fixed_initializer_names = [
|
fixed_initializer_names = [
|
||||||
fix_initializer_name(node.name) for node in base_model.graph.initializer
|
fix_initializer_name(node.name) for node in base_model.graph.initializer
|
||||||
]
|
]
|
||||||
logger.trace("fixed initializer names: %s", fixed_initializer_names)
|
|
||||||
|
|
||||||
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
|
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
|
||||||
logger.trace("fixed node names: %s", fixed_node_names)
|
|
||||||
|
|
||||||
unmatched_keys = []
|
unmatched_keys = []
|
||||||
for base_key, weights in blended.items():
|
for base_key, weights in blended.items():
|
||||||
|
@ -421,9 +486,10 @@ def blend_loras(
|
||||||
matmul_key = base_key + "_MatMul"
|
matmul_key = base_key + "_MatMul"
|
||||||
|
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"key %s has conv: %s, matmul: %s",
|
"key %s has conv: %s, gemm: %s, matmul: %s",
|
||||||
base_key,
|
base_key,
|
||||||
conv_key in fixed_node_names,
|
conv_key in fixed_node_names,
|
||||||
|
gemm_key in fixed_node_names,
|
||||||
matmul_key in fixed_node_names,
|
matmul_key in fixed_node_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -449,38 +515,9 @@ def blend_loras(
|
||||||
weight_node = base_model.graph.initializer[weight_idx]
|
weight_node = base_model.graph.initializer[weight_idx]
|
||||||
logger.trace("found weight initializer: %s", weight_node.name)
|
logger.trace("found weight initializer: %s", weight_node.name)
|
||||||
|
|
||||||
# blending
|
# replace the previous node
|
||||||
onnx_weights = numpy_helper.to_array(weight_node)
|
updated_node = blend_node_conv_gemm(weight_node, weights)
|
||||||
logger.trace(
|
|
||||||
"found blended weights for conv: %s, %s",
|
|
||||||
onnx_weights.shape,
|
|
||||||
weights.shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
if onnx_weights.shape[-2:] == (1, 1):
|
|
||||||
if weights.shape[-2:] == (1, 1):
|
|
||||||
blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
|
||||||
else:
|
|
||||||
blended = onnx_weights.squeeze((3, 2)) + weights
|
|
||||||
|
|
||||||
blended = np.expand_dims(blended, (2, 3))
|
|
||||||
else:
|
|
||||||
if onnx_weights.shape != weights.shape:
|
|
||||||
logger.warning(
|
|
||||||
"reshaping weights for mismatched Conv node: %s, %s",
|
|
||||||
onnx_weights.shape,
|
|
||||||
weights.shape,
|
|
||||||
)
|
|
||||||
blended = onnx_weights + weights.reshape(onnx_weights.shape)
|
|
||||||
else:
|
|
||||||
blended = onnx_weights + weights
|
|
||||||
|
|
||||||
logger.trace("blended weight shape: %s", blended.shape)
|
|
||||||
|
|
||||||
# replace the original initializer
|
|
||||||
updated_node = numpy_helper.from_array(
|
|
||||||
blended.astype(onnx_weights.dtype), weight_node.name
|
|
||||||
)
|
|
||||||
del base_model.graph.initializer[weight_idx]
|
del base_model.graph.initializer[weight_idx]
|
||||||
base_model.graph.initializer.insert(weight_idx, updated_node)
|
base_model.graph.initializer.insert(weight_idx, updated_node)
|
||||||
elif matmul_key in fixed_node_names:
|
elif matmul_key in fixed_node_names:
|
||||||
|
@ -497,36 +534,9 @@ def blend_loras(
|
||||||
matmul_node = base_model.graph.initializer[matmul_idx]
|
matmul_node = base_model.graph.initializer[matmul_idx]
|
||||||
logger.trace("found matmul initializer: %s", matmul_node.name)
|
logger.trace("found matmul initializer: %s", matmul_node.name)
|
||||||
|
|
||||||
# blending
|
# replace the previous node
|
||||||
onnx_weights = numpy_helper.to_array(matmul_node)
|
updated_node = blend_node_matmul(matmul_node, weights, matmul_key)
|
||||||
logger.trace(
|
|
||||||
"found blended weights for matmul: %s, %s",
|
|
||||||
weights.shape,
|
|
||||||
onnx_weights.shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
t_weights = weights.transpose()
|
|
||||||
if (
|
|
||||||
weights.shape != onnx_weights.shape
|
|
||||||
and t_weights.shape != onnx_weights.shape
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"weight shapes do not match for %s: %s vs %s",
|
|
||||||
matmul_key,
|
|
||||||
weights.shape,
|
|
||||||
onnx_weights.shape,
|
|
||||||
)
|
|
||||||
t_weights = interp_to_match(weights, onnx_weights).transpose()
|
|
||||||
|
|
||||||
blended = onnx_weights + t_weights
|
|
||||||
logger.trace(
|
|
||||||
"blended weight shape: %s, %s", blended.shape, onnx_weights.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
# replace the original initializer
|
|
||||||
updated_node = numpy_helper.from_array(
|
|
||||||
blended.astype(onnx_weights.dtype), matmul_node.name
|
|
||||||
)
|
|
||||||
del base_model.graph.initializer[matmul_idx]
|
del base_model.graph.initializer[matmul_idx]
|
||||||
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
||||||
else:
|
else:
|
||||||
|
@ -565,63 +575,3 @@ def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray:
|
||||||
logger.debug("weights after interpolation: %s", output.shape)
|
logger.debug("weights after interpolation: %s", output.shape)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
context = ConversionContext.from_environ()
|
|
||||||
parser = ArgumentParser()
|
|
||||||
parser.add_argument("--base", type=str)
|
|
||||||
parser.add_argument("--dest", type=str)
|
|
||||||
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
|
|
||||||
parser.add_argument("--lora_models", nargs="+", type=str, default=[])
|
|
||||||
parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
logger.info(
|
|
||||||
"merging %s with %s with weights: %s",
|
|
||||||
args.lora_models,
|
|
||||||
args.base,
|
|
||||||
args.lora_weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
default_weight = 1.0 / len(args.lora_models)
|
|
||||||
while len(args.lora_weights) < len(args.lora_models):
|
|
||||||
args.lora_weights.append(default_weight)
|
|
||||||
|
|
||||||
blend_model = blend_loras(
|
|
||||||
context,
|
|
||||||
args.base,
|
|
||||||
list(zip(args.lora_models, args.lora_weights)),
|
|
||||||
args.type,
|
|
||||||
)
|
|
||||||
if args.dest is None or args.dest == "" or args.dest == ":load":
|
|
||||||
# convert to external data and save to memory
|
|
||||||
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
|
|
||||||
logger.info("saved external data for %s nodes", len(external_data))
|
|
||||||
|
|
||||||
external_names, external_values = zip(*external_data)
|
|
||||||
opts = SessionOptions()
|
|
||||||
opts.add_external_initializers(list(external_names), list(external_values))
|
|
||||||
sess = InferenceSession(
|
|
||||||
bare_model.SerializeToString(),
|
|
||||||
sess_options=opts,
|
|
||||||
providers=["CPUExecutionProvider"],
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
convert_model_to_external_data(
|
|
||||||
blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
|
|
||||||
)
|
|
||||||
bare_model = write_external_data_tensors(blend_model, args.dest)
|
|
||||||
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
|
|
||||||
|
|
||||||
with open(dest_file, "w+b") as model_file:
|
|
||||||
model_file.write(bare_model.SerializeToString())
|
|
||||||
|
|
||||||
logger.info("successfully saved blended model: %s", dest_file)
|
|
||||||
|
|
||||||
check_model(dest_file)
|
|
||||||
|
|
||||||
logger.info("checked blended model")
|
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from onnx_web.convert.diffusion.lora import blend_loras, buffer_external_data_tensors
|
||||||
|
from os import path
|
||||||
|
from onnx.checker import check_model
|
||||||
|
from onnx.external_data_helper import (
|
||||||
|
convert_model_to_external_data,
|
||||||
|
write_external_data_tensors,
|
||||||
|
)
|
||||||
|
from onnxruntime import InferenceSession, SessionOptions
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
from onnx_web.convert.utils import ConversionContext
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
context = ConversionContext.from_environ()
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--base", type=str)
|
||||||
|
parser.add_argument("--dest", type=str)
|
||||||
|
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
|
||||||
|
parser.add_argument("--lora_models", nargs="+", type=str, default=[])
|
||||||
|
parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
logger.info(
|
||||||
|
"merging %s with %s with weights: %s",
|
||||||
|
args.lora_models,
|
||||||
|
args.base,
|
||||||
|
args.lora_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
default_weight = 1.0 / len(args.lora_models)
|
||||||
|
while len(args.lora_weights) < len(args.lora_models):
|
||||||
|
args.lora_weights.append(default_weight)
|
||||||
|
|
||||||
|
blend_model = blend_loras(
|
||||||
|
context,
|
||||||
|
args.base,
|
||||||
|
list(zip(args.lora_models, args.lora_weights)),
|
||||||
|
args.type,
|
||||||
|
)
|
||||||
|
if args.dest is None or args.dest == "" or args.dest == ":load":
|
||||||
|
# convert to external data and save to memory
|
||||||
|
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
|
||||||
|
logger.info("saved external data for %s nodes", len(external_data))
|
||||||
|
|
||||||
|
external_names, external_values = zip(*external_data)
|
||||||
|
opts = SessionOptions()
|
||||||
|
opts.add_external_initializers(list(external_names), list(external_values))
|
||||||
|
sess = InferenceSession(
|
||||||
|
bare_model.SerializeToString(),
|
||||||
|
sess_options=opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
convert_model_to_external_data(
|
||||||
|
blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
|
||||||
|
)
|
||||||
|
bare_model = write_external_data_tensors(blend_model, args.dest)
|
||||||
|
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
|
||||||
|
|
||||||
|
with open(dest_file, "w+b") as model_file:
|
||||||
|
model_file.write(bare_model.SerializeToString())
|
||||||
|
|
||||||
|
logger.info("successfully saved blended model: %s", dest_file)
|
||||||
|
|
||||||
|
check_model(dest_file)
|
||||||
|
|
||||||
|
logger.info("checked blended model")
|
|
@ -1,11 +1,16 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from onnx import GraphProto, ModelProto, NodeProto
|
from onnx import GraphProto, ModelProto, NodeProto
|
||||||
from onnx.numpy_helper import from_array
|
from onnx.numpy_helper import from_array
|
||||||
|
|
||||||
from onnx_web.convert.diffusion.lora import (
|
from onnx_web.convert.diffusion.lora import (
|
||||||
blend_loras,
|
blend_loras,
|
||||||
|
blend_node_conv_gemm,
|
||||||
|
blend_node_matmul,
|
||||||
|
blend_weights_loha,
|
||||||
|
blend_weights_lora,
|
||||||
buffer_external_data_tensors,
|
buffer_external_data_tensors,
|
||||||
fix_initializer_name,
|
fix_initializer_name,
|
||||||
fix_node_name,
|
fix_node_name,
|
||||||
|
@ -151,6 +156,23 @@ class KernelSliceTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InterpToMatchTests(unittest.TestCase):
|
||||||
|
def test_same_shape(self):
|
||||||
|
ref = np.zeros((4, 4))
|
||||||
|
resize = np.zeros((4, 4))
|
||||||
|
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
|
||||||
|
|
||||||
|
def test_different_one_dim(self):
|
||||||
|
ref = np.zeros((4, 2))
|
||||||
|
resize = np.zeros((4, 4))
|
||||||
|
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
|
||||||
|
|
||||||
|
def test_different_both_dims(self):
|
||||||
|
ref = np.zeros((2, 2))
|
||||||
|
resize = np.zeros((4, 4))
|
||||||
|
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
|
||||||
|
|
||||||
|
|
||||||
class BlendLoRATests(unittest.TestCase):
|
class BlendLoRATests(unittest.TestCase):
|
||||||
def test_blend_unet(self):
|
def test_blend_unet(self):
|
||||||
"""
|
"""
|
||||||
|
@ -183,18 +205,131 @@ class BlendLoRATests(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InterpToMatchTests(unittest.TestCase):
|
class BlendWeightsLoHATests(unittest.TestCase):
|
||||||
def test_same_shape(self):
|
def test_blend_t1_t2(self):
|
||||||
ref = np.zeros((4, 4))
|
# blend einsum: i j k l, j r, i p -> p r k l
|
||||||
resize = np.zeros((4, 4))
|
i = 32
|
||||||
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
|
j = 4
|
||||||
|
k = 1
|
||||||
|
l = 1
|
||||||
|
p = 2
|
||||||
|
r = 4
|
||||||
|
|
||||||
def test_different_one_dim(self):
|
model = {
|
||||||
ref = np.zeros((4, 2))
|
"foo.hada_t1": torch.from_numpy(np.ones((i, j, k, l))),
|
||||||
resize = np.zeros((4, 4))
|
"foo.hada_t2": torch.from_numpy(np.ones((i, j, k, l))),
|
||||||
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
|
"foo.hada_w1_a": torch.from_numpy(np.ones((i, p))),
|
||||||
|
"foo.hada_w1_b": torch.from_numpy(np.ones((j, r))),
|
||||||
|
"foo.hada_w2_a": torch.from_numpy(np.ones((i, p))),
|
||||||
|
"foo.hada_w2_b": torch.from_numpy(np.ones((j, r))),
|
||||||
|
"foo.alpha": torch.tensor(1),
|
||||||
|
}
|
||||||
|
key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
|
||||||
|
self.assertEqual(result.shape, (p, r, k, l))
|
||||||
|
|
||||||
def test_different_both_dims(self):
|
def test_blend_w1_w2(self):
|
||||||
ref = np.zeros((2, 2))
|
model = {
|
||||||
resize = np.zeros((4, 4))
|
"foo.hada_w1_a": torch.from_numpy(np.ones((4, 1))),
|
||||||
self.assertEqual(interp_to_match(ref, resize).shape, (4, 4))
|
"foo.hada_w1_b": torch.from_numpy(np.ones((1, 4))),
|
||||||
|
"foo.hada_w2_a": torch.from_numpy(np.ones((4, 1))),
|
||||||
|
"foo.hada_w2_b": torch.from_numpy(np.ones((1, 4))),
|
||||||
|
"foo.alpha": torch.tensor(1),
|
||||||
|
}
|
||||||
|
key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
|
||||||
|
self.assertEqual(result.shape, (4, 4))
|
||||||
|
|
||||||
|
def test_blend_no_dim(self):
|
||||||
|
"""
|
||||||
|
model = {
|
||||||
|
"foo.hada_w1_a": torch.from_numpy(np.ones((1, 4))),
|
||||||
|
"foo.hada_w1_b": torch.from_numpy(np.ones((4, 1))),
|
||||||
|
"foo.hada_w2_a": torch.from_numpy(np.ones((1, 4))),
|
||||||
|
"foo.hada_w2_b": torch.from_numpy(np.ones((4, 1))),
|
||||||
|
}
|
||||||
|
result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32)
|
||||||
|
self.assertEqual(result.shape, (4, 4))
|
||||||
|
"""
|
||||||
|
|
||||||
|
class BlendWeightsLoRATests(unittest.TestCase):
|
||||||
|
def test_blend_kernel_none(self):
|
||||||
|
model = {
|
||||||
|
"foo.lora_down": torch.from_numpy(np.ones((1, 4))),
|
||||||
|
"foo.lora_up": torch.from_numpy(np.ones((4, 1))),
|
||||||
|
"foo.alpha": 1,
|
||||||
|
}
|
||||||
|
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
|
||||||
|
self.assertEqual(result.shape, (4, 4))
|
||||||
|
|
||||||
|
|
||||||
|
def test_blend_kernel_1x1(self):
|
||||||
|
model = {
|
||||||
|
"foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))),
|
||||||
|
"foo.lora_up": torch.from_numpy(np.ones((4, 1, 1, 1))),
|
||||||
|
"foo.alpha": 1,
|
||||||
|
}
|
||||||
|
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
|
||||||
|
self.assertEqual(result.shape, (4, 4, 1, 1))
|
||||||
|
|
||||||
|
def test_blend_kernel_3x3(self):
|
||||||
|
model = {
|
||||||
|
"foo.lora_down": torch.from_numpy(np.ones((1, 4, 3, 3))),
|
||||||
|
"foo.lora_up": torch.from_numpy(np.ones((4, 1, 3, 3))),
|
||||||
|
"foo.alpha": 1,
|
||||||
|
}
|
||||||
|
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
|
||||||
|
self.assertEqual(result.shape, (4, 4, 3, 3))
|
||||||
|
|
||||||
|
def test_blend_kernel_3x3_cp_decomp(self):
|
||||||
|
model = {
|
||||||
|
"foo.lora_down": torch.from_numpy(np.ones((2, 4, 1, 1))),
|
||||||
|
"foo.lora_mid": torch.from_numpy(np.ones((2, 2, 3, 3))),
|
||||||
|
"foo.lora_up": torch.from_numpy(np.ones((4, 2, 1, 1))),
|
||||||
|
"foo.alpha": 1,
|
||||||
|
}
|
||||||
|
key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32)
|
||||||
|
self.assertEqual(result.shape, (4, 4, 3, 3))
|
||||||
|
|
||||||
|
def test_blend_unknown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BlendNodeConvGemmTests(unittest.TestCase):
|
||||||
|
def test_blend_kernel_1x1_and_1x1(self):
|
||||||
|
node = from_array(np.ones((4, 4, 1, 1)))
|
||||||
|
result = blend_node_conv_gemm(node, np.ones((4, 4, 1, 1)))
|
||||||
|
|
||||||
|
self.assertEqual(result.dims, [4, 4, 1, 1])
|
||||||
|
self.assertEqual(len(result.raw_data), 4 * 4 * 8)
|
||||||
|
|
||||||
|
def test_blend_kernel_1x1_and_none(self):
|
||||||
|
node = from_array(np.ones((4, 4, 1, 1)))
|
||||||
|
result = blend_node_conv_gemm(node, np.ones((4, 4)))
|
||||||
|
|
||||||
|
self.assertEqual(result.dims, [4, 4, 1, 1])
|
||||||
|
self.assertEqual(len(result.raw_data), 4 * 4 * 8)
|
||||||
|
|
||||||
|
def test_blend_other_matching(self):
|
||||||
|
node = from_array(np.ones((4, 4)))
|
||||||
|
result = blend_node_conv_gemm(node, np.ones((4, 4)))
|
||||||
|
|
||||||
|
self.assertEqual(result.dims, [4, 4])
|
||||||
|
self.assertEqual(len(result.raw_data), 4 * 4 * 8)
|
||||||
|
|
||||||
|
def test_blend_other_mismatched(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BlendNodeMatMulTests(unittest.TestCase):
|
||||||
|
def test_blend_matching(self):
|
||||||
|
node = from_array(np.ones((4, 4)))
|
||||||
|
result = blend_node_matmul(node, np.ones((4, 4)), "test")
|
||||||
|
|
||||||
|
self.assertEqual(result.dims, [4, 4])
|
||||||
|
self.assertEqual(len(result.raw_data), 4 * 4 * 8)
|
||||||
|
|
||||||
|
def test_blend_mismatched(self):
|
||||||
|
node = from_array(np.ones((4, 4)))
|
||||||
|
result = blend_node_matmul(node, np.ones((2, 2)), "test")
|
||||||
|
|
||||||
|
self.assertEqual(result.dims, [4, 4])
|
||||||
|
self.assertEqual(len(result.raw_data), 4 * 4 * 8)
|
||||||
|
|
Loading…
Reference in New Issue