From 150ed7564da34b8d03e225dd992aab3fa40fded5 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 13 Jan 2024 20:54:59 -0600 Subject: [PATCH] add network tokens to experimental prompt parser --- api/onnx_web/prompt/grammar.py | 57 ++++++++++++++++++++++++++++++++- api/tests/prompt/test_parser.py | 13 +++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/prompt/grammar.py b/api/onnx_web/prompt/grammar.py index 20127030..eee58140 100644 --- a/api/onnx_web/prompt/grammar.py +++ b/api/onnx_web/prompt/grammar.py @@ -3,6 +3,10 @@ from typing import List, Union from arpeggio import EOF, OneOrMore, PTNodeVisitor, RegExMatch +def token_delimiter(): + return ":" + + def token(): return RegExMatch(r"\w+") @@ -11,6 +15,22 @@ def token_run(): return OneOrMore(token) +def decimal(): + return RegExMatch(r"\d+\.\d*") + + +def token_inversion(): + return ("inversion", token_delimiter, token_run, token_delimiter, decimal) + + +def token_lora(): + return ("lora", token_delimiter, token_run, token_delimiter, decimal) + + +def token_inner(): + return [token_inversion, token_lora] + + def phrase_inner(): return [phrase, token_run] @@ -23,8 +43,12 @@ def neg_phrase(): return ("[", OneOrMore(phrase_inner), "]") +def token_phrase(): + return ("<", OneOrMore(token_inner), ">") + + def phrase(): - return [pos_phrase, neg_phrase, token_run] + return [pos_phrase, neg_phrase, token_phrase, token_run] def prompt(): @@ -46,6 +70,26 @@ class PromptPhrase: return False +class PromptToken: + def __init__(self, token_type: str, token_name: str, *rest): + self.token_type = token_type + self.token_name = token_name + self.rest = rest + + def __repr__(self) -> str: + return f"<{self.token_type}:{self.token_name}:{self.rest}>" + + def __eq__(self, other: object) -> bool: + if isinstance(other, self.__class__): + return ( + other.token_type == self.token_type + and other.token_name == self.token_name + and other.rest == self.rest + ) + + return False + + class OnnxPromptVisitor(PTNodeVisitor): def __init__(self, defaults=True, weight=0.5, **kwargs): super().__init__(defaults, **kwargs) @@ -53,15 +97,26 @@ class OnnxPromptVisitor(PTNodeVisitor): self.neg_weight = weight self.pos_weight = 1.0 + weight + def visit_decimal(self, node, children): + return float(node.value) + def visit_token(self, node, children): return str(node.value) + def visit_token_inversion(self, node, children): + return PromptToken("lora", children[0][0], children[1]) + + def visit_token_lora(self, node, children): + return PromptToken("lora", children[0][0], children[1]) + def visit_token_run(self, node, children): return children def visit_phrase_inner(self, node, children): if isinstance(children[0], PromptPhrase): return children[0] + elif isinstance(children[0], PromptToken): + return children[0] else: return PromptPhrase(children[0]) diff --git a/api/tests/prompt/test_parser.py b/api/tests/prompt/test_parser.py index 15c91d6c..d0097bb1 100644 --- a/api/tests/prompt/test_parser.py +++ b/api/tests/prompt/test_parser.py @@ -1,6 +1,6 @@ import unittest -from onnx_web.prompt.grammar import PromptPhrase +from onnx_web.prompt.grammar import PromptPhrase, PromptToken from onnx_web.prompt.parser import parse_prompt_onnx @@ -37,3 +37,14 @@ class ParserTests(unittest.TestCase): str(["bin"]), ], ) + + def test_lora_token(self): + res = parse_prompt_onnx(None, "foo bin", debug=False) + self.assertListEqual( + [str(i) for i in res], + [ + str(["foo"]), + str(PromptToken("lora", "name", 1.5)), + str(["bin"]), + ], + )