48 lines
1.3 KiB
Python
48 lines
1.3 KiB
Python
from typing import Literal
|
|
|
|
import numpy as np
|
|
from arpeggio import ParserPython, visit_parse_tree
|
|
|
|
from .grammar import OnnxPromptVisitor
|
|
from .grammar import prompt as prompt_base
|
|
|
|
|
|
def parse_prompt_compel(pipeline, prompt: str) -> np.ndarray:
|
|
from compel import Compel
|
|
|
|
parser = Compel(tokenizer=pipeline.tokenizer, text_encoder=pipeline.text_encoder)
|
|
return parser([prompt])
|
|
|
|
|
|
def parse_prompt_lpw(pipeline, prompt: str, debug=False) -> np.ndarray:
|
|
pass
|
|
|
|
|
|
def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray:
|
|
parser = ParserPython(prompt_base, debug=debug)
|
|
visitor = OnnxPromptVisitor()
|
|
|
|
ast = parser.parse(prompt)
|
|
return visit_parse_tree(ast, visitor)
|
|
|
|
|
|
def parse_prompt_vanilla(pipeline, prompt: str) -> np.ndarray:
|
|
return pipeline._encode_prompt(prompt)
|
|
|
|
|
|
def parse_prompt(
|
|
pipeline,
|
|
prompt: str,
|
|
engine: Literal["compel", "lpw", "onnx-web", "pipeline"] = "onnx-web",
|
|
) -> np.ndarray:
|
|
if engine == "compel":
|
|
return parse_prompt_compel(pipeline, prompt)
|
|
if engine == "lpw":
|
|
return parse_prompt_lpw(pipeline, prompt)
|
|
elif engine == "onnx-web":
|
|
return parse_prompt_onnx(pipeline, prompt)
|
|
elif engine == "pipeline":
|
|
return parse_prompt_vanilla(pipeline, prompt)
|
|
else:
|
|
raise ValueError("invalid prompt parser")
|