add region and reseed tokens to arpeggio parser
This commit is contained in:
parent
4713169ad9
commit
86a2db1c1a
|
@ -19,7 +19,7 @@ def get_inference_session(model):
|
|||
raise ValueError("Model does not have an inference session")
|
||||
|
||||
|
||||
def wrap_encoder(text_encoder, sdxl=False):
|
||||
def wrap_encoder(text_encoder):
|
||||
class WrappedEncoder:
|
||||
device = "cpu"
|
||||
|
||||
|
@ -120,8 +120,8 @@ def encode_prompt_compel_sdxl(
|
|||
pooled_prompt_embeds: Optional[np.ndarray] = None,
|
||||
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
|
||||
) -> np.ndarray:
|
||||
wrapped_encoder = wrap_encoder(self.text_encoder, sdxl=True)
|
||||
wrapped_encoder_2 = wrap_encoder(self.text_encoder_2, sdxl=True)
|
||||
wrapped_encoder = wrap_encoder(self.text_encoder)
|
||||
wrapped_encoder_2 = wrap_encoder(self.text_encoder_2)
|
||||
compel = Compel(
|
||||
tokenizer=[self.tokenizer, self.tokenizer_2],
|
||||
text_encoder=[wrapped_encoder, wrapped_encoder_2],
|
||||
|
|
|
@ -19,6 +19,10 @@ def decimal():
|
|||
return RegExMatch(r"\d+\.\d*")
|
||||
|
||||
|
||||
def integer():
|
||||
return RegExMatch(r"\d+")
|
||||
|
||||
|
||||
def token_inversion():
|
||||
return ("inversion", token_delimiter, token_run, token_delimiter, decimal)
|
||||
|
||||
|
@ -27,8 +31,44 @@ def token_lora():
|
|||
return ("lora", token_delimiter, token_run, token_delimiter, decimal)
|
||||
|
||||
|
||||
def token_region():
|
||||
return (
|
||||
"region",
|
||||
token_delimiter,
|
||||
integer,
|
||||
token_delimiter,
|
||||
integer,
|
||||
token_delimiter,
|
||||
integer,
|
||||
token_delimiter,
|
||||
integer,
|
||||
token_delimiter,
|
||||
decimal,
|
||||
token_delimiter,
|
||||
decimal,
|
||||
token_delimiter,
|
||||
token_run,
|
||||
)
|
||||
|
||||
|
||||
def token_reseed():
|
||||
return (
|
||||
"reseed",
|
||||
token_delimiter,
|
||||
integer,
|
||||
token_delimiter,
|
||||
integer,
|
||||
token_delimiter,
|
||||
integer,
|
||||
token_delimiter,
|
||||
integer,
|
||||
token_delimiter,
|
||||
integer,
|
||||
)
|
||||
|
||||
|
||||
def token_inner():
|
||||
return [token_inversion, token_lora]
|
||||
return [token_inversion, token_lora, token_region, token_reseed]
|
||||
|
||||
|
||||
def phrase_inner():
|
||||
|
@ -100,15 +140,24 @@ class OnnxPromptVisitor(PTNodeVisitor):
|
|||
def visit_decimal(self, node, children):
|
||||
return float(node.value)
|
||||
|
||||
def visit_integer(self, node, children):
|
||||
return int(node.value)
|
||||
|
||||
def visit_token(self, node, children):
|
||||
return str(node.value)
|
||||
|
||||
def visit_token_inversion(self, node, children):
|
||||
return PromptToken("lora", children[0][0], children[1])
|
||||
return PromptToken("inversion", children[0][0], children[1])
|
||||
|
||||
def visit_token_lora(self, node, children):
|
||||
return PromptToken("lora", children[0][0], children[1])
|
||||
|
||||
def visit_token_region(self, node, children):
|
||||
return PromptToken("region", None, children)
|
||||
|
||||
def visit_token_reseed(self, node, children):
|
||||
return PromptToken("reseed", None, children)
|
||||
|
||||
def visit_token_run(self, node, children):
|
||||
return children
|
||||
|
||||
|
|
|
@ -48,3 +48,27 @@ class ParserTests(unittest.TestCase):
|
|||
str(["bin"]),
|
||||
],
|
||||
)
|
||||
|
||||
def test_region_token(self):
|
||||
res = parse_prompt_onnx(
|
||||
None, "foo <region:1:2:3:4:0.5:0.75:prompt> bin", debug=False
|
||||
)
|
||||
self.assertListEqual(
|
||||
[str(i) for i in res],
|
||||
[
|
||||
str(["foo"]),
|
||||
str(PromptToken("region", None, [1, 2, 3, 4, 0.5, 0.75, ["prompt"]])),
|
||||
str(["bin"]),
|
||||
],
|
||||
)
|
||||
|
||||
def test_reseed_token(self):
|
||||
res = parse_prompt_onnx(None, "foo <reseed:1:2:3:4:12345> bin", debug=False)
|
||||
self.assertListEqual(
|
||||
[str(i) for i in res],
|
||||
[
|
||||
str(["foo"]),
|
||||
str(PromptToken("reseed", None, [1, 2, 3, 4, 12345])),
|
||||
str(["bin"]),
|
||||
],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue