
convert LoHA T weights to same dtype, log shapes rather than data

Sean Sube 2023-04-09 23:12:32 -05:00
parent 1ed51352c4
commit 7d2d865c19
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 13 additions and 10 deletions


@@ -111,14 +111,17 @@ def blend_loras(
             alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()
 
             if t1_weight is not None and t2_weight is not None:
+                t1_weight = t1_weight.to(dtype=dtype)
+                t2_weight = t2_weight.to(dtype=dtype)
                 logger.trace(
                     "composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)",
-                    t1_weight,
-                    w1a_weight,
-                    w1b_weight,
-                    t2_weight,
-                    w2a_weight,
-                    w2b_weight,
+                    t1_weight.shape,
+                    w1a_weight.shape,
+                    w1b_weight.shape,
+                    t2_weight.shape,
+                    w2a_weight.shape,
+                    w2b_weight.shape,
                 )
                 weights_1 = torch.einsum('i j k l, j r, i p -> p r k l', t1_weight, w1b_weight, w1a_weight)
                 weights_2 = torch.einsum('i j k l, j r, i p -> p r k l', t2_weight, w2b_weight, w2a_weight)
@@ -127,10 +130,10 @@ def blend_loras(
             else:
                 logger.trace(
                     "blending weights for LoHA node: (%s @ %s) * (%s @ %s)",
-                    w1a_weight,
-                    w1b_weight,
-                    w2a_weight,
-                    w2b_weight,
+                    w1a_weight.shape,
+                    w1b_weight.shape,
+                    w2a_weight.shape,
+                    w2b_weight.shape,
                 )
                 weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
                 np_weights = weights.numpy() * (alpha / dim)
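
For context, here is a minimal standalone sketch of the two branches this diff touches. All shapes and names below (out_features, in_features, kernel, w1a_conv, w1b_conv) are hypothetical, chosen only so the expressions run; the real tensors come from the loaded LoHA model. It also illustrates why the new dtype casts matter: torch.einsum typically raises a dtype-mismatch error when its operands differ, so Tucker cores loaded in a different precision than the low-rank factors would fail before this commit.

# Sketch of the LoHA blend paths above; shapes are illustrative only.
import torch

dim = 4                      # LoHA rank
alpha = 4.0                  # scale numerator, applied as alpha / dim
dtype = torch.float32

# Linear branch: delta-W is the Hadamard product of two low-rank factors.
out_features, in_features = 8, 16
w1a_weight = torch.randn(out_features, dim, dtype=dtype)
w1b_weight = torch.randn(dim, in_features, dtype=dtype)
w2a_weight = torch.randn(out_features, dim, dtype=dtype)
w2b_weight = torch.randn(dim, in_features, dtype=dtype)
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)

# Tucker branch: t1/t2 are 4D conv cores that may load in a different
# precision than the factors, which is why the diff casts them first.
kernel = 3
t1_weight = torch.randn(dim, dim, kernel, kernel, dtype=torch.float16)
t1_weight = t1_weight.to(dtype=dtype)  # the cast added by this commit
w1a_conv = torch.randn(dim, out_features, dtype=dtype)
w1b_conv = torch.randn(dim, in_features, dtype=dtype)
weights_1 = torch.einsum(
    "i j k l, j r, i p -> p r k l", t1_weight, w1b_conv, w1a_conv
)
assert weights_1.shape == (out_features, in_features, kernel, kernel)

np_weights = weights.numpy() * (alpha / dim)
print(np_weights.shape)  # (8, 16)

The logging change follows the same logic as the cast: at trace level, printing tensor shapes keeps the output readable and cheap, where printing full tensor data would dump every element.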