feat(api): do not repeat wildcard values
This commit is contained in:
parent
2121c7aa5d
commit
e65de8115e
|
@ -3,6 +3,7 @@ from logging import getLogger
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from re import Pattern, compile
|
from re import Pattern, compile
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -375,6 +376,8 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
|
||||||
next_match = WILDCARD_TOKEN.search(prompt)
|
next_match = WILDCARD_TOKEN.search(prompt)
|
||||||
remaining_prompt = prompt
|
remaining_prompt = prompt
|
||||||
|
|
||||||
|
# prep a local copy to avoid mutating the main one
|
||||||
|
wildcards = deepcopy(wildcards)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
while next_match is not None:
|
while next_match is not None:
|
||||||
|
@ -383,7 +386,7 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
|
||||||
|
|
||||||
wildcard = ""
|
wildcard = ""
|
||||||
if name in wildcards:
|
if name in wildcards:
|
||||||
wildcard = random.choice(wildcards.get(name))
|
wildcard = pop_random(wildcards.get(name))
|
||||||
else:
|
else:
|
||||||
logger.warning("unknown wildcard: %s", name)
|
logger.warning("unknown wildcard: %s", name)
|
||||||
|
|
||||||
|
@ -395,3 +398,12 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
|
||||||
next_match = WILDCARD_TOKEN.search(remaining_prompt)
|
next_match = WILDCARD_TOKEN.search(remaining_prompt)
|
||||||
|
|
||||||
return remaining_prompt
|
return remaining_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def pop_random(list: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
From https://stackoverflow.com/a/14088129
|
||||||
|
"""
|
||||||
|
i = random.randrange(len(list))
|
||||||
|
list[i], list[-1] = list[-1], list[i]
|
||||||
|
return list.pop()
|
Loading…
Reference in New Issue