1
0
Fork 0

fix(api): handle blending of mismatched kernels

This commit is contained in:
Sean Sube 2023-06-16 20:41:25 -05:00
parent f8d59ab65a
commit 719b34967f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 12 additions and 2 deletions

View File

@ -77,6 +77,13 @@ def fix_node_name(key: str):
return fixed_name
def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
return (
max(x, shape[2]),
max(y, shape[3]),
)
def blend_loras(
_conversion: ServerContext,
base_name: Union[str, ModelProto],
@ -278,9 +285,12 @@ def blend_loras(
for w in range(kernel[0]):
for h in range(kernel[1]):
weights[:, :, w, h] = up_weight.squeeze(3).squeeze(
down_w, down_h = kernel_slice(w, h, down_weight.shape)
up_w, up_h = kernel_slice(w, h, up_weight.shape)
weights[:, :, w, h] = up_weight[:, :, up_w, up_h].squeeze(3).squeeze(
2
) @ down_weight.squeeze(3).squeeze(2)
) @ down_weight[:, :, down_w, down_h].squeeze(3).squeeze(2)
np_weights = weights.numpy() * (alpha / dim)
else: