From 9113cc53d932a5df971b802ecda7fef9e9521e5c Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 2 Mar 2024 22:44:39 -0600 Subject: [PATCH] detect dtype in encoder patch --- api/onnx_web/prompt/compel.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/prompt/compel.py b/api/onnx_web/prompt/compel.py index e1dca077..4e53afdc 100644 --- a/api/onnx_web/prompt/compel.py +++ b/api/onnx_web/prompt/compel.py @@ -20,15 +20,21 @@ def wrap_encoder(text_encoder): def forward( self, token_ids, attention_mask, output_hidden_states=True, return_dict=True ): + dtype = np.int32 + if text_encoder.session.get_inputs()[0].type == "tensor(int64)": + dtype = np.int64 + # TODO: does compel use attention masks? - outputs = text_encoder(input_ids=token_ids.numpy().astype(np.int32)) + outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype)) if return_dict: if output_hidden_states: hidden_states = outputs[2:] return SimpleNamespace( last_hidden_state=torch.from_numpy(outputs[0]), pooler_output=torch.from_numpy(outputs[1]), - hidden_states=torch.from_numpy(hidden_states), + hidden_states=torch.from_numpy( + np.concatenate(hidden_states, axis=0) + ), ) else: return SimpleNamespace(