
collapse prompt phrase runs

Sean Sube 2024-03-15 20:59:25 -05:00
parent 501dbff8a5
commit 3765fb9cbb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 188 additions and 52 deletions

View File

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Any, List, Optional, Union


 class PromptNetwork:
@@ -127,7 +127,7 @@ class Prompt:
         clip_skip: int,
     ) -> None:
         self.positive_phrases = positive_phrases
-        self.negative_prompt = negative_phrases
+        self.negative_phrases = negative_phrases
         self.networks = networks or []
         self.region_prompts = region_prompts or []
         self.region_seeds = region_seeds or []
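The rename above fixes a real attribute mismatch: __init__ wrote the negative phrases to negative_prompt, while the rest of the class reads self.negative_phrases (see __repr__ and collapse_runs below), so reading them back would most likely fail. A sketch of the old behavior, for illustration only:

# before this commit (sketch): written under one name, read under another
# self.negative_prompt = negative_phrases
# f"Prompt(... {self.negative_phrases} ...)"  # likely AttributeError in __repr__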
@@ -146,3 +146,45 @@ class Prompt:

     def __repr__(self) -> str:
         return f"Prompt({self.networks}, {self.positive_phrases}, {self.negative_phrases}, {self.region_prompts}, {self.region_seeds}, {self.clip_skip})"
+
+    def collapse_runs(self) -> None:
+        self.positive_phrases = collapse_phrases(self.positive_phrases)
+        self.negative_phrases = collapse_phrases(self.negative_phrases)
+
+
+def collapse_phrases(
+    nodes: List[Union[Any]],
+) -> List[Union[Any]]:
+    """
+    Combine phrases with the same weight.
+    """
+    weight = None
+    tokens = []
+    phrases = []
+
+    def flush_tokens():
+        nonlocal weight, tokens
+        if len(tokens) > 0:
+            phrase = " ".join(tokens)
+            phrases.append(PromptPhrase([phrase], weight))
+            tokens = []
+            weight = None
+
+    for node in nodes:
+        if isinstance(node, str):
+            node = PromptPhrase(node)
+        elif isinstance(node, (PromptNetwork, PromptRegion, PromptSeed)):
+            flush_tokens()
+            phrases.append(node)
+            continue
+
+        if node.weight == weight:
+            tokens.extend(node.phrase)
+        else:
+            flush_tokens()
+            tokens = node.phrase
+            weight = node.weight
+
+    flush_tokens()
+    return phrases
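As a quick illustration of the run-collapsing behavior added above, here is a minimal sketch; it assumes PromptPhrase(phrases, weight=None) exposes its phrase list as .phrase, which is what the loop implies, and it mirrors the expected values in the parser test at the end of this commit:

phrases = [
    PromptPhrase(["foo"]),       # weight None
    PromptPhrase(["bar"]),       # weight None: same run as "foo"
    PromptPhrase(["baz"], 1.5),  # new weight starts a new run
]
collapsed = collapse_phrases(phrases)
# -> [PromptPhrase(["foo bar"]), PromptPhrase(["baz"], 1.5)]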

View File

@@ -10,13 +10,13 @@ from ..diffusers.utils import split_clip_skip


 def get_inference_session(model):
-    if hasattr(model, "session"):
+    if hasattr(model, "session") and model.session is not None:
         return model.session

-    if hasattr(model, "model"):
+    if hasattr(model, "model") and model.model is not None:
         return model.model

-    raise ValueError("Model does not have an inference session")
+    raise ValueError("model does not have an inference session")


 def wrap_encoder(text_encoder):
@@ -50,6 +50,7 @@ def wrap_encoder(text_encoder):
             )
         elif output_hidden_states is True:
             hidden_states = [torch.from_numpy(state) for state in outputs[2:]]
+            print("outputs", outputs)
             return SimpleNamespace(
                 last_hidden_state=torch.from_numpy(outputs[0]),
                 pooler_output=torch.from_numpy(outputs[1]),
@@ -125,7 +126,7 @@ def encode_prompt_compel_sdxl(
     prompt: Union[str, List[str]],
     num_images_per_prompt: int,
     do_classifier_free_guidance: bool,
-    negative_prompt: Optional[Union[str, list]],
+    negative_prompt: Optional[Union[str, list]] = None,
     prompt_embeds: Optional[np.ndarray] = None,
     negative_prompt_embeds: Optional[np.ndarray] = None,
     pooled_prompt_embeds: Optional[np.ndarray] = None,
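Since negative_prompt now defaults to None, callers without a negative prompt can drop the argument entirely. A minimal usage sketch, mirroring the positional call in the new test file below (the prompt string itself is illustrative):

# assumes a loaded ONNX SDXL pipeline, as in the tests
embeds = encode_prompt_compel_sdxl(pipeline, "a painting of a fox", 1, True)
# equivalent to passing negative_prompt=None explicitly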

View File

@@ -2,6 +2,8 @@ from typing import List, Union

 from arpeggio import EOF, OneOrMore, PTNodeVisitor, RegExMatch

+from .utils import collapse_phrases, flatten
+

 def token_delimiter():
     return ":"
@@ -192,7 +194,7 @@ class OnnxPromptVisitor(PTNodeVisitor):
         return list(flatten(children))

     def visit_prompt(self, node, children):
-        return collapse_phrases(list(flatten(children)))
+        return collapse_phrases(list(flatten(children)), PhraseNode, TokenNode)


 def parse_phrase(child, weight):
@@ -206,48 +208,3 @@ def parse_phrase(child, weight):
     # return PhraseNode(child, weight)

     return [parse_phrase(c, weight) for c in child]
-
-
-def flatten(lst):
-    for el in lst:
-        if isinstance(el, list):
-            yield from flatten(el)
-        else:
-            yield el
-
-
-def collapse_phrases(
-    nodes: List[Union[PhraseNode, str]]
-) -> List[Union[PhraseNode, str]]:
-    """
-    Combine phrases with the same weight.
-    """
-    weight = None
-    tokens = []
-    phrases = []
-
-    def flush_tokens():
-        nonlocal weight, tokens
-        if len(tokens) > 0:
-            phrases.append(PhraseNode(tokens, weight))
-            tokens = []
-            weight = None
-
-    for node in nodes:
-        if isinstance(node, str):
-            node = PhraseNode([node])
-        elif isinstance(node, TokenNode):
-            flush_tokens()
-            phrases.append(node)
-            continue
-
-        if node.weight == weight:
-            tokens.extend(node.tokens)
-        else:
-            flush_tokens()
-            tokens = [*node.tokens]
-            weight = node.weight
-
-    flush_tokens()
-    return phrases

View File

@@ -0,0 +1,48 @@
+from typing import Any, List, Union
+
+
+def flatten(lst):
+    for el in lst:
+        if isinstance(el, list):
+            yield from flatten(el)
+        else:
+            yield el
+
+
+def collapse_phrases(
+    nodes: List[Union[Any]],
+    phrase,
+    token,
+) -> List[Union[Any]]:
+    """
+    Combine phrases with the same weight.
+    """
+    weight = None
+    tokens = []
+    phrases = []
+
+    def flush_tokens():
+        nonlocal weight, tokens
+        if len(tokens) > 0:
+            phrases.append(phrase(tokens, weight))
+            tokens = []
+            weight = None
+
+    for node in nodes:
+        if isinstance(node, str):
+            node = phrase([node])
+        elif isinstance(node, token):
+            flush_tokens()
+            phrases.append(node)
+            continue
+
+        if node.weight == weight:
+            tokens.extend(node.tokens)
+        else:
+            flush_tokens()
+            tokens = [*node.tokens]
+            weight = node.weight
+
+    flush_tokens()
+    return phrases
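This shared helper takes the concrete phrase and token classes as parameters, which lets both the parser (PhraseNode/TokenNode) and the Prompt model reuse one implementation without importing each other. A self-contained sketch with stand-in classes (these names are illustrative, not from the codebase):

class FakePhrase:
    def __init__(self, tokens, weight=None):
        self.tokens = tokens
        self.weight = weight

class FakeToken:
    pass

nodes = ["foo", FakePhrase(["bar"]), FakePhrase(["baz"], 1.5), FakeToken()]
merged = collapse_phrases(nodes, FakePhrase, FakeToken)
# -> [FakePhrase(["foo", "bar"]), FakePhrase(["baz"], 1.5), FakeToken()]
# note: unlike the Prompt version above, tokens stay a list here, not a joined string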

View File

@@ -0,0 +1,76 @@
+import unittest
+from unittest.mock import MagicMock
+
+import numpy as np
+
+from onnx_web.prompt.compel import (
+    encode_prompt_compel,
+    encode_prompt_compel_sdxl,
+    get_inference_session,
+    wrap_encoder,
+)
+
+
+class TestCompelHelpers(unittest.TestCase):
+    def test_get_inference_session_missing(self):
+        self.assertRaises(ValueError, get_inference_session, None)
+
+    def test_get_inference_session_onnx_session(self):
+        model = MagicMock()
+        model.model = None
+        model.session = "session"
+        self.assertEqual(get_inference_session(model), "session")
+
+    def test_get_inference_session_onnx_model(self):
+        model = MagicMock()
+        model.model = "model"
+        model.session = None
+        self.assertEqual(get_inference_session(model), "model")
+
+    def test_wrap_encoder(self):
+        text_encoder = MagicMock()
+        wrapped = wrap_encoder(text_encoder)
+        self.assertEqual(wrapped.device, "cpu")
+        self.assertEqual(wrapped.text_encoder, text_encoder)
+
+
+class TestCompelEncodePrompt(unittest.TestCase):
+    def test_encode_basic(self):
+        pipeline = MagicMock()
+        pipeline.text_encoder = MagicMock()
+        pipeline.text_encoder.return_value = [
+            np.array([[1], [2]]),
+            np.array([[3], [4]]),
+        ]
+        pipeline.tokenizer = MagicMock()
+        pipeline.tokenizer.model_max_length = 1
+
+        embeds = encode_prompt_compel(pipeline, "prompt", 1, True)
+        np.testing.assert_equal(embeds, [[[3, 3]], [[3, 3]]])
+
+
+class TestCompelEncodePromptSDXL(unittest.TestCase):
+    @unittest.skip("need to fix the tensor shapes")
+    def test_encode_basic(self):
+        text_encoder_output = MagicMock()
+        text_encoder_output.hidden_states = [[0], [1], [2], [3]]
+
+        def call_text_encoder(*args, return_dict=False, **kwargs):
+            print("call_text_encoder", return_dict)
+            if return_dict:
+                return text_encoder_output
+
+            return [np.array([[1]]), np.array([[3]]), np.array([[5]]), np.array([[7]])]
+
+        pipeline = MagicMock()
+        pipeline.text_encoder.side_effect = call_text_encoder
+        pipeline.text_encoder_2.side_effect = call_text_encoder
+        pipeline.tokenizer.model_max_length = 1
+        pipeline.tokenizer_2.model_max_length = 1
+
+        embeds = encode_prompt_compel_sdxl(pipeline, "prompt", 1, True)
+        np.testing.assert_equal(embeds, [[[3, 3]], [[3, 3]]])
+
+
+if __name__ == "__main__":
+    unittest.main()

View File

@@ -114,3 +114,15 @@ class ParserTests(unittest.TestCase):
                 PromptPhrase(["me"], weight=1.5),
             ],
         )
+
+    def test_compile_runs(self):
+        prompt = compile_prompt_onnx("foo <clip:skip:2> bar (baz) <lora:qux:1.5>")
+        prompt.collapse_runs()
+        self.assertEqual(
+            prompt.positive_phrases,
+            [
+                PromptPhrase(["foo bar"]),
+                PromptPhrase(["baz"], weight=1.5),
+            ],
+        )