1
0
Fork 0

feat(api): do not repeat wildcard values

This commit is contained in:
Sean Sube 2023-07-06 18:10:08 -05:00
parent 2121c7aa5d
commit e65de8115e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 13 additions and 1 deletions

View File

@ -3,6 +3,7 @@ from logging import getLogger
from math import ceil from math import ceil
from re import Pattern, compile from re import Pattern, compile
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from copy import deepcopy
import numpy as np import numpy as np
import torch import torch
@ -375,6 +376,8 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
next_match = WILDCARD_TOKEN.search(prompt) next_match = WILDCARD_TOKEN.search(prompt)
remaining_prompt = prompt remaining_prompt = prompt
# prep a local copy to avoid mutating the main one
wildcards = deepcopy(wildcards)
random.seed(seed) random.seed(seed)
while next_match is not None: while next_match is not None:
@ -383,7 +386,7 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
wildcard = "" wildcard = ""
if name in wildcards: if name in wildcards:
wildcard = random.choice(wildcards.get(name)) wildcard = pop_random(wildcards.get(name))
else: else:
logger.warning("unknown wildcard: %s", name) logger.warning("unknown wildcard: %s", name)
@ -395,3 +398,12 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
next_match = WILDCARD_TOKEN.search(remaining_prompt) next_match = WILDCARD_TOKEN.search(remaining_prompt)
return remaining_prompt return remaining_prompt
def pop_random(list: List[str]) -> str:
"""
From https://stackoverflow.com/a/14088129
"""
i = random.randrange(len(list))
list[i], list[-1] = list[-1], list[i]
return list.pop()