Compare commits
2 Commits
fceffb8040
...
5ae3d72968
Author | SHA1 | Date |
---|---|---|
Sean Sube | 5ae3d72968 | |
Sean Sube | 17749396b5 |
|
@ -1,30 +1,144 @@
|
|||
from typing import List, Optional
|
||||
|
||||
|
||||
class NetworkWeight:
|
||||
pass
|
||||
class PromptNetwork:
|
||||
type: str
|
||||
name: str
|
||||
strength: float
|
||||
|
||||
def __init__(self, type: str, name: str, strength: float) -> None:
|
||||
self.type = type
|
||||
self.name = name
|
||||
self.strength = strength
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and other.type == self.type
|
||||
and other.name == self.name
|
||||
and other.strength == self.strength
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"PromptNetwork({self.type}, {self.name}, {self.strength})"
|
||||
|
||||
|
||||
class PromptPhrase:
|
||||
phrase: str
|
||||
weight: float
|
||||
|
||||
def __init__(self, phrase: str, weight: float = 1.0) -> None:
|
||||
self.phrase = phrase
|
||||
self.weight = weight
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and other.phrase == self.phrase
|
||||
and other.weight == self.weight
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"PromptPhrase({self.phrase}, {self.weight})"
|
||||
|
||||
|
||||
class PromptRegion:
|
||||
pass
|
||||
top: int
|
||||
left: int
|
||||
bottom: int
|
||||
right: int
|
||||
prompt: str
|
||||
append: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
top: int,
|
||||
left: int,
|
||||
bottom: int,
|
||||
right: int,
|
||||
prompt: str,
|
||||
append: bool,
|
||||
) -> None:
|
||||
self.top = top
|
||||
self.left = left
|
||||
self.bottom = bottom
|
||||
self.right = right
|
||||
self.prompt = prompt
|
||||
self.append = append
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and other.top == self.top
|
||||
and other.left == self.left
|
||||
and other.bottom == self.bottom
|
||||
and other.right == self.right
|
||||
and other.prompt == self.prompt
|
||||
and other.append == self.append
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"PromptRegion({self.top}, {self.left}, {self.bottom}, {self.right}, {self.prompt}, {self.append})"
|
||||
|
||||
|
||||
class PromptSeed:
|
||||
pass
|
||||
top: int
|
||||
left: int
|
||||
bottom: int
|
||||
right: int
|
||||
seed: int
|
||||
|
||||
def __init__(self, top: int, left: int, bottom: int, right: int, seed: int) -> None:
|
||||
self.top = top
|
||||
self.left = left
|
||||
self.bottom = bottom
|
||||
self.right = right
|
||||
self.seed = seed
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and other.top == self.top
|
||||
and other.left == self.left
|
||||
and other.bottom == self.bottom
|
||||
and other.right == self.right
|
||||
and other.seed == self.seed
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"PromptSeed({self.top}, {self.left}, {self.bottom}, {self.right}, {self.seed})"
|
||||
|
||||
|
||||
class StructuredPrompt:
|
||||
prompt: str
|
||||
negative_prompt: Optional[str]
|
||||
networks: List[NetworkWeight]
|
||||
class Prompt:
|
||||
networks: List[PromptNetwork]
|
||||
positive_phrases: List[PromptPhrase]
|
||||
negative_phrases: List[PromptPhrase]
|
||||
region_prompts: List[PromptRegion]
|
||||
region_seeds: List[PromptSeed]
|
||||
|
||||
def __init__(
|
||||
self, prompt: str, negative_prompt: Optional[str], networks: List[NetworkWeight]
|
||||
self,
|
||||
networks: Optional[List[PromptNetwork]],
|
||||
positive_phrases: List[PromptPhrase],
|
||||
negative_phrases: List[PromptPhrase],
|
||||
region_prompts: List[PromptRegion],
|
||||
region_seeds: List[PromptSeed],
|
||||
) -> None:
|
||||
self.prompt = prompt
|
||||
self.negative_prompt = negative_prompt
|
||||
self.positive_phrases = positive_phrases
|
||||
self.negative_prompt = negative_phrases
|
||||
self.networks = networks or []
|
||||
self.region_prompts = []
|
||||
self.region_seeds = []
|
||||
self.region_prompts = region_prompts or []
|
||||
self.region_seeds = region_seeds or []
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and other.networks == self.networks
|
||||
and other.positive_phrases == self.positive_phrases
|
||||
and other.negative_phrases == self.negative_phrases
|
||||
and other.region_prompts == self.region_prompts
|
||||
and other.region_seeds == self.region_seeds
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Prompt({self.networks}, {self.positive_phrases}, {self.negative_phrases}, {self.region_prompts}, {self.region_seeds})"
|
||||
|
|
|
@ -98,8 +98,10 @@ def encode_prompt_compel(
|
|||
|
||||
prompt_embeds = compel(prompt)
|
||||
|
||||
if negative_prompt is not None:
|
||||
negative_prompt_embeds = compel(negative_prompt)
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ""
|
||||
|
||||
negative_prompt_embeds = compel(negative_prompt)
|
||||
|
||||
if negative_prompt_embeds is not None:
|
||||
[prompt_embeds, negative_prompt_embeds] = (
|
||||
|
@ -142,8 +144,10 @@ def encode_prompt_compel_sdxl(
|
|||
prompt_embeds, prompt_pooled = compel(prompt)
|
||||
|
||||
negative_pooled = None
|
||||
if negative_prompt is not None:
|
||||
negative_prompt_embeds, negative_pooled = compel(negative_prompt)
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ""
|
||||
|
||||
negative_prompt_embeds, negative_pooled = compel(negative_prompt)
|
||||
|
||||
if negative_prompt_embeds is not None:
|
||||
[prompt_embeds, negative_prompt_embeds] = (
|
||||
|
|
|
@ -99,7 +99,7 @@ def prompt():
|
|||
return OneOrMore(phrase), EOF
|
||||
|
||||
|
||||
class PromptPhrase:
|
||||
class PhraseNode:
|
||||
def __init__(self, tokens: Union[List[str], str], weight: float = 1.0) -> None:
|
||||
self.tokens = tokens
|
||||
self.weight = weight
|
||||
|
@ -114,20 +114,20 @@ class PromptPhrase:
|
|||
return False
|
||||
|
||||
|
||||
class PromptToken:
|
||||
def __init__(self, token_type: str, token_name: str, *rest):
|
||||
self.token_type = token_type
|
||||
self.token_name = token_name
|
||||
class TokenNode:
|
||||
def __init__(self, type: str, name: str, *rest):
|
||||
self.type = type
|
||||
self.name = name
|
||||
self.rest = rest
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.token_type}:{self.token_name}:{self.rest}>"
|
||||
return f"<{self.type}:{self.name}:{self.rest}>"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, self.__class__):
|
||||
return (
|
||||
other.token_type == self.token_type
|
||||
and other.token_name == self.token_name
|
||||
other.type == self.type
|
||||
and other.name == self.name
|
||||
and other.rest == self.rest
|
||||
)
|
||||
|
||||
|
@ -151,44 +151,44 @@ class OnnxPromptVisitor(PTNodeVisitor):
|
|||
return str(node.value)
|
||||
|
||||
def visit_token_clip_skip(self, node, children):
|
||||
return PromptToken("clip", "skip", children[0])
|
||||
return TokenNode("clip", "skip", children[0])
|
||||
|
||||
def visit_token_inversion(self, node, children):
|
||||
return PromptToken("inversion", children[0][0], children[1])
|
||||
return TokenNode("inversion", children[0][0], children[1])
|
||||
|
||||
def visit_token_lora(self, node, children):
|
||||
return PromptToken("lora", children[0][0], children[1])
|
||||
return TokenNode("lora", children[0][0], children[1])
|
||||
|
||||
def visit_token_region(self, node, children):
|
||||
return PromptToken("region", None, children)
|
||||
return TokenNode("region", None, children)
|
||||
|
||||
def visit_token_reseed(self, node, children):
|
||||
return PromptToken("reseed", None, children)
|
||||
return TokenNode("reseed", None, children)
|
||||
|
||||
def visit_token_run(self, node, children):
|
||||
return children
|
||||
|
||||
def visit_phrase_inner(self, node, children):
|
||||
if isinstance(children[0], PromptPhrase):
|
||||
if isinstance(children[0], PhraseNode):
|
||||
return children[0]
|
||||
elif isinstance(children[0], PromptToken):
|
||||
elif isinstance(children[0], TokenNode):
|
||||
return children[0]
|
||||
else:
|
||||
return PromptPhrase(children[0])
|
||||
return PhraseNode(children[0])
|
||||
|
||||
def visit_pos_phrase(self, node, children):
|
||||
c = children[0]
|
||||
if isinstance(c, PromptPhrase):
|
||||
return PromptPhrase(c.tokens, c.weight * self.pos_weight)
|
||||
if isinstance(c, PhraseNode):
|
||||
return PhraseNode(c.tokens, c.weight * self.pos_weight)
|
||||
elif isinstance(c, str):
|
||||
return PromptPhrase(c, self.pos_weight)
|
||||
return PhraseNode(c, self.pos_weight)
|
||||
|
||||
def visit_neg_phrase(self, node, children):
|
||||
c = children[0]
|
||||
if isinstance(c, PromptPhrase):
|
||||
return PromptPhrase(c.tokens, c.weight * self.neg_weight)
|
||||
if isinstance(c, PhraseNode):
|
||||
return PhraseNode(c.tokens, c.weight * self.neg_weight)
|
||||
elif isinstance(c, str):
|
||||
return PromptPhrase(c, self.neg_weight)
|
||||
return PhraseNode(c, self.neg_weight)
|
||||
|
||||
def visit_phrase(self, node, children):
|
||||
return children[0]
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Literal
|
||||
from typing import Literal, Union
|
||||
|
||||
import numpy as np
|
||||
from arpeggio import ParserPython, visit_parse_tree
|
||||
|
||||
from .grammar import OnnxPromptVisitor
|
||||
from .base import Prompt, PromptNetwork, PromptPhrase, PromptRegion, PromptSeed
|
||||
from .grammar import OnnxPromptVisitor, PhraseNode, TokenNode
|
||||
from .grammar import prompt as prompt_base
|
||||
|
||||
|
||||
|
@ -22,8 +23,10 @@ def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray:
|
|||
parser = ParserPython(prompt_base, debug=debug)
|
||||
visitor = OnnxPromptVisitor()
|
||||
|
||||
ast = parser.parse(prompt)
|
||||
return visit_parse_tree(ast, visitor)
|
||||
lst = parser.parse(prompt)
|
||||
ast = visit_parse_tree(lst, visitor)
|
||||
|
||||
return ast
|
||||
|
||||
|
||||
def parse_prompt_vanilla(pipeline, prompt: str) -> np.ndarray:
|
||||
|
@ -45,3 +48,49 @@ def parse_prompt(
|
|||
return parse_prompt_vanilla(pipeline, prompt)
|
||||
else:
|
||||
raise ValueError("invalid prompt parser")
|
||||
|
||||
|
||||
def compile_prompt_onnx(prompt: str) -> Prompt:
|
||||
ast = parse_prompt_onnx(None, prompt)
|
||||
|
||||
tokens = [node for node in ast if isinstance(node, TokenNode)]
|
||||
networks = [
|
||||
PromptNetwork(token.type, token.name, token.rest[0])
|
||||
for token in tokens
|
||||
if token.type in ["lora", "inversion"]
|
||||
]
|
||||
regions = [PromptRegion(*token.rest) for token in tokens if token.type == "region"]
|
||||
reseeds = [PromptSeed(*token.rest) for token in tokens if token.type == "reseed"]
|
||||
|
||||
phrases = [
|
||||
compile_prompt_phrase(node)
|
||||
for node in ast
|
||||
if isinstance(node, (list, PhraseNode, str))
|
||||
]
|
||||
phrases = list(flatten(phrases))
|
||||
|
||||
return Prompt(
|
||||
networks=networks,
|
||||
positive_phrases=phrases,
|
||||
negative_phrases=[],
|
||||
region_prompts=regions,
|
||||
region_seeds=reseeds,
|
||||
)
|
||||
|
||||
|
||||
def compile_prompt_phrase(node: Union[PhraseNode, str]) -> PromptPhrase:
|
||||
if isinstance(node, list):
|
||||
return [compile_prompt_phrase(subnode) for subnode in node]
|
||||
|
||||
if isinstance(node, str):
|
||||
return PromptPhrase(node)
|
||||
|
||||
return PromptPhrase(node.tokens, node.weight)
|
||||
|
||||
|
||||
def flatten(val):
|
||||
if isinstance(val, list):
|
||||
for subval in val:
|
||||
yield from flatten(subval)
|
||||
else:
|
||||
yield val
|
||||
|
|
|
@ -1,62 +1,63 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.prompt.grammar import PromptPhrase, PromptToken
|
||||
from onnx_web.prompt.parser import parse_prompt_onnx
|
||||
from onnx_web.prompt.base import PromptNetwork, PromptPhrase
|
||||
from onnx_web.prompt.grammar import PhraseNode, TokenNode
|
||||
from onnx_web.prompt.parser import compile_prompt_onnx, parse_prompt_onnx
|
||||
|
||||
|
||||
class ParserTests(unittest.TestCase):
|
||||
def test_single_word_phrase(self):
|
||||
res = parse_prompt_onnx(None, "foo (bar) bin", debug=False)
|
||||
self.assertListEqual(
|
||||
[str(i) for i in res],
|
||||
res,
|
||||
[
|
||||
str(["foo"]),
|
||||
str(PromptPhrase(["bar"], weight=1.5)),
|
||||
str(["bin"]),
|
||||
["foo"],
|
||||
PhraseNode(["bar"], weight=1.5),
|
||||
["bin"],
|
||||
],
|
||||
)
|
||||
|
||||
def test_multi_word_phrase(self):
|
||||
res = parse_prompt_onnx(None, "foo bar (middle words) bin bun", debug=False)
|
||||
self.assertListEqual(
|
||||
[str(i) for i in res],
|
||||
res,
|
||||
[
|
||||
str(["foo", "bar"]),
|
||||
str(PromptPhrase(["middle", "words"], weight=1.5)),
|
||||
str(["bin", "bun"]),
|
||||
["foo", "bar"],
|
||||
PhraseNode(["middle", "words"], weight=1.5),
|
||||
["bin", "bun"],
|
||||
],
|
||||
)
|
||||
|
||||
def test_nested_phrase(self):
|
||||
res = parse_prompt_onnx(None, "foo (((bar))) bin", debug=False)
|
||||
self.assertListEqual(
|
||||
[str(i) for i in res],
|
||||
res,
|
||||
[
|
||||
str(["foo"]),
|
||||
str(PromptPhrase(["bar"], weight=(1.5**3))),
|
||||
str(["bin"]),
|
||||
["foo"],
|
||||
PhraseNode(["bar"], weight=(1.5**3)),
|
||||
["bin"],
|
||||
],
|
||||
)
|
||||
|
||||
def test_clip_skip_token(self):
|
||||
res = parse_prompt_onnx(None, "foo <clip:skip:2> bin", debug=False)
|
||||
self.assertListEqual(
|
||||
[str(i) for i in res],
|
||||
res,
|
||||
[
|
||||
str(["foo"]),
|
||||
str(PromptToken("clip", "skip", 2)),
|
||||
str(["bin"]),
|
||||
["foo"],
|
||||
TokenNode("clip", "skip", 2),
|
||||
["bin"],
|
||||
],
|
||||
)
|
||||
|
||||
def test_lora_token(self):
|
||||
res = parse_prompt_onnx(None, "foo <lora:name:1.5> bin", debug=False)
|
||||
self.assertListEqual(
|
||||
[str(i) for i in res],
|
||||
res,
|
||||
[
|
||||
str(["foo"]),
|
||||
str(PromptToken("lora", "name", 1.5)),
|
||||
str(["bin"]),
|
||||
["foo"],
|
||||
TokenNode("lora", "name", 1.5),
|
||||
["bin"],
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -65,21 +66,33 @@ class ParserTests(unittest.TestCase):
|
|||
None, "foo <region:1:2:3:4:0.5:0.75:prompt> bin", debug=False
|
||||
)
|
||||
self.assertListEqual(
|
||||
[str(i) for i in res],
|
||||
res,
|
||||
[
|
||||
str(["foo"]),
|
||||
str(PromptToken("region", None, [1, 2, 3, 4, 0.5, 0.75, ["prompt"]])),
|
||||
str(["bin"]),
|
||||
["foo"],
|
||||
TokenNode("region", None, [1, 2, 3, 4, 0.5, 0.75, ["prompt"]]),
|
||||
["bin"],
|
||||
],
|
||||
)
|
||||
|
||||
def test_reseed_token(self):
|
||||
res = parse_prompt_onnx(None, "foo <reseed:1:2:3:4:12345> bin", debug=False)
|
||||
self.assertListEqual(
|
||||
[str(i) for i in res],
|
||||
res,
|
||||
[
|
||||
str(["foo"]),
|
||||
str(PromptToken("reseed", None, [1, 2, 3, 4, 12345])),
|
||||
str(["bin"]),
|
||||
["foo"],
|
||||
TokenNode("reseed", None, [1, 2, 3, 4, 12345]),
|
||||
["bin"],
|
||||
],
|
||||
)
|
||||
|
||||
def test_compile_basic(self):
|
||||
prompt = compile_prompt_onnx("foo <clip:skip:2> bar (baz) <lora:qux:1.5>")
|
||||
self.assertEqual(prompt.networks, [PromptNetwork("lora", "qux", 1.5)])
|
||||
self.assertEqual(
|
||||
prompt.positive_phrases,
|
||||
[
|
||||
PromptPhrase("foo"),
|
||||
PromptPhrase("bar"),
|
||||
PromptPhrase(["baz"], weight=1.5),
|
||||
],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue