From 0a4f83ac0f1ceca57b824593f57348affb5abcad Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Tue, 7 Mar 2023 20:48:26 -0600
Subject: [PATCH] feat(api): add range syntax to expand numbered tokens (#179)

---
 api/onnx_web/diffusers/utils.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py
index 5bd05e43..0a4e9a89 100644
--- a/api/onnx_web/diffusers/utils.py
+++ b/api/onnx_web/diffusers/utils.py
@@ -1,5 +1,6 @@
 from logging import getLogger
 from math import ceil
+from re import compile
 from typing import List, Optional
 
 import numpy as np
@@ -7,7 +8,18 @@ from diffusers import OnnxStableDiffusionPipeline
 
 logger = getLogger(__name__)
 
+
 MAX_TOKENS_PER_GROUP = 77
+PATTERN_RANGE = compile("(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
+
+
+def expand_prompt_ranges(prompt: str) -> str:
+    def expand_range(match):
+        (base_token, start, end, step) = match.groups(default=1)
+        num_tokens = [f"{base_token}-{i}" for i in range(int(start), int(end), int(step))]
+        return " ".join(num_tokens)
+
+    return PATTERN_RANGE.sub(expand_range, prompt)
 
 
 def expand_prompt(
@@ -22,6 +34,7 @@ def expand_prompt(
     # encoder: OnnxRuntimeModel
 
     batch_size = len(prompt) if isinstance(prompt, list) else 1
+    prompt = expand_prompt_ranges(prompt)
 
     # split prompt into 75 token chunks
     tokens = self.tokenizer(
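
Note (not part of the patch): the new expand_prompt_ranges helper rewrites range tokens such as "cloud-{1,4}" into a run of numbered tokens before the prompt is tokenized. A minimal usage sketch follows, assuming the package is importable as onnx_web after this change and reading the end of the range as exclusive, since the values are passed to Python's range():

    from onnx_web.diffusers.utils import expand_prompt_ranges

    # "word-{start,end}" expands to word-start ... word-(end-1); an optional
    # third number sets the step, e.g. "cloud-{0,6,2}" -> "cloud-0 cloud-2 cloud-4".
    print(expand_prompt_ranges("a photo of cloud-{1,4} at sunset"))
    # expected: a photo of cloud-1 cloud-2 cloud-3 at sunset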