add network tokens to experimental prompt parser
This commit is contained in:
parent
a7568dbef1
commit
150ed7564d
|
@ -3,6 +3,10 @@ from typing import List, Union
|
||||||
from arpeggio import EOF, OneOrMore, PTNodeVisitor, RegExMatch
|
from arpeggio import EOF, OneOrMore, PTNodeVisitor, RegExMatch
|
||||||
|
|
||||||
|
|
||||||
|
def token_delimiter():
    """Grammar rule: the literal ``:`` placed between the fields of a network token."""
    return ":"
|
||||||
|
|
||||||
|
|
||||||
def token():
|
def token():
|
||||||
return RegExMatch(r"\w+")
|
return RegExMatch(r"\w+")
|
||||||
|
|
||||||
|
@ -11,6 +15,22 @@ def token_run():
|
||||||
return OneOrMore(token)
|
return OneOrMore(token)
|
||||||
|
|
||||||
|
|
||||||
|
def decimal():
    """Grammar rule: a decimal number.

    The pattern requires at least one digit followed by a dot; digits after
    the dot are optional, so ``1.5`` and ``1.`` match but a bare ``1`` does not.
    """
    return RegExMatch(r"\d+\.\d*")
|
||||||
|
|
||||||
|
|
||||||
|
def token_inversion():
    """Grammar rule: an inversion token of the form ``inversion:<words>:<decimal>``."""
    # a tuple is a sequence: every element must match, in this order
    return ("inversion", token_delimiter, token_run, token_delimiter, decimal)
|
||||||
|
|
||||||
|
|
||||||
|
def token_lora():
    """Grammar rule: a LoRA token of the form ``lora:<words>:<decimal>``."""
    # a tuple is a sequence: every element must match, in this order
    return ("lora", token_delimiter, token_run, token_delimiter, decimal)
|
||||||
|
|
||||||
|
|
||||||
|
def token_inner():
    """Grammar rule: the body of a network token, either an inversion or a LoRA token."""
    # a list is an ordered choice: alternatives are tried left to right
    return [token_inversion, token_lora]
|
||||||
|
|
||||||
|
|
||||||
def phrase_inner():
|
def phrase_inner():
|
||||||
return [phrase, token_run]
|
return [phrase, token_run]
|
||||||
|
|
||||||
|
@ -23,8 +43,12 @@ def neg_phrase():
|
||||||
return ("[", OneOrMore(phrase_inner), "]")
|
return ("[", OneOrMore(phrase_inner), "]")
|
||||||
|
|
||||||
|
|
||||||
|
def token_phrase():
    """Grammar rule: one or more network tokens wrapped in ``<`` and ``>``."""
    return ("<", OneOrMore(token_inner), ">")
|
||||||
|
|
||||||
|
|
||||||
def phrase():
|
def phrase():
|
||||||
return [pos_phrase, neg_phrase, token_run]
|
return [pos_phrase, neg_phrase, token_phrase, token_run]
|
||||||
|
|
||||||
|
|
||||||
def prompt():
|
def prompt():
|
||||||
|
@ -46,6 +70,26 @@ class PromptPhrase:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class PromptToken:
    """A typed, named token found in a prompt, e.g. ``<lora:name:1.5>``.

    The constructor stores its arguments unchanged as ``token_type``,
    ``token_name`` and ``rest``.
    """

    def __init__(self, token_type: str, token_name: str, *rest):
        """Store the token type, its name, and any extra positional values."""
        self.token_type = token_type
        self.token_name = token_name
        # *rest is always a tuple, so a single extra value is kept as ``(value,)``
        self.rest = rest

    def __repr__(self) -> str:
        return f"<{self.token_type}:{self.token_name}:{self.rest}>"

    def __eq__(self, other: object) -> bool:
        """Compare type, name and extra values with another PromptToken."""
        if not isinstance(other, self.__class__):
            # let Python try the reflected comparison; ``==`` still ends up
            # False for unrelated types
            return NotImplemented

        return (
            other.token_type == self.token_type
            and other.token_name == self.token_name
            and other.rest == self.rest
        )
|
||||||
|
|
||||||
|
|
||||||
class OnnxPromptVisitor(PTNodeVisitor):
|
class OnnxPromptVisitor(PTNodeVisitor):
|
||||||
def __init__(self, defaults=True, weight=0.5, **kwargs):
|
def __init__(self, defaults=True, weight=0.5, **kwargs):
|
||||||
super().__init__(defaults, **kwargs)
|
super().__init__(defaults, **kwargs)
|
||||||
|
@ -53,15 +97,26 @@ class OnnxPromptVisitor(PTNodeVisitor):
|
||||||
self.neg_weight = weight
|
self.neg_weight = weight
|
||||||
self.pos_weight = 1.0 + weight
|
self.pos_weight = 1.0 + weight
|
||||||
|
|
||||||
|
def visit_decimal(self, node, children):
    """Convert the text matched by the ``decimal`` rule into a float."""
    return float(node.value)
|
||||||
|
|
||||||
def visit_token(self, node, children):
|
def visit_token(self, node, children):
|
||||||
return str(node.value)
|
return str(node.value)
|
||||||
|
|
||||||
|
def visit_token_inversion(self, node, children):
    """Build the PromptToken for an ``inversion`` network token."""
    # The token type must be "inversion"; it was previously "lora", which made
    # inversion tokens indistinguishable from LoRA tokens.
    # Presumably children[0] is the token_run word list (only its first word is
    # used as the name) and children[1] is the parsed decimal; confirm against
    # the grammar.
    return PromptToken("inversion", children[0][0], children[1])
|
||||||
|
|
||||||
|
def visit_token_lora(self, node, children):
    """Build the PromptToken for a ``lora`` network token."""
    # Presumably children[0] is the token_run word list (only its first word is
    # used as the name) and children[1] is the parsed decimal; confirm against
    # the grammar.
    return PromptToken("lora", children[0][0], children[1])
|
||||||
|
|
||||||
def visit_token_run(self, node, children):
|
def visit_token_run(self, node, children):
|
||||||
return children
|
return children
|
||||||
|
|
||||||
def visit_phrase_inner(self, node, children):
|
def visit_phrase_inner(self, node, children):
|
||||||
if isinstance(children[0], PromptPhrase):
|
if isinstance(children[0], PromptPhrase):
|
||||||
return children[0]
|
return children[0]
|
||||||
|
elif isinstance(children[0], PromptToken):
|
||||||
|
return children[0]
|
||||||
else:
|
else:
|
||||||
return PromptPhrase(children[0])
|
return PromptPhrase(children[0])
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from onnx_web.prompt.grammar import PromptPhrase
|
from onnx_web.prompt.grammar import PromptPhrase, PromptToken
|
||||||
from onnx_web.prompt.parser import parse_prompt_onnx
|
from onnx_web.prompt.parser import parse_prompt_onnx
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,3 +37,14 @@ class ParserTests(unittest.TestCase):
|
||||||
str(["bin"]),
|
str(["bin"]),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_lora_token(self):
    # a <lora:name:weight> token between two plain words should come back as
    # its own PromptToken, with the surrounding words left as separate items
    res = parse_prompt_onnx(None, "foo <lora:name:1.5> bin", debug=False)
    expected = [
        str(["foo"]),
        str(PromptToken("lora", "name", 1.5)),
        str(["bin"]),
    ]
    self.assertListEqual([str(i) for i in res], expected)
|
||||||
|
|
Loading…
Reference in New Issue