1
0
Fork 0

parse region groups better

This commit is contained in:
Sean Sube 2023-11-05 16:06:49 -06:00
parent c3f4c52004
commit 8498252c75
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 12 additions and 4 deletions

View File

@ -220,8 +220,15 @@ def expand_prompt(
return prompt_embeds return prompt_embeds
def parse_float_group(group: Tuple[str, str]) -> Tuple[str, float]:
name, weight = group
return (name, float(weight))
def get_tokens_from_prompt( def get_tokens_from_prompt(
prompt: str, pattern: Pattern prompt: str,
pattern: Pattern,
parser = parse_float_group,
) -> Tuple[str, List[Tuple[str, float]]]: ) -> Tuple[str, List[Tuple[str, float]]]:
""" """
TODO: replace with Arpeggio TODO: replace with Arpeggio
@ -232,8 +239,9 @@ def get_tokens_from_prompt(
next_match = pattern.search(remaining_prompt) next_match = pattern.search(remaining_prompt)
while next_match is not None: while next_match is not None:
logger.debug("found token in prompt: %s", next_match) logger.debug("found token in prompt: %s", next_match)
name, weight = next_match.groups() group = next_match.groups()
tokens.append((name, float(weight))) tokens.append(parser(group))
# remove this match and look for another # remove this match and look for another
remaining_prompt = ( remaining_prompt = (
remaining_prompt[: next_match.start()] remaining_prompt[: next_match.start()]
@ -451,4 +459,4 @@ Region = Tuple[int, int, int, int, Literal["add", "replace"], str]
def parse_regions(prompt: str) -> List[Region]: def parse_regions(prompt: str) -> List[Region]:
return get_tokens_from_prompt(prompt, REGION_TOKEN) return get_tokens_from_prompt(prompt, REGION_TOKEN, lambda it: it)