add clip skip to custom prompt parser, collapse phrases with the same weight
This commit is contained in:
parent
322aa3fd7f
commit
501dbff8a5
|
@ -110,6 +110,7 @@ class PromptSeed:
|
|||
|
||||
|
||||
class Prompt:
|
||||
clip_skip: int
|
||||
networks: List[PromptNetwork]
|
||||
positive_phrases: List[PromptPhrase]
|
||||
negative_phrases: List[PromptPhrase]
|
||||
|
@ -123,12 +124,14 @@ class Prompt:
|
|||
negative_phrases: List[PromptPhrase],
|
||||
region_prompts: List[PromptRegion],
|
||||
region_seeds: List[PromptSeed],
|
||||
clip_skip: int,
|
||||
) -> None:
|
||||
self.positive_phrases = positive_phrases
|
||||
self.negative_prompt = negative_phrases
|
||||
self.networks = networks or []
|
||||
self.region_prompts = region_prompts or []
|
||||
self.region_seeds = region_seeds or []
|
||||
self.clip_skip = clip_skip
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
|
@ -138,7 +141,8 @@ class Prompt:
|
|||
and other.negative_phrases == self.negative_phrases
|
||||
and other.region_prompts == self.region_prompts
|
||||
and other.region_seeds == self.region_seeds
|
||||
and other.clip_skip == self.clip_skip
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Prompt({self.networks}, {self.positive_phrases}, {self.negative_phrases}, {self.region_prompts}, {self.region_seeds})"
|
||||
return f"Prompt({self.networks}, {self.positive_phrases}, {self.negative_phrases}, {self.region_prompts}, {self.region_seeds}, {self.clip_skip})"
|
||||
|
|
|
@ -169,29 +169,85 @@ class OnnxPromptVisitor(PTNodeVisitor):
|
|||
return children
|
||||
|
||||
def visit_phrase_inner(self, node, children):
|
||||
if isinstance(children[0], PhraseNode):
|
||||
return children[0]
|
||||
elif isinstance(children[0], TokenNode):
|
||||
return children[0]
|
||||
else:
|
||||
return PhraseNode(children[0])
|
||||
return [
|
||||
(
|
||||
child
|
||||
if isinstance(child, (PhraseNode, TokenNode, list))
|
||||
else PhraseNode(child)
|
||||
)
|
||||
for child in children
|
||||
]
|
||||
|
||||
def visit_pos_phrase(self, node, children):
|
||||
c = children[0]
|
||||
if isinstance(c, PhraseNode):
|
||||
return PhraseNode(c.tokens, c.weight * self.pos_weight)
|
||||
elif isinstance(c, str):
|
||||
return PhraseNode(c, self.pos_weight)
|
||||
print("positive phrase", len(children), children)
|
||||
|
||||
return parse_phrase(children, self.pos_weight)
|
||||
|
||||
def visit_neg_phrase(self, node, children):
|
||||
c = children[0]
|
||||
if isinstance(c, PhraseNode):
|
||||
return PhraseNode(c.tokens, c.weight * self.neg_weight)
|
||||
elif isinstance(c, str):
|
||||
return PhraseNode(c, self.neg_weight)
|
||||
print("negative phrase", len(children), children)
|
||||
|
||||
return parse_phrase(children, self.neg_weight)
|
||||
|
||||
def visit_phrase(self, node, children):
|
||||
return children[0]
|
||||
return list(flatten(children))
|
||||
|
||||
def visit_prompt(self, node, children):
|
||||
return children
|
||||
return collapse_phrases(list(flatten(children)))
|
||||
|
||||
|
||||
def parse_phrase(child, weight):
    """
    Apply a weight multiplier to a parsed phrase child.

    - A ``PhraseNode`` is rebuilt with its own weight multiplied by ``weight``.
    - A bare ``str`` becomes a single-token ``PhraseNode`` carrying ``weight``.
    - A ``list`` is processed element by element, preserving its nesting.
    - Any other value is returned unchanged.

    The last case matters because ``visit_phrase_inner`` passes ``TokenNode``
    children through untouched. Previously such a child fell off the end of
    this function and became ``None``, which later fails in
    ``collapse_phrases`` when it reads ``node.weight``.
    """
    if isinstance(child, PhraseNode):
        return PhraseNode(child.tokens, child.weight * weight)
    elif isinstance(child, str):
        return PhraseNode([child], weight)
    elif isinstance(child, list):
        # TODO: when this is a list of strings, create a single node with all of them
        return [parse_phrase(c, weight) for c in child]

    # Not a phrase, string or list (e.g. a TokenNode inside a weighted group):
    # there is no phrase weight to scale, so hand the node back as-is instead
    # of implicitly returning None.
    return child
