parse region groups better
This commit is contained in:
parent
c3f4c52004
commit
8498252c75
|
@ -220,8 +220,15 @@ def expand_prompt(
|
|||
return prompt_embeds
|
||||
|
||||
|
||||
def parse_float_group(group: Tuple[str, str]) -> Tuple[str, float]:
|
||||
name, weight = group
|
||||
return (name, float(weight))
|
||||
|
||||
|
||||
def get_tokens_from_prompt(
|
||||
prompt: str, pattern: Pattern
|
||||
prompt: str,
|
||||
pattern: Pattern,
|
||||
parser = parse_float_group,
|
||||
) -> Tuple[str, List[Tuple[str, float]]]:
|
||||
"""
|
||||
TODO: replace with Arpeggio
|
||||
|
@ -232,8 +239,9 @@ def get_tokens_from_prompt(
|
|||
next_match = pattern.search(remaining_prompt)
|
||||
while next_match is not None:
|
||||
logger.debug("found token in prompt: %s", next_match)
|
||||
name, weight = next_match.groups()
|
||||
tokens.append((name, float(weight)))
|
||||
group = next_match.groups()
|
||||
tokens.append(parser(group))
|
||||
|
||||
# remove this match and look for another
|
||||
remaining_prompt = (
|
||||
remaining_prompt[: next_match.start()]
|
||||
|
@ -451,4 +459,4 @@ Region = Tuple[int, int, int, int, Literal["add", "replace"], str]
|
|||
|
||||
|
||||
def parse_regions(prompt: str) -> List[Region]:
|
||||
return get_tokens_from_prompt(prompt, REGION_TOKEN)
|
||||
return get_tokens_from_prompt(prompt, REGION_TOKEN, lambda it: it)
|
||||
|
|
Loading…
Reference in New Issue