134 lines
4.7 KiB
Python
134 lines
4.7 KiB
Python
from types import SimpleNamespace
|
|
from typing import List, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from compel import Compel, ReturnedEmbeddingsType
|
|
from diffusers import OnnxStableDiffusionPipeline
|
|
|
|
|
|
def wrap_encoder(text_encoder, sdxl=False):
    """
    Wrap an ONNX text encoder so it can be called like a torch text encoder.

    The wrapper takes torch token tensors, feeds them to `text_encoder` as
    numpy arrays, and hands the results back as torch tensors on a
    `SimpleNamespace`. Attributes the wrapper does not define are looked up
    on the wrapped encoder.

    Args:
        text_encoder: callable accepting `input_ids=<numpy array>` and
            returning an indexable sequence of numpy arrays. It must expose
            `session.get_inputs()` (presumably an onnxruntime session --
            confirm against the caller).
        sdxl: accepted for call-site compatibility; not used by the wrapper.

    Returns:
        A `WrappedEncoder` instance around `text_encoder`.
    """

    class WrappedEncoder:
        # Read by callers that need to know where to place input tensors.
        device = "cpu"

        def __init__(self, text_encoder):
            self.text_encoder = text_encoder

        def __call__(self, *args, **kwargs):
            return self.forward(*args, **kwargs)

        def forward(
            self, token_ids, attention_mask, output_hidden_states=None, return_dict=True
        ):
            """
            If `output_hidden_states` is None, return pooled embeds.

            Output layout, by position in the wrapped encoder's result:
              * `output_hidden_states is None`: `text_embeds` = output 0,
                `last_hidden_state` = output 1.
              * `output_hidden_states is True`: `last_hidden_state` = output 0,
                `pooler_output` = output 1, `hidden_states` = outputs 2 onward.
              * anything else: `last_hidden_state` = output 0,
                `pooler_output` = output 1.

            `attention_mask` and `return_dict` are accepted but ignored.
            """
            # Match the integer width declared by the model's first input.
            dtype = np.int32
            if self.text_encoder.session.get_inputs()[0].type == "tensor(int64)":
                dtype = np.int64

            # TODO: does compel use attention masks?
            outputs = self.text_encoder(input_ids=token_ids.numpy().astype(dtype))

            if output_hidden_states is None:
                return SimpleNamespace(
                    text_embeds=torch.from_numpy(outputs[0]),
                    last_hidden_state=torch.from_numpy(outputs[1]),
                )
            elif output_hidden_states is True:
                hidden_states = [torch.from_numpy(state) for state in outputs[2:]]
                return SimpleNamespace(
                    last_hidden_state=torch.from_numpy(outputs[0]),
                    pooler_output=torch.from_numpy(outputs[1]),
                    hidden_states=hidden_states,
                )
            else:
                return SimpleNamespace(
                    last_hidden_state=torch.from_numpy(outputs[0]),
                    pooler_output=torch.from_numpy(outputs[1]),
                )

        def __getattr__(self, name):
            # Only reached when normal lookup fails. If `text_encoder` itself
            # is the missing name (e.g. on the bare instance copy.copy builds
            # before restoring state), delegating would re-enter this method
            # forever, so fail with the regular AttributeError instead.
            if name == "text_encoder":
                raise AttributeError(name)
            return getattr(self.text_encoder, name)

    return WrappedEncoder(text_encoder)
