Compare commits

3 Commits

Author     SHA1        Message                            Date
Sean Sube  6dbed8c114  return text embeds when requested  2024-03-03 11:57:55 -06:00
Sean Sube  4f5f87bc96  leave hidden states in a list      2024-03-03 11:26:56 -06:00
Sean Sube  463799f6c8  fix SDXL patch output              2024-03-03 07:31:50 -06:00
1 changed file with 24 additions and 32 deletions

@@ -18,8 +18,11 @@ def wrap_encoder(text_encoder, sdxl=False):
             return self.forward(*args, **kwargs)
 
         def forward(
-            self, token_ids, attention_mask, output_hidden_states=True, return_dict=True
+            self, token_ids, attention_mask, output_hidden_states=None, return_dict=True
         ):
+            """
+            If `output_hidden_states` is None, return pooled embeds.
+            """
             dtype = np.int32
             if text_encoder.session.get_inputs()[0].type == "tensor(int64)":
                 dtype = np.int64
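
For reference, a standalone sketch of the three modes the patched forward now
distinguishes. The fake output list, shapes, and ordering below are illustrative
assumptions, not taken from the repo; only the branch logic mirrors the diff.

import numpy as np
import torch
from types import SimpleNamespace

# Stand-in for the ONNX text encoder's output list; which output carries which
# tensor depends on how the model was exported, so these are placeholders.
outputs = [np.zeros((1, 1280), np.float32), np.zeros((1, 77, 1280), np.float32)]
outputs += [np.zeros((1, 77, 1280), np.float32) for _ in range(12)]

def forward(output_hidden_states=None):
    if output_hidden_states is None:
        # New default: pooled text embeds plus the last hidden state.
        return SimpleNamespace(
            text_embeds=torch.from_numpy(outputs[0]),
            last_hidden_state=torch.from_numpy(outputs[1]),
        )
    elif output_hidden_states is True:
        # Per-layer hidden states, kept as a list of torch tensors.
        return SimpleNamespace(
            last_hidden_state=torch.from_numpy(outputs[0]),
            pooler_output=torch.from_numpy(outputs[1]),
            hidden_states=[torch.from_numpy(s) for s in outputs[2:]],
        )
    # Anything else: plain CLIPTextModel-style output.
    return SimpleNamespace(
        last_hidden_state=torch.from_numpy(outputs[0]),
        pooler_output=torch.from_numpy(outputs[1]),
    )

print(sorted(vars(forward(None))))       # ['last_hidden_state', 'text_embeds']
print(len(forward(True).hidden_states))  # 12
print(sorted(vars(forward(False))))      # ['last_hidden_state', 'pooler_output']
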
@@ -27,34 +30,23 @@ def wrap_encoder(text_encoder, sdxl=False):
             # TODO: does compel use attention masks?
             outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
 
-            if output_hidden_states:
-                if sdxl:
-                    hidden_states = outputs[1:]
-                    return SimpleNamespace(
-                        text_embeds=torch.from_numpy(outputs[0]),
-                        hidden_states=torch.from_numpy(
-                            np.concatenate(hidden_states, axis=0)
-                        ),
-                    )
-                else:
-                    hidden_states = outputs[2:]
-                    return SimpleNamespace(
-                        last_hidden_state=torch.from_numpy(outputs[0]),
-                        pooler_output=torch.from_numpy(outputs[1]),
-                        hidden_states=torch.from_numpy(
-                            np.concatenate(hidden_states, axis=0)
-                        ),
-                    )
+            if output_hidden_states is None:
+                return SimpleNamespace(
+                    text_embeds=torch.from_numpy(outputs[0]),
+                    last_hidden_state=torch.from_numpy(outputs[1]),
+                )
+            elif output_hidden_states is True:
+                hidden_states = [torch.from_numpy(state) for state in outputs[2:]]
+                return SimpleNamespace(
+                    last_hidden_state=torch.from_numpy(outputs[0]),
+                    pooler_output=torch.from_numpy(outputs[1]),
+                    hidden_states=hidden_states,
+                )
             else:
-                if sdxl:
-                    return SimpleNamespace(
-                        text_embeds=torch.from_numpy(outputs[0]),
-                    )
-                else:
-                    return SimpleNamespace(
-                        last_hidden_state=torch.from_numpy(outputs[0]),
-                        pooler_output=torch.from_numpy(outputs[1]),
-                    )
+                return SimpleNamespace(
+                    last_hidden_state=torch.from_numpy(outputs[0]),
+                    pooler_output=torch.from_numpy(outputs[1]),
+                )
 
         def __getattr__(self, name):
             return getattr(self.text_encoder, name)
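
Why keep the hidden states in a list: callers that want a single layer index
into hidden_states, e.g. the penultimate entry (hidden_states[-2]) that compel
uses for SDXL-style embeddings. A minimal sketch of the indexing problem; the
layer count and dimensions are assumptions:

import numpy as np
import torch

# 13 per-layer hidden states, each (batch, tokens, dim), batch of 2; each layer
# is filled with its own index so we can see which one an index selects.
layers = [torch.full((2, 77, 1280), float(i)) for i in range(13)]

# With a list, [-2] is the whole penultimate layer, batch dimension intact.
print(layers[-2].shape, layers[-2][0, 0, 0].item())  # torch.Size([2, 77, 1280]) 11.0

# The old np.concatenate(..., axis=0) merged layers into the batch axis, so
# [-2] picked a single batch row of the *last* layer instead.
stacked = torch.from_numpy(np.concatenate([l.numpy() for l in layers], axis=0))
print(stacked.shape)                                 # torch.Size([26, 77, 1280])
print(stacked[-2].shape, stacked[-2][0, 0].item())   # torch.Size([77, 1280]) 12.0
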
@@ -127,11 +119,11 @@ def encode_prompt_compel_sdxl(
         )
     )
-    prompt_embeds = prompt_embeds.numpy().astype(np.int32)
-    prompt_pooled = prompt_pooled.numpy().astype(np.int32)
+    prompt_embeds = prompt_embeds.numpy().astype(np.float32)
+    prompt_pooled = prompt_pooled.numpy().astype(np.float32)
 
     if negative_prompt_embeds is not None:
-        negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.int32)
-        negative_pooled = negative_pooled.numpy().astype(np.int32)
+        negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.float32)
+        negative_pooled = negative_pooled.numpy().astype(np.float32)
 
     return (
         prompt_embeds,
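
On the dtype fix: the prompt embeddings are float tensors headed for the UNet,
and casting them with astype(np.int32) drops the fractional part, flattening
most values to 0 or ±1. A minimal illustration (the embed shape is an
assumption):

import numpy as np
import torch

prompt_embeds = torch.randn(1, 77, 2048)

as_int = prompt_embeds.numpy().astype(np.int32)      # 0.42 -> 0, -0.87 -> 0
as_float = prompt_embeds.numpy().astype(np.float32)

print(np.count_nonzero(as_int), as_int.size)         # most of 157696 values lost
print(np.allclose(as_float, prompt_embeds.numpy()))  # True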