return text embeds when requested
This commit is contained in:
parent
4f5f87bc96
commit
6dbed8c114
|
@ -18,8 +18,11 @@ def wrap_encoder(text_encoder, sdxl=False):
|
|||
return self.forward(*args, **kwargs)
|
||||
|
||||
def forward(
|
||||
self, token_ids, attention_mask, output_hidden_states=True, return_dict=True
|
||||
self, token_ids, attention_mask, output_hidden_states=None, return_dict=True
|
||||
):
|
||||
"""
|
||||
If `output_hidden_states` is None, return pooled embeds.
|
||||
"""
|
||||
dtype = np.int32
|
||||
if text_encoder.session.get_inputs()[0].type == "tensor(int64)":
|
||||
dtype = np.int64
|
||||
|
@ -27,7 +30,12 @@ def wrap_encoder(text_encoder, sdxl=False):
|
|||
# TODO: does compel use attention masks?
|
||||
outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
|
||||
|
||||
if output_hidden_states:
|
||||
if output_hidden_states is None:
|
||||
return SimpleNamespace(
|
||||
text_embeds=torch.from_numpy(outputs[0]),
|
||||
last_hidden_state=torch.from_numpy(outputs[1]),
|
||||
)
|
||||
elif output_hidden_states is True:
|
||||
hidden_states = [torch.from_numpy(state) for state in outputs[2:]]
|
||||
return SimpleNamespace(
|
||||
last_hidden_state=torch.from_numpy(outputs[0]),
|
||||
|
|
Loading…
Reference in New Issue