|
|
|
|
|
|
def encode_prompt_compel(
    self: OnnxStableDiffusionPipeline,
    prompt: str,
    num_images_per_prompt: int,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    skip_clip_states: int = 0,
) -> np.ndarray:
    """
    Encode `prompt` (and optionally a negative prompt) with Compel.

    Args:
        self: pipeline providing `text_encoder` and `tokenizer`.
        prompt: prompt text passed to Compel.
        num_images_per_prompt: accepted for signature compatibility; unused.
        do_classifier_free_guidance: accepted for signature compatibility;
            unused.
        negative_prompt: optional negative prompt text; when given it is
            encoded and replaces any `negative_prompt_embeds` argument.
        prompt_embeds: ignored; always recomputed from `prompt`.
        negative_prompt_embeds: used only when `negative_prompt` is None.
            NOTE(review): it is handed to Compel's padding helper and then
            `.numpy()` is called on it, so it presumably has to be a torch
            tensor despite the annotation -- confirm against callers.
        skip_clip_states: accepted for signature compatibility; unused.

    Returns:
        A float32 array. With negative embeddings available, the negative
        and prompt embeddings concatenated in that order; otherwise the
        prompt embeddings alone.
    """
    wrapped_encoder = wrap_encoder(self.text_encoder)
    compel = Compel(tokenizer=self.tokenizer, text_encoder=wrapped_encoder)

    prompt_embeds = compel(prompt)

    if negative_prompt is not None:
        negative_prompt_embeds = compel(negative_prompt)

    if negative_prompt_embeds is None:
        # Nothing to pair with the prompt: concatenating with None would
        # raise, so hand back the prompt embeddings on their own.
        return prompt_embeds.numpy().astype(np.float32)

    # Both conditionings need the same sequence length to be stacked.
    [prompt_embeds, negative_prompt_embeds] = (
        compel.pad_conditioning_tensors_to_same_length(
            [prompt_embeds, negative_prompt_embeds]
        )
    )

    # Embeddings are floating point; an integer cast would truncate them.
    prompt_embeds = prompt_embeds.numpy().astype(np.float32)
    negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.float32)

    return np.concatenate([negative_prompt_embeds, prompt_embeds])
|
|
|
|
|
|
def encode_prompt_compel_sdxl(
    self: OnnxStableDiffusionPipeline,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[Union[str, list]],
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    pooled_prompt_embeds: Optional[np.ndarray] = None,
    negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
    skip_clip_states: int = 0,
) -> np.ndarray:
    """
    Encode `prompt` (and optionally a negative prompt) with Compel using both
    of the pipeline's tokenizers and text encoders.

    Args:
        self: pipeline providing `text_encoder`, `text_encoder_2`,
            `tokenizer` and `tokenizer_2`.
        prompt: prompt passed to Compel.
        num_images_per_prompt: accepted for signature compatibility; unused.
        do_classifier_free_guidance: accepted for signature compatibility;
            unused.
        negative_prompt: optional negative prompt; when given it is encoded
            and replaces both negative embedding arguments.
        prompt_embeds: ignored; always recomputed from `prompt`.
        negative_prompt_embeds: used only when `negative_prompt` is None.
        pooled_prompt_embeds: ignored; always recomputed from `prompt`.
        negative_pooled_prompt_embeds: used as the negative pooled value
            only when `negative_prompt` is None.
        skip_clip_states: accepted for signature compatibility; unused.

        NOTE(review): caller-supplied negative values are passed to Compel's
        padding helper and/or `.numpy()` as-is, so they presumably have to be
        torch tensors despite the annotations -- confirm against callers.

    Returns:
        A 4-tuple `(prompt_embeds, negative_prompt_embeds, prompt_pooled,
        negative_pooled)`. Note this differs from the `np.ndarray` return
        annotation. The prompt entries are float32 arrays; each negative
        entry is a float32 array, or None when no value was available.
    """
    wrapped_encoder = wrap_encoder(self.text_encoder, sdxl=True)
    wrapped_encoder_2 = wrap_encoder(self.text_encoder_2, sdxl=True)
    compel = Compel(
        tokenizer=[self.tokenizer, self.tokenizer_2],
        text_encoder=[wrapped_encoder, wrapped_encoder_2],
        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
        requires_pooled=[False, True],
    )

    prompt_embeds, prompt_pooled = compel(prompt)

    # Bind the negative pooled value up front so it is always defined, even
    # when no negative prompt is encoded below.
    negative_pooled = negative_pooled_prompt_embeds
    if negative_prompt is not None:
        negative_prompt_embeds, negative_pooled = compel(negative_prompt)

    if negative_prompt_embeds is not None:
        # Both conditionings need the same sequence length downstream.
        [prompt_embeds, negative_prompt_embeds] = (
            compel.pad_conditioning_tensors_to_same_length(
                [prompt_embeds, negative_prompt_embeds]
            )
        )

    prompt_embeds = prompt_embeds.numpy().astype(np.float32)
    prompt_pooled = prompt_pooled.numpy().astype(np.float32)
    if negative_prompt_embeds is not None:
        negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.float32)
    if negative_pooled is not None:
        negative_pooled = negative_pooled.numpy().astype(np.float32)

    return (
        prompt_embeds,
        negative_prompt_embeds,
        prompt_pooled,
        negative_pooled,
    )
|