
parse CLIP skip tokens for Compel

Sean Sube 2024-03-03 12:45:27 -06:00
parent ce45e63d65
commit 4713169ad9
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 17 additions and 9 deletions


@@ -79,6 +79,17 @@ def expand_alternative_ranges(prompt: str) -> List[str]:
     return prompts
 
 
+def split_clip_skip(prompt: str) -> Tuple[str, int]:
+    prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN)
+
+    skip_clip_states = 0
+    if len(clip_tokens) > 0:
+        skip_clip_states = int(clip_tokens[0][1])
+        logger.info("skipping %s CLIP layers", skip_clip_states)
+
+    return prompt, skip_clip_states
+
+
 @torch.no_grad()
 def expand_prompt(
     self: OnnxStableDiffusionPipeline,
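
For reference, a standalone sketch of how the new helper behaves. The CLIP_TOKEN pattern and get_tokens_from_prompt are defined elsewhere in this module; the regex and helper below are simplified stand-ins, assuming prompt tokens of the form <clip:skip:N>:

import re
from typing import Tuple

# simplified stand-ins for the module's real CLIP_TOKEN pattern and
# get_tokens_from_prompt helper, shown here only for illustration
CLIP_TOKEN = re.compile(r"\<clip:([-\w]+):(\d+)\>")

def get_tokens_from_prompt(prompt, pattern):
    # collect (name, value) pairs and strip the matched tokens from the prompt
    tokens = pattern.findall(prompt)
    return pattern.sub("", prompt).strip(), tokens

def split_clip_skip(prompt: str) -> Tuple[str, int]:
    prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN)
    skip_clip_states = int(clip_tokens[0][1]) if clip_tokens else 0
    return prompt, skip_clip_states

print(split_clip_skip("an astronaut riding a horse <clip:skip:2>"))
# ('an astronaut riding a horse', 2)
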
@@ -94,10 +105,7 @@ def expand_prompt(
     # tokenizer: CLIPTokenizer
     # encoder: OnnxRuntimeModel
 
-    prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN)
-    if len(clip_tokens) > 0:
-        skip_clip_states = int(clip_tokens[0][1])
-        logger.info("skipping %s CLIP layers", skip_clip_states)
+    prompt, skip_clip_states = split_clip_skip(prompt)
 
     batch_size = len(prompt) if isinstance(prompt, list) else 1
     prompt = expand_interval_ranges(prompt)
@@ -403,9 +411,6 @@ def encode_prompt(
     num_images_per_prompt: int = 1,
     do_classifier_free_guidance: bool = True,
 ) -> List[np.ndarray]:
-    """
-    TODO: does not work with SDXL, fix or turn into a pipeline patch
-    """
     return [
         pipe._encode_prompt(
             remove_tokens(prompt),


@@ -6,6 +6,8 @@ import torch
 from compel import Compel, ReturnedEmbeddingsType
 from diffusers import OnnxStableDiffusionPipeline
 
+from ..diffusers.utils import split_clip_skip
+
 
 def get_inference_session(model):
     if hasattr(model, "session"):
@@ -73,8 +75,9 @@ def encode_prompt_compel(
     negative_prompt: Optional[str] = None,
     prompt_embeds: Optional[np.ndarray] = None,
     negative_prompt_embeds: Optional[np.ndarray] = None,
-    skip_clip_states: int = 0,
 ) -> np.ndarray:
+    prompt, skip_clip_states = split_clip_skip(prompt)
+
     embeddings_type = (
         ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
         if skip_clip_states == 0
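
The visible context cuts off inside the embeddings_type expression, so the else branch is not shown. Roughly, the selection works like the sketch below, using Compel's actual ReturnedEmbeddingsType members; the helper name is hypothetical, and whether the non-zero branch uses the normalized or non-normalized penultimate variant is an assumption here, since that line falls outside the hunk:

from compel import ReturnedEmbeddingsType

def select_embeddings_type(skip_clip_states: int) -> ReturnedEmbeddingsType:
    # skip 0 keeps the usual final hidden states, the Stable Diffusion default
    if skip_clip_states == 0:
        return ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
    # assumption: any other value falls back to the penultimate layer ("clip skip")
    return ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
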
@@ -116,7 +119,6 @@ def encode_prompt_compel_sdxl(
     negative_prompt_embeds: Optional[np.ndarray] = None,
     pooled_prompt_embeds: Optional[np.ndarray] = None,
     negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
-    skip_clip_states: int = 0,
 ) -> np.ndarray:
     wrapped_encoder = wrap_encoder(self.text_encoder, sdxl=True)
     wrapped_encoder_2 = wrap_encoder(self.text_encoder_2, sdxl=True)
@@ -127,6 +129,7 @@ def encode_prompt_compel_sdxl(
         requires_pooled=[False, True],
     )
 
+    prompt, _skip_clip_states = split_clip_skip(prompt)
     prompt_embeds, prompt_pooled = compel(prompt)
 
     if negative_prompt is not None:
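
Note that the SDXL path binds the parsed count to _skip_clip_states and never uses it, presumably because clip skip is not wired into the SDXL encoders yet. The call still matters: it strips the token before the prompt reaches Compel's own parser. A minimal illustration:

# sketch: the count is parsed but deliberately ignored in the SDXL path;
# the useful effect is removing the token from the prompt text
prompt = "a watercolor city skyline <clip:skip:2>"
prompt, _skip_clip_states = split_clip_skip(prompt)
# prompt is now "a watercolor city skyline", safe to pass to compel(prompt)
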