
return text embeds when requested

Sean Sube 2024-03-03 11:57:55 -06:00
parent 4f5f87bc96
commit 6dbed8c114
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 10 additions and 2 deletions

@@ -18,8 +18,11 @@ def wrap_encoder(text_encoder, sdxl=False):
             return self.forward(*args, **kwargs)
 
         def forward(
-            self, token_ids, attention_mask, output_hidden_states=True, return_dict=True
+            self, token_ids, attention_mask, output_hidden_states=None, return_dict=True
         ):
+            """
+            If `output_hidden_states` is None, return pooled embeds.
+            """
             dtype = np.int32
             if text_encoder.session.get_inputs()[0].type == "tensor(int64)":
                 dtype = np.int64
@@ -27,7 +30,12 @@ def wrap_encoder(text_encoder, sdxl=False):
             # TODO: does compel use attention masks?
             outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
 
-            if output_hidden_states:
+            if output_hidden_states is None:
+                return SimpleNamespace(
+                    text_embeds=torch.from_numpy(outputs[0]),
+                    last_hidden_state=torch.from_numpy(outputs[1]),
+                )
+            elif output_hidden_states is True:
                 hidden_states = [torch.from_numpy(state) for state in outputs[2:]]
                 return SimpleNamespace(
                     last_hidden_state=torch.from_numpy(outputs[0]),
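
For context, a minimal sketch of the return-shape contract this change introduces. The standalone forward() below and the fake output list are illustrative assumptions that mirror the patched branches above; they are not the actual onnx-web wrapper or its real tensor shapes.

from types import SimpleNamespace

import numpy as np
import torch


def forward(outputs, output_hidden_states=None):
    # Simplified mirror of the patched forward() above, using the same
    # index layout: in the default branch, outputs[0] holds the pooled
    # embeds and outputs[1] the last hidden state.
    if output_hidden_states is None:
        # New in this commit: return pooled text embeds by default.
        return SimpleNamespace(
            text_embeds=torch.from_numpy(outputs[0]),
            last_hidden_state=torch.from_numpy(outputs[1]),
        )
    if output_hidden_states is True:
        # Pre-existing branch: per-layer hidden states for compel.
        return SimpleNamespace(
            last_hidden_state=torch.from_numpy(outputs[0]),
            hidden_states=[torch.from_numpy(s) for s in outputs[2:]],
        )


# Fake encoder outputs (shapes are assumptions): pooled embeds, last
# hidden state, then one entry per layer.
outputs = [np.zeros((1, 1280), dtype=np.float32)] + [
    np.zeros((1, 77, 1280), dtype=np.float32) for _ in range(3)
]

result = forward(outputs)  # the default now yields text embeds
print(result.text_embeds.shape)  # torch.Size([1, 1280])
print(result.last_hidden_state.shape)  # torch.Size([1, 77, 1280])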