|
||||
|
||||
|
||||
def flatten(lst):
    """
    Lazily yield the leaf elements of an arbitrarily nested list.

    Only ``list`` instances are descended into; every other value, including
    tuples and strings, is yielded as a single element. Empty lists contribute
    nothing.
    """
    for item in lst:
        if not isinstance(item, list):
            # Leaf value: emit it untouched.
            yield item
            continue

        # Nested list: emit its leaves in order before moving on.
        yield from flatten(item)
|
||||
|
||||
|
||||
def collapse_phrases(
    nodes: List[Union[PhraseNode, str]]
) -> List[Union[PhraseNode, str]]:
    """
    Combine phrases with the same weight.

    Walks ``nodes`` in order and merges each run of adjacent phrases whose
    weights compare equal into one ``PhraseNode`` holding all of their tokens.
    Bare strings are treated as single-token ``PhraseNode`` instances built
    with the default weight. A ``TokenNode`` ends the current run and is
    copied to the output unchanged. Runs that hold no tokens are not emitted.
    """

    merged = []
    run_weight = None  # weight shared by the phrases in the current run
    run_tokens = []  # tokens gathered for the current run

    for item in nodes:
        if isinstance(item, str):
            item = PhraseNode([item])
        elif isinstance(item, TokenNode):
            # Tokens never merge with phrases: close the open run, if any,
            # then pass the token through as-is.
            if len(run_tokens) > 0:
                merged.append(PhraseNode(run_tokens, run_weight))
                run_tokens = []
                run_weight = None
            merged.append(item)
            continue

        if item.weight == run_weight:
            # Same weight as the open run: keep accumulating.
            run_tokens.extend(item.tokens)
            continue

        # Weight changed: emit the finished run and start a new one from a
        # copy of this phrase's tokens.
        if len(run_tokens) > 0:
            merged.append(PhraseNode(run_tokens, run_weight))
        run_tokens = [*item.tokens]
        run_weight = item.weight

    # Emit whatever is still pending once the input is exhausted.
    if len(run_tokens) > 0:
        merged.append(PhraseNode(run_tokens, run_weight))

    return merged
