feat(api): add range syntax to expand numbered tokens (#179)
This commit is contained in:
parent
66c42485cb
commit
0a4f83ac0f
|
@ -1,5 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
from re import compile
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -7,7 +8,18 @@ from diffusers import OnnxStableDiffusionPipeline
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
MAX_TOKENS_PER_GROUP = 77
|
MAX_TOKENS_PER_GROUP = 77
|
||||||
|
PATTERN_RANGE = compile("(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||||
|
|
||||||
|
|
||||||
|
def expand_prompt_ranges(prompt: str) -> str:
|
||||||
|
def expand_range(match):
|
||||||
|
(base_token, start, end, step) = match.groups(default=1)
|
||||||
|
num_tokens = [f"{base_token}-{i}" for i in range(int(start), int(end), int(step))]
|
||||||
|
return " ".join(num_tokens)
|
||||||
|
|
||||||
|
return PATTERN_RANGE.sub(expand_range, prompt)
|
||||||
|
|
||||||
|
|
||||||
def expand_prompt(
|
def expand_prompt(
|
||||||
|
@ -22,6 +34,7 @@ def expand_prompt(
|
||||||
# encoder: OnnxRuntimeModel
|
# encoder: OnnxRuntimeModel
|
||||||
|
|
||||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||||
|
prompt = expand_prompt_ranges(prompt)
|
||||||
|
|
||||||
# split prompt into 75 token chunks
|
# split prompt into 75 token chunks
|
||||||
tokens = self.tokenizer(
|
tokens = self.tokenizer(
|
||||||
|
|
Loading…
Reference in New Issue