1
0
Fork 0

add clip skip to custom prompt parser, collapse phrases with the same weight

This commit is contained in:
Sean Sube 2024-03-09 21:07:57 -06:00
parent 322aa3fd7f
commit 501dbff8a5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 126 additions and 41 deletions

View File

@ -110,6 +110,7 @@ class PromptSeed:
class Prompt:
clip_skip: int
networks: List[PromptNetwork]
positive_phrases: List[PromptPhrase]
negative_phrases: List[PromptPhrase]
@ -123,12 +124,14 @@ class Prompt:
negative_phrases: List[PromptPhrase],
region_prompts: List[PromptRegion],
region_seeds: List[PromptSeed],
clip_skip: int,
) -> None:
self.positive_phrases = positive_phrases
self.negative_prompt = negative_phrases
self.networks = networks or []
self.region_prompts = region_prompts or []
self.region_seeds = region_seeds or []
self.clip_skip = clip_skip
def __eq__(self, other: object) -> bool:
return (
@ -138,7 +141,8 @@ class Prompt:
and other.negative_phrases == self.negative_phrases
and other.region_prompts == self.region_prompts
and other.region_seeds == self.region_seeds
and other.clip_skip == self.clip_skip
)
def __repr__(self) -> str:
return f"Prompt({self.networks}, {self.positive_phrases}, {self.negative_phrases}, {self.region_prompts}, {self.region_seeds})"
return f"Prompt({self.networks}, {self.positive_phrases}, {self.negative_phrases}, {self.region_prompts}, {self.region_seeds}, {self.clip_skip})"

View File

@ -169,29 +169,85 @@ class OnnxPromptVisitor(PTNodeVisitor):
return children
def visit_phrase_inner(self, node, children):
if isinstance(children[0], PhraseNode):
return children[0]
elif isinstance(children[0], TokenNode):
return children[0]
else:
return PhraseNode(children[0])
return [
(
child
if isinstance(child, (PhraseNode, TokenNode, list))
else PhraseNode(child)
)
for child in children
]
def visit_pos_phrase(self, node, children):
c = children[0]
if isinstance(c, PhraseNode):
return PhraseNode(c.tokens, c.weight * self.pos_weight)
elif isinstance(c, str):
return PhraseNode(c, self.pos_weight)
print("positive phrase", len(children), children)
return parse_phrase(children, self.pos_weight)
def visit_neg_phrase(self, node, children):
c = children[0]
if isinstance(c, PhraseNode):
return PhraseNode(c.tokens, c.weight * self.neg_weight)
elif isinstance(c, str):
return PhraseNode(c, self.neg_weight)
print("negative phrase", len(children), children)
return parse_phrase(children, self.neg_weight)
def visit_phrase(self, node, children):
return children[0]
return list(flatten(children))
def visit_prompt(self, node, children):
return children
return collapse_phrases(list(flatten(children)))
def parse_phrase(child, weight):
    """
    Apply a weight multiplier to a parsed phrase, a bare string, or a
    (possibly nested) list of them.

    - A PhraseNode is rebuilt with the same tokens and its weight multiplied
      by `weight`.
    - A string becomes a single-token PhraseNode carrying `weight`.
    - A list is processed element by element; the nesting is preserved in the
      returned list.
    - Any other node (for example a TokenNode nested inside a weighted group)
      is returned unchanged.
    """
    if isinstance(child, PhraseNode):
        return PhraseNode(child.tokens, child.weight * weight)

    if isinstance(child, str):
        return PhraseNode([child], weight)

    if isinstance(child, list):
        # TODO: when this is a list of strings, create a single node with all of them
        return [parse_phrase(c, weight) for c in child]

    # Nodes that carry no weight used to fall off the end of this function and
    # turn into None, which then leaked into the flattened prompt. Pass them
    # through untouched so later stages still see the original node.
    return child
def flatten(lst):
    """
    Lazily yield the leaves of an arbitrarily nested list, depth first.

    Only `list` instances are descended into; every other value (including
    tuples and strings) is yielded as a single item.
    """
    for item in lst:
        if not isinstance(item, list):
            yield item
            continue

        # nested list: recurse and splice its leaves into the output stream
        yield from flatten(item)
def collapse_phrases(
    nodes: List[Union[PhraseNode, str]]
) -> List[Union[PhraseNode, str]]:
    """
    Merge runs of adjacent phrases that share a weight into one PhraseNode.

    Bare strings are treated as single-token phrases built with
    `PhraseNode([text])`. A TokenNode is copied to the output as-is and ends
    the current run, so phrases on either side of it are never merged.
    """
    collapsed = []
    run_tokens = []
    run_weight = None

    def emit_run():
        # close the current run: write it out if it holds any tokens, then
        # start over with a fresh list so the emitted node is never mutated
        nonlocal run_tokens, run_weight
        if run_tokens:
            collapsed.append(PhraseNode(run_tokens, run_weight))
        run_tokens = []
        run_weight = None

    for item in nodes:
        if isinstance(item, str):
            phrase = PhraseNode([item])
        elif isinstance(item, TokenNode):
            emit_run()
            collapsed.append(item)
            continue
        else:
            phrase = item

        if phrase.weight != run_weight:
            # weight changed: finish the previous run and begin a new one
            emit_run()
            run_weight = phrase.weight

        run_tokens.extend(phrase.tokens)

    emit_run()
    return collapsed

