diff --git a/api/onnx_web/prompt/compel.py b/api/onnx_web/prompt/compel.py index e1dca077..4e53afdc 100644 --- a/api/onnx_web/prompt/compel.py +++ b/api/onnx_web/prompt/compel.py @@ -20,15 +20,21 @@ def wrap_encoder(text_encoder): def forward( self, token_ids, attention_mask, output_hidden_states=True, return_dict=True ): + dtype = np.int32 + if text_encoder.session.get_inputs()[0].type == "tensor(int64)": + dtype = np.int64 + # TODO: does compel use attention masks? - outputs = text_encoder(input_ids=token_ids.numpy().astype(np.int32)) + outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype)) if return_dict: if output_hidden_states: hidden_states = outputs[2:] return SimpleNamespace( last_hidden_state=torch.from_numpy(outputs[0]), pooler_output=torch.from_numpy(outputs[1]), - hidden_states=torch.from_numpy(hidden_states), + hidden_states=torch.from_numpy( + np.concatenate(hidden_states, axis=0) + ), ) else: return SimpleNamespace(