1
0
Fork 0

leave hidden states in a list

This commit is contained in:
Sean Sube 2024-03-03 11:26:56 -06:00
parent 463799f6c8
commit 4f5f87bc96
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 2 additions and 4 deletions

View File

@ -28,13 +28,11 @@ def wrap_encoder(text_encoder, sdxl=False):
outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
if output_hidden_states:
hidden_states = outputs[2:]
hidden_states = [torch.from_numpy(state) for state in outputs[2:]]
return SimpleNamespace(
last_hidden_state=torch.from_numpy(outputs[0]),
pooler_output=torch.from_numpy(outputs[1]),
hidden_states=torch.from_numpy(
np.concatenate(hidden_states, axis=0)
),
hidden_states=hidden_states,
)
else:
return SimpleNamespace(