From b82246fdab08e03e98df8da4e6b3597c98997013 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 19 Mar 2023 08:43:39 -0500 Subject: [PATCH] parse CLIP skip count from prompt --- api/onnx_web/diffusers/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index 0b8c6d47..7cc39875 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -10,6 +10,7 @@ from diffusers import OnnxStableDiffusionPipeline logger = getLogger(__name__) +CLIP_TOKEN = compile(r"\") INVERSION_TOKEN = compile(r"\") LORA_TOKEN = compile(r"\") MAX_TOKENS_PER_GROUP = 77 @@ -34,12 +35,16 @@ def expand_prompt( num_images_per_prompt: int, do_classifier_free_guidance: bool, negative_prompt: Optional[str] = None, - skip_clip_states: Optional[str] = 0, + skip_clip_states: Optional[int] = 0, ) -> "np.NDArray": # self provides: # tokenizer: CLIPTokenizer # encoder: OnnxRuntimeModel + prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN) + if len(clip_tokens) > 0: + skip_clip_states = int(clip_tokens[0][1]) + logger.info("skipping %s CLIP layers", skip_clip_states) batch_size = len(prompt) if isinstance(prompt, list) else 1 prompt = expand_prompt_ranges(prompt)