1
0
Fork 0

fix SDXL patch output

This commit is contained in:
Sean Sube 2024-03-03 07:31:50 -06:00
parent 069b79583a
commit 463799f6c8
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 16 additions and 30 deletions

View File

@ -28,33 +28,19 @@ def wrap_encoder(text_encoder, sdxl=False):
outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype))
if output_hidden_states:
if sdxl:
hidden_states = outputs[1:]
return SimpleNamespace(
text_embeds=torch.from_numpy(outputs[0]),
hidden_states=torch.from_numpy(
np.concatenate(hidden_states, axis=0)
),
)
else:
hidden_states = outputs[2:]
return SimpleNamespace(
last_hidden_state=torch.from_numpy(outputs[0]),
pooler_output=torch.from_numpy(outputs[1]),
hidden_states=torch.from_numpy(
np.concatenate(hidden_states, axis=0)
),
)
hidden_states = outputs[2:]
return SimpleNamespace(
last_hidden_state=torch.from_numpy(outputs[0]),
pooler_output=torch.from_numpy(outputs[1]),
hidden_states=torch.from_numpy(
np.concatenate(hidden_states, axis=0)
),
)
else:
if sdxl:
return SimpleNamespace(
text_embeds=torch.from_numpy(outputs[0]),
)
else:
return SimpleNamespace(
last_hidden_state=torch.from_numpy(outputs[0]),
pooler_output=torch.from_numpy(outputs[1]),
)
return SimpleNamespace(
last_hidden_state=torch.from_numpy(outputs[0]),
pooler_output=torch.from_numpy(outputs[1]),
)
def __getattr__(self, name):
return getattr(self.text_encoder, name)
@ -127,11 +113,11 @@ def encode_prompt_compel_sdxl(
)
)
prompt_embeds = prompt_embeds.numpy().astype(np.int32)
prompt_pooled = prompt_pooled.numpy().astype(np.int32)
prompt_embeds = prompt_embeds.numpy().astype(np.float32)
prompt_pooled = prompt_pooled.numpy().astype(np.float32)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.int32)
negative_pooled = negative_pooled.numpy().astype(np.int32)
negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.float32)
negative_pooled = negative_pooled.numpy().astype(np.float32)
return (
prompt_embeds,