
detect dtype in encoder patch

Sean Sube 2024-03-02 22:44:39 -06:00
parent a1657a6b09
commit 9113cc53d9
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 8 additions and 2 deletions


@@ -20,15 +20,21 @@ def wrap_encoder(text_encoder):
         def forward(
             self, token_ids, attention_mask, output_hidden_states=True, return_dict=True
         ):
+            dtype = np.int32
+            if text_encoder.session.get_inputs()[0].type == "tensor(int64)":
+                dtype = np.int64
+
             # TODO: does compel use attention masks?
-            outputs = text_encoder(input_ids=token_ids.numpy().astype(np.int32))
+            outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
             if return_dict:
                 if output_hidden_states:
                     hidden_states = outputs[2:]
                     return SimpleNamespace(
                         last_hidden_state=torch.from_numpy(outputs[0]),
                         pooler_output=torch.from_numpy(outputs[1]),
-                        hidden_states=torch.from_numpy(hidden_states),
+                        hidden_states=torch.from_numpy(
+                            np.concatenate(hidden_states, axis=0)
+                        ),
                     )
                 else:
                     return SimpleNamespace(
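
The dtype check relies on ONNX Runtime reporting each graph input's element type as a string via NodeArg.type. Below is a minimal sketch of the same detection outside the patch, assuming a plain onnxruntime InferenceSession; the helper name, model path, and "input_ids" input name are illustrative and do not come from this commit.

import numpy as np
import onnxruntime

def detect_input_dtype(session):
    # NodeArg.type is a string such as "tensor(int64)" or "tensor(int32)";
    # default to int32 unless the first input explicitly asks for int64.
    if session.get_inputs()[0].type == "tensor(int64)":
        return np.int64
    return np.int32

session = onnxruntime.InferenceSession("text_encoder/model.onnx")  # hypothetical path
token_ids = np.array([[101, 7592, 102]])  # example token IDs
# "input_ids" is an assumed input name; check session.get_inputs()[0].name
inputs = {"input_ids": token_ids.astype(detect_input_dtype(session))}
outputs = session.run(None, inputs)

The second change fixes the hidden_states conversion: torch.from_numpy accepts a single ndarray, but outputs[2:] is a sequence of per-layer hidden-state arrays, so np.concatenate joins them along axis 0 before conversion.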