2024-05-31 23:58:01 +00:00
|
|
|
from logging import getLogger
|
|
|
|
|
|
|
|
from jinja2 import Environment
|
|
|
|
|
|
|
|
from taleweave.context import get_prompt_library
|
2024-06-03 01:00:17 +00:00
|
|
|
from taleweave.utils.string import and_list, or_list
|
2024-05-31 23:58:01 +00:00
|
|
|
from taleweave.utils.world import describe_entity, name_entity
|
|
|
|
|
|
|
|
# Module-level logger named after this module; used by format_prompt to
# record failures before re-raising them.
logger = getLogger(__name__)
|
|
|
|
|
2024-06-06 00:05:21 +00:00
|
|
|
|
|
|
|
def a_prefix(name: str) -> str:
    """Prefix *name* with the indefinite article "a" or "an".

    If the first space-separated word of *name* is already an article
    ("a", "an", or "the", compared case-insensitively), the name is returned
    unchanged. An empty name is also returned unchanged.

    The choice between "a" and "an" looks only at whether the first
    character is one of the letters a/e/i/o/u, so it is wrong for words whose
    pronunciation differs from their spelling (e.g. "hour", "unicorn").

    Args:
        name: The noun phrase to prefix.

    Returns:
        The name with an indefinite article prepended, or the original name.
    """
    if not name:
        # Guard: the name[0] lookup below would raise IndexError on "".
        return name

    first_word = name.split(" ")[0]
    if first_word.lower() in ("a", "an", "the"):
        return name

    if name[0].lower() in "aeiou":
        return f"an {name}"

    return f"a {name}"
|
|
|
|
|
|
|
|
|
|
|
|
def the_prefix(name: str) -> str:
    """Prefix *name* with the definite article "the".

    When the text before the first space in *name* is already an article
    ("a", "an", or "the", compared case-insensitively), *name* is returned
    as-is; otherwise "the " is prepended.

    Args:
        name: The noun phrase to prefix.

    Returns:
        The possibly-prefixed name.
    """
    leading_word, _, _ = name.partition(" ")
    has_article = leading_word.lower() in {"a", "an", "the"}
    return name if has_article else f"the {name}"
|
|
|
|
|
|
|
|
|
|
|
|
def punctuate(name: str, suffix: str) -> str:
    """Ensure *name* ends with terminal punctuation.

    If *name* already ends with ".", "!", "?", or *suffix*, it is returned
    unchanged; otherwise *suffix* is appended. An empty *name* is returned
    unchanged rather than being turned into a bare *suffix*.

    Args:
        name: The text to punctuate.
        suffix: The punctuation to append. It may be more than one character
            long (e.g. ", " or ":)").

    Returns:
        The punctuated text.
    """
    if not name:
        # Guard: there is nothing to punctuate, and indexing "" would fail.
        return name

    # endswith (rather than comparing only the last character) so that a
    # multi-character suffix already present is recognised and not doubled.
    if name.endswith((".", "!", "?", suffix)):
        return name

    return f"{name}{suffix}"
|
|
|
|
|
|
|
|
|
2024-06-03 01:00:17 +00:00
|
|
|
# Shared Jinja environment used to render prompt templates. The filters
# registered here expose entity-description and list-joining helpers to
# template authors (e.g. ``{{ item | describe }}``).
jinja_env = Environment()
jinja_env.filters.update(
    {
        "describe": describe_entity,
        "name": name_entity,
        "and_list": and_list,
        "or_list": or_list,
    }
)
|
2024-06-06 00:05:21 +00:00
|
|
|
# Expose the article and punctuation helpers defined in this module as
# template filters (e.g. ``{{ name | a_prefix }}``).
jinja_env.filters.update(
    {
        "a_prefix": a_prefix,
        "the_prefix": the_prefix,
        "punctuate": punctuate,
    }
)
|
2024-06-03 01:00:17 +00:00
|
|
|
|
2024-05-31 23:58:01 +00:00
|
|
|
|
|
|
|
def format_prompt(prompt_key: str, **kwargs) -> str:
    """Render the prompt template registered under *prompt_key*.

    The template source is looked up in ``get_prompt_library().prompts`` and
    rendered with *kwargs* through :func:`format_str`.

    Args:
        prompt_key: Key of the prompt in the prompt library's ``prompts``.
        **kwargs: Variables made available to the template.

    Returns:
        The rendered prompt text.

    Raises:
        Any exception raised while fetching the library, looking up
        *prompt_key*, or rendering the template. It is logged with its
        traceback and then re-raised unchanged.
    """
    try:
        library = get_prompt_library()
        template_str = library.prompts[prompt_key]
        return format_str(template_str, **kwargs)
    except Exception:
        logger.exception("error formatting prompt: %s", prompt_key)
        # A bare ``raise`` re-raises the active exception with its original
        # traceback, without adding this line as a new re-raise point the way
        # ``raise e`` does.
        raise
|
|
|
|
|
|
|
|
|
|
|
|
def format_str(template_str: str, **kwargs) -> str:
    """Render a Jinja template string with this module's shared environment.

    Rendering goes through the module-level ``jinja_env``, so the custom
    filters registered on it (``describe``, ``name``, ``and_list``, ...) are
    available to the template.

    Args:
        template_str: Jinja template source.
        **kwargs: Variables made available to the template.

    Returns:
        The rendered text.
    """
    # TODO: cache templates -- Environment.from_string compiles the source
    # on every call.
    return jinja_env.from_string(template_str).render(**kwargs)
|