
fix encoder patch for SDXL

Sean Sube 2024-03-02 23:43:34 -06:00
parent ec6421a310
commit 069b79583a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 16 additions and 4 deletions


@@ -26,8 +26,17 @@ def wrap_encoder(text_encoder, sdxl=False):
            # TODO: does compel use attention masks?
            outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
            if return_dict and not sdxl:
                if output_hidden_states:
                    if sdxl:
                        hidden_states = outputs[1:]
                        return SimpleNamespace(
                            text_embeds=torch.from_numpy(outputs[0]),
                            hidden_states=torch.from_numpy(
                                np.concatenate(hidden_states, axis=0)
                            ),
                        )
                    else:
                        hidden_states = outputs[2:]
                        return SimpleNamespace(
                            last_hidden_state=torch.from_numpy(outputs[0]),
@@ -36,13 +45,16 @@ def wrap_encoder(text_encoder, sdxl=False):
                                np.concatenate(hidden_states, axis=0)
                            ),
                        )
                else:
                    if sdxl:
                        return SimpleNamespace(
                            text_embeds=torch.from_numpy(outputs[0]),
                        )
                    else:
                        return SimpleNamespace(
                            last_hidden_state=torch.from_numpy(outputs[0]),
                            pooler_output=torch.from_numpy(outputs[1]),
                        )
            else:
                return outputs

        def __getattr__(self, name):
            return getattr(self.text_encoder, name)
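
For context on how this wrapper is consumed, here is a minimal sketch of handing wrapped ONNX text encoders to compel for SDXL prompt weighting, following compel's documented dual-encoder usage. The pipeline attribute names (tokenizer_2, text_encoder_2), the helper name build_sdxl_conditioning, and the assumption that wrap_encoder returns the wrapped encoder instance are illustrative only and not part of this commit.

    # Sketch only: wire the wrapped ONNX text encoders into compel for SDXL.
    # wrap_encoder is the function patched in the diff above; the pipeline
    # attributes follow the usual diffusers SDXL layout and are assumptions here.
    from compel import Compel, ReturnedEmbeddingsType


    def build_sdxl_conditioning(pipe, prompt):
        # wrap both ONNX text encoders so compel sees torch-style outputs
        encoder_1 = wrap_encoder(pipe.text_encoder, sdxl=True)
        encoder_2 = wrap_encoder(pipe.text_encoder_2, sdxl=True)

        compel = Compel(
            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
            text_encoder=[encoder_1, encoder_2],
            # SDXL conditioning uses the penultimate hidden states of both encoders
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            # only the second encoder provides the pooled text_embeds
            requires_pooled=[False, True],
        )

        # token embeddings plus pooled embeddings, both as torch tensors
        conditioning, pooled = compel(prompt)
        return conditioning, pooled

The text_embeds and hidden_states attributes returned in the SDXL branches mirror what a torch CLIPTextModelWithProjection exposes, so compel can consume the ONNX encoder through the same interface as the torch one.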