1
0
Fork 0

feat(api): add range syntax to expand numbered tokens (#179)

This commit is contained in:
Sean Sube 2023-03-07 20:48:26 -06:00
parent 66c42485cb
commit 0a4f83ac0f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 13 additions and 0 deletions

View File

@ -1,5 +1,6 @@
from logging import getLogger
from math import ceil
from re import compile
from typing import List, Optional
import numpy as np
@ -7,7 +8,18 @@ from diffusers import OnnxStableDiffusionPipeline
logger = getLogger(__name__)
MAX_TOKENS_PER_GROUP = 77
PATTERN_RANGE = compile("(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
def expand_prompt_ranges(prompt: str) -> str:
def expand_range(match):
(base_token, start, end, step) = match.groups(default=1)
num_tokens = [f"{base_token}-{i}" for i in range(int(start), int(end), int(step))]
return " ".join(num_tokens)
return PATTERN_RANGE.sub(expand_range, prompt)
def expand_prompt(
@ -22,6 +34,7 @@ def expand_prompt(
# encoder: OnnxRuntimeModel
batch_size = len(prompt) if isinstance(prompt, list) else 1
prompt = expand_prompt_ranges(prompt)
# split prompt into 75 token chunks
tokens = self.tokenizer(