collapse prompt phrase runs
This commit is contained in:
parent
501dbff8a5
commit
3765fb9cbb
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Optional
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
class PromptNetwork:
|
class PromptNetwork:
|
||||||
|
@ -127,7 +127,7 @@ class Prompt:
|
||||||
clip_skip: int,
|
clip_skip: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.positive_phrases = positive_phrases
|
self.positive_phrases = positive_phrases
|
||||||
self.negative_prompt = negative_phrases
|
self.negative_phrases = negative_phrases
|
||||||
self.networks = networks or []
|
self.networks = networks or []
|
||||||
self.region_prompts = region_prompts or []
|
self.region_prompts = region_prompts or []
|
||||||
self.region_seeds = region_seeds or []
|
self.region_seeds = region_seeds or []
|
||||||
|
@ -146,3 +146,45 @@ class Prompt:
|
||||||
|
|
||||||
def __repr__(self) -> str:
    """
    Describe the prompt for debugging.

    The fields appear in this order: networks, positive phrases, negative
    phrases, region prompts, region seeds, and clip skip.
    """
    return f"Prompt({self.networks}, {self.positive_phrases}, {self.negative_phrases}, {self.region_prompts}, {self.region_seeds}, {self.clip_skip})"
|
||||||
|
|
||||||
|
def collapse_runs(self) -> None:
    """
    Collapse the positive and negative phrase lists in place.

    Each list is replaced by the result of collapse_phrases, positive first.
    """
    for attr in ("positive_phrases", "negative_phrases"):
        setattr(self, attr, collapse_phrases(getattr(self, attr)))
|
||||||
|
|
||||||
|
|
||||||
|
def collapse_phrases(
    nodes: List[Union[Any]],
) -> List[Union[Any]]:
    """
    Combine phrases with the same weight.

    Each run of adjacent phrases sharing a weight becomes a single PromptPhrase
    whose only token is the run's text joined with spaces. Bare strings are
    treated as phrases with the default weight. Network, region, and seed
    nodes end the current run and are passed through unchanged.

    The input nodes are never modified.
    """

    weight = None
    tokens = []
    phrases = []

    def flush_tokens():
        # emit the pending run, if any, and start a fresh one
        nonlocal weight, tokens
        if len(tokens) > 0:
            phrase = " ".join(tokens)
            phrases.append(PromptPhrase([phrase], weight))
            tokens = []
            weight = None

    for node in nodes:
        if isinstance(node, str):
            # wrap the string in a list so it is a single token, matching every
            # other PromptPhrase construction, rather than a sequence of characters
            node = PromptPhrase([node])
        elif isinstance(node, (PromptNetwork, PromptRegion, PromptSeed)):
            flush_tokens()
            phrases.append(node)
            continue

        if node.weight == weight:
            tokens.extend(node.phrase)
        else:
            flush_tokens()
            # copy the list: extending an alias of node.phrase would silently
            # grow the caller's phrase as the run accumulates
            tokens = [*node.phrase]
            weight = node.weight

    flush_tokens()
    return phrases
|
||||||
|
|
|
@ -10,13 +10,13 @@ from ..diffusers.utils import split_clip_skip
|
||||||
|
|
||||||
|
|
||||||
def get_inference_session(model):
    """
    Return the inference session held by a model wrapper.

    The `session` attribute is checked first, then `model`. An attribute that
    is missing or set to None is skipped, so a wrapper exposing both names with
    only one populated still resolves.

    Raises:
        ValueError: when neither attribute holds a value.
    """
    for attr in ("session", "model"):
        candidate = getattr(model, attr, None)
        if candidate is not None:
            return candidate

    raise ValueError("model does not have an inference session")
