1
0
Fork 0

Compare commits

...

2 Commits

Author SHA1 Message Date
Sean Sube 5ae3d72968
fix Compel parsing when negative prompt is empty 2024-03-03 21:45:14 -06:00
Sean Sube 17749396b5
compile AST to structured prompt 2024-03-03 21:44:51 -06:00
5 changed files with 253 additions and 73 deletions

View File

@ -1,30 +1,144 @@
from typing import List, Optional
class NetworkWeight:
pass
class PromptNetwork:
type: str
name: str
strength: float
def __init__(self, type: str, name: str, strength: float) -> None:
self.type = type
self.name = name
self.strength = strength
def __eq__(self, other: object) -> bool:
return (
isinstance(other, self.__class__)
and other.type == self.type
and other.name == self.name
and other.strength == self.strength
)
def __repr__(self) -> str:
return f"PromptNetwork({self.type}, {self.name}, {self.strength})"
class PromptPhrase:
phrase: str
weight: float
def __init__(self, phrase: str, weight: float = 1.0) -> None:
self.phrase = phrase
self.weight = weight
def __eq__(self, other: object) -> bool:
return (
isinstance(other, self.__class__)
and other.phrase == self.phrase
and other.weight == self.weight
)
def __repr__(self) -> str:
return f"PromptPhrase({self.phrase}, {self.weight})"
class PromptRegion:
pass
top: int
left: int
bottom: int
right: int
prompt: str
append: bool
def __init__(
self,
top: int,
left: int,
bottom: int,
right: int,
prompt: str,
append: bool,
) -> None:
self.top = top
self.left = left
self.bottom = bottom
self.right = right
self.prompt = prompt
self.append = append
def __eq__(self, other: object) -> bool:
return (
isinstance(other, self.__class__)
and other.top == self.top
and other.left == self.left
and other.bottom == self.bottom
and other.right == self.right
and other.prompt == self.prompt
and other.append == self.append
)
def __repr__(self) -> str:
return f"PromptRegion({self.top}, {self.left}, {self.bottom}, {self.right}, {self.prompt}, {self.append})"
class PromptSeed:
pass
top: int
left: int
bottom: int
right: int
seed: int
def __init__(self, top: int, left: int, bottom: int, right: int, seed: int) -> None:
self.top = top
self.left = left
self.bottom = bottom
self.right = right
self.seed = seed
def __eq__(self, other: object) -> bool:
return (
isinstance(other, self.__class__)
and other.top == self.top
and other.left == self.left
and other.bottom == self.bottom
and other.right == self.right
and other.seed == self.seed
)
def __repr__(self) -> str:
return f"PromptSeed({self.top}, {self.left}, {self.bottom}, {self.right}, {self.seed})"
class StructuredPrompt:
prompt: str
negative_prompt: Optional[str]
networks: List[NetworkWeight]
class Prompt:
networks: List[PromptNetwork]
positive_phrases: List[PromptPhrase]
negative_phrases: List[PromptPhrase]
region_prompts: List[PromptRegion]
region_seeds: List[PromptSeed]
def __init__(
self, prompt: str, negative_prompt: Optional[str], networks: List[NetworkWeight]
self,
networks: Optional[List[PromptNetwork]],
positive_phrases: List[PromptPhrase],
negative_phrases: List[PromptPhrase],
region_prompts: List[PromptRegion],
region_seeds: List[PromptSeed],
) -> None:
self.prompt = prompt
self.negative_prompt = negative_prompt
self.positive_phrases = positive_phrases
self.negative_prompt = negative_phrases
self.networks = networks or []
self.region_prompts = []
self.region_seeds = []
self.region_prompts = region_prompts or []
self.region_seeds = region_seeds or []
def __eq__(self, other: object) -> bool:
return (
isinstance(other, self.__class__)
and other.networks == self.networks
and other.positive_phrases == self.positive_phrases
and other.negative_phrases == self.negative_phrases
and other.region_prompts == self.region_prompts
and other.region_seeds == self.region_seeds
)
def __repr__(self) -> str:
return f"Prompt({self.networks}, {self.positive_phrases}, {self.negative_phrases}, {self.region_prompts}, {self.region_seeds})"

View File

