1
0
Fork 0

add clip skip tokens to arpeggio parser

This commit is contained in:
Sean Sube 2024-03-03 13:12:25 -06:00
parent 86a2db1c1a
commit 1e73eac68d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 19 additions and 1 deletions

View File

@ -23,6 +23,10 @@ def integer():
return RegExMatch(r"\d+")
def token_clip_skip():
return ("clip", token_delimiter, "skip", token_delimiter, integer)
def token_inversion():
return ("inversion", token_delimiter, token_run, token_delimiter, decimal)
@ -68,7 +72,7 @@ def token_reseed():
def token_inner():
return [token_inversion, token_lora, token_region, token_reseed]
return [token_clip_skip, token_inversion, token_lora, token_region, token_reseed]
def phrase_inner():
@ -146,6 +150,9 @@ class OnnxPromptVisitor(PTNodeVisitor):
def visit_token(self, node, children):
return str(node.value)
def visit_token_clip_skip(self, node, children):
return PromptToken("clip", "skip", children[0])
def visit_token_inversion(self, node, children):
return PromptToken("inversion", children[0][0], children[1])

View File

@ -38,6 +38,17 @@ class ParserTests(unittest.TestCase):
],
)
def test_clip_skip_token(self):
res = parse_prompt_onnx(None, "foo <clip:skip:2> bin", debug=False)
self.assertListEqual(
[str(i) for i in res],
[
str(["foo"]),
str(PromptToken("clip", "skip", 2)),
str(["bin"]),
],
)
def test_lora_token(self):
res = parse_prompt_onnx(None, "foo <lora:name:1.5> bin", debug=False)
self.assertListEqual(