From 315e5a383771b5875dbaec623520d980d1b6904d Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 15 Mar 2023 19:37:17 -0500 Subject: [PATCH] fix(api): add theoretical support for 3x3 conv in LoRA --- api/onnx_web/convert/diffusion/lora.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index 0c9edd30..62a02898 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -107,7 +107,7 @@ def blend_loras( elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1): # blend for nn.Conv2d 1x1 logger.debug( - "blending weights for Conv node: %s, %s, %s", + "blending weights for Conv 1x1 node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha, @@ -121,8 +121,17 @@ def blend_loras( .unsqueeze(3) ) np_weights = weights.numpy() * (alpha / dim) + elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (3, 3): + # blend for nn.Conv2d 3x3 + logger.debug( + "blending weights for Conv 3x3 node: %s, %s, %s", + down_weight.shape, + up_weight.shape, + alpha, + ) + weights = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + np_weights = weights.numpy() * (alpha / dim) else: - # TODO: add support for Conv2d 3x3 logger.warning( "unknown LoRA node type at %s: %s", base_key,