parse CLIP skip count from prompt
This commit is contained in:
parent
46d1b5636d
commit
b82246fdab
|
@ -10,6 +10,7 @@ from diffusers import OnnxStableDiffusionPipeline
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
CLIP_TOKEN = compile(r"\<clip:([-\w]+):([\.|\d]+)\>")
|
||||
INVERSION_TOKEN = compile(r"\<inversion:([-\w]+):([\.|\d]+)\>")
|
||||
LORA_TOKEN = compile(r"\<lora:([-\w]+):([\.|\d]+)\>")
|
||||
MAX_TOKENS_PER_GROUP = 77
|
||||
|
@ -34,12 +35,16 @@ def expand_prompt(
|
|||
num_images_per_prompt: int,
|
||||
do_classifier_free_guidance: bool,
|
||||
negative_prompt: Optional[str] = None,
|
||||
skip_clip_states: Optional[str] = 0,
|
||||
skip_clip_states: Optional[int] = 0,
|
||||
) -> "np.NDArray":
|
||||
# self provides:
|
||||
# tokenizer: CLIPTokenizer
|
||||
# encoder: OnnxRuntimeModel
|
||||
|
||||
prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN)
|
||||
if len(clip_tokens) > 0:
|
||||
skip_clip_states = int(clip_tokens[0][1])
|
||||
logger.info("skipping %s CLIP layers", skip_clip_states)
|
||||
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
prompt = expand_prompt_ranges(prompt)
|
||||
|
|
Loading…
Reference in New Issue