fix encoder patch for SDXL
parent ec6421a310
commit 069b79583a
@@ -26,8 +26,17 @@ def wrap_encoder(text_encoder, sdxl=False):
             # TODO: does compel use attention masks?
             outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
-            if return_dict and not sdxl:
-                if output_hidden_states:
+            if return_dict:
+                if output_hidden_states:
+                    if sdxl:
+                        hidden_states = outputs[1:]
+                        return SimpleNamespace(
+                            text_embeds=torch.from_numpy(outputs[0]),
+                            hidden_states=torch.from_numpy(
+                                np.concatenate(hidden_states, axis=0)
+                            ),
+                        )
+                    else:
                         hidden_states = outputs[2:]
                         return SimpleNamespace(
                             last_hidden_state=torch.from_numpy(outputs[0]),
@@ -36,13 +45,16 @@ def wrap_encoder(text_encoder, sdxl=False):
                                 np.concatenate(hidden_states, axis=0)
                             ),
                         )
                 else:
+                    if sdxl:
+                        return SimpleNamespace(
+                            text_embeds=torch.from_numpy(outputs[0]),
+                        )
+                    else:
                         return SimpleNamespace(
                             last_hidden_state=torch.from_numpy(outputs[0]),
                             pooler_output=torch.from_numpy(outputs[1]),
                         )
             else:
                 return outputs

     def __getattr__(self, name):
         return getattr(self.text_encoder, name)
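The wrapper mimics the return type of a Hugging Face CLIP text encoder by packing the raw ONNX outputs into a SimpleNamespace. Below is a minimal, self-contained sketch of that return-shape contract, assuming the output layout the patch relies on (for SDXL, outputs[0] is the pooled text_embeds and outputs[1:] are per-layer hidden states); fake_encoder and all tensor sizes are illustrative assumptions, not part of the commit:

from types import SimpleNamespace

import numpy as np
import torch


def fake_encoder(input_ids):
    # Stand-in for an ONNX text-encoder session, which returns a tuple of
    # numpy arrays. Assumed SDXL layout: outputs[0] is the pooled
    # text_embeds, outputs[1:] are the per-layer hidden states.
    pooled = np.zeros((1, 1280), dtype=np.float32)
    layers = [np.zeros((1, 77, 1280), dtype=np.float32) for _ in range(3)]
    return (pooled, *layers)


outputs = fake_encoder(np.zeros((1, 77), dtype=np.int32))

# The SDXL branch with output_hidden_states, as in the patched code above:
result = SimpleNamespace(
    text_embeds=torch.from_numpy(outputs[0]),
    hidden_states=torch.from_numpy(np.concatenate(outputs[1:], axis=0)),
)
print(result.text_embeds.shape)    # torch.Size([1, 1280])
print(result.hidden_states.shape)  # torch.Size([3, 77, 1280])

Concatenating along axis 0 stacks the per-layer hidden states into a single tensor, which is what the hidden_states field of the patched wrapper carries back to the caller.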