detect dtype in encoder patch
This commit is contained in:
parent
a1657a6b09
commit
9113cc53d9
|
@ -20,15 +20,21 @@ def wrap_encoder(text_encoder):
|
|||
def forward(
|
||||
self, token_ids, attention_mask, output_hidden_states=True, return_dict=True
|
||||
):
|
||||
dtype = np.int32
|
||||
if text_encoder.session.get_inputs()[0].type == "tensor(int64)":
|
||||
dtype = np.int64
|
||||
|
||||
# TODO: does compel use attention masks?
|
||||
outputs = text_encoder(input_ids=token_ids.numpy().astype(np.int32))
|
||||
outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
|
||||
if return_dict:
|
||||
if output_hidden_states:
|
||||
hidden_states = outputs[2:]
|
||||
return SimpleNamespace(
|
||||
last_hidden_state=torch.from_numpy(outputs[0]),
|
||||
pooler_output=torch.from_numpy(outputs[1]),
|
||||
hidden_states=torch.from_numpy(hidden_states),
|
||||
hidden_states=torch.from_numpy(
|
||||
np.concatenate(hidden_states, axis=0)
|
||||
),
|
||||
)
|
||||
else:
|
||||
return SimpleNamespace(
|
||||
|
|
Loading…
Reference in New Issue