From 7d2d865c19ab8f84b4e5cca5b800b4aec268c20c Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 9 Apr 2023 23:12:32 -0500 Subject: [PATCH] convert LoHA T weights to same dtype, log shapes rather than data --- api/onnx_web/convert/diffusion/lora.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index 478b50aa..22e259e4 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -111,14 +111,17 @@ def blend_loras( alpha = lora_model.get(alpha_key, dim).to(dtype).numpy() if t1_weight is not None and t2_weight is not None: + t1_weight = t1_weight.to(dtype=dtype) + t2_weight = t2_weight.to(dtype=dtype) + logger.trace( "composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)", - t1_weight, - w1a_weight, - w1b_weight, - t2_weight, - w2a_weight, - w2b_weight, + t1_weight.shape, + w1a_weight.shape, + w1b_weight.shape, + t2_weight.shape, + w2a_weight.shape, + w2b_weight.shape, ) weights_1 = torch.einsum('i j k l, j r, i p -> p r k l', t1_weight, w1b_weight, w1a_weight) weights_2 = torch.einsum('i j k l, j r, i p -> p r k l', t2_weight, w2b_weight, w2a_weight) @@ -127,10 +130,10 @@ def blend_loras( else: logger.trace( "blending weights for LoHA node: (%s @ %s) * (%s @ %s)", - w1a_weight, - w1b_weight, - w2a_weight, - w2b_weight, + w1a_weight.shape, + w1b_weight.shape, + w2a_weight.shape, + w2b_weight.shape, ) weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight) np_weights = weights.numpy() * (alpha / dim)