2023-02-22 03:16:34 +00:00
|
|
|
from logging import getLogger
|
2023-08-26 04:33:41 +00:00
|
|
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
2023-02-22 03:16:34 +00:00
|
|
|
|
2023-03-14 23:00:26 +00:00
|
|
|
import numpy as np
|
2023-02-25 18:03:00 +00:00
|
|
|
import torch
|
2023-09-21 23:24:08 +00:00
|
|
|
from onnx import ModelProto, NodeProto, TensorProto, load, numpy_helper
|
|
|
|
from onnx.external_data_helper import set_external_data
|
|
|
|
from onnxruntime import OrtValue
|
2023-08-26 04:36:30 +00:00
|
|
|
from scipy import interpolate
|
2023-02-25 13:40:51 +00:00
|
|
|
|
2023-03-16 00:27:29 +00:00
|
|
|
from ...server.context import ServerContext
|
2023-09-21 23:24:08 +00:00
|
|
|
from ..utils import load_tensor
|
2023-02-22 03:16:34 +00:00
|
|
|
|
|
|
|
logger = getLogger(__name__)
|
|
|
|
|
|
|
|
|
2023-04-21 01:06:43 +00:00
|
|
|
def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Add two blended weight tensors, reconciling mismatched ranks.

    When the ranks match, the tensors are summed directly. Otherwise the
    lower-rank tensor is expanded to match a 1x1 convolution kernel before
    the sum, relying on broadcasting for the remaining dimensions.
    """
    logger.trace("summing weights with shapes: %s + %s", a.shape, b.shape)

    # same rank: plain elementwise sum
    if len(a.shape) == len(b.shape):
        return a + b

    # order the operands by rank; the kernel comes from the higher-rank one
    if len(b.shape) > len(a.shape):
        high, low = b, a
    else:
        high, low = a, b

    kernel = high.shape[-2:]
    if kernel == (1, 1):
        # lift the low-rank tensor to NCHW with a 1x1 kernel
        low = np.expand_dims(low, axis=(2, 3))  # TODO: generate axis

    return high + low
|
|
|
|
|
|
|
|
|
2023-03-15 04:32:47 +00:00
|
|
|
def buffer_external_data_tensors(
    model: ModelProto,
) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
    """Move initializer weights out of the graph into OrtValue buffers.

    Every initializer that carries raw data is converted to an OrtValue and
    the tensor is re-marked as externally stored, so the model proto itself
    no longer duplicates the weight bytes.

    Returns the modified model and a list of (tensor name, OrtValue) pairs
    suitable for passing to onnxruntime as pre-loaded external data.
    """
    external_data = []
    for tensor in model.graph.initializer:
        name = tensor.name

        # NOTE(review): `trace` is not a stdlib logging level — assumed to be
        # registered elsewhere in the project.
        logger.trace("externalizing tensor: %s", name)
        if tensor.HasField("raw_data"):
            npt = numpy_helper.to_array(tensor)
            orv = OrtValue.ortvalue_from_numpy(npt)
            external_data.append((name, orv))
            # mimic set_external_data
            # the location is a placeholder: the data is provided in memory
            # via the OrtValue list, never read from "foo.bin"
            set_external_data(tensor, location="foo.bin")
            # set_external_data can touch the tensor; restore the name and
            # drop the in-proto bytes now that they live in the OrtValue
            tensor.name = name
            tensor.ClearField("raw_data")

    return (model, external_data)
|
|
|
|
|
|
|
|
|
2023-03-14 23:00:26 +00:00
|
|
|
def fix_initializer_name(key: str):
    """Normalize an initializer name by turning dots into underscores.

    Example:
    lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
    becomes a single underscore-delimited identifier.
    """
    return "_".join(key.split("."))
|
2023-02-22 05:50:27 +00:00
|
|
|
|
|
|
|
|
2023-03-14 23:00:26 +00:00
|
|
|
def fix_node_name(key: str):
    """Normalize a graph node name: slashes and dots become underscores,
    and a single leading underscore (from a leading slash) is dropped.
    """
    cleaned = fix_initializer_name(key.replace("/", "_"))
    return cleaned[1:] if cleaned[0] == "_" else cleaned
|
2023-03-12 18:38:51 +00:00
|
|
|
|
|
|
|
|
2023-11-22 04:33:17 +00:00
|
|
|
def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]:
    """Remap SDXL LoRA key names onto the ONNX graph's node names.

    For each LoRA key, the UNet/text-encoder block is inferred from the key's
    root segment, then a matching graph node (Conv, Gemm, or MatMul) is
    located and the key is rewritten to that node's base name. Matched nodes
    are removed from the candidate list so each node is used at most once.

    Fix: every `next(...)` lookup now supplies a `None` default. Previously a
    missing node raised StopIteration out of the function instead of reaching
    the `match is None` warning branch.

    Returns a new dict of fixed-name -> weight; unknown keys pass through.
    """
    fixed = {}
    names = [fix_node_name(node.name) for node in nodes]

    for key, value in keys.items():
        root, *_rest = key.split(".")
        logger.trace("fixing XL node name: %s -> %s", key, root)

        # "simple" keys already use diffusers-style block names
        simple = False
        if root.startswith("input"):
            block = "down_blocks"
        elif root.startswith("middle"):
            block = "mid_block"  # not plural
        elif root.startswith("output"):
            block = "up_blocks"
        elif root.startswith("text_model"):
            block = "text_model"
        elif root.startswith("down_blocks"):
            block = "down_blocks"
            simple = True
        elif root.startswith("mid_block"):
            block = "mid_block"
            simple = True
        elif root.startswith("up_blocks"):
            block = "up_blocks"
            simple = True
        else:
            logger.warning("unknown XL key name: %s", key)
            fixed[key] = value
            continue

        # pick the layer-type suffix; later entries win, so the most
        # specific suffix (e.g. proj_in over proj) must come later
        suffix = None
        for s in [
            "conv",
            "conv_shortcut",
            "conv1",
            "conv2",
            "fc1",
            "fc2",
            "ff_net_0_proj",
            "ff_net_2",
            "proj",
            "proj_in",
            "proj_out",
            "to_k",
            "to_out_0",
            "to_q",
            "to_v",
        ]:
            if root.endswith(s):
                suffix = s

        if suffix is None:
            logger.warning("new XL key type: %s", root)
            continue

        logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
        match: Optional[str] = None
        if "conv" in suffix:
            match = next((node for node in names if node == f"{root}_Conv"), None)
        elif "time_emb_proj" in root:
            match = next((node for node in names if node == f"{root}_Gemm"), None)
        elif block == "text_model" or simple:
            match = next((node for node in names if node == f"{root}_MatMul"), None)
        else:
            # search in order. one side has sparse indices, so they will not match.
            match = next(
                (
                    node
                    for node in names
                    if node.startswith(block)
                    and node.endswith(
                        f"{suffix}_MatMul"
                    )  # needs to be fixed because some places use to_out.0
                ),
                None,
            )

        if match is None:
            logger.warning("no matches for XL key: %s", root)
            continue
        else:
            logger.trace("matched key: %s -> %s", key, match)

        # strip the op-type suffix to recover the node's base name
        name = match
        if name.endswith("_MatMul"):
            name = name[:-7]
        elif name.endswith("_Gemm"):
            name = name[:-5]
        elif name.endswith("_Conv"):
            name = name[:-5]

        logger.trace("matching XL key with node: %s -> %s, %s", key, match, name)

        fixed[name] = value
        # consume the node so it cannot match a second key
        names.remove(match)

    logger.debug(
        "SDXL LoRA key fixup matched %s of %s keys, %s nodes remaining",
        len(fixed.keys()),
        len(keys.keys()),
        len(names),
    )

    return fixed
|
2023-08-26 04:33:41 +00:00
|
|
|
|
|
|
|
|
2023-06-17 01:41:25 +00:00
|
|
|
def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
    """Clamp the kernel coordinates (x, y) to the spatial extent of *shape*.

    The spatial dimensions are assumed to be the last two entries of an
    NCHW-style shape tuple.
    """
    max_x = shape[2] - 1
    max_y = shape[3] - 1
    return (min(x, max_x), min(y, max_y))
|
|
|
|
|
|
|
|
|
2023-09-21 23:24:08 +00:00
|
|
|
def blend_weights_loha(
    key: str, lora_prefix: str, lora_model: Dict, dtype
) -> Tuple[str, np.ndarray]:
    """Compose a LoHA (Hadamard-product) layer into a single weight delta.

    *key* is the `.hada_w1_a` tensor key for the layer; its sibling tensors
    (w1_b, w2_a, w2_b, optional t1/t2 and alpha) are looked up by key
    substitution. Returns the base layer key (with *lora_prefix* stripped)
    and the blended delta as a numpy array scaled by alpha/dim.
    """
    base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")

    # derive the sibling tensor keys from the w1_a key
    t1_key = key.replace("hada_w1_a", "hada_t1")
    t2_key = key.replace("hada_w1_a", "hada_t2")
    w1b_key = key.replace("hada_w1_a", "hada_w1_b")
    w2a_key = key.replace("hada_w1_a", "hada_w2_a")
    w2b_key = key.replace("hada_w1_a", "hada_w2_b")
    alpha_key = key[: key.index("hada_w1_a")] + "alpha"
    logger.trace(
        "blending weights for LoHA keys: %s, %s, %s, %s, %s",
        key,
        w1b_key,
        w2a_key,
        w2b_key,
        alpha_key,
    )

    w1a_weight = lora_model[key].to(dtype=dtype)
    w1b_weight = lora_model[w1b_key].to(dtype=dtype)
    w2a_weight = lora_model[w2a_key].to(dtype=dtype)
    w2b_weight = lora_model[w2b_key].to(dtype=dtype)

    # t1/t2 are only present for CP-decomposed conv layers
    t1_weight = lora_model.get(t1_key, None)
    t2_weight = lora_model.get(t2_key, None)

    dim = w1b_weight.size()[0]
    # NOTE(review): if alpha_key is absent, `dim` (a plain int) is returned
    # and `.to(dtype)` would raise AttributeError — assumes alpha is always
    # present in LoHA checkpoints; confirm against real models.
    alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()

    if t1_weight is not None and t2_weight is not None:
        t1_weight = t1_weight.to(dtype=dtype)
        t2_weight = t2_weight.to(dtype=dtype)

        logger.trace(
            "composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)",
            t1_weight.shape,
            w1a_weight.shape,
            w1b_weight.shape,
            t2_weight.shape,
            w2a_weight.shape,
            w2b_weight.shape,
        )
        # rebuild each conv factor from its CP decomposition, then take the
        # Hadamard product of the two reconstructed tensors
        weights_1 = torch.einsum(
            "i j k l, j r, i p -> p r k l",
            t1_weight,
            w1b_weight,
            w1a_weight,
        )
        weights_2 = torch.einsum(
            "i j k l, j r, i p -> p r k l",
            t2_weight,
            w2b_weight,
            w2a_weight,
        )
        weights = weights_1 * weights_2
        np_weights = weights.numpy() * (alpha / dim)
    else:
        logger.trace(
            "blending weights for LoHA node: (%s @ %s) * (%s @ %s)",
            w1a_weight.shape,
            w1b_weight.shape,
            w2a_weight.shape,
            w2b_weight.shape,
        )
        # linear case: Hadamard product of the two low-rank factors
        weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
        np_weights = weights.numpy() * (alpha / dim)

    return base_key, np_weights
|
|
|
|
|
|
|
|
|
|
|
|
def blend_weights_lora(
    key: str, lora_prefix: str, lora_model: Dict, dtype
) -> Tuple[str, np.ndarray]:
    """Compose a LoRA/LoCON layer (up @ down, optional mid) into one delta.

    *key* is the `.lora_down` tensor key; the matching up/mid/alpha tensors
    are looked up by key substitution. Handles Linear, Conv 1x1, and Conv 3x3
    layers (with or without CP decomposition). Returns the base layer key
    (with *lora_prefix* stripped) and the blended delta scaled by alpha/dim.
    """
    base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")

    mid_key = key.replace("lora_down", "lora_mid")
    up_key = key.replace("lora_down", "lora_up")
    alpha_key = key[: key.index("lora_down")] + "alpha"
    logger.trace("blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key)

    down_weight = lora_model[key].to(dtype=dtype)
    up_weight = lora_model[up_key].to(dtype=dtype)

    # mid is only present for CP-decomposed conv layers
    mid_weight = None
    if mid_key in lora_model:
        mid_weight = lora_model[mid_key].to(dtype=dtype)

    dim = down_weight.size()[0]
    # alpha defaults to dim (so alpha/dim == 1) when absent
    alpha = lora_model.get(alpha_key, dim)

    if not isinstance(alpha, int):
        alpha = alpha.to(dtype).numpy()

    # the spatial kernel comes from mid when present, otherwise from down
    kernel = down_weight.shape[-2:]
    if mid_weight is not None:
        kernel = mid_weight.shape[-2:]

    if len(down_weight.size()) == 2:
        # blend for nn.Linear
        logger.trace(
            "blending weights for Linear node: (%s @ %s) * %s",
            down_weight.shape,
            up_weight.shape,
            alpha,
        )
        weights = up_weight @ down_weight
        np_weights = weights.numpy() * (alpha / dim)
    elif len(down_weight.size()) == 4 and kernel == (
        1,
        1,
    ):
        # blend for nn.Conv2d 1x1: collapse the kernel dims, matmul, restore
        logger.trace(
            "blending weights for Conv 1x1 node: %s, %s, %s",
            down_weight.shape,
            up_weight.shape,
            alpha,
        )
        weights = (
            (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
            .unsqueeze(2)
            .unsqueeze(3)
        )
        np_weights = weights.numpy() * (alpha / dim)
    elif len(down_weight.size()) == 4 and kernel == (
        3,
        3,
    ):
        if mid_weight is not None:
            # blend for nn.Conv2d 3x3 with CP decomp
            logger.trace(
                "composing weights for Conv 3x3 node: %s, %s, %s, %s",
                down_weight.shape,
                up_weight.shape,
                mid_weight.shape,
                alpha,
            )
            weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))

            # compose up @ mid @ down independently per kernel position
            for w in range(kernel[0]):
                for h in range(kernel[1]):
                    weights[:, :, w, h] = (
                        up_weight.squeeze(3).squeeze(2) @ mid_weight[:, :, w, h]
                    ) @ down_weight.squeeze(3).squeeze(2)

            np_weights = weights.numpy() * (alpha / dim)
        else:
            # blend for nn.Conv2d 3x3
            logger.trace(
                "blending weights for Conv 3x3 node: %s, %s, %s",
                down_weight.shape,
                up_weight.shape,
                alpha,
            )
            weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))

            for w in range(kernel[0]):
                for h in range(kernel[1]):
                    # factors may have smaller kernels; clamp the indices
                    down_w, down_h = kernel_slice(w, h, down_weight.shape)
                    up_w, up_h = kernel_slice(w, h, up_weight.shape)

                    weights[:, :, w, h] = (
                        up_weight[:, :, up_w, up_h] @ down_weight[:, :, down_w, down_h]
                    )

            np_weights = weights.numpy() * (alpha / dim)
    else:
        logger.warning(
            "unknown LoRA node type at %s: %s",
            base_key,
            up_weight.shape[-2:],
        )
        # TODO: should this be None?
        np_weights = np.zeros((1, 1, 1, 1))

    return base_key, np_weights
|
|
|
|
|
|
|
|
|
|
|
|
def blend_node_conv_gemm(weight_node, weights) -> TensorProto:
    """Blend a LoRA weight delta into a Conv/Gemm initializer.

    Handles 1x1 kernels by collapsing and restoring the spatial dims, and
    reshapes mismatched deltas as a fallback. Returns a replacement
    TensorProto carrying the blended weights under the original name.
    """
    # blending
    base = numpy_helper.to_array(weight_node)
    logger.trace(
        "found blended weights for conv: %s, %s",
        base.shape,
        weights.shape,
    )

    if base.shape[-2:] == (1, 1):
        # drop the 1x1 kernel dims, sum, then restore them
        flat = base.squeeze((3, 2))
        if weights.shape[-2:] == (1, 1):
            result = flat + weights.squeeze((3, 2))
        else:
            result = flat + weights
        result = np.expand_dims(result, (2, 3))
    elif base.shape != weights.shape:
        logger.warning(
            "reshaping weights for mismatched Conv node: %s, %s",
            base.shape,
            weights.shape,
        )
        # TODO: test if this can be replaced with interpolation, simply reshaping is pretty sus
        result = base + weights.reshape(base.shape)
    else:
        result = base + weights

    logger.trace("blended weight shape: %s", result.shape)

    # replace the original initializer
    return numpy_helper.from_array(result.astype(base.dtype), weight_node.name)
|
|
|
|
|
|
|
|
|
|
|
|
def blend_node_matmul(matmul_node, weights, matmul_key) -> TensorProto:
    """Blend a LoRA weight delta into a MatMul initializer.

    MatMul initializers store the transposed layout, so the delta is
    transposed before summing; mismatched shapes are interpolated to fit.
    Returns a replacement TensorProto under the original name.
    """
    base = numpy_helper.to_array(matmul_node)
    logger.trace(
        "found blended weights for matmul: %s, %s",
        weights.shape,
        base.shape,
    )

    transposed = weights.transpose()
    shapes_match = weights.shape == base.shape or transposed.shape == base.shape
    if not shapes_match:
        logger.warning(
            "weight shapes do not match for %s: %s vs %s",
            matmul_key,
            weights.shape,
            base.shape,
        )
        # resample the delta onto the initializer's grid as a fallback
        transposed = interp_to_match(weights, base).transpose()

    result = base + transposed
    logger.trace("blended weight shape: %s, %s", result.shape, base.dtype)

    # replace the original initializer
    return numpy_helper.from_array(result.astype(base.dtype), matmul_node.name)
|
|
|
|
|
|
|
|
|
2023-03-15 22:14:52 +00:00
|
|
|
def blend_loras(
    _conversion: ServerContext,
    base_name: Union[str, ModelProto],
    loras: List[Tuple[str, float]],
    model_type: Literal["text_encoder", "unet"],
    model_index: Optional[int] = None,
    xl: Optional[bool] = False,
):
    """Blend one or more LoRA/LoHA checkpoints into an ONNX base model.

    Each LoRA tensor file in *loras* (name, weight) is loaded, its layer
    deltas are composed and scaled by the blending weight, XL key names are
    remapped when *xl* is set, deltas for the same layer are summed across
    checkpoints, and finally each delta is added into the matching Conv,
    Gemm, or MatMul initializer of the base model.

    Fix: the per-layer loop previously rebound the `weights` dict as its own
    loop variable (`for base_key, weights in weights.items()`); the loop
    value is now `layer_delta` to remove the shadowing.

    Returns the modified base ModelProto.
    """
    # always load to CPU for blending
    device = torch.device("cpu")
    dtype = torch.float32

    base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
    lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]

    # key prefix depends on the model type; multi-encoder XL models use an index
    if model_type == "text_encoder":
        if model_index is None:
            lora_prefix = "lora_te_"
        else:
            lora_prefix = f"lora_te{model_index}_"
    else:
        lora_prefix = f"lora_{model_type}_"

    # one dict of base_key -> delta per successfully loaded LoRA
    layers = []
    for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
        logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
        if lora_model is None:
            logger.warning("unable to load tensor for LoRA")
            continue

        blended: Dict[str, np.ndarray] = {}
        layers.append(blended)

        for key in lora_model.keys():
            if ".hada_w1_a" in key and lora_prefix in key:
                # LoHA
                base_key, np_weights = blend_weights_loha(
                    key, lora_prefix, lora_model, dtype
                )
                np_weights = np_weights * lora_weight
                logger.trace(
                    "adding LoHA weights: %s",
                    np_weights.shape,
                )
                blended[base_key] = np_weights
            elif ".lora_down" in key and lora_prefix in key:
                # LoRA or LoCON
                base_key, np_weights = blend_weights_lora(
                    key, lora_prefix, lora_model, dtype
                )
                np_weights = np_weights * lora_weight
                logger.trace(
                    "adding LoRA weights: %s",
                    np_weights.shape,
                )
                blended[base_key] = np_weights

    # rewrite node names for XL and flatten layers
    weights: Dict[str, np.ndarray] = {}

    for blended in layers:
        if xl:
            nodes = list(base_model.graph.node)
            blended = fix_xl_names(blended, nodes)

        for key, value in blended.items():
            if key in weights:
                weights[key] = sum_weights(weights[key], value)
            else:
                weights[key] = value

    # fix node names once
    fixed_initializer_names = [
        fix_initializer_name(node.name) for node in base_model.graph.initializer
    ]
    fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]

    logger.debug(
        "updating %s of %s initializers",
        len(weights.keys()),
        len(base_model.graph.initializer),
    )

    unmatched_keys = []
    for base_key, layer_delta in weights.items():
        conv_key = base_key + "_Conv"
        gemm_key = base_key + "_Gemm"
        matmul_key = base_key + "_MatMul"

        logger.trace(
            "key %s has conv: %s, gemm: %s, matmul: %s",
            base_key,
            conv_key in fixed_node_names,
            gemm_key in fixed_node_names,
            matmul_key in fixed_node_names,
        )

        if conv_key in fixed_node_names or gemm_key in fixed_node_names:
            if conv_key in fixed_node_names:
                conv_idx = fixed_node_names.index(conv_key)
                conv_node = base_model.graph.node[conv_idx]
                logger.trace(
                    "found conv node %s using %s", conv_node.name, conv_node.input
                )
            else:
                conv_idx = fixed_node_names.index(gemm_key)
                conv_node = base_model.graph.node[conv_idx]
                logger.trace(
                    "found gemm node %s using %s", conv_node.name, conv_node.input
                )

            # find weight initializer
            weight_name = [n for n in conv_node.input if ".weight" in n][0]
            weight_name = fix_initializer_name(weight_name)

            weight_idx = fixed_initializer_names.index(weight_name)
            weight_node = base_model.graph.initializer[weight_idx]
            logger.trace("found weight initializer: %s", weight_node.name)

            # replace the previous node
            updated_node = blend_node_conv_gemm(weight_node, layer_delta)

            del base_model.graph.initializer[weight_idx]
            base_model.graph.initializer.insert(weight_idx, updated_node)
        elif matmul_key in fixed_node_names:
            weight_idx = fixed_node_names.index(matmul_key)
            weight_node = base_model.graph.node[weight_idx]
            logger.trace(
                "found matmul node %s using %s", weight_node.name, weight_node.input
            )

            # find the MatMul initializer
            matmul_name = [n for n in weight_node.input if "MatMul" in n][0]

            matmul_idx = fixed_initializer_names.index(matmul_name)
            matmul_node = base_model.graph.initializer[matmul_idx]
            logger.trace("found matmul initializer: %s", matmul_node.name)

            # replace the previous node
            updated_node = blend_node_matmul(matmul_node, layer_delta, matmul_key)

            del base_model.graph.initializer[matmul_idx]
            base_model.graph.initializer.insert(matmul_idx, updated_node)
        else:
            unmatched_keys.append(base_key)

    logger.trace(
        "node counts: %s -> %s, %s -> %s",
        len(fixed_initializer_names),
        len(base_model.graph.initializer),
        len(fixed_node_names),
        len(base_model.graph.node),
    )

    if len(unmatched_keys) > 0:
        logger.warning("could not find nodes for some LoRA keys: %s", unmatched_keys)

    return base_model
|
2023-02-25 13:40:51 +00:00
|
|
|
|
|
|
|
|
2023-08-26 04:33:41 +00:00
|
|
|
def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray:
    """Linearly resample the 2D array *ref* onto the grid size of *resize*.

    Both arrays are treated as values on a unit-square grid; *ref* is
    interpolated at the grid positions implied by *resize*'s shape.

    NOTE(review): np.meshgrid defaults to 'xy' indexing, so `output` has
    shape (resize.shape[1], resize.shape[0]) — transposed relative to
    *resize*. The caller (blend_node_matmul) transposes the result, which
    restores the expected orientation; confirm before reusing elsewhere.
    """
    # target sample positions, one per row/column of the resize grid
    res_x = np.linspace(0, 1, resize.shape[0])
    res_y = np.linspace(0, 1, resize.shape[1])
    # source grid positions for the reference values
    ref_x = np.linspace(0, 1, ref.shape[0])
    ref_y = np.linspace(0, 1, ref.shape[1])
    logger.debug(
        "dims: %s, %s, %s, %s",
        resize.shape[0],
        resize.shape[1],
        ref.shape[0],
        ref.shape[1],
    )

    f = interpolate.RegularGridInterpolator((ref_x, ref_y), ref, method="linear")
    xg, yg = np.meshgrid(res_x, res_y)
    output = f((xg, yg))
    logger.debug("weights after interpolation: %s", output.shape)

    return output
|