# onnx-web/api/onnx_web/prompt/compel.py
# Compel-based prompt encoding helpers for ONNX Stable Diffusion pipelines.
from types import SimpleNamespace
from typing import List, Optional, Union

import numpy as np
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import OnnxStableDiffusionPipeline
def wrap_encoder(text_encoder, sdxl=False):
    """Wrap an ONNX text encoder so compel can call it like a torch module.

    The returned object accepts torch token-id tensors, runs the underlying
    ONNX session, and repackages the numpy outputs as torch tensors on
    namespace objects shaped like transformers model outputs.

    NOTE(review): the ``sdxl`` flag is currently unused in this body — kept
    for caller compatibility.
    """

    class WrappedEncoder:
        # compel may consult the encoder's device; ONNX inputs are prepared on CPU
        device = "cpu"

        def __init__(self, inner):
            self.text_encoder = inner

        def __call__(self, *args, **kwargs):
            return self.forward(*args, **kwargs)

        def forward(
            self, token_ids, attention_mask, output_hidden_states=None, return_dict=True
        ):
            """
            If `output_hidden_states` is None, return pooled embeds.
            """
            # match the token dtype to what the ONNX session's first input expects
            wants_int64 = (
                self.text_encoder.session.get_inputs()[0].type == "tensor(int64)"
            )
            token_dtype = np.int64 if wants_int64 else np.int32

            # TODO: does compel use attention masks?
            outputs = self.text_encoder(
                input_ids=token_ids.numpy().astype(token_dtype)
            )

            if output_hidden_states is None:
                return SimpleNamespace(
                    text_embeds=torch.from_numpy(outputs[0]),
                    last_hidden_state=torch.from_numpy(outputs[1]),
                )

            result = SimpleNamespace(
                last_hidden_state=torch.from_numpy(outputs[0]),
                pooler_output=torch.from_numpy(outputs[1]),
            )
            if output_hidden_states is True:
                # remaining session outputs are the per-layer hidden states
                result.hidden_states = [
                    torch.from_numpy(state) for state in outputs[2:]
                ]
            return result

        def __getattr__(self, name):
            # delegate any other attribute lookups to the wrapped encoder
            return getattr(self.text_encoder, name)

    return WrappedEncoder(text_encoder)
def encode_prompt_compel(
    self: OnnxStableDiffusionPipeline,
    prompt: str,
    num_images_per_prompt: int,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    skip_clip_states: int = 0,
) -> np.ndarray:
    """Encode a prompt (and optional negative prompt) with compel.

    Returns the negative embeds concatenated ahead of the positive embeds
    when a negative prompt is available, otherwise just the positive embeds.

    NOTE(review): num_images_per_prompt, do_classifier_free_guidance, and
    skip_clip_states are accepted for signature compatibility but unused here.
    """
    wrapped_encoder = wrap_encoder(self.text_encoder)
    compel = Compel(tokenizer=self.tokenizer, text_encoder=wrapped_encoder)

    prompt_embeds = compel(prompt)

    if negative_prompt is not None:
        negative_prompt_embeds = compel(negative_prompt)

    if negative_prompt_embeds is not None:
        # both tensors must share a token length before concatenation
        [prompt_embeds, negative_prompt_embeds] = (
            compel.pad_conditioning_tensors_to_same_length(
                [prompt_embeds, negative_prompt_embeds]
            )
        )

    # FIX: embeddings are floating point — the previous int32 cast truncated
    # them; use float32 to match encode_prompt_compel_sdxl
    prompt_embeds = prompt_embeds.numpy().astype(np.float32)

    if negative_prompt_embeds is not None:
        negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.float32)
        # guidance convention: negative embeds first, then positive
        return np.concatenate([negative_prompt_embeds, prompt_embeds])

    # FIX: without negative embeds the old unconditional concatenate raised
    return prompt_embeds
def encode_prompt_compel_sdxl(
    self: OnnxStableDiffusionPipeline,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[Union[str, list]],
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    pooled_prompt_embeds: Optional[np.ndarray] = None,
    negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
    skip_clip_states: int = 0,
) -> np.ndarray:
    """Encode an SDXL prompt with compel using both text encoders.

    Returns a 4-tuple of float32 arrays: (prompt_embeds,
    negative_prompt_embeds, prompt_pooled, negative_pooled); the negative
    entries may be None when no negative prompt/embeds are supplied.

    NOTE(review): num_images_per_prompt, do_classifier_free_guidance, and
    skip_clip_states are accepted for signature compatibility but unused here.
    """
    wrapped_encoder = wrap_encoder(self.text_encoder, sdxl=True)
    wrapped_encoder_2 = wrap_encoder(self.text_encoder_2, sdxl=True)
    compel = Compel(
        tokenizer=[self.tokenizer, self.tokenizer_2],
        text_encoder=[wrapped_encoder, wrapped_encoder_2],
        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
        requires_pooled=[False, True],
    )

    prompt_embeds, prompt_pooled = compel(prompt)

    # FIX: bind negative_pooled up front — it was only assigned inside the
    # branch below, raising NameError at return when negative_prompt was None
    negative_pooled = negative_pooled_prompt_embeds

    if negative_prompt is not None:
        negative_prompt_embeds, negative_pooled = compel(negative_prompt)

    if negative_prompt_embeds is not None:
        # both tensors must share a token length before downstream use
        [prompt_embeds, negative_prompt_embeds] = (
            compel.pad_conditioning_tensors_to_same_length(
                [prompt_embeds, negative_prompt_embeds]
            )
        )

    prompt_embeds = prompt_embeds.numpy().astype(np.float32)
    prompt_pooled = prompt_pooled.numpy().astype(np.float32)

    if negative_prompt_embeds is not None:
        negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.float32)
        # only compel outputs are torch tensors; a caller-supplied pooled
        # array is already numpy and must not be converted again
        if isinstance(negative_pooled, torch.Tensor):
            negative_pooled = negative_pooled.numpy().astype(np.float32)

    return (
        prompt_embeds,
        negative_prompt_embeds,
        prompt_pooled,
        negative_pooled,
    )