diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index a29c8876..d96e3d5e 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -220,8 +220,15 @@ def expand_prompt( return prompt_embeds +def parse_float_group(group: Tuple[str, str]) -> Tuple[str, float]: + name, weight = group + return (name, float(weight)) + + def get_tokens_from_prompt( - prompt: str, pattern: Pattern + prompt: str, + pattern: Pattern, + parser = parse_float_group, ) -> Tuple[str, List[Tuple[str, float]]]: """ TODO: replace with Arpeggio @@ -232,8 +239,9 @@ def get_tokens_from_prompt( next_match = pattern.search(remaining_prompt) while next_match is not None: logger.debug("found token in prompt: %s", next_match) - name, weight = next_match.groups() - tokens.append((name, float(weight))) + group = next_match.groups() + tokens.append(parser(group)) + # remove this match and look for another remaining_prompt = ( remaining_prompt[: next_match.start()] @@ -451,4 +459,4 @@ Region = Tuple[int, int, int, int, Literal["add", "replace"], str] def parse_regions(prompt: str) -> List[Region]: - return get_tokens_from_prompt(prompt, REGION_TOKEN) + return get_tokens_from_prompt(prompt, REGION_TOKEN, lambda it: it)