From e65de8115e6cf961cd1f889f156271f914fbc76a Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 6 Jul 2023 18:10:08 -0500 Subject: [PATCH] feat(api): do not repeat wildcard values --- api/onnx_web/diffusers/utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index fc45fdff..f16cfe8a 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -3,6 +3,7 @@ from logging import getLogger from math import ceil from re import Pattern, compile from typing import Dict, List, Optional, Tuple +from copy import deepcopy import numpy as np import torch @@ -375,6 +376,8 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) - next_match = WILDCARD_TOKEN.search(prompt) remaining_prompt = prompt + # prep a local copy to avoid mutating the main one + wildcards = deepcopy(wildcards) random.seed(seed) while next_match is not None: @@ -383,7 +386,7 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) - wildcard = "" if name in wildcards: - wildcard = random.choice(wildcards.get(name)) + wildcard = pop_random(wildcards.get(name)) else: logger.warning("unknown wildcard: %s", name) @@ -395,3 +398,12 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) - next_match = WILDCARD_TOKEN.search(remaining_prompt) return remaining_prompt + + +def pop_random(list: List[str]) -> str: + """ + From https://stackoverflow.com/a/14088129 + """ + i = random.randrange(len(list)) + list[i], list[-1] = list[-1], list[i] + return list.pop() \ No newline at end of file