# onnx-web/api/tests/prompt/test_parser.py
# (repository page metadata: 51 lines, 1.5 KiB, Python)
import unittest
from onnx_web.prompt.grammar import PromptPhrase, PromptToken
from onnx_web.prompt.parser import parse_prompt_onnx
class ParserTests(unittest.TestCase):
    """Tests for ``parse_prompt_onnx``: plain words, weighted phrases and tokens."""

    def _assert_parsed(self, prompt, expected):
        """Parse *prompt* and assert the resulting nodes match *expected*.

        Nodes are compared by their ``str()`` form rather than by direct
        equality. NOTE(review): presumably this is because the parser's node
        types do not define ``__eq__`` -- confirm before switching to a direct
        comparison.

        The first argument to ``parse_prompt_onnx`` is ``None`` in every test
        here; what it represents is not visible from this file.

        Args:
            prompt: The raw prompt text to parse.
            expected: The expected sequence of nodes. Unweighted runs of
                words are plain lists of strings; weighted groups are
                ``PromptPhrase``; ``<...>`` tags are ``PromptToken``.
        """
        res = parse_prompt_onnx(None, prompt, debug=False)
        self.assertListEqual(
            [str(i) for i in res],
            [str(e) for e in expected],
        )

    def test_single_word_phrase(self):
        """A single parenthesized word becomes a phrase with weight 1.5."""
        self._assert_parsed(
            "foo (bar) bin",
            [
                ["foo"],
                PromptPhrase(["bar"], weight=1.5),
                ["bin"],
            ],
        )

    def test_multi_word_phrase(self):
        """Several words inside one pair of parentheses share a single phrase."""
        self._assert_parsed(
            "foo bar (middle words) bin bun",
            [
                ["foo", "bar"],
                PromptPhrase(["middle", "words"], weight=1.5),
                ["bin", "bun"],
            ],
        )

    def test_nested_phrase(self):
        """Each nesting level of parentheses multiplies the weight by 1.5."""
        self._assert_parsed(
            "foo (((bar))) bin",
            [
                ["foo"],
                PromptPhrase(["bar"], weight=(1.5**3)),
                ["bin"],
            ],
        )

    def test_lora_token(self):
        """A ``<lora:name:weight>`` tag is parsed into a ``PromptToken``."""
        self._assert_parsed(
            "foo <lora:name:1.5> bin",
            [
                ["foo"],
                PromptToken("lora", "name", 1.5),
                ["bin"],
            ],
        )