parse CLIP skip tokens for Compel

parent ce45e63d65
commit 4713169ad9
@@ -79,6 +79,17 @@ def expand_alternative_ranges(prompt: str) -> List[str]:
     return prompts
 
 
+def split_clip_skip(prompt: str) -> Tuple[str, int]:
+    prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN)
+
+    skip_clip_states = 0
+    if len(clip_tokens) > 0:
+        skip_clip_states = int(clip_tokens[0][1])
+        logger.info("skipping %s CLIP layers", skip_clip_states)
+
+    return prompt, skip_clip_states
+
+
 @torch.no_grad()
 def expand_prompt(
     self: OnnxStableDiffusionPipeline,
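Neither CLIP_TOKEN nor get_tokens_from_prompt appears in this diff. A minimal sketch of what they plausibly look like, assuming the token syntax is `<clip:skip:N>` (only the two names come from the diff; the regex and helper below are reconstructions, not the repository's actual definitions):

    import re
    from typing import List, Tuple

    # Assumed token syntax: "<clip:skip:2>". Group 1 is the token name
    # ("skip") and group 2 is the layer count, which matches how
    # split_clip_skip reads int(clip_tokens[0][1]).
    CLIP_TOKEN = re.compile(r"<clip:([-\w]+):(\d+)>")

    def get_tokens_from_prompt(
        prompt: str, pattern: re.Pattern
    ) -> Tuple[str, List[Tuple[str, str]]]:
        # collect every match, then strip the tokens from the prompt
        # so the text encoder never sees them as literal text
        tokens = pattern.findall(prompt)
        remainder = pattern.sub("", prompt).strip()
        return remainder, tokens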
@@ -94,10 +105,7 @@ def expand_prompt(
     # tokenizer: CLIPTokenizer
     # encoder: OnnxRuntimeModel
 
-    prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN)
-    if len(clip_tokens) > 0:
-        skip_clip_states = int(clip_tokens[0][1])
-        logger.info("skipping %s CLIP layers", skip_clip_states)
+    prompt, skip_clip_states = split_clip_skip(prompt)
 
     batch_size = len(prompt) if isinstance(prompt, list) else 1
     prompt = expand_interval_ranges(prompt)
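With the helper extracted, the inline parsing collapses to a single call, and any other entry point can reuse it. Expected behavior, under the token syntax assumed in the sketch above:

    prompt, skip = split_clip_skip("a portrait photo <clip:skip:2>")
    # prompt == "a portrait photo", skip == 2

    prompt, skip = split_clip_skip("a portrait photo")
    # no token present: prompt is unchanged and skip == 0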
@@ -403,9 +411,6 @@ def encode_prompt(
     num_images_per_prompt: int = 1,
     do_classifier_free_guidance: bool = True,
 ) -> List[np.ndarray]:
-    """
-    TODO: does not work with SDXL, fix or turn into a pipeline patch
-    """
     return [
         pipe._encode_prompt(
             remove_tokens(prompt),

@@ -6,6 +6,8 @@ import torch
 from compel import Compel, ReturnedEmbeddingsType
 from diffusers import OnnxStableDiffusionPipeline
 
+from ..diffusers.utils import split_clip_skip
+
 
 def get_inference_session(model):
     if hasattr(model, "session"):
@@ -73,8 +75,9 @@ def encode_prompt_compel(
     negative_prompt: Optional[str] = None,
     prompt_embeds: Optional[np.ndarray] = None,
     negative_prompt_embeds: Optional[np.ndarray] = None,
-    skip_clip_states: int = 0,
 ) -> np.ndarray:
+    prompt, skip_clip_states = split_clip_skip(prompt)
+
     embeddings_type = (
         ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
         if skip_clip_states == 0
@@ -116,7 +119,6 @@ def encode_prompt_compel_sdxl(
     negative_prompt_embeds: Optional[np.ndarray] = None,
     pooled_prompt_embeds: Optional[np.ndarray] = None,
     negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
-    skip_clip_states: int = 0,
 ) -> np.ndarray:
     wrapped_encoder = wrap_encoder(self.text_encoder, sdxl=True)
     wrapped_encoder_2 = wrap_encoder(self.text_encoder_2, sdxl=True)
@@ -127,6 +129,7 @@ def encode_prompt_compel_sdxl(
         requires_pooled=[False, True],
     )
 
+    prompt, _skip_clip_states = split_clip_skip(prompt)
     prompt_embeds, prompt_pooled = compel(prompt)
 
     if negative_prompt is not None:
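On the SDXL path the skip count is parsed but deliberately discarded (hence the underscore); the call still matters because it removes the `<clip:skip:N>` token from the prompt before Compel tokenizes it. A sketch of the surrounding setup, with everything except requires_pooled and the lines shown above treated as assumption:

    compel = Compel(
        tokenizer=[self.tokenizer, self.tokenizer_2],      # assumed: both SDXL tokenizers
        text_encoder=[wrapped_encoder, wrapped_encoder_2], # from the hunk above
        # compel's usual SDXL mode; the actual value is not visible in the diff
        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
        requires_pooled=[False, True],                     # shown in the diff
    )

    prompt, _skip_clip_states = split_clip_skip(prompt)  # strips the token; skip value unused here
    prompt_embeds, prompt_pooled = compel(prompt)        # pooled output from the second encoder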