1
0
Fork 0

feat(api): do not repeat wildcard values

This commit is contained in:
Sean Sube 2023-07-06 18:10:08 -05:00
parent 2121c7aa5d
commit e65de8115e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 13 additions and 1 deletions

View File

@ -3,6 +3,7 @@ from logging import getLogger
from math import ceil
from re import Pattern, compile
from typing import Dict, List, Optional, Tuple
from copy import deepcopy
import numpy as np
import torch
@ -375,6 +376,8 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
next_match = WILDCARD_TOKEN.search(prompt)
remaining_prompt = prompt
# prep a local copy to avoid mutating the main one
wildcards = deepcopy(wildcards)
random.seed(seed)
while next_match is not None:
@ -383,7 +386,7 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
wildcard = ""
if name in wildcards:
wildcard = random.choice(wildcards.get(name))
wildcard = pop_random(wildcards.get(name))
else:
logger.warning("unknown wildcard: %s", name)
@ -395,3 +398,12 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
next_match = WILDCARD_TOKEN.search(remaining_prompt)
return remaining_prompt
def pop_random(list: List[str]) -> str:
"""
From https://stackoverflow.com/a/14088129
"""
i = random.randrange(len(list))
list[i], list[-1] = list[-1], list[i]
return list.pop()