View File

@ -54,6 +54,7 @@ def compile_prompt_onnx(prompt: str) -> Prompt:
ast = parse_prompt_onnx(None, prompt)
tokens = [node for node in ast if isinstance(node, TokenNode)]
clip_skip = [token.rest[0] for token in tokens if token.type == "clip"]
networks = [
PromptNetwork(token.type, token.name, token.rest[0])
for token in tokens
@ -68,6 +69,7 @@ def compile_prompt_onnx(prompt: str) -> Prompt:
if isinstance(node, (list, PhraseNode, str))
]
phrases = list(flatten(phrases))
# TODO: collapse phrases with the same weight
return Prompt(
networks=networks,
@ -75,6 +77,7 @@ def compile_prompt_onnx(prompt: str) -> Prompt:
negative_phrases=[],
region_prompts=regions,
region_seeds=reseeds,
clip_skip=next(iter(clip_skip), 0),
)
@ -83,7 +86,7 @@ def compile_prompt_phrase(node: Union[PhraseNode, str]) -> PromptPhrase:
return [compile_prompt_phrase(subnode) for subnode in node]
if isinstance(node, str):
return PromptPhrase(node)
return PromptPhrase([node])
return PromptPhrase(node.tokens, node.weight)

View File

@ -125,24 +125,28 @@ class TestGenerateTileSpiral(unittest.TestCase):
class TestProcessTileStack(unittest.TestCase):
def test_grid_full(self):
source = Image.new("RGB", (64, 64))
blend = process_tile_stack(
result = process_tile_stack(
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
32,
1,
[],
generate_tile_grid,
)
images = result.as_images()
self.assertEqual(blend[0].size, (64, 64))
self.assertEqual(len(images), 1)
self.assertEqual(images[0].size, (64, 64))
def test_grid_partial(self):
source = Image.new("RGB", (72, 72))
blend = process_tile_stack(
result = process_tile_stack(
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
32,
1,
[],
generate_tile_grid,
)
images = result.as_images()
self.assertEqual(blend[0].size, (72, 72))
self.assertEqual(len(images), 1)
self.assertEqual(images[0].size, (72, 72))

View File

@ -11,9 +11,9 @@ class ParserTests(unittest.TestCase):
self.assertListEqual(
res,
[
["foo"],
PhraseNode(["foo"]),
PhraseNode(["bar"], weight=1.5),
["bin"],
PhraseNode(["bin"]),
],
)
@ -22,9 +22,9 @@ class ParserTests(unittest.TestCase):
self.assertListEqual(
res,
[
["foo", "bar"],
PhraseNode(["foo", "bar"]),
PhraseNode(["middle", "words"], weight=1.5),
["bin", "bun"],
PhraseNode(["bin", "bun"]),
],
)
@ -33,9 +33,9 @@ class ParserTests(unittest.TestCase):
self.assertListEqual(
res,
[
["foo"],
PhraseNode(["foo"]),
PhraseNode(["bar"], weight=(1.5**3)),
["bin"],
PhraseNode(["bin"]),
],
)
@ -44,9 +44,9 @@ class ParserTests(unittest.TestCase):
self.assertListEqual(
res,
[
["foo"],
PhraseNode(["foo"]),
TokenNode("clip", "skip", 2),
["bin"],
PhraseNode(["bin"]),
],
)
@ -55,9 +55,9 @@ class ParserTests(unittest.TestCase):
self.assertListEqual(
res,
[
["foo"],
PhraseNode(["foo"]),
TokenNode("lora", "name", 1.5),
["bin"],
PhraseNode(["bin"]),
],
)
@ -68,9 +68,9 @@ class ParserTests(unittest.TestCase):
self.assertListEqual(
res,
[
["foo"],
PhraseNode(["foo"]),
TokenNode("region", None, [1, 2, 3, 4, 0.5, 0.75, ["prompt"]]),
["bin"],
PhraseNode(["bin"]),
],
)
@ -79,20 +79,38 @@ class ParserTests(unittest.TestCase):
self.assertListEqual(
res,
[
["foo"],
PhraseNode(["foo"]),
TokenNode("reseed", None, [1, 2, 3, 4, 12345]),
["bin"],
PhraseNode(["bin"]),
],
)
def test_compile_basic(self):
def test_compile_tokens(self):
prompt = compile_prompt_onnx("foo <clip:skip:2> bar (baz) <lora:qux:1.5>")
self.assertEqual(prompt.clip_skip, 2)
self.assertEqual(prompt.networks, [PromptNetwork("lora", "qux", 1.5)])
self.assertEqual(
prompt.positive_phrases,
[
PromptPhrase("foo"),
PromptPhrase("bar"),
PromptPhrase(["foo"]),
PromptPhrase(["bar"]),
PromptPhrase(["baz"], weight=1.5),
],
)
def test_compile_weights(self):
prompt = compile_prompt_onnx("foo ((bar)) baz [[qux]] bun ([nest] me)")
self.assertEqual(
prompt.positive_phrases,
[
PromptPhrase(["foo"]),
PromptPhrase(["bar"], weight=2.25),
PromptPhrase(["baz"]),
PromptPhrase(["qux"], weight=0.25),
PromptPhrase(["bun"]),
PromptPhrase(["nest"], weight=0.75),
PromptPhrase(["me"], weight=1.5),
],
)