add clip skip tokens to arpeggio parser
This commit is contained in:
parent
86a2db1c1a
commit
1e73eac68d
|
@ -23,6 +23,10 @@ def integer():
|
|||
return RegExMatch(r"\d+")
|
||||
|
||||
|
||||
def token_clip_skip():
|
||||
return ("clip", token_delimiter, "skip", token_delimiter, integer)
|
||||
|
||||
|
||||
def token_inversion():
|
||||
return ("inversion", token_delimiter, token_run, token_delimiter, decimal)
|
||||
|
||||
|
@ -68,7 +72,7 @@ def token_reseed():
|
|||
|
||||
|
||||
def token_inner():
|
||||
return [token_inversion, token_lora, token_region, token_reseed]
|
||||
return [token_clip_skip, token_inversion, token_lora, token_region, token_reseed]
|
||||
|
||||
|
||||
def phrase_inner():
|
||||
|
@ -146,6 +150,9 @@ class OnnxPromptVisitor(PTNodeVisitor):
|
|||
def visit_token(self, node, children):
|
||||
return str(node.value)
|
||||
|
||||
def visit_token_clip_skip(self, node, children):
|
||||
return PromptToken("clip", "skip", children[0])
|
||||
|
||||
def visit_token_inversion(self, node, children):
|
||||
return PromptToken("inversion", children[0][0], children[1])
|
||||
|
||||
|
|
|
@ -38,6 +38,17 @@ class ParserTests(unittest.TestCase):
|
|||
],
|
||||
)
|
||||
|
||||
def test_clip_skip_token(self):
|
||||
res = parse_prompt_onnx(None, "foo <clip:skip:2> bin", debug=False)
|
||||
self.assertListEqual(
|
||||
[str(i) for i in res],
|
||||
[
|
||||
str(["foo"]),
|
||||
str(PromptToken("clip", "skip", 2)),
|
||||
str(["bin"]),
|
||||
],
|
||||
)
|
||||
|
||||
def test_lora_token(self):
|
||||
res = parse_prompt_onnx(None, "foo <lora:name:1.5> bin", debug=False)
|
||||
self.assertListEqual(
|
||||
|
|
Loading…
Reference in New Issue