detect dtype in encoder patch
This commit is contained in:
parent
a1657a6b09
commit
9113cc53d9
|
@ -20,15 +20,21 @@ def wrap_encoder(text_encoder):
|
||||||
def forward(
|
def forward(
|
||||||
self, token_ids, attention_mask, output_hidden_states=True, return_dict=True
|
self, token_ids, attention_mask, output_hidden_states=True, return_dict=True
|
||||||
):
|
):
|
||||||
|
dtype = np.int32
|
||||||
|
if text_encoder.session.get_inputs()[0].type == "tensor(int64)":
|
||||||
|
dtype = np.int64
|
||||||
|
|
||||||
# TODO: does compel use attention masks?
|
# TODO: does compel use attention masks?
|
||||||
outputs = text_encoder(input_ids=token_ids.numpy().astype(np.int32))
|
outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
|
||||||
if return_dict:
|
if return_dict:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
hidden_states = outputs[2:]
|
hidden_states = outputs[2:]
|
||||||
return SimpleNamespace(
|
return SimpleNamespace(
|
||||||
last_hidden_state=torch.from_numpy(outputs[0]),
|
last_hidden_state=torch.from_numpy(outputs[0]),
|
||||||
pooler_output=torch.from_numpy(outputs[1]),
|
pooler_output=torch.from_numpy(outputs[1]),
|
||||||
hidden_states=torch.from_numpy(hidden_states),
|
hidden_states=torch.from_numpy(
|
||||||
|
np.concatenate(hidden_states, axis=0)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return SimpleNamespace(
|
return SimpleNamespace(
|
||||||
|
|
Loading…
Reference in New Issue