1
0
Fork 0

feat(api): add range syntax to expand numbered tokens (#179)

This commit is contained in:
Sean Sube 2023-03-07 20:48:26 -06:00
parent 66c42485cb
commit 0a4f83ac0f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 13 additions and 0 deletions

View File

@ -1,5 +1,6 @@
from logging import getLogger from logging import getLogger
from math import ceil from math import ceil
from re import compile
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
@ -7,7 +8,18 @@ from diffusers import OnnxStableDiffusionPipeline
logger = getLogger(__name__) logger = getLogger(__name__)
MAX_TOKENS_PER_GROUP = 77 MAX_TOKENS_PER_GROUP = 77
# Matches range tokens such as `name-{1,5}` or `name-{1,10,2}`: a run of word
# characters, a hyphen, then `{start,end}` with an optional `,step`.
# Written as a raw string so `\w` and `\d` reach the regex engine instead of
# being parsed as (invalid) string escapes; the braces are escaped so they are
# always read as literal characters rather than a repetition quantifier.
PATTERN_RANGE = compile(r"(\w+)-\{(\d+),(\d+)(?:,(\d+))?\}")


def expand_prompt_ranges(prompt: str) -> str:
    """Expand numbered range tokens in *prompt* into explicit tokens.

    A token of the form ``base-{start,end}`` or ``base-{start,end,step}`` is
    replaced by the space-separated tokens ``base-<i>`` for every ``i`` in
    ``range(start, end, step)``. The end bound is exclusive, so
    ``foo-{1,4}`` becomes ``foo-1 foo-2 foo-3``. An empty range (for example
    ``foo-{3,3}``) is replaced by an empty string.

    Args:
        prompt: The prompt text to scan for range tokens.

    Returns:
        The prompt with every range token expanded. Text that does not match
        the range syntax is returned unchanged.
    """

    def expand_range(match) -> str:
        base_token, start, end, step = match.groups()
        # The step group is optional in the pattern; fall back to 1.
        stride = 1 if step is None else int(step)
        if stride == 0:
            # range() rejects a zero step with ValueError. Prompts are user
            # input, so leave the malformed token as written instead of
            # aborting the whole prompt expansion.
            return match.group(0)
        return " ".join(
            f"{base_token}-{i}" for i in range(int(start), int(end), stride)
        )

    return PATTERN_RANGE.sub(expand_range, prompt)
def expand_prompt( def expand_prompt(
@ -22,6 +34,7 @@ def expand_prompt(
# encoder: OnnxRuntimeModel # encoder: OnnxRuntimeModel
batch_size = len(prompt) if isinstance(prompt, list) else 1 batch_size = len(prompt) if isinstance(prompt, list) else 1
prompt = expand_prompt_ranges(prompt)
# split prompt into 75 token chunks # split prompt into 75 token chunks
tokens = self.tokenizer( tokens = self.tokenizer(