Convert LoHA T weights to the same dtype; log tensor shapes rather than tensor data
This commit is contained in:
parent
1ed51352c4
commit
7d2d865c19
|
@ -111,14 +111,17 @@ def blend_loras(
|
|||
alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()
|
||||
|
||||
if t1_weight is not None and t2_weight is not None:
|
||||
t1_weight = t1_weight.to(dtype=dtype)
|
||||
t2_weight = t2_weight.to(dtype=dtype)
|
||||
|
||||
logger.trace(
|
||||
"composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)",
|
||||
t1_weight,
|
||||
w1a_weight,
|
||||
w1b_weight,
|
||||
t2_weight,
|
||||
w2a_weight,
|
||||
w2b_weight,
|
||||
t1_weight.shape,
|
||||
w1a_weight.shape,
|
||||
w1b_weight.shape,
|
||||
t2_weight.shape,
|
||||
w2a_weight.shape,
|
||||
w2b_weight.shape,
|
||||
)
|
||||
weights_1 = torch.einsum('i j k l, j r, i p -> p r k l', t1_weight, w1b_weight, w1a_weight)
|
||||
weights_2 = torch.einsum('i j k l, j r, i p -> p r k l', t2_weight, w2b_weight, w2a_weight)
|
||||
|
@ -127,10 +130,10 @@ def blend_loras(
|
|||
else:
|
||||
logger.trace(
|
||||
"blending weights for LoHA node: (%s @ %s) * (%s @ %s)",
|
||||
w1a_weight,
|
||||
w1b_weight,
|
||||
w2a_weight,
|
||||
w2b_weight,
|
||||
w1a_weight.shape,
|
||||
w1b_weight.shape,
|
||||
w2a_weight.shape,
|
||||
w2b_weight.shape,
|
||||
)
|
||||
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
|
||||
np_weights = weights.numpy() * (alpha / dim)
|
||||
|
|
Loading…
Reference in New Issue