fix(api): handle blending of mismatched kernels
This commit is contained in:
parent
f8d59ab65a
commit
719b34967f
|
@ -77,6 +77,13 @@ def fix_node_name(key: str):
|
|||
return fixed_name
|
||||
|
||||
|
||||
def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
|
||||
return (
|
||||
max(x, shape[2]),
|
||||
max(y, shape[3]),
|
||||
)
|
||||
|
||||
|
||||
def blend_loras(
|
||||
_conversion: ServerContext,
|
||||
base_name: Union[str, ModelProto],
|
||||
|
@ -278,9 +285,12 @@ def blend_loras(
|
|||
|
||||
for w in range(kernel[0]):
|
||||
for h in range(kernel[1]):
|
||||
weights[:, :, w, h] = up_weight.squeeze(3).squeeze(
|
||||
down_w, down_h = kernel_slice(w, h, down_weight.shape)
|
||||
up_w, up_h = kernel_slice(w, h, up_weight.shape)
|
||||
|
||||
weights[:, :, w, h] = up_weight[:, :, up_w, up_h].squeeze(3).squeeze(
|
||||
2
|
||||
) @ down_weight.squeeze(3).squeeze(2)
|
||||
) @ down_weight[:, :, down_w, down_h].squeeze(3).squeeze(2)
|
||||
|
||||
np_weights = weights.numpy() * (alpha / dim)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue