1
0
Fork 0

fix(api): convert hidden states to fp32 before doing normalization on the CPU (#309)

This commit is contained in:
Sean Sube 2023-04-01 17:49:25 -05:00
parent e8ac20b51c
commit 85b332467e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 1 additions and 1 deletions

View File

@ -86,7 +86,7 @@ def expand_prompt(
if skip_clip_states > 0:
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
norm_state = layer_norm(
torch.from_numpy(hidden_states[-skip_clip_states]).detach()
torch.from_numpy(hidden_states[-skip_clip_states].astype(np.float32)).detach()
)
logger.trace(
"normalized results after skipping %s layers: %s",