feat(api): do not repeat wildcard values
This commit is contained in:
parent
2121c7aa5d
commit
e65de8115e
|
@ -3,6 +3,7 @@ from logging import getLogger
|
|||
from math import ceil
|
||||
from re import Pattern, compile
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -375,6 +376,8 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
|
|||
next_match = WILDCARD_TOKEN.search(prompt)
|
||||
remaining_prompt = prompt
|
||||
|
||||
# prep a local copy to avoid mutating the main one
|
||||
wildcards = deepcopy(wildcards)
|
||||
random.seed(seed)
|
||||
|
||||
while next_match is not None:
|
||||
|
@ -383,7 +386,7 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
|
|||
|
||||
wildcard = ""
|
||||
if name in wildcards:
|
||||
wildcard = random.choice(wildcards.get(name))
|
||||
wildcard = pop_random(wildcards.get(name))
|
||||
else:
|
||||
logger.warning("unknown wildcard: %s", name)
|
||||
|
||||
|
@ -395,3 +398,12 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
|
|||
next_match = WILDCARD_TOKEN.search(remaining_prompt)
|
||||
|
||||
return remaining_prompt
|
||||
|
||||
|
||||
def pop_random(list: List[str]) -> str:
|
||||
"""
|
||||
From https://stackoverflow.com/a/14088129
|
||||
"""
|
||||
i = random.randrange(len(list))
|
||||
list[i], list[-1] = list[-1], list[i]
|
||||
return list.pop()
|
Loading…
Reference in New Issue