convert LoHA T weights to same dtype, log shapes rather than data
This commit is contained in:
parent
1ed51352c4
commit
7d2d865c19
|
@ -111,14 +111,17 @@ def blend_loras(
|
||||||
alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()
|
alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()
|
||||||
|
|
||||||
if t1_weight is not None and t2_weight is not None:
|
if t1_weight is not None and t2_weight is not None:
|
||||||
|
t1_weight = t1_weight.to(dtype=dtype)
|
||||||
|
t2_weight = t2_weight.to(dtype=dtype)
|
||||||
|
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)",
|
"composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)",
|
||||||
t1_weight,
|
t1_weight.shape,
|
||||||
w1a_weight,
|
w1a_weight.shape,
|
||||||
w1b_weight,
|
w1b_weight.shape,
|
||||||
t2_weight,
|
t2_weight.shape,
|
||||||
w2a_weight,
|
w2a_weight.shape,
|
||||||
w2b_weight,
|
w2b_weight.shape,
|
||||||
)
|
)
|
||||||
weights_1 = torch.einsum('i j k l, j r, i p -> p r k l', t1_weight, w1b_weight, w1a_weight)
|
weights_1 = torch.einsum('i j k l, j r, i p -> p r k l', t1_weight, w1b_weight, w1a_weight)
|
||||||
weights_2 = torch.einsum('i j k l, j r, i p -> p r k l', t2_weight, w2b_weight, w2a_weight)
|
weights_2 = torch.einsum('i j k l, j r, i p -> p r k l', t2_weight, w2b_weight, w2a_weight)
|
||||||
|
@ -127,10 +130,10 @@ def blend_loras(
|
||||||
else:
|
else:
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"blending weights for LoHA node: (%s @ %s) * (%s @ %s)",
|
"blending weights for LoHA node: (%s @ %s) * (%s @ %s)",
|
||||||
w1a_weight,
|
w1a_weight.shape,
|
||||||
w1b_weight,
|
w1b_weight.shape,
|
||||||
w2a_weight,
|
w2a_weight.shape,
|
||||||
w2b_weight,
|
w2b_weight.shape,
|
||||||
)
|
)
|
||||||
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
|
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
|
|
Loading…
Reference in New Issue