1
0
Fork 0

dry model dumping, improve DM prompt
Run Docker Build / build (push) Successful in 14s Details
Run Python Build / build (push) Successful in 26s Details

This commit is contained in:
Sean Sube 2024-06-07 21:18:56 -05:00
parent 9f435eeb8b
commit e1c72c3717
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 36 additions and 24 deletions

View File

@ -209,10 +209,10 @@ prompts:
# world generation
world_generate_dungeon_master: |
You are an experienced dungeon master creating a visually detailed world for a new adventure set in {{theme}}. Be
creative and original, creating a world that is visually detailed and full of curious details. Do not repeat
yourself unless you are given the same prompt with the same characters, room, and context. {{flavor}}. The theme of
the world must be: {{theme}}.
You are an experienced dungeon master creating a visually detailed world for a new adventure set in {{theme | punctuate}} Be
creative and original, creating a world that is visually detailed, consistent, and plausible within the context of
the setting. Do not repeat yourself unless you are given the same prompt with the same characters, room, and
context. {{flavor | punctuate}} The theme of the world must be: {{theme | punctuate}}
world_generate_world_broadcast_theme: |
Generating a {{theme}} with {{room_count}} rooms

View File

@ -10,9 +10,10 @@ from packit.agent import Agent, agent_easy_connect
from packit.memory import make_limited_memory
from packit.utils import logger_with_colors
# configure logging
# this is the only taleweave import allowed before the logger has been created
from taleweave.utils.file import load_yaml
# configure logging
LOG_PATH = "logging.json"
try:
if path.exists(LOG_PATH):
@ -30,6 +31,15 @@ logger = logger_with_colors(__name__) # , level="DEBUG")
load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True)
# start the debugger, if needed
if environ.get("DEBUG", "false").lower() == "true":
import debugpy
debugpy.listen(5679)
logger.info("waiting for debugger to attach...")
debugpy.wait_for_client()
if True:
from taleweave.context import (
get_prompt_library,
@ -52,14 +62,6 @@ if True:
from taleweave.state import create_agents, save_world, save_world_state
from taleweave.utils.template import format_prompt
# start the debugger, if needed
if environ.get("DEBUG", "false").lower() == "true":
import debugpy
debugpy.listen(5679)
logger.info("waiting for debugger to attach...")
debugpy.wait_for_client()
def int_or_inf(value: str) -> float | int:
if value == "inf":
@ -415,7 +417,9 @@ def main():
world_builder = Agent(
"dungeon master",
format_prompt(
"world_generate_dungeon_master", flavor=args.flavor, theme=world.theme
"world_generate_dungeon_master",
flavor=world_prompt.flavor,
theme=world_prompt.theme,
),
{},
llm,

View File

@ -1,6 +1,8 @@
from typing import TYPE_CHECKING, Dict
from uuid import uuid4
from pydantic import RootModel
if TYPE_CHECKING:
from dataclasses import dataclass
else:
@ -16,6 +18,14 @@ class BaseModel:
id: str
def dump_model(cls, model: BaseModel) -> Dict:
return RootModel[cls](model).model_dump()
def dump_model_json(cls, model: BaseModel) -> str:
return RootModel[cls](model).model_dump_json(indent=2)
def uuid() -> str:
return uuid4().hex

View File

@ -10,7 +10,6 @@ from uuid import uuid4
import websockets
from PIL import Image
from pydantic import RootModel
from taleweave.context import (
broadcast,
@ -20,6 +19,7 @@ from taleweave.context import (
set_character_agent,
subscribe,
)
from taleweave.models.base import dump_model
from taleweave.models.config import WebsocketServerConfig
from taleweave.models.entity import World, WorldEntity
from taleweave.models.event import (
@ -343,7 +343,7 @@ def server_system(world: World, turn: int, data: Any | None = None):
def server_event(event: GameEvent):
json_event: Dict[str, Any] = RootModel[event.__class__](event).model_dump()
json_event: Dict[str, Any] = dump_model(event.__class__, event)
json_event.update(
{
"id": event.id,

View File

@ -5,13 +5,13 @@ from typing import Dict, List, Sequence
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from packit.agent import Agent, agent_easy_connect
from pydantic import RootModel
from taleweave.context import (
get_all_character_agents,
get_game_config,
set_character_agent,
)
from taleweave.models.base import dump_model, dump_model_json
from taleweave.models.entity import World
from taleweave.player import LocalPlayer
from taleweave.utils.template import format_prompt
@ -58,10 +58,9 @@ def graph_world(world: World, turn: int):
def snapshot_world(world: World, turn: int):
# save the world itself, along with the turn number and the memory of each agent
json_world = RootModel[World](world).model_dump()
json_world = dump_model(World, world)
json_memory = {}
for character, agent in get_all_character_agents():
json_memory[character.name] = list(agent.memory or [])
@ -97,7 +96,7 @@ def restore_memory(
def save_world(world, filename):
with open(filename, "w") as f:
json_world = RootModel[World](world).model_dump_json(indent=2)
json_world = dump_model_json(World, world)
f.write(json_world)

View File

@ -1,5 +1,4 @@
from pydantic import RootModel
from taleweave.models.base import dump_model
from taleweave.utils.file import load_yaml, save_yaml
@ -11,6 +10,6 @@ def load_system_data(cls, file):
def save_system_data(cls, file, model):
data = RootModel[cls](model).model_dump()
data = dump_model(cls, model)
with open(file, "w") as f:
save_yaml(f, data)

View File

@ -28,7 +28,7 @@ def the_prefix(name: str) -> str:
return f"the {name}"
def punctuate(name: str, suffix: str) -> str:
def punctuate(name: str, suffix: str = ".") -> str:
if name[-1] in [".", "!", "?", suffix]:
return name