|
||||||
|
|
||||||
|
|
||||||
def wrap_encoder(text_encoder):
|
def wrap_encoder(text_encoder):
|
||||||
|
@ -50,6 +50,7 @@ def wrap_encoder(text_encoder):
|
||||||
)
|
)
|
||||||
elif output_hidden_states is True:
|
elif output_hidden_states is True:
|
||||||
hidden_states = [torch.from_numpy(state) for state in outputs[2:]]
|
hidden_states = [torch.from_numpy(state) for state in outputs[2:]]
|
||||||
|
print("outputs", outputs)
|
||||||
return SimpleNamespace(
|
return SimpleNamespace(
|
||||||
last_hidden_state=torch.from_numpy(outputs[0]),
|
last_hidden_state=torch.from_numpy(outputs[0]),
|
||||||
pooler_output=torch.from_numpy(outputs[1]),
|
pooler_output=torch.from_numpy(outputs[1]),
|
||||||
|
@ -125,7 +126,7 @@ def encode_prompt_compel_sdxl(
|
||||||
prompt: Union[str, List[str]],
|
prompt: Union[str, List[str]],
|
||||||
num_images_per_prompt: int,
|
num_images_per_prompt: int,
|
||||||
do_classifier_free_guidance: bool,
|
do_classifier_free_guidance: bool,
|
||||||
negative_prompt: Optional[Union[str, list]],
|
negative_prompt: Optional[Union[str, list]] = None,
|
||||||
prompt_embeds: Optional[np.ndarray] = None,
|
prompt_embeds: Optional[np.ndarray] = None,
|
||||||
negative_prompt_embeds: Optional[np.ndarray] = None,
|
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||||
pooled_prompt_embeds: Optional[np.ndarray] = None,
|
pooled_prompt_embeds: Optional[np.ndarray] = None,
|
||||||
|
|
|
@ -2,6 +2,8 @@ from typing import List, Union
|
||||||
|
|
||||||
from arpeggio import EOF, OneOrMore, PTNodeVisitor, RegExMatch
|
from arpeggio import EOF, OneOrMore, PTNodeVisitor, RegExMatch
|
||||||
|
|
||||||
|
from .utils import collapse_phrases, flatten
|
||||||
|
|
||||||
|
|
||||||
def token_delimiter():
    """Return the ":" literal used as the delimiter."""
    delimiter = ":"
    return delimiter
|
||||||
|
@ -192,7 +194,7 @@ class OnnxPromptVisitor(PTNodeVisitor):
|
||||||
return list(flatten(children))
|
return list(flatten(children))
|
||||||
|
|
||||||
def visit_prompt(self, node, children):
    """
    Flatten the visited children and merge runs of same-weight phrases.

    collapse_phrases needs the phrase and token classes passed explicitly, so
    PhraseNode and TokenNode are supplied here.
    """
    flat_children = list(flatten(children))
    return collapse_phrases(flat_children, PhraseNode, TokenNode)
|
||||||
|
|
||||||
|
|
||||||
def parse_phrase(child, weight):
|
def parse_phrase(child, weight):
|
||||||
|
@ -206,48 +208,3 @@ def parse_phrase(child, weight):
|
||||||
# return PhraseNode(child, weight)
|
# return PhraseNode(child, weight)
|
||||||
|
|
||||||
return [parse_phrase(c, weight) for c in child]
|
return [parse_phrase(c, weight) for c in child]
|
||||||
|
|
||||||
|
|
||||||
def flatten(lst):
    """
    Lazily yield the non-list items of an arbitrarily nested list, in order.

    Only `list` instances are descended into; tuples, strings, and other
    iterables are yielded as single items.
    """
    for item in lst:
        if not isinstance(item, list):
            yield item
            continue

        yield from flatten(item)
|
|
||||||
|
|
||||||
|
|
||||||
def collapse_phrases(
|
|
||||||
nodes: List[Union[PhraseNode, str]]
|
|
||||||
) -> List[Union[PhraseNode, str]]:
|
|
||||||
"""
|
|
||||||
Combine phrases with the same weight.
|
|
||||||
"""
|
|
||||||
|
|
||||||
weight = None
|
|
||||||
tokens = []
|
|
||||||
phrases = []
|
|
||||||
|
|
||||||
def flush_tokens():
|
|
||||||
nonlocal weight, tokens
|
|
||||||
if len(tokens) > 0:
|
|
||||||
phrases.append(PhraseNode(tokens, weight))
|
|
||||||
tokens = []
|
|
||||||
weight = None
|
|
||||||
|
|
||||||
for node in nodes:
|
|
||||||
if isinstance(node, str):
|
|
||||||
node = PhraseNode([node])
|
|
||||||
elif isinstance(node, TokenNode):
|
|
||||||
flush_tokens()
|
|
||||||
phrases.append(node)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if node.weight == weight:
|
|
||||||
tokens.extend(node.tokens)
|
|
||||||
else:
|
|
||||||
flush_tokens()
|
|
||||||
tokens = [*node.tokens]
|
|
||||||
weight = node.weight
|
|
||||||
|
|
||||||
flush_tokens()
|
|
||||||
return phrases
|
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
from typing import Any, List, Union
|
||||||
|
|
||||||
|
|
||||||
|
def flatten(lst):
    """
    Lazily yield the non-list items of an arbitrarily nested list, in order.

    Only `list` instances are descended into; tuples, strings, and other
    iterables are yielded as single items. The traversal keeps its own stack
    of iterators instead of recursing, so the nesting depth is not limited by
    the interpreter's recursion limit.
    """
    stack = [iter(lst)]
    while stack:
        try:
            el = next(stack[-1])
        except StopIteration:
            # innermost list is exhausted, resume its parent
            stack.pop()
            continue

        if isinstance(el, list):
            stack.append(iter(el))
        else:
            yield el
|
||||||
|
|
||||||
|
|
||||||
|
def collapse_phrases(
    nodes: List[Union[Any]],
    phrase,
    token,
) -> List[Union[Any]]:
    """
    Combine phrases with the same weight.

    Adjacent phrases sharing a weight are merged into a single phrase holding
    all of their tokens. Bare strings count as single-token phrases, and
    instances of `token` end the current run and pass through unchanged.

    `phrase` is called as `phrase(tokens, weight)` or `phrase([text])` to build
    phrases, and its results must expose `tokens` and `weight`. `token` is the
    class (or tuple of classes) given to isinstance for pass-through nodes.
    """
    collapsed = []
    run_tokens = []
    run_weight = None

    def emit(pending, pending_weight):
        # only non-empty runs become phrases
        if pending:
            collapsed.append(phrase(pending, pending_weight))

    for item in nodes:
        if not isinstance(item, str) and isinstance(item, token):
            emit(run_tokens, run_weight)
            run_tokens, run_weight = [], None
            collapsed.append(item)
            continue

        current = phrase([item]) if isinstance(item, str) else item
        if current.weight != run_weight:
            # weight changed: close the previous run and start a new list, so
            # the phrase just emitted keeps its own tokens
            emit(run_tokens, run_weight)
            run_tokens, run_weight = [], current.weight

        run_tokens.extend(current.tokens)

    emit(run_tokens, run_weight)
    return collapsed
|
|
@ -0,0 +1,76 @@
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from onnx_web.prompt.compel import (
|
||||||
|
encode_prompt_compel,
|
||||||
|
encode_prompt_compel_sdxl,
|
||||||
|
get_inference_session,
|
||||||
|
wrap_encoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompelHelpers(unittest.TestCase):
    """Tests for get_inference_session and wrap_encoder."""

    def test_get_inference_session_missing(self):
        # None has neither a session nor a model attribute
        with self.assertRaises(ValueError):
            get_inference_session(None)

    def test_get_inference_session_onnx_session(self):
        # only the session attribute is populated
        model = MagicMock(model=None, session="session")
        self.assertEqual(get_inference_session(model), "session")

    def test_get_inference_session_onnx_model(self):
        # only the model attribute is populated
        model = MagicMock(model="model", session=None)
        self.assertEqual(get_inference_session(model), "model")

    def test_wrap_encoder(self):
        encoder = MagicMock()
        wrapper = wrap_encoder(encoder)
        self.assertEqual(wrapper.device, "cpu")
        self.assertEqual(wrapper.text_encoder, encoder)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompelEncodePrompt(unittest.TestCase):
    """Tests for encode_prompt_compel against a mocked pipeline."""

    def test_encode_basic(self):
        # the mocked text encoder always returns these two arrays
        encoder_outputs = [
            np.array([[1], [2]]),
            np.array([[3], [4]]),
        ]

        pipeline = MagicMock()
        pipeline.text_encoder = MagicMock(return_value=encoder_outputs)
        pipeline.tokenizer = MagicMock(model_max_length=1)

        embeds = encode_prompt_compel(pipeline, "prompt", 1, True)
        np.testing.assert_equal(embeds, [[[3, 3]], [[3, 3]]])
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompelEncodePromptSDXL(unittest.TestCase):
    """Tests for encode_prompt_compel_sdxl against a mocked two-encoder pipeline."""

    @unittest.skip("need to fix the tensor shapes")
    def test_encode_basic(self):
        text_encoder_output = MagicMock()
        text_encoder_output.hidden_states = [[0], [1], [2], [3]]

        def call_text_encoder(*args, return_dict=False, **kwargs):
            # stand-in for both text encoders: an object with hidden_states when
            # return_dict is requested, otherwise a plain list of arrays
            if return_dict:
                return text_encoder_output

            return [np.array([[1]]), np.array([[3]]), np.array([[5]]), np.array([[7]])]

        pipeline = MagicMock()
        pipeline.text_encoder.side_effect = call_text_encoder
        pipeline.text_encoder_2.side_effect = call_text_encoder
        pipeline.tokenizer.model_max_length = 1
        pipeline.tokenizer_2.model_max_length = 1

        embeds = encode_prompt_compel_sdxl(pipeline, "prompt", 1, True)
        np.testing.assert_equal(embeds, [[[3, 3]], [[3, 3]]])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
|
@ -114,3 +114,15 @@ class ParserTests(unittest.TestCase):
|
||||||
PromptPhrase(["me"], weight=1.5),
|
PromptPhrase(["me"], weight=1.5),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_compile_runs(self):
    """Check the positive phrases left after collapse_runs on a compiled prompt."""
    expected = [
        PromptPhrase(["foo bar"]),
        PromptPhrase(["baz"], weight=1.5),
    ]

    prompt = compile_prompt_onnx("foo <clip:skip:2> bar (baz) <lora:qux:1.5>")
    prompt.collapse_runs()

    self.assertEqual(prompt.positive_phrases, expected)
|
||||||
|
|
Loading…
Reference in New Issue