Compare commits: 069b79583a...6dbed8c114

3 Commits

| Author | SHA1 | Date |
|---|---|---|
| Sean Sube | 6dbed8c114 | |
| Sean Sube | 4f5f87bc96 | |
| Sean Sube | 463799f6c8 | |
```diff
@@ -18,8 +18,11 @@ def wrap_encoder(text_encoder, sdxl=False):
             return self.forward(*args, **kwargs)
 
         def forward(
-            self, token_ids, attention_mask, output_hidden_states=True, return_dict=True
+            self, token_ids, attention_mask, output_hidden_states=None, return_dict=True
         ):
+            """
+            If `output_hidden_states` is None, return pooled embeds.
+            """
             dtype = np.int32
             if text_encoder.session.get_inputs()[0].type == "tensor(int64)":
                 dtype = np.int64
```
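This first hunk makes `output_hidden_states` a tri-state flag: `None` (the new default) selects the pooled path, `True` requests per-layer hidden states, and `False` asks for a plain CLIP-style output. A minimal sketch of the three call modes, assuming `encoder` stands in for the object returned by `wrap_encoder()` and that the 77-token shape is a typical padded CLIP input (both are assumptions, not code from the repo):

```python
import torch

def call_modes(encoder):
    """Sketch only: `encoder` stands in for the wrapped ONNX text encoder;
    field names follow the SimpleNamespace objects in the diff, but the
    shapes and usage here are assumptions."""
    token_ids = torch.zeros((1, 77), dtype=torch.int64)  # padded CLIP tokens
    attention_mask = torch.ones_like(token_ids)

    # None (the new default): pooled path, text_embeds + last_hidden_state
    pooled = encoder(token_ids, attention_mask)

    # True: hidden states come back as a list of per-layer tensors, so a
    # caller can index a specific layer (e.g. the penultimate one)
    with_states = encoder(token_ids, attention_mask, output_hidden_states=True)
    penultimate = with_states.hidden_states[-2]

    # False: plain CLIP-style output, no hidden states attached
    plain = encoder(token_ids, attention_mask, output_hidden_states=False)

    return pooled.text_embeds, penultimate, plain.last_hidden_state
```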
```diff
@@ -27,34 +30,23 @@ def wrap_encoder(text_encoder, sdxl=False):
             # TODO: does compel use attention masks?
             outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
 
-            if output_hidden_states:
-                if sdxl:
-                    hidden_states = outputs[1:]
-                    return SimpleNamespace(
-                        text_embeds=torch.from_numpy(outputs[0]),
-                        hidden_states=torch.from_numpy(
-                            np.concatenate(hidden_states, axis=0)
-                        ),
-                    )
-                else:
-                    hidden_states = outputs[2:]
-                    return SimpleNamespace(
-                        last_hidden_state=torch.from_numpy(outputs[0]),
-                        pooler_output=torch.from_numpy(outputs[1]),
-                        hidden_states=torch.from_numpy(
-                            np.concatenate(hidden_states, axis=0)
-                        ),
-                    )
+            if output_hidden_states is None:
+                return SimpleNamespace(
+                    text_embeds=torch.from_numpy(outputs[0]),
+                    last_hidden_state=torch.from_numpy(outputs[1]),
+                )
+            elif output_hidden_states is True:
+                hidden_states = [torch.from_numpy(state) for state in outputs[2:]]
+                return SimpleNamespace(
+                    last_hidden_state=torch.from_numpy(outputs[0]),
+                    pooler_output=torch.from_numpy(outputs[1]),
+                    hidden_states=hidden_states,
+                )
             else:
-                if sdxl:
-                    return SimpleNamespace(
-                        text_embeds=torch.from_numpy(outputs[0]),
-                    )
-                else:
-                    return SimpleNamespace(
-                        last_hidden_state=torch.from_numpy(outputs[0]),
-                        pooler_output=torch.from_numpy(outputs[1]),
-                    )
+                return SimpleNamespace(
+                    last_hidden_state=torch.from_numpy(outputs[0]),
+                    pooler_output=torch.from_numpy(outputs[1]),
+                )
 
         def __getattr__(self, name):
             return getattr(self.text_encoder, name)
```
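Besides dropping the `sdxl` branching here, the notable behavioral change in this hunk is the shape of the hidden states: the old code concatenated them into one stacked array, while the new code keeps a list of per-layer tensors, matching how transformers' CLIP models expose `hidden_states`. A sketch of the difference under assumed dimensions (batch 1, 77 tokens, 768 dims, 12 layers; all made up for illustration):

```python
import numpy as np
import torch

# Illustrative per-layer encoder outputs: 12 layers of (1, 77, 768).
layers = [np.zeros((1, 77, 768), dtype=np.float32) for _ in range(12)]

# Old behavior: one concatenated tensor of shape (12, 77, 768); picking a
# single layer means slicing a stacked array.
stacked = torch.from_numpy(np.concatenate(layers, axis=0))

# New behavior: a plain list of per-layer tensors, so callers can write
# hidden_states[-2] for the penultimate layer, as CLIP consumers expect.
as_list = [torch.from_numpy(layer) for layer in layers]
assert as_list[-2].shape == (1, 77, 768)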
```diff
@@ -127,11 +119,11 @@ def encode_prompt_compel_sdxl(
         )
     )
 
-    prompt_embeds = prompt_embeds.numpy().astype(np.int32)
-    prompt_pooled = prompt_pooled.numpy().astype(np.int32)
+    prompt_embeds = prompt_embeds.numpy().astype(np.float32)
+    prompt_pooled = prompt_pooled.numpy().astype(np.float32)
     if negative_prompt_embeds is not None:
-        negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.int32)
-        negative_pooled = negative_pooled.numpy().astype(np.int32)
+        negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.float32)
+        negative_pooled = negative_pooled.numpy().astype(np.float32)
 
     return (
         prompt_embeds,
```
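The last hunk fixes a dtype bug: prompt embeddings are floating point, and the old `np.int32` cast truncated every value toward zero. A small illustration with made-up values:

```python
import numpy as np
import torch

# Embedding values are small floats; the old int32 cast collapses nearly
# all of them to zero. Numbers here are invented for illustration.
prompt_embeds = torch.tensor([[0.21, -0.87, 1.34, -0.02]])

as_int = prompt_embeds.numpy().astype(np.int32)      # truncates toward zero
as_float = prompt_embeds.numpy().astype(np.float32)  # values preserved

print(as_int)    # [[0 0 1 0]]
print(as_float)  # [[ 0.21 -0.87  1.34 -0.02]]
```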