1
0
Fork 0

dry model dumping, improve DM prompt
Run Docker Build / build (push) Successful in 14s Details
Run Python Build / build (push) Successful in 26s Details

This commit is contained in:
Sean Sube 2024-06-07 21:18:56 -05:00
parent 9f435eeb8b
commit e1c72c3717
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 36 additions and 24 deletions

View File

@ -209,10 +209,10 @@ prompts:
# world generation # world generation
world_generate_dungeon_master: | world_generate_dungeon_master: |
You are an experienced dungeon master creating a visually detailed world for a new adventure set in {{theme}}. Be You are an experienced dungeon master creating a visually detailed world for a new adventure set in {{theme | punctuate}} Be
creative and original, creating a world that is visually detailed and full of curious details. Do not repeat creative and original, creating a world that is visually detailed, consistent, and plausible within the context of
yourself unless you are given the same prompt with the same characters, room, and context. {{flavor}}. The theme of the setting. Do not repeat yourself unless you are given the same prompt with the same characters, room, and
the world must be: {{theme}}. context. {{flavor | punctuate}} The theme of the world must be: {{theme | punctuate}}
world_generate_world_broadcast_theme: | world_generate_world_broadcast_theme: |
Generating a {{theme}} with {{room_count}} rooms Generating a {{theme}} with {{room_count}} rooms

View File

@ -10,9 +10,10 @@ from packit.agent import Agent, agent_easy_connect
from packit.memory import make_limited_memory from packit.memory import make_limited_memory
from packit.utils import logger_with_colors from packit.utils import logger_with_colors
# configure logging
# this is the only taleweave import allowed before the logger has been created
from taleweave.utils.file import load_yaml from taleweave.utils.file import load_yaml
# configure logging
LOG_PATH = "logging.json" LOG_PATH = "logging.json"
try: try:
if path.exists(LOG_PATH): if path.exists(LOG_PATH):
@ -30,6 +31,15 @@ logger = logger_with_colors(__name__) # , level="DEBUG")
load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True) load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True)
# start the debugger, if needed
if environ.get("DEBUG", "false").lower() == "true":
import debugpy
debugpy.listen(5679)
logger.info("waiting for debugger to attach...")
debugpy.wait_for_client()
if True: if True:
from taleweave.context import ( from taleweave.context import (
get_prompt_library, get_prompt_library,
@ -52,14 +62,6 @@ if True:
from taleweave.state import create_agents, save_world, save_world_state from taleweave.state import create_agents, save_world, save_world_state
from taleweave.utils.template import format_prompt from taleweave.utils.template import format_prompt
# start the debugger, if needed
if environ.get("DEBUG", "false").lower() == "true":
import debugpy
debugpy.listen(5679)
logger.info("waiting for debugger to attach...")
debugpy.wait_for_client()
def int_or_inf(value: str) -> float | int: def int_or_inf(value: str) -> float | int:
if value == "inf": if value == "inf":
@ -415,7 +417,9 @@ def main():
world_builder = Agent( world_builder = Agent(
"dungeon master", "dungeon master",
format_prompt( format_prompt(
"world_generate_dungeon_master", flavor=args.flavor, theme=world.theme "world_generate_dungeon_master",
flavor=world_prompt.flavor,
theme=world_prompt.theme,
), ),
{}, {},
llm, llm,

View File

@ -1,6 +1,8 @@
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
from uuid import uuid4 from uuid import uuid4
from pydantic import RootModel
if TYPE_CHECKING: if TYPE_CHECKING:
from dataclasses import dataclass from dataclasses import dataclass
else: else:
@ -16,6 +18,14 @@ class BaseModel:
id: str id: str
def dump_model(cls, model: BaseModel) -> Dict:
return RootModel[cls](model).model_dump()
def dump_model_json(cls, model: BaseModel) -> str:
return RootModel[cls](model).model_dump_json(indent=2)
def uuid() -> str: def uuid() -> str:
return uuid4().hex return uuid4().hex

View File

@ -10,7 +10,6 @@ from uuid import uuid4
import websockets import websockets
from PIL import Image from PIL import Image
from pydantic import RootModel
from taleweave.context import ( from taleweave.context import (
broadcast, broadcast,
@ -20,6 +19,7 @@ from taleweave.context import (
set_character_agent, set_character_agent,
subscribe, subscribe,
) )
from taleweave.models.base import dump_model
from taleweave.models.config import WebsocketServerConfig from taleweave.models.config import WebsocketServerConfig
from taleweave.models.entity import World, WorldEntity from taleweave.models.entity import World, WorldEntity
from taleweave.models.event import ( from taleweave.models.event import (
@ -343,7 +343,7 @@ def server_system(world: World, turn: int, data: Any | None = None):
def server_event(event: GameEvent): def server_event(event: GameEvent):
json_event: Dict[str, Any] = RootModel[event.__class__](event).model_dump() json_event: Dict[str, Any] = dump_model(event.__class__, event)
json_event.update( json_event.update(
{ {
"id": event.id, "id": event.id,

View File

@ -5,13 +5,13 @@ from typing import Dict, List, Sequence
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from packit.agent import Agent, agent_easy_connect from packit.agent import Agent, agent_easy_connect
from pydantic import RootModel
from taleweave.context import ( from taleweave.context import (
get_all_character_agents, get_all_character_agents,
get_game_config, get_game_config,
set_character_agent, set_character_agent,
) )
from taleweave.models.base import dump_model, dump_model_json
from taleweave.models.entity import World from taleweave.models.entity import World
from taleweave.player import LocalPlayer from taleweave.player import LocalPlayer
from taleweave.utils.template import format_prompt from taleweave.utils.template import format_prompt
@ -58,10 +58,9 @@ def graph_world(world: World, turn: int):
def snapshot_world(world: World, turn: int): def snapshot_world(world: World, turn: int):
# save the world itself, along with the turn number and the memory of each agent # save the world itself, along with the turn number and the memory of each agent
json_world = RootModel[World](world).model_dump() json_world = dump_model(World, world)
json_memory = {} json_memory = {}
for character, agent in get_all_character_agents(): for character, agent in get_all_character_agents():
json_memory[character.name] = list(agent.memory or []) json_memory[character.name] = list(agent.memory or [])
@ -97,7 +96,7 @@ def restore_memory(
def save_world(world, filename): def save_world(world, filename):
with open(filename, "w") as f: with open(filename, "w") as f:
json_world = RootModel[World](world).model_dump_json(indent=2) json_world = dump_model_json(World, world)
f.write(json_world) f.write(json_world)

View File

@ -1,5 +1,4 @@
from pydantic import RootModel from taleweave.models.base import dump_model
from taleweave.utils.file import load_yaml, save_yaml from taleweave.utils.file import load_yaml, save_yaml
@ -11,6 +10,6 @@ def load_system_data(cls, file):
def save_system_data(cls, file, model): def save_system_data(cls, file, model):
data = RootModel[cls](model).model_dump() data = dump_model(cls, model)
with open(file, "w") as f: with open(file, "w") as f:
save_yaml(f, data) save_yaml(f, data)

View File

@ -28,7 +28,7 @@ def the_prefix(name: str) -> str:
return f"the {name}" return f"the {name}"
def punctuate(name: str, suffix: str) -> str: def punctuate(name: str, suffix: str = ".") -> str:
if name[-1] in [".", "!", "?", suffix]: if name[-1] in [".", "!", "?", suffix]:
return name return name