1
0
Fork 0

add region and reseed tokens to arpeggio parser

This commit is contained in:
Sean Sube 2024-03-03 13:10:04 -06:00
parent 4713169ad9
commit 86a2db1c1a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 78 additions and 5 deletions

View File

@ -19,7 +19,7 @@ def get_inference_session(model):
raise ValueError("Model does not have an inference session")
def wrap_encoder(text_encoder, sdxl=False):
def wrap_encoder(text_encoder):
class WrappedEncoder:
device = "cpu"
@ -120,8 +120,8 @@ def encode_prompt_compel_sdxl(
pooled_prompt_embeds: Optional[np.ndarray] = None,
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
) -> np.ndarray:
wrapped_encoder = wrap_encoder(self.text_encoder, sdxl=True)
wrapped_encoder_2 = wrap_encoder(self.text_encoder_2, sdxl=True)
wrapped_encoder = wrap_encoder(self.text_encoder)
wrapped_encoder_2 = wrap_encoder(self.text_encoder_2)
compel = Compel(
tokenizer=[self.tokenizer, self.tokenizer_2],
text_encoder=[wrapped_encoder, wrapped_encoder_2],

View File

@ -19,6 +19,10 @@ def decimal():
return RegExMatch(r"\d+\.\d*")
def integer():
return RegExMatch(r"\d+")
def token_inversion():
return ("inversion", token_delimiter, token_run, token_delimiter, decimal)
@ -27,8 +31,44 @@ def token_lora():
return ("lora", token_delimiter, token_run, token_delimiter, decimal)
def token_region():
return (
"region",
token_delimiter,
integer,
token_delimiter,
integer,
token_delimiter,
integer,
token_delimiter,
integer,
token_delimiter,
decimal,
token_delimiter,
decimal,
token_delimiter,
token_run,
)
def token_reseed():
return (
"reseed",
token_delimiter,
integer,
token_delimiter,
integer,
token_delimiter,
integer,
token_delimiter,
integer,
token_delimiter,
integer,
)
def token_inner():
return [token_inversion, token_lora]
return [token_inversion, token_lora, token_region, token_reseed]
def phrase_inner():
@ -100,15 +140,24 @@ class OnnxPromptVisitor(PTNodeVisitor):
def visit_decimal(self, node, children):
return float(node.value)
def visit_integer(self, node, children):
return int(node.value)
def visit_token(self, node, children):
return str(node.value)
def visit_token_inversion(self, node, children):
return PromptToken("lora", children[0][0], children[1])
return PromptToken("inversion", children[0][0], children[1])
def visit_token_lora(self, node, children):
return PromptToken("lora", children[0][0], children[1])
def visit_token_region(self, node, children):
return PromptToken("region", None, children)
def visit_token_reseed(self, node, children):
return PromptToken("reseed", None, children)
def visit_token_run(self, node, children):
return children

View File

@ -48,3 +48,27 @@ class ParserTests(unittest.TestCase):
str(["bin"]),
],
)
def test_region_token(self):
res = parse_prompt_onnx(
None, "foo <region:1:2:3:4:0.5:0.75:prompt> bin", debug=False
)
self.assertListEqual(
[str(i) for i in res],
[
str(["foo"]),
str(PromptToken("region", None, [1, 2, 3, 4, 0.5, 0.75, ["prompt"]])),
str(["bin"]),
],
)
def test_reseed_token(self):
res = parse_prompt_onnx(None, "foo <reseed:1:2:3:4:12345> bin", debug=False)
self.assertListEqual(
[str(i) for i in res],
[
str(["foo"]),
str(PromptToken("reseed", None, [1, 2, 3, 4, 12345])),
str(["bin"]),
],
)