1
0
Fork 0

fix(api): add theoretical support for 3x3 conv in LoRA

This commit is contained in:
Sean Sube 2023-03-15 19:37:17 -05:00
parent 8e8e230ffd
commit 315e5a3837
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 11 additions and 2 deletions

View File

@ -107,7 +107,7 @@ def blend_loras(
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
# blend for nn.Conv2d 1x1
logger.debug(
"blending weights for Conv node: %s, %s, %s",
"blending weights for Conv 1x1 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
@ -121,8 +121,17 @@ def blend_loras(
.unsqueeze(3)
)
np_weights = weights.numpy() * (alpha / dim)
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (3, 3):
# blend for nn.Conv2d 3x3
logger.debug(
"blending weights for Conv 3x3 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
np_weights = weights.numpy() * (alpha / dim)
else:
# TODO: add support for Conv2d 3x3
logger.warning(
"unknown LoRA node type at %s: %s",
base_key,