diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index 29f51fc1..9fb16517 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -77,6 +77,13 @@ def fix_node_name(key: str): return fixed_name +def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]: + return ( + max(x, shape[2]), + max(y, shape[3]), + ) + + def blend_loras( _conversion: ServerContext, base_name: Union[str, ModelProto], @@ -278,9 +285,12 @@ def blend_loras( for w in range(kernel[0]): for h in range(kernel[1]): - weights[:, :, w, h] = up_weight.squeeze(3).squeeze( + down_w, down_h = kernel_slice(w, h, down_weight.shape) + up_w, up_h = kernel_slice(w, h, up_weight.shape) + + weights[:, :, w, h] = up_weight[:, :, up_w, up_h].squeeze(3).squeeze( 2 - ) @ down_weight.squeeze(3).squeeze(2) + ) @ down_weight[:, :, down_w, down_h].squeeze(3).squeeze(2) np_weights = weights.numpy() * (alpha / dim) else: