1
0
Fork 0

detect dtype in encoder patch

This commit is contained in:
Sean Sube 2024-03-02 22:44:39 -06:00
parent a1657a6b09
commit 9113cc53d9
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 8 additions and 2 deletions

View File

@ -20,15 +20,21 @@ def wrap_encoder(text_encoder):
def forward( def forward(
self, token_ids, attention_mask, output_hidden_states=True, return_dict=True self, token_ids, attention_mask, output_hidden_states=True, return_dict=True
): ):
dtype = np.int32
if text_encoder.session.get_inputs()[0].type == "tensor(int64)":
dtype = np.int64
# TODO: does compel use attention masks? # TODO: does compel use attention masks?
outputs = text_encoder(input_ids=token_ids.numpy().astype(np.int32)) outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
if return_dict: if return_dict:
if output_hidden_states: if output_hidden_states:
hidden_states = outputs[2:] hidden_states = outputs[2:]
return SimpleNamespace( return SimpleNamespace(
last_hidden_state=torch.from_numpy(outputs[0]), last_hidden_state=torch.from_numpy(outputs[0]),
pooler_output=torch.from_numpy(outputs[1]), pooler_output=torch.from_numpy(outputs[1]),
hidden_states=torch.from_numpy(hidden_states), hidden_states=torch.from_numpy(
np.concatenate(hidden_states, axis=0)
),
) )
else: else:
return SimpleNamespace( return SimpleNamespace(