fix SDXL patch output
This commit is contained in:
parent
069b79583a
commit
463799f6c8
|
@ -28,33 +28,19 @@ def wrap_encoder(text_encoder, sdxl=False):
|
|||
outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
|
||||
|
||||
if output_hidden_states:
|
||||
if sdxl:
|
||||
hidden_states = outputs[1:]
|
||||
return SimpleNamespace(
|
||||
text_embeds=torch.from_numpy(outputs[0]),
|
||||
hidden_states=torch.from_numpy(
|
||||
np.concatenate(hidden_states, axis=0)
|
||||
),
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs[2:]
|
||||
return SimpleNamespace(
|
||||
last_hidden_state=torch.from_numpy(outputs[0]),
|
||||
pooler_output=torch.from_numpy(outputs[1]),
|
||||
hidden_states=torch.from_numpy(
|
||||
np.concatenate(hidden_states, axis=0)
|
||||
),
|
||||
)
|
||||
hidden_states = outputs[2:]
|
||||
return SimpleNamespace(
|
||||
last_hidden_state=torch.from_numpy(outputs[0]),
|
||||
pooler_output=torch.from_numpy(outputs[1]),
|
||||
hidden_states=torch.from_numpy(
|
||||
np.concatenate(hidden_states, axis=0)
|
||||
),
|
||||
)
|
||||
else:
|
||||
if sdxl:
|
||||
return SimpleNamespace(
|
||||
text_embeds=torch.from_numpy(outputs[0]),
|
||||
)
|
||||
else:
|
||||
return SimpleNamespace(
|
||||
last_hidden_state=torch.from_numpy(outputs[0]),
|
||||
pooler_output=torch.from_numpy(outputs[1]),
|
||||
)
|
||||
return SimpleNamespace(
|
||||
last_hidden_state=torch.from_numpy(outputs[0]),
|
||||
pooler_output=torch.from_numpy(outputs[1]),
|
||||
)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.text_encoder, name)
|
||||
|
@ -127,11 +113,11 @@ def encode_prompt_compel_sdxl(
|
|||
)
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.numpy().astype(np.int32)
|
||||
prompt_pooled = prompt_pooled.numpy().astype(np.int32)
|
||||
prompt_embeds = prompt_embeds.numpy().astype(np.float32)
|
||||
prompt_pooled = prompt_pooled.numpy().astype(np.float32)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.int32)
|
||||
negative_pooled = negative_pooled.numpy().astype(np.int32)
|
||||
negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.float32)
|
||||
negative_pooled = negative_pooled.numpy().astype(np.float32)
|
||||
|
||||
return (
|
||||
prompt_embeds,
|
||||
|
|
Loading…
Reference in New Issue