
parse CLIP skip count from prompt

Sean Sube 2023-03-19 08:43:39 -05:00
parent 46d1b5636d
commit b82246fdab
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 6 additions and 1 deletion


@@ -10,6 +10,7 @@ from diffusers import OnnxStableDiffusionPipeline
 logger = getLogger(__name__)
+CLIP_TOKEN = compile(r"\<clip:([-\w]+):([\.|\d]+)\>")
 INVERSION_TOKEN = compile(r"\<inversion:([-\w]+):([\.|\d]+)\>")
 LORA_TOKEN = compile(r"\<lora:([-\w]+):([\.|\d]+)\>")
 MAX_TOKENS_PER_GROUP = 77
@@ -34,12 +35,16 @@ def expand_prompt(
     num_images_per_prompt: int,
     do_classifier_free_guidance: bool,
     negative_prompt: Optional[str] = None,
-    skip_clip_states: Optional[str] = 0,
+    skip_clip_states: Optional[int] = 0,
 ) -> "np.NDArray":
     # self provides:
     # tokenizer: CLIPTokenizer
     # encoder: OnnxRuntimeModel
+    prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN)
+    if len(clip_tokens) > 0:
+        skip_clip_states = int(clip_tokens[0][1])
+        logger.info("skipping %s CLIP layers", skip_clip_states)
     batch_size = len(prompt) if isinstance(prompt, list) else 1
     prompt = expand_prompt_ranges(prompt)
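
For context, a minimal sketch (not part of the commit) of how a <clip:skip:N> prompt token becomes an integer skip count. get_tokens_from_prompt is called in the diff but not shown there, so its body below is an assumption based on how the new code uses its return value: a cleaned prompt plus a list of (name, value) capture tuples. The example prompt is illustrative only.

# Hedged sketch: CLIP_TOKEN is copied from the diff; get_tokens_from_prompt is
# an assumed stand-in for the helper the diff calls but does not define.
from re import Pattern, compile
from typing import List, Tuple

CLIP_TOKEN = compile(r"\<clip:([-\w]+):([\.|\d]+)\>")


def get_tokens_from_prompt(prompt: str, pattern: Pattern) -> Tuple[str, List[Tuple[str, str]]]:
    # collect every (name, value) pair, e.g. [("skip", "2")] for <clip:skip:2>
    tokens = pattern.findall(prompt)
    # strip the tokens so they never reach the CLIP tokenizer
    remainder = pattern.sub("", prompt).strip()
    return remainder, tokens


prompt, clip_tokens = get_tokens_from_prompt(
    "a painting of a fox <clip:skip:2>", CLIP_TOKEN
)
skip_clip_states = 0
if len(clip_tokens) > 0:
    # mirror the new code: the numeric capture of the first CLIP token
    skip_clip_states = int(clip_tokens[0][1])

print(prompt)            # a painting of a fox
print(skip_clip_states)  # 2

The real helper may handle more cases (multiple tokens, other patterns); only the first CLIP token's count is used here, matching clip_tokens[0][1] in the diff.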