leave hidden states in a list
This commit is contained in:
parent
463799f6c8
commit
4f5f87bc96
|
@ -28,13 +28,11 @@ def wrap_encoder(text_encoder, sdxl=False):
|
|||
outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
|
||||
|
||||
if output_hidden_states:
|
||||
hidden_states = outputs[2:]
|
||||
hidden_states = [torch.from_numpy(state) for state in outputs[2:]]
|
||||
return SimpleNamespace(
|
||||
last_hidden_state=torch.from_numpy(outputs[0]),
|
||||
pooler_output=torch.from_numpy(outputs[1]),
|
||||
hidden_states=torch.from_numpy(
|
||||
np.concatenate(hidden_states, axis=0)
|
||||
),
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
else:
|
||||
return SimpleNamespace(
|
||||
|
|
Loading…
Reference in New Issue