diff --git a/api/.gitignore b/api/.gitignore
index 17c17360..165cb300 100644
--- a/api/.gitignore
+++ b/api/.gitignore
@@ -2,6 +2,7 @@
 coverage.xml
 
 entry.py
+*.dot
 *.log
 *.swp
 *.pyc
diff --git a/api/onnx_web/prompt/grammar.py b/api/onnx_web/prompt/grammar.py
new file mode 100644
index 00000000..ae28d23e
--- /dev/null
+++ b/api/onnx_web/prompt/grammar.py
@@ -0,0 +1,100 @@
+from typing import List, Union
+
+from arpeggio import EOF, OneOrMore, PTNodeVisitor, RegExMatch
+
+
+# Arpeggio grammar: each function below returns the parsing expression for
+# the rule that shares its name.
+def token():
+    return RegExMatch(r"\w+")
+
+
+def token_run():
+    return OneOrMore(token)
+
+
+def phrase_inner():
+    return [phrase, token_run]
+
+
+def pos_phrase():
+    return ("(", OneOrMore(phrase_inner), ")")
+
+
+def neg_phrase():
+    return ("[", OneOrMore(phrase_inner), "]")
+
+
+def phrase():
+    return [pos_phrase, neg_phrase, token_run]
+
+
+def prompt():
+    return OneOrMore(phrase), EOF
+
+
+class PromptPhrase:
+    """A run of prompt tokens together with the weight applied to them."""
+
+    def __init__(self, tokens: Union[List[str], str], weight: float = 1.0) -> None:
+        self.tokens = tokens
+        self.weight = weight
+
+    def __repr__(self) -> str:
+        return f"{self.tokens} * {self.weight}"
+
+    def __eq__(self, other: object) -> bool:
+        # Defer to the other operand for foreign types instead of falling
+        # off the end of the method and returning None.
+        if not isinstance(other, self.__class__):
+            return NotImplemented
+
+        return other.tokens == self.tokens and other.weight == self.weight
+
+
+class OnnxPromptVisitor(PTNodeVisitor):
+    """Turn the prompt parse tree into token lists and weighted phrases.
+
+    Each "(...)" level multiplies the phrase weight by ``1.0 + weight`` and
+    each "[...]" level multiplies it by ``weight``.
+    """
+
+    def __init__(self, defaults=True, weight=0.5, **kwargs):
+        super().__init__(defaults, **kwargs)
+
+        self.neg_weight = weight
+        self.pos_weight = 1.0 + weight
+
+    def _scale(self, children, factor):
+        # Shared by the pos/neg phrase visitors. Only the first child is
+        # used; a child that is neither a phrase nor a string yields None.
+        c = children[0]
+        if isinstance(c, PromptPhrase):
+            return PromptPhrase(c.tokens, c.weight * factor)
+        elif isinstance(c, str):
+            return PromptPhrase(c, factor)
+
+    def visit_token(self, node, children):
+        return str(node.value)
+
+    def visit_token_run(self, node, children):
+        return children
+
+    def visit_phrase_inner(self, node, children):
+        # Bare token runs become a phrase with the default weight.
+        if isinstance(children[0], PromptPhrase):
+            return children[0]
+        else:
+            return PromptPhrase(children[0])
+
+    def visit_pos_phrase(self, node, children):
+        return self._scale(children, self.pos_weight)
+
+    def visit_neg_phrase(self, node, children):
+        return self._scale(children, self.neg_weight)
+
+    def visit_phrase(self, node, children):
+        return children[0]
+
+    def visit_prompt(self, node, children):
+        return children
diff --git a/api/onnx_web/prompt/parser.py b/api/onnx_web/prompt/parser.py
new file mode 100644
index 00000000..6a0fde22
--- /dev/null
+++ b/api/onnx_web/prompt/parser.py
@@ -0,0 +1,60 @@
+from typing import Literal
+
+import numpy as np
+from arpeggio import ParserPython, visit_parse_tree
+
+from .grammar import OnnxPromptVisitor
+from .grammar import prompt as prompt_base
+
+
+def parse_prompt_compel(pipeline, prompt: str) -> np.ndarray:
+    """Encode the prompt with Compel built from the pipeline's components."""
+    # function-scope import: compel is only imported when this engine runs
+    from compel import Compel
+
+    parser = Compel(tokenizer=pipeline.tokenizer, text_encoder=pipeline.text_encoder)
+    return parser([prompt])
+
+
+def parse_prompt_lpw(pipeline, prompt: str, debug=False) -> np.ndarray:
+    """Placeholder for the "lpw" engine: does nothing and returns None."""
+    pass
+
+
+def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray:
+    """Parse the prompt with the onnx-web grammar and return the visitor output.
+
+    NOTE(review): the visitor yields token runs and PromptPhrase objects
+    rather than an ndarray, and pipeline is unused here - confirm the intent.
+    """
+    parser = ParserPython(prompt_base, debug=debug)
+    visitor = OnnxPromptVisitor()
+
+    ast = parser.parse(prompt)
+    return visit_parse_tree(ast, visitor)
+
+
+def parse_prompt_vanilla(pipeline, prompt: str) -> np.ndarray:
+    """Delegate encoding to the pipeline's own _encode_prompt method."""
+    return pipeline._encode_prompt(prompt)
+
+
+def parse_prompt(
+    pipeline,
+    prompt: str,
+    engine: Literal["compel", "lpw", "onnx-web", "pipeline"] = "onnx-web",
+) -> np.ndarray:
+    """Parse the prompt with the selected engine.
+
+    Raises ValueError when the engine name is not one of the known options.
+    """
+    if engine == "compel":
+        return parse_prompt_compel(pipeline, prompt)
+    elif engine == "lpw":
+        return parse_prompt_lpw(pipeline, prompt)
+    elif engine == "onnx-web":
+        return parse_prompt_onnx(pipeline, prompt)
+    elif engine == "pipeline":
+        return parse_prompt_vanilla(pipeline, prompt)
+    else:
+        raise ValueError("invalid prompt parser")
diff --git a/api/tests/prompt/__init__.py b/api/tests/prompt/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/api/tests/prompt/test_parser.py b/api/tests/prompt/test_parser.py
new file mode 100644
index 00000000..20c03341
--- /dev/null
+++ b/api/tests/prompt/test_parser.py
@@ -0,0 +1,37 @@
+import unittest
+from onnx_web.prompt.grammar import PromptPhrase
+from onnx_web.prompt.parser import parse_prompt_onnx
+
+class ParserTests(unittest.TestCase):
+    def test_single_word_phrase(self):
+        res = parse_prompt_onnx(None, "foo (bar) bin", debug=False)
+        self.assertListEqual(
+            [str(i) for i in res],
+            [
+                str(["foo"]),
+                str(PromptPhrase(["bar"], weight=1.5)),
+                str(["bin"]),
+            ]
+        )
+
+    def test_multi_word_phrase(self):
+        res = parse_prompt_onnx(None, "foo bar (middle words) bin bun", debug=False)
+        self.assertListEqual(
+            [str(i) for i in res],
+            [
+                str(["foo", "bar"]),
+                str(PromptPhrase(["middle", "words"], weight=1.5)),
+                str(["bin", "bun"]),
+            ]
+        )
+
+    def test_nested_phrase(self):
+        res = parse_prompt_onnx(None, "foo (((bar))) bin", debug=False)
+        self.assertListEqual(
+            [str(i) for i in res],
+            [
+                str(["foo"]),
+                str(PromptPhrase(["bar"], weight=(1.5 ** 3))),
+                str(["bin"]),
+            ]
+        )