fix(api): add theoretical support for 3x3 conv in LoRA
This commit is contained in:
parent
8e8e230ffd
commit
315e5a3837
|
@ -107,7 +107,7 @@ def blend_loras(
|
|||
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
|
||||
# blend for nn.Conv2d 1x1
|
||||
logger.debug(
|
||||
"blending weights for Conv node: %s, %s, %s",
|
||||
"blending weights for Conv 1x1 node: %s, %s, %s",
|
||||
down_weight.shape,
|
||||
up_weight.shape,
|
||||
alpha,
|
||||
|
@ -121,8 +121,17 @@ def blend_loras(
|
|||
.unsqueeze(3)
|
||||
)
|
||||
np_weights = weights.numpy() * (alpha / dim)
|
||||
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (3, 3):
|
||||
# blend for nn.Conv2d 3x3
|
||||
logger.debug(
|
||||
"blending weights for Conv 3x3 node: %s, %s, %s",
|
||||
down_weight.shape,
|
||||
up_weight.shape,
|
||||
alpha,
|
||||
)
|
||||
weights = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
np_weights = weights.numpy() * (alpha / dim)
|
||||
else:
|
||||
# TODO: add support for Conv2d 3x3
|
||||
logger.warning(
|
||||
"unknown LoRA node type at %s: %s",
|
||||
base_key,
|
||||
|
|
Loading…
Reference in New Issue