feat(api): start adding support for multiple prompt parsers
This commit is contained in:
parent
98f99b1523
commit
0f1298824c
|
@ -2,6 +2,7 @@
|
||||||
coverage.xml
|
coverage.xml
|
||||||
entry.py
|
entry.py
|
||||||
|
|
||||||
|
*.dot
|
||||||
*.log
|
*.log
|
||||||
*.swp
|
*.swp
|
||||||
*.pyc
|
*.pyc
|
||||||
|
|
|
@ -0,0 +1,84 @@
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from arpeggio import EOF, OneOrMore, PTNodeVisitor, RegExMatch
|
||||||
|
|
||||||
|
|
||||||
|
def token():
|
||||||
|
return RegExMatch(r"\w+")
|
||||||
|
|
||||||
|
|
||||||
|
def token_run():
|
||||||
|
return OneOrMore(token)
|
||||||
|
|
||||||
|
|
||||||
|
def phrase_inner():
|
||||||
|
return [phrase, token_run]
|
||||||
|
|
||||||
|
|
||||||
|
def pos_phrase():
|
||||||
|
return ("(", OneOrMore(phrase_inner), ")")
|
||||||
|
|
||||||
|
|
||||||
|
def neg_phrase():
|
||||||
|
return ("[", OneOrMore(phrase_inner), "]")
|
||||||
|
|
||||||
|
|
||||||
|
def phrase():
|
||||||
|
return [pos_phrase, neg_phrase, token_run]
|
||||||
|
|
||||||
|
|
||||||
|
def prompt():
|
||||||
|
return OneOrMore(phrase), EOF
|
||||||
|
|
||||||
|
|
||||||
|
class PromptPhrase:
|
||||||
|
def __init__(self, tokens: Union[List[str], str], weight: float = 1.0) -> None:
|
||||||
|
self.tokens = tokens
|
||||||
|
self.weight = weight
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.tokens} * {self.weight}"
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if isinstance(other, self.__class__):
|
||||||
|
return other.tokens == self.tokens and other.weight == self.weight
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxPromptVisitor(PTNodeVisitor):
|
||||||
|
def __init__(self, defaults=True, weight=0.5, **kwargs):
|
||||||
|
super().__init__(defaults, **kwargs)
|
||||||
|
|
||||||
|
self.neg_weight = weight
|
||||||
|
self.pos_weight = 1.0 + weight
|
||||||
|
|
||||||
|
def visit_token(self, node, children):
|
||||||
|
return str(node.value)
|
||||||
|
|
||||||
|
def visit_token_run(self, node, children):
|
||||||
|
return children
|
||||||
|
|
||||||
|
def visit_phrase_inner(self, node, children):
|
||||||
|
if isinstance(children[0], PromptPhrase):
|
||||||
|
return children[0]
|
||||||
|
else:
|
||||||
|
return PromptPhrase(children[0])
|
||||||
|
|
||||||
|
def visit_pos_phrase(self, node, children):
|
||||||
|
c = children[0]
|
||||||
|
if isinstance(c, PromptPhrase):
|
||||||
|
return PromptPhrase(c.tokens, c.weight * self.pos_weight)
|
||||||
|
elif isinstance(c, str):
|
||||||
|
return PromptPhrase(c, self.pos_weight)
|
||||||
|
|
||||||
|
def visit_neg_phrase(self, node, children):
|
||||||
|
c = children[0]
|
||||||
|
if isinstance(c, PromptPhrase):
|
||||||
|
return PromptPhrase(c.tokens, c.weight * self.neg_weight)
|
||||||
|
elif isinstance(c, str):
|
||||||
|
return PromptPhrase(c, self.neg_weight)
|
||||||
|
|
||||||
|
def visit_phrase(self, node, children):
|
||||||
|
return children[0]
|
||||||
|
|
||||||
|
def visit_prompt(self, node, children):
|
||||||
|
return children
|
|
@ -0,0 +1,47 @@
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from arpeggio import ParserPython, visit_parse_tree
|
||||||
|
|
||||||
|
from .grammar import OnnxPromptVisitor
|
||||||
|
from .grammar import prompt as prompt_base
|
||||||
|
|
||||||
|
|
||||||
|
def parse_prompt_compel(pipeline, prompt: str) -> np.ndarray:
|
||||||
|
from compel import Compel
|
||||||
|
|
||||||
|
parser = Compel(tokenizer=pipeline.tokenizer, text_encoder=pipeline.text_encoder)
|
||||||
|
return parser([prompt])
|
||||||
|
|
||||||
|
|
||||||
|
def parse_prompt_lpw(pipeline, prompt: str, debug=False) -> np.ndarray:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray:
|
||||||
|
parser = ParserPython(prompt_base, debug=debug)
|
||||||
|
visitor = OnnxPromptVisitor()
|
||||||
|
|
||||||
|
ast = parser.parse(prompt)
|
||||||
|
return visit_parse_tree(ast, visitor)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_prompt_vanilla(pipeline, prompt: str) -> np.ndarray:
|
||||||
|
return pipeline._encode_prompt(prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_prompt(
|
||||||
|
pipeline,
|
||||||
|
prompt: str,
|
||||||
|
engine: Literal["compel", "lpw", "onnx-web", "pipeline"] = "onnx-web",
|
||||||
|
) -> np.ndarray:
|
||||||
|
if engine == "compel":
|
||||||
|
return parse_prompt_compel(pipeline, prompt)
|
||||||
|
if engine == "lpw":
|
||||||
|
return parse_prompt_lpw(pipeline, prompt)
|
||||||
|
elif engine == "onnx-web":
|
||||||
|
return parse_prompt_onnx(pipeline, prompt)
|
||||||
|
elif engine == "pipeline":
|
||||||
|
return parse_prompt_vanilla(pipeline, prompt)
|
||||||
|
else:
|
||||||
|
raise ValueError("invalid prompt parser")
|
|
@ -0,0 +1,37 @@
|
||||||
|
import unittest
|
||||||
|
from onnx_web.prompt.grammar import PromptPhrase
|
||||||
|
from onnx_web.prompt.parser import parse_prompt_onnx
|
||||||
|
|
||||||
|
class ParserTests(unittest.TestCase):
|
||||||
|
def test_single_word_phrase(self):
|
||||||
|
res = parse_prompt_onnx(None, "foo (bar) bin", debug=False)
|
||||||
|
self.assertListEqual(
|
||||||
|
[str(i) for i in res],
|
||||||
|
[
|
||||||
|
str(["foo"]),
|
||||||
|
str(PromptPhrase(["bar"], weight=1.5)),
|
||||||
|
str(["bin"]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_multi_word_phrase(self):
|
||||||
|
res = parse_prompt_onnx(None, "foo bar (middle words) bin bun", debug=False)
|
||||||
|
self.assertListEqual(
|
||||||
|
[str(i) for i in res],
|
||||||
|
[
|
||||||
|
str(["foo", "bar"]),
|
||||||
|
str(PromptPhrase(["middle", "words"], weight=1.5)),
|
||||||
|
str(["bin", "bun"]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_nested_phrase(self):
|
||||||
|
res = parse_prompt_onnx(None, "foo (((bar))) bin", debug=False)
|
||||||
|
self.assertListEqual(
|
||||||
|
[str(i) for i in res],
|
||||||
|
[
|
||||||
|
str(["foo"]),
|
||||||
|
str(PromptPhrase(["bar"], weight=(1.5 ** 3))),
|
||||||
|
str(["bin"]),
|
||||||
|
]
|
||||||
|
)
|
Loading…
Reference in New Issue