diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index 0b8c6d47..7cc39875 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -10,6 +10,7 @@ from diffusers import OnnxStableDiffusionPipeline logger = getLogger(__name__) +CLIP_TOKEN = compile(r"\") INVERSION_TOKEN = compile(r"\") LORA_TOKEN = compile(r"\") MAX_TOKENS_PER_GROUP = 77 @@ -34,12 +35,16 @@ def expand_prompt( num_images_per_prompt: int, do_classifier_free_guidance: bool, negative_prompt: Optional[str] = None, - skip_clip_states: Optional[str] = 0, + skip_clip_states: Optional[int] = 0, ) -> "np.NDArray": # self provides: # tokenizer: CLIPTokenizer # encoder: OnnxRuntimeModel + prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN) + if len(clip_tokens) > 0: + skip_clip_states = int(clip_tokens[0][1]) + logger.info("skipping %s CLIP layers", skip_clip_states) batch_size = len(prompt) if isinstance(prompt, list) else 1 prompt = expand_prompt_ranges(prompt)