collapse prompt phrase runs
This commit is contained in:
parent
501dbff8a5
commit
3765fb9cbb
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
|
||||
class PromptNetwork:
|
||||
|
@ -127,7 +127,7 @@ class Prompt:
|
|||
clip_skip: int,
|
||||
) -> None:
|
||||
self.positive_phrases = positive_phrases
|
||||
self.negative_prompt = negative_phrases
|
||||
self.negative_phrases = negative_phrases
|
||||
self.networks = networks or []
|
||||
self.region_prompts = region_prompts or []
|
||||
self.region_seeds = region_seeds or []
|
||||
|
@ -146,3 +146,45 @@ class Prompt:
|
|||
|
||||
def __repr__(self) -> str:
|
||||
return f"Prompt({self.networks}, {self.positive_phrases}, {self.negative_phrases}, {self.region_prompts}, {self.region_seeds}, {self.clip_skip})"
|
||||
|
||||
def collapse_runs(self) -> None:
|
||||
self.positive_phrases = collapse_phrases(self.positive_phrases)
|
||||
self.negative_phrases = collapse_phrases(self.negative_phrases)
|
||||
|
||||
|
||||
def collapse_phrases(
|
||||
nodes: List[Union[Any]],
|
||||
) -> List[Union[Any]]:
|
||||
"""
|
||||
Combine phrases with the same weight.
|
||||
"""
|
||||
|
||||
weight = None
|
||||
tokens = []
|
||||
phrases = []
|
||||
|
||||
def flush_tokens():
|
||||
nonlocal weight, tokens
|
||||
if len(tokens) > 0:
|
||||
phrase = " ".join(tokens)
|
||||
phrases.append(PromptPhrase([phrase], weight))
|
||||
tokens = []
|
||||
weight = None
|
||||
|
||||
for node in nodes:
|
||||
if isinstance(node, str):
|
||||
node = PromptPhrase(node)
|
||||
elif isinstance(node, (PromptNetwork, PromptRegion, PromptSeed)):
|
||||
flush_tokens()
|
||||
phrases.append(node)
|
||||
continue
|
||||
|
||||
if node.weight == weight:
|
||||
tokens.extend(node.phrase)
|
||||
else:
|
||||
flush_tokens()
|
||||
tokens = node.phrase
|
||||
weight = node.weight
|
||||
|
||||
flush_tokens()
|
||||
return phrases
|
||||
|
|
|
@ -10,13 +10,13 @@ from ..diffusers.utils import split_clip_skip
|
|||
|
||||
|
||||
def get_inference_session(model):
|
||||
if hasattr(model, "session"):
|
||||
if hasattr(model, "session") and model.session is not None:
|
||||
return model.session
|
||||
|
||||
if hasattr(model, "model"):
|
||||
if hasattr(model, "model") and model.model is not None:
|
||||
return model.model
|
||||
|
||||
raise ValueError("Model does not have an inference session")
|
||||
raise ValueError("model does not have an inference session")
|
||||
|
||||
|
||||
def wrap_encoder(text_encoder):
|
||||
|
@ -50,6 +50,7 @@ def wrap_encoder(text_encoder):
|
|||
)
|
||||
elif output_hidden_states is True:
|
||||
hidden_states = [torch.from_numpy(state) for state in outputs[2:]]
|
||||
print("outputs", outputs)
|
||||
return SimpleNamespace(
|
||||
last_hidden_state=torch.from_numpy(outputs[0]),
|
||||
pooler_output=torch.from_numpy(outputs[1]),
|
||||
|
@ -125,7 +126,7 @@ def encode_prompt_compel_sdxl(
|
|||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int,
|
||||
do_classifier_free_guidance: bool,
|
||||
negative_prompt: Optional[Union[str, list]],
|
||||
negative_prompt: Optional[Union[str, list]] = None,
|
||||
prompt_embeds: Optional[np.ndarray] = None,
|
||||
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||
pooled_prompt_embeds: Optional[np.ndarray] = None,
|
||||
|
|
|
@ -2,6 +2,8 @@ from typing import List, Union
|
|||
|
||||
from arpeggio import EOF, OneOrMore, PTNodeVisitor, RegExMatch
|
||||
|
||||
from .utils import collapse_phrases, flatten
|
||||
|
||||
|
||||
def token_delimiter():
|
||||
return ":"
|
||||
|
@ -192,7 +194,7 @@ class OnnxPromptVisitor(PTNodeVisitor):
|
|||
return list(flatten(children))
|
||||
|
||||
def visit_prompt(self, node, children):
|
||||
return collapse_phrases(list(flatten(children)))
|
||||
return collapse_phrases(list(flatten(children)), PhraseNode, TokenNode)
|
||||
|
||||
|
||||
def parse_phrase(child, weight):
|
||||
|
@ -206,48 +208,3 @@ def parse_phrase(child, weight):
|
|||
# return PhraseNode(child, weight)
|
||||
|
||||
return [parse_phrase(c, weight) for c in child]
|
||||
|
||||
|
||||
def flatten(lst):
|
||||
for el in lst:
|
||||
if isinstance(el, list):
|
||||
yield from flatten(el)
|
||||
else:
|
||||
yield el
|
||||
|
||||
|
||||
def collapse_phrases(
|
||||
nodes: List[Union[PhraseNode, str]]
|
||||
) -> List[Union[PhraseNode, str]]:
|
||||
"""
|
||||
Combine phrases with the same weight.
|
||||
"""
|
||||
|
||||
weight = None
|
||||
tokens = []
|
||||
phrases = []
|
||||
|
||||
def flush_tokens():
|
||||
nonlocal weight, tokens
|
||||
if len(tokens) > 0:
|
||||
phrases.append(PhraseNode(tokens, weight))
|
||||
tokens = []
|
||||
weight = None
|
||||
|
||||
for node in nodes:
|
||||
if isinstance(node, str):
|
||||
node = PhraseNode([node])
|
||||
elif isinstance(node, TokenNode):
|
||||
flush_tokens()
|
||||
phrases.append(node)
|
||||
continue
|
||||
|
||||
if node.weight == weight:
|
||||
tokens.extend(node.tokens)
|
||||
else:
|
||||
flush_tokens()
|
||||
tokens = [*node.tokens]
|
||||
weight = node.weight
|
||||
|
||||
flush_tokens()
|
||||
return phrases
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
from typing import Any, List, Union
|
||||
|
||||
|
||||
def flatten(lst):
|
||||
for el in lst:
|
||||
if isinstance(el, list):
|
||||
yield from flatten(el)
|
||||
else:
|
||||
yield el
|
||||
|
||||
|
||||
def collapse_phrases(
|
||||
nodes: List[Union[Any]],
|
||||
phrase,
|
||||
token,
|
||||
) -> List[Union[Any]]:
|
||||
"""
|
||||
Combine phrases with the same weight.
|
||||
"""
|
||||
|
||||
weight = None
|
||||
tokens = []
|
||||
phrases = []
|
||||
|
||||
def flush_tokens():
|
||||
nonlocal weight, tokens
|
||||
if len(tokens) > 0:
|
||||
phrases.append(phrase(tokens, weight))
|
||||
tokens = []
|
||||
weight = None
|
||||
|
||||
for node in nodes:
|
||||
if isinstance(node, str):
|
||||
node = phrase([node])
|
||||
elif isinstance(node, token):
|
||||
flush_tokens()
|
||||
phrases.append(node)
|
||||
continue
|
||||
|
||||
if node.weight == weight:
|
||||
tokens.extend(node.tokens)
|
||||
else:
|
||||
flush_tokens()
|
||||
tokens = [*node.tokens]
|
||||
weight = node.weight
|
||||
|
||||
flush_tokens()
|
||||
return phrases
|
|
@ -0,0 +1,76 @@
|
|||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
|
||||
from onnx_web.prompt.compel import (
|
||||
encode_prompt_compel,
|
||||
encode_prompt_compel_sdxl,
|
||||
get_inference_session,
|
||||
wrap_encoder,
|
||||
)
|
||||
|
||||
|
||||
class TestCompelHelpers(unittest.TestCase):
|
||||
def test_get_inference_session_missing(self):
|
||||
self.assertRaises(ValueError, get_inference_session, None)
|
||||
|
||||
def test_get_inference_session_onnx_session(self):
|
||||
model = MagicMock()
|
||||
model.model = None
|
||||
model.session = "session"
|
||||
self.assertEqual(get_inference_session(model), "session")
|
||||
|
||||
def test_get_inference_session_onnx_model(self):
|
||||
model = MagicMock()
|
||||
model.model = "model"
|
||||
model.session = None
|
||||
self.assertEqual(get_inference_session(model), "model")
|
||||
|
||||
def test_wrap_encoder(self):
|
||||
text_encoder = MagicMock()
|
||||
wrapped = wrap_encoder(text_encoder)
|
||||
self.assertEqual(wrapped.device, "cpu")
|
||||
self.assertEqual(wrapped.text_encoder, text_encoder)
|
||||
|
||||
|
||||
class TestCompelEncodePrompt(unittest.TestCase):
|
||||
def test_encode_basic(self):
|
||||
pipeline = MagicMock()
|
||||
pipeline.text_encoder = MagicMock()
|
||||
pipeline.text_encoder.return_value = [
|
||||
np.array([[1], [2]]),
|
||||
np.array([[3], [4]]),
|
||||
]
|
||||
pipeline.tokenizer = MagicMock()
|
||||
pipeline.tokenizer.model_max_length = 1
|
||||
|
||||
embeds = encode_prompt_compel(pipeline, "prompt", 1, True)
|
||||
np.testing.assert_equal(embeds, [[[3, 3]], [[3, 3]]])
|
||||
|
||||
|
||||
class TestCompelEncodePromptSDXL(unittest.TestCase):
|
||||
@unittest.skip("need to fix the tensor shapes")
|
||||
def test_encode_basic(self):
|
||||
text_encoder_output = MagicMock()
|
||||
text_encoder_output.hidden_states = [[0], [1], [2], [3]]
|
||||
|
||||
def call_text_encoder(*args, return_dict=False, **kwargs):
|
||||
print("call_text_encoder", return_dict)
|
||||
if return_dict:
|
||||
return text_encoder_output
|
||||
|
||||
return [np.array([[1]]), np.array([[3]]), np.array([[5]]), np.array([[7]])]
|
||||
|
||||
pipeline = MagicMock()
|
||||
pipeline.text_encoder.side_effect = call_text_encoder
|
||||
pipeline.text_encoder_2.side_effect = call_text_encoder
|
||||
pipeline.tokenizer.model_max_length = 1
|
||||
pipeline.tokenizer_2.model_max_length = 1
|
||||
|
||||
embeds = encode_prompt_compel_sdxl(pipeline, "prompt", 1, True)
|
||||
np.testing.assert_equal(embeds, [[[3, 3]], [[3, 3]]])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -114,3 +114,15 @@ class ParserTests(unittest.TestCase):
|
|||
PromptPhrase(["me"], weight=1.5),
|
||||
],
|
||||
)
|
||||
|
||||
def test_compile_runs(self):
|
||||
prompt = compile_prompt_onnx("foo <clip:skip:2> bar (baz) <lora:qux:1.5>")
|
||||
prompt.collapse_runs()
|
||||
|
||||
self.assertEqual(
|
||||
prompt.positive_phrases,
|
||||
[
|
||||
PromptPhrase(["foo bar"]),
|
||||
PromptPhrase(["baz"], weight=1.5),
|
||||
],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue