1
0
Fork 0

add network tokens to experimental prompt parser

This commit is contained in:
Sean Sube 2024-01-13 20:54:59 -06:00
parent a7568dbef1
commit 150ed7564d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 68 additions and 2 deletions

View File

@ -3,6 +3,10 @@ from typing import List, Union
from arpeggio import EOF, OneOrMore, PTNodeVisitor, RegExMatch from arpeggio import EOF, OneOrMore, PTNodeVisitor, RegExMatch
def token_delimiter():
    """Grammar rule: the literal that separates the fields of a network token."""
    delimiter = ":"
    return delimiter
def token():
    """Grammar rule: a single token made of one or more word characters."""
    word_pattern = r"\w+"
    return RegExMatch(word_pattern)
@ -11,6 +15,22 @@ def token_run():
return OneOrMore(token) return OneOrMore(token)
def decimal():
    """Grammar rule: an unsigned number with an optional fractional part.

    Accepts ``1``, ``1.`` and ``1.5``. The fractional part is optional so
    that integer values parse as well; previously a value without a dot was
    rejected even though a trailing dot with no digits was allowed. The
    matched text is later converted with ``float()`` by ``visit_decimal``.
    """
    # Non-capturing group keeps the match a single flat string.
    return RegExMatch(r"\d+(?:\.\d*)?")
def token_inversion():
    """Grammar rule for the body of an inversion token.

    Sequence: the keyword ``inversion``, a delimiter, a run of tokens,
    another delimiter, and a decimal.
    """
    keyword = "inversion"
    return (keyword, token_delimiter, token_run, token_delimiter, decimal)
def token_lora():
    """Grammar rule for the body of a LoRA token.

    Sequence: the keyword ``lora``, a delimiter, a run of tokens, another
    delimiter, and a decimal.
    """
    keyword = "lora"
    return (keyword, token_delimiter, token_run, token_delimiter, decimal)
def token_inner():
    """Grammar rule: the alternatives allowed inside ``<...>``.

    The alternatives are listed in the order they are tried: inversion
    first, then LoRA.
    """
    alternatives = [token_inversion, token_lora]
    return alternatives
def phrase_inner():
    """Grammar rule: the content of a phrase, either a nested phrase or a token run."""
    alternatives = [phrase, token_run]
    return alternatives
@ -23,8 +43,12 @@ def neg_phrase():
return ("[", OneOrMore(phrase_inner), "]") return ("[", OneOrMore(phrase_inner), "]")
def token_phrase():
    """Grammar rule: one or more network tokens wrapped in angle brackets."""
    body = OneOrMore(token_inner)
    return ("<", body, ">")
def phrase():
    """Grammar rule: any phrase form.

    Alternatives, in the order listed: positive phrase, negative phrase,
    angle-bracket token phrase, then a plain token run.
    """
    alternatives = [pos_phrase, neg_phrase, token_phrase, token_run]
    return alternatives
def prompt(): def prompt():
@ -46,6 +70,26 @@ class PromptPhrase:
return False return False
class PromptToken:
    """A typed, named token parsed from a prompt, plus any extra values.

    Attributes are set directly from the constructor arguments:
    ``token_type``, ``token_name`` and ``rest`` (a tuple holding every
    additional positional argument).
    """

    def __init__(self, token_type: str, token_name: str, *rest):
        self.token_type, self.token_name = token_type, token_name
        # Extra positional arguments are kept together as a tuple.
        self.rest = rest

    def __repr__(self) -> str:
        kind, name, extra = self.token_type, self.token_name, self.rest
        return f"<{kind}:{name}:{extra}>"

    def __eq__(self, other: object) -> bool:
        # Objects that are not instances of this class never compare equal.
        if not isinstance(other, self.__class__):
            return False

        return (
            other.token_type == self.token_type
            and other.token_name == self.token_name
            and other.rest == self.rest
        )
class OnnxPromptVisitor(PTNodeVisitor): class OnnxPromptVisitor(PTNodeVisitor):
def __init__(self, defaults=True, weight=0.5, **kwargs): def __init__(self, defaults=True, weight=0.5, **kwargs):
super().__init__(defaults, **kwargs) super().__init__(defaults, **kwargs)
@ -53,15 +97,26 @@ class OnnxPromptVisitor(PTNodeVisitor):
self.neg_weight = weight self.neg_weight = weight
self.pos_weight = 1.0 + weight self.pos_weight = 1.0 + weight
def visit_decimal(self, node, children):
    """Convert the text matched by the ``decimal`` rule into a float."""
    matched_text = node.value
    return float(matched_text)
def visit_token(self, node, children):
    """Return the text matched by the ``token`` rule as a plain string."""
    matched = node.value
    return str(matched)
def visit_token_inversion(self, node, children):
    """Build the ``PromptToken`` for an ``inversion`` network token.

    ``children[0][0]`` is used as the token name and ``children[1]`` as the
    extra value.
    NOTE(review): this assumes the literal keyword and delimiters are
    dropped from ``children`` by the visitor, leaving the token run and the
    decimal; confirm against Arpeggio's default handling of string matches.
    """
    name = children[0][0]
    weight = children[1]
    # The type tag must be "inversion"; it was previously "lora", which
    # made inversion tokens indistinguishable from LoRA tokens.
    return PromptToken("inversion", name, weight)
def visit_token_lora(self, node, children):
    """Build the ``PromptToken`` for a ``lora`` network token.

    ``children[0][0]`` is used as the token name and ``children[1]`` as the
    extra value.
    """
    name_run, weight = children[0], children[1]
    return PromptToken("lora", name_run[0], weight)
def visit_token_run(self, node, children):
    """Pass the visited children of a token run through unchanged."""
    tokens = children
    return tokens
def visit_phrase_inner(self, node, children):
    """Normalise the first child of a ``phrase_inner`` node.

    A child that is already a ``PromptPhrase`` or a ``PromptToken`` is
    returned as-is; anything else is wrapped in a new ``PromptPhrase``.
    """
    first = children[0]
    if isinstance(first, (PromptPhrase, PromptToken)):
        return first

    return PromptPhrase(first)

View File

@ -1,6 +1,6 @@
import unittest import unittest
from onnx_web.prompt.grammar import PromptPhrase from onnx_web.prompt.grammar import PromptPhrase, PromptToken
from onnx_web.prompt.parser import parse_prompt_onnx from onnx_web.prompt.parser import parse_prompt_onnx
@ -37,3 +37,14 @@ class ParserTests(unittest.TestCase):
str(["bin"]), str(["bin"]),
], ],
) )
def test_lora_token(self):
    """A LoRA token between plain words is parsed into a ``PromptToken``."""
    parsed = parse_prompt_onnx(None, "foo <lora:name:1.5> bin", debug=False)
    actual = [str(item) for item in parsed]
    expected = [
        str(["foo"]),
        str(PromptToken("lora", "name", 1.5)),
        str(["bin"]),
    ]
    self.assertListEqual(actual, expected)