fix(api): convert hidden states to fp32 before doing normalization on the CPU (#309)
This commit is contained in:
parent
e8ac20b51c
commit
85b332467e
|
@ -86,7 +86,7 @@ def expand_prompt(
|
|||
if skip_clip_states > 0:
|
||||
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
|
||||
norm_state = layer_norm(
|
||||
torch.from_numpy(hidden_states[-skip_clip_states]).detach()
|
||||
torch.from_numpy(hidden_states[-skip_clip_states].astype(np.float32)).detach()
|
||||
)
|
||||
logger.trace(
|
||||
"normalized results after skipping %s layers: %s",
|
||||
|
|
Loading…
Reference in New Issue