|
||||
|
|
|
@ -54,6 +54,7 @@ def compile_prompt_onnx(prompt: str) -> Prompt:
|
|||
ast = parse_prompt_onnx(None, prompt)
|
||||
|
||||
tokens = [node for node in ast if isinstance(node, TokenNode)]
|
||||
clip_skip = [token.rest[0] for token in tokens if token.type == "clip"]
|
||||
networks = [
|
||||
PromptNetwork(token.type, token.name, token.rest[0])
|
||||
for token in tokens
|
||||
|
@ -68,6 +69,7 @@ def compile_prompt_onnx(prompt: str) -> Prompt:
|
|||
if isinstance(node, (list, PhraseNode, str))
|
||||
]
|
||||
phrases = list(flatten(phrases))
|
||||
# TODO: collapse phrases with the same weight
|
||||
|
||||
return Prompt(
|
||||
networks=networks,
|
||||
|
@ -75,6 +77,7 @@ def compile_prompt_onnx(prompt: str) -> Prompt:
|
|||
negative_phrases=[],
|
||||
region_prompts=regions,
|
||||
region_seeds=reseeds,
|
||||
clip_skip=next(iter(clip_skip), 0),
|
||||
)
|
||||
|
||||
|
||||
|
@ -83,7 +86,7 @@ def compile_prompt_phrase(node: Union[PhraseNode, str]) -> PromptPhrase:
|
|||
return [compile_prompt_phrase(subnode) for subnode in node]
|
||||
|
||||
if isinstance(node, str):
|
||||
return PromptPhrase(node)
|
||||
return PromptPhrase([node])
|
||||
|
||||
return PromptPhrase(node.tokens, node.weight)
|
||||
|
||||
|
|
|
@ -125,24 +125,28 @@ class TestGenerateTileSpiral(unittest.TestCase):
|
|||
class TestProcessTileStack(unittest.TestCase):
|
||||
def test_grid_full(self):
|
||||
source = Image.new("RGB", (64, 64))
|
||||
blend = process_tile_stack(
|
||||
result = process_tile_stack(
|
||||
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
||||
32,
|
||||
1,
|
||||
[],
|
||||
generate_tile_grid,
|
||||
)
|
||||
images = result.as_images()
|
||||
|
||||
self.assertEqual(blend[0].size, (64, 64))
|
||||
self.assertEqual(len(images), 1)
|
||||
self.assertEqual(images[0].size, (64, 64))
|
||||
|
||||
def test_grid_partial(self):
|
||||
source = Image.new("RGB", (72, 72))
|
||||
blend = process_tile_stack(
|
||||
result = process_tile_stack(
|
||||
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
||||
32,
|
||||
1,
|
||||
[],
|
||||
generate_tile_grid,
|
||||
)
|
||||
images = result.as_images()
|
||||
|
||||
self.assertEqual(blend[0].size, (72, 72))
|
||||
self.assertEqual(len(images), 1)
|
||||
self.assertEqual(images[0].size, (72, 72))
|
||||
|
|
|
@ -11,9 +11,9 @@ class ParserTests(unittest.TestCase):
|
|||
self.assertListEqual(
|
||||
res,
|
||||
[
|
||||
["foo"],
|
||||
PhraseNode(["foo"]),
|
||||
PhraseNode(["bar"], weight=1.5),
|
||||
["bin"],
|
||||
PhraseNode(["bin"]),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -22,9 +22,9 @@ class ParserTests(unittest.TestCase):
|
|||
self.assertListEqual(
|
||||
res,
|
||||
[
|
||||
["foo", "bar"],
|
||||
PhraseNode(["foo", "bar"]),
|
||||
PhraseNode(["middle", "words"], weight=1.5),
|
||||
["bin", "bun"],
|
||||
PhraseNode(["bin", "bun"]),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -33,9 +33,9 @@ class ParserTests(unittest.TestCase):
|
|||
self.assertListEqual(
|
||||
res,
|
||||
[
|
||||
["foo"],
|
||||
PhraseNode(["foo"]),
|
||||
PhraseNode(["bar"], weight=(1.5**3)),
|
||||
["bin"],
|
||||
PhraseNode(["bin"]),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -44,9 +44,9 @@ class ParserTests(unittest.TestCase):
|
|||
self.assertListEqual(
|
||||
res,
|
||||
[
|
||||
["foo"],
|
||||
PhraseNode(["foo"]),
|
||||
TokenNode("clip", "skip", 2),
|
||||
["bin"],
|
||||
PhraseNode(["bin"]),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -55,9 +55,9 @@ class ParserTests(unittest.TestCase):
|
|||
self.assertListEqual(
|
||||
res,
|
||||
[
|
||||
["foo"],
|
||||
PhraseNode(["foo"]),
|
||||
TokenNode("lora", "name", 1.5),
|
||||
["bin"],
|
||||
PhraseNode(["bin"]),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -68,9 +68,9 @@ class ParserTests(unittest.TestCase):
|
|||
self.assertListEqual(
|
||||
res,
|
||||
[
|
||||
["foo"],
|
||||
PhraseNode(["foo"]),
|
||||
TokenNode("region", None, [1, 2, 3, 4, 0.5, 0.75, ["prompt"]]),
|
||||
["bin"],
|
||||
PhraseNode(["bin"]),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -79,20 +79,38 @@ class ParserTests(unittest.TestCase):
|
|||
self.assertListEqual(
|
||||
res,
|
||||
[
|
||||
["foo"],
|
||||
PhraseNode(["foo"]),
|
||||
TokenNode("reseed", None, [1, 2, 3, 4, 12345]),
|
||||
["bin"],
|
||||
PhraseNode(["bin"]),
|
||||
],
|
||||
)
|
||||
|
||||
def test_compile_basic(self):
|
||||
def test_compile_tokens(self):
|
||||
prompt = compile_prompt_onnx("foo <clip:skip:2> bar (baz) <lora:qux:1.5>")
|
||||
|
||||
self.assertEqual(prompt.clip_skip, 2)
|
||||
self.assertEqual(prompt.networks, [PromptNetwork("lora", "qux", 1.5)])
|
||||
self.assertEqual(
|
||||
prompt.positive_phrases,
|
||||
[
|
||||
PromptPhrase("foo"),
|
||||
PromptPhrase("bar"),
|
||||
PromptPhrase(["foo"]),
|
||||
PromptPhrase(["bar"]),
|
||||
PromptPhrase(["baz"], weight=1.5),
|
||||
],
|
||||
)
|
||||
|
||||
def test_compile_weights(self):
|
||||
prompt = compile_prompt_onnx("foo ((bar)) baz [[qux]] bun ([nest] me)")
|
||||
|
||||
self.assertEqual(
|
||||
prompt.positive_phrases,
|
||||
[
|
||||
PromptPhrase(["foo"]),
|
||||
PromptPhrase(["bar"], weight=2.25),
|
||||
PromptPhrase(["baz"]),
|
||||
PromptPhrase(["qux"], weight=0.25),
|
||||
PromptPhrase(["bun"]),
|
||||
PromptPhrase(["nest"], weight=0.75),
|
||||
PromptPhrase(["me"], weight=1.5),
|
||||
],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue