parse CLIP skip count from prompt
This commit is contained in:
parent
46d1b5636d
commit
b82246fdab
|
@ -10,6 +10,7 @@ from diffusers import OnnxStableDiffusionPipeline
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
CLIP_TOKEN = compile(r"\<clip:([-\w]+):([\.|\d]+)\>")
|
||||||
INVERSION_TOKEN = compile(r"\<inversion:([-\w]+):([\.|\d]+)\>")
|
INVERSION_TOKEN = compile(r"\<inversion:([-\w]+):([\.|\d]+)\>")
|
||||||
LORA_TOKEN = compile(r"\<lora:([-\w]+):([\.|\d]+)\>")
|
LORA_TOKEN = compile(r"\<lora:([-\w]+):([\.|\d]+)\>")
|
||||||
MAX_TOKENS_PER_GROUP = 77
|
MAX_TOKENS_PER_GROUP = 77
|
||||||
|
@ -34,12 +35,16 @@ def expand_prompt(
|
||||||
num_images_per_prompt: int,
|
num_images_per_prompt: int,
|
||||||
do_classifier_free_guidance: bool,
|
do_classifier_free_guidance: bool,
|
||||||
negative_prompt: Optional[str] = None,
|
negative_prompt: Optional[str] = None,
|
||||||
skip_clip_states: Optional[str] = 0,
|
skip_clip_states: Optional[int] = 0,
|
||||||
) -> "np.NDArray":
|
) -> "np.NDArray":
|
||||||
# self provides:
|
# self provides:
|
||||||
# tokenizer: CLIPTokenizer
|
# tokenizer: CLIPTokenizer
|
||||||
# encoder: OnnxRuntimeModel
|
# encoder: OnnxRuntimeModel
|
||||||
|
|
||||||
|
prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN)
|
||||||
|
if len(clip_tokens) > 0:
|
||||||
|
skip_clip_states = int(clip_tokens[0][1])
|
||||||
|
logger.info("skipping %s CLIP layers", skip_clip_states)
|
||||||
|
|
||||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||||
prompt = expand_prompt_ranges(prompt)
|
prompt = expand_prompt_ranges(prompt)
|
||||||
|
|
Loading…
Reference in New Issue