1
0
Fork 0

fix(api): add theoretical support for 3x3 conv in LoRA

This commit is contained in:
Sean Sube 2023-03-15 19:37:17 -05:00
parent 8e8e230ffd
commit 315e5a3837
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 11 additions and 2 deletions

View File

@ -107,7 +107,7 @@ def blend_loras(
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1): elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
# blend for nn.Conv2d 1x1 # blend for nn.Conv2d 1x1
logger.debug( logger.debug(
"blending weights for Conv node: %s, %s, %s", "blending weights for Conv 1x1 node: %s, %s, %s",
down_weight.shape, down_weight.shape,
up_weight.shape, up_weight.shape,
alpha, alpha,
@ -121,8 +121,17 @@ def blend_loras(
.unsqueeze(3) .unsqueeze(3)
) )
np_weights = weights.numpy() * (alpha / dim) np_weights = weights.numpy() * (alpha / dim)
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (3, 3):
# blend for nn.Conv2d 3x3
logger.debug(
"blending weights for Conv 3x3 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
np_weights = weights.numpy() * (alpha / dim)
else: else:
# TODO: add support for Conv2d 3x3
logger.warning( logger.warning(
"unknown LoRA node type at %s: %s", "unknown LoRA node type at %s: %s",
base_key, base_key,