@ -98,8 +98,10 @@ def encode_prompt_compel(
prompt_embeds = compel(prompt)
if negative_prompt is not None:
negative_prompt_embeds = compel(negative_prompt)
if negative_prompt is None:
negative_prompt = ""
negative_prompt_embeds = compel(negative_prompt)
if negative_prompt_embeds is not None:
[prompt_embeds, negative_prompt_embeds] = (
@ -142,8 +144,10 @@ def encode_prompt_compel_sdxl(
prompt_embeds, prompt_pooled = compel(prompt)
negative_pooled = None
if negative_prompt is not None:
negative_prompt_embeds, negative_pooled = compel(negative_prompt)
if negative_prompt is None:
negative_prompt = ""
negative_prompt_embeds, negative_pooled = compel(negative_prompt)
if negative_prompt_embeds is not None:
[prompt_embeds, negative_prompt_embeds] = (

View File

@ -99,7 +99,7 @@ def prompt():
return OneOrMore(phrase), EOF
class PromptPhrase:
class PhraseNode:
def __init__(self, tokens: Union[List[str], str], weight: float = 1.0) -> None:
self.tokens = tokens
self.weight = weight
@ -114,20 +114,20 @@ class PromptPhrase:
return False
class PromptToken:
def __init__(self, token_type: str, token_name: str, *rest):
self.token_type = token_type
self.token_name = token_name
class TokenNode:
def __init__(self, type: str, name: str, *rest):
self.type = type
self.name = name
self.rest = rest
def __repr__(self) -> str:
return f"<{self.token_type}:{self.token_name}:{self.rest}>"
return f"<{self.type}:{self.name}:{self.rest}>"
def __eq__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return (
other.token_type == self.token_type
and other.token_name == self.token_name
other.type == self.type
and other.name == self.name
and other.rest == self.rest
)
@ -151,44 +151,44 @@ class OnnxPromptVisitor(PTNodeVisitor):
return str(node.value)
def visit_token_clip_skip(self, node, children):
return PromptToken("clip", "skip", children[0])
return TokenNode("clip", "skip", children[0])
def visit_token_inversion(self, node, children):
return PromptToken("inversion", children[0][0], children[1])
return TokenNode("inversion", children[0][0], children[1])
def visit_token_lora(self, node, children):
return PromptToken("lora", children[0][0], children[1])
return TokenNode("lora", children[0][0], children[1])
def visit_token_region(self, node, children):
return PromptToken("region", None, children)
return TokenNode("region", None, children)
def visit_token_reseed(self, node, children):
return PromptToken("reseed", None, children)
return TokenNode("reseed", None, children)
def visit_token_run(self, node, children):
return children
def visit_phrase_inner(self, node, children):
if isinstance(children[0], PromptPhrase):
if isinstance(children[0], PhraseNode):
return children[0]
elif isinstance(children[0], PromptToken):
elif isinstance(children[0], TokenNode):
return children[0]
else:
return PromptPhrase(children[0])
return PhraseNode(children[0])
def visit_pos_phrase(self, node, children):
c = children[0]
if isinstance(c, PromptPhrase):
return PromptPhrase(c.tokens, c.weight * self.pos_weight)
if isinstance(c, PhraseNode):
return PhraseNode(c.tokens, c.weight * self.pos_weight)
elif isinstance(c, str):
return PromptPhrase(c, self.pos_weight)
return PhraseNode(c, self.pos_weight)
def visit_neg_phrase(self, node, children):
c = children[0]
if isinstance(c, PromptPhrase):
return PromptPhrase(c.tokens, c.weight * self.neg_weight)
if isinstance(c, PhraseNode):
return PhraseNode(c.tokens, c.weight * self.neg_weight)
elif isinstance(c, str):
return PromptPhrase(c, self.neg_weight)
return PhraseNode(c, self.neg_weight)
def visit_phrase(self, node, children):
return children[0]

View File

@ -1,9 +1,10 @@
from typing import Literal
from typing import Literal, Union
import numpy as np
from arpeggio import ParserPython, visit_parse_tree
from .grammar import OnnxPromptVisitor
from .base import Prompt, PromptNetwork, PromptPhrase, PromptRegion, PromptSeed
from .grammar import OnnxPromptVisitor, PhraseNode, TokenNode
from .grammar import prompt as prompt_base
@ -22,8 +23,10 @@ def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray:
parser = ParserPython(prompt_base, debug=debug)
visitor = OnnxPromptVisitor()
ast = parser.parse(prompt)
return visit_parse_tree(ast, visitor)
lst = parser.parse(prompt)
ast = visit_parse_tree(lst, visitor)
return ast
def parse_prompt_vanilla(pipeline, prompt: str) -> np.ndarray:
@ -45,3 +48,49 @@ def parse_prompt(
return parse_prompt_vanilla(pipeline, prompt)
else:
raise ValueError("invalid prompt parser")
def compile_prompt_onnx(prompt: str) -> Prompt:
ast = parse_prompt_onnx(None, prompt)
tokens = [node for node in ast if isinstance(node, TokenNode)]
networks = [
PromptNetwork(token.type, token.name, token.rest[0])
for token in tokens
if token.type in ["lora", "inversion"]
]
regions = [PromptRegion(*token.rest) for token in tokens if token.type == "region"]
reseeds = [PromptSeed(*token.rest) for token in tokens if token.type == "reseed"]
phrases = [
compile_prompt_phrase(node)
for node in ast
if isinstance(node, (list, PhraseNode, str))
]
phrases = list(flatten(phrases))
return Prompt(
networks=networks,
positive_phrases=phrases,
negative_phrases=[],
region_prompts=regions,
region_seeds=reseeds,
)
def compile_prompt_phrase(node: Union[PhraseNode, str]) -> PromptPhrase:
if isinstance(node, list):
return [compile_prompt_phrase(subnode) for subnode in node]
if isinstance(node, str):
return PromptPhrase(node)
return PromptPhrase(node.tokens, node.weight)
def flatten(val):
if isinstance(val, list):
for subval in val:
yield from flatten(subval)
else:
yield val

View File

@ -1,62 +1,63 @@
import unittest
from onnx_web.prompt.grammar import PromptPhrase, PromptToken
from onnx_web.prompt.parser import parse_prompt_onnx
from onnx_web.prompt.base import PromptNetwork, PromptPhrase
from onnx_web.prompt.grammar import PhraseNode, TokenNode
from onnx_web.prompt.parser import compile_prompt_onnx, parse_prompt_onnx
class ParserTests(unittest.TestCase):
def test_single_word_phrase(self):
res = parse_prompt_onnx(None, "foo (bar) bin", debug=False)
self.assertListEqual(
[str(i) for i in res],
res,
[
str(["foo"]),
str(PromptPhrase(["bar"], weight=1.5)),
str(["bin"]),
["foo"],
PhraseNode(["bar"], weight=1.5),
["bin"],
],
)
def test_multi_word_phrase(self):
res = parse_prompt_onnx(None, "foo bar (middle words) bin bun", debug=False)
self.assertListEqual(
[str(i) for i in res],
res,
[
str(["foo", "bar"]),
str(PromptPhrase(["middle", "words"], weight=1.5)),
str(["bin", "bun"]),
["foo", "bar"],
PhraseNode(["middle", "words"], weight=1.5),
["bin", "bun"],
],
)
def test_nested_phrase(self):
res = parse_prompt_onnx(None, "foo (((bar))) bin", debug=False)
self.assertListEqual(
[str(i) for i in res],
res,
[
str(["foo"]),
str(PromptPhrase(["bar"], weight=(1.5**3))),
str(["bin"]),
["foo"],
PhraseNode(["bar"], weight=(1.5**3)),
["bin"],
],
)
def test_clip_skip_token(self):
res = parse_prompt_onnx(None, "foo <clip:skip:2> bin", debug=False)
self.assertListEqual(
[str(i) for i in res],
res,
[
str(["foo"]),
str(PromptToken("clip", "skip", 2)),
str(["bin"]),
["foo"],
TokenNode("clip", "skip", 2),
["bin"],
],
)
def test_lora_token(self):
res = parse_prompt_onnx(None, "foo <lora:name:1.5> bin", debug=False)
self.assertListEqual(
[str(i) for i in res],
res,
[
str(["foo"]),
str(PromptToken("lora", "name", 1.5)),
str(["bin"]),
["foo"],
TokenNode("lora", "name", 1.5),
["bin"],
],
)
@ -65,21 +66,33 @@ class ParserTests(unittest.TestCase):
None, "foo <region:1:2:3:4:0.5:0.75:prompt> bin", debug=False
)
self.assertListEqual(
[str(i) for i in res],
res,
[
str(["foo"]),
str(PromptToken("region", None, [1, 2, 3, 4, 0.5, 0.75, ["prompt"]])),
str(["bin"]),
["foo"],
TokenNode("region", None, [1, 2, 3, 4, 0.5, 0.75, ["prompt"]]),
["bin"],
],
)
def test_reseed_token(self):
res = parse_prompt_onnx(None, "foo <reseed:1:2:3:4:12345> bin", debug=False)
self.assertListEqual(
[str(i) for i in res],
res,
[
str(["foo"]),
str(PromptToken("reseed", None, [1, 2, 3, 4, 12345])),
str(["bin"]),
["foo"],
TokenNode("reseed", None, [1, 2, 3, 4, 12345]),
["bin"],
],
)
def test_compile_basic(self):
prompt = compile_prompt_onnx("foo <clip:skip:2> bar (baz) <lora:qux:1.5>")
self.assertEqual(prompt.networks, [PromptNetwork("lora", "qux", 1.5)])
self.assertEqual(
prompt.positive_phrases,
[
PromptPhrase("foo"),
PromptPhrase("bar"),
PromptPhrase(["baz"], weight=1.5),
],
)