fix(api): add theoretical support for 3x3 conv in LoRA
This commit is contained in:
parent
8e8e230ffd
commit
315e5a3837
|
@ -107,7 +107,7 @@ def blend_loras(
|
||||||
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
|
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
|
||||||
# blend for nn.Conv2d 1x1
|
# blend for nn.Conv2d 1x1
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"blending weights for Conv node: %s, %s, %s",
|
"blending weights for Conv 1x1 node: %s, %s, %s",
|
||||||
down_weight.shape,
|
down_weight.shape,
|
||||||
up_weight.shape,
|
up_weight.shape,
|
||||||
alpha,
|
alpha,
|
||||||
|
@ -121,8 +121,17 @@ def blend_loras(
|
||||||
.unsqueeze(3)
|
.unsqueeze(3)
|
||||||
)
|
)
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
|
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (3, 3):
|
||||||
|
# blend for nn.Conv2d 3x3
|
||||||
|
logger.debug(
|
||||||
|
"blending weights for Conv 3x3 node: %s, %s, %s",
|
||||||
|
down_weight.shape,
|
||||||
|
up_weight.shape,
|
||||||
|
alpha,
|
||||||
|
)
|
||||||
|
weights = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||||
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
else:
|
else:
|
||||||
# TODO: add support for Conv2d 3x3
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"unknown LoRA node type at %s: %s",
|
"unknown LoRA node type at %s: %s",
|
||||||
base_key,
|
base_key,
|
||||||
|
|
Loading…
Reference in New Issue