dry model dumping, improve DM prompt
This commit is contained in:
parent
9f435eeb8b
commit
e1c72c3717
|
@ -209,10 +209,10 @@ prompts:
|
||||||
|
|
||||||
# world generation
|
# world generation
|
||||||
world_generate_dungeon_master: |
|
world_generate_dungeon_master: |
|
||||||
You are an experienced dungeon master creating a visually detailed world for a new adventure set in {{theme}}. Be
|
You are an experienced dungeon master creating a visually detailed world for a new adventure set in {{theme | punctuate}} Be
|
||||||
creative and original, creating a world that is visually detailed and full of curious details. Do not repeat
|
creative and original, creating a world that is visually detailed, consistent, and plausible within the context of
|
||||||
yourself unless you are given the same prompt with the same characters, room, and context. {{flavor}}. The theme of
|
the setting. Do not repeat yourself unless you are given the same prompt with the same characters, room, and
|
||||||
the world must be: {{theme}}.
|
context. {{flavor | punctuate}} The theme of the world must be: {{theme | punctuate}}
|
||||||
|
|
||||||
world_generate_world_broadcast_theme: |
|
world_generate_world_broadcast_theme: |
|
||||||
Generating a {{theme}} with {{room_count}} rooms
|
Generating a {{theme}} with {{room_count}} rooms
|
||||||
|
|
|
@ -10,9 +10,10 @@ from packit.agent import Agent, agent_easy_connect
|
||||||
from packit.memory import make_limited_memory
|
from packit.memory import make_limited_memory
|
||||||
from packit.utils import logger_with_colors
|
from packit.utils import logger_with_colors
|
||||||
|
|
||||||
|
# configure logging
|
||||||
|
# this is the only taleweave import allowed before the logger has been created
|
||||||
from taleweave.utils.file import load_yaml
|
from taleweave.utils.file import load_yaml
|
||||||
|
|
||||||
# configure logging
|
|
||||||
LOG_PATH = "logging.json"
|
LOG_PATH = "logging.json"
|
||||||
try:
|
try:
|
||||||
if path.exists(LOG_PATH):
|
if path.exists(LOG_PATH):
|
||||||
|
@ -30,6 +31,15 @@ logger = logger_with_colors(__name__) # , level="DEBUG")
|
||||||
|
|
||||||
load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True)
|
load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True)
|
||||||
|
|
||||||
|
# start the debugger, if needed
|
||||||
|
if environ.get("DEBUG", "false").lower() == "true":
|
||||||
|
import debugpy
|
||||||
|
|
||||||
|
debugpy.listen(5679)
|
||||||
|
logger.info("waiting for debugger to attach...")
|
||||||
|
debugpy.wait_for_client()
|
||||||
|
|
||||||
|
|
||||||
if True:
|
if True:
|
||||||
from taleweave.context import (
|
from taleweave.context import (
|
||||||
get_prompt_library,
|
get_prompt_library,
|
||||||
|
@ -52,14 +62,6 @@ if True:
|
||||||
from taleweave.state import create_agents, save_world, save_world_state
|
from taleweave.state import create_agents, save_world, save_world_state
|
||||||
from taleweave.utils.template import format_prompt
|
from taleweave.utils.template import format_prompt
|
||||||
|
|
||||||
# start the debugger, if needed
|
|
||||||
if environ.get("DEBUG", "false").lower() == "true":
|
|
||||||
import debugpy
|
|
||||||
|
|
||||||
debugpy.listen(5679)
|
|
||||||
logger.info("waiting for debugger to attach...")
|
|
||||||
debugpy.wait_for_client()
|
|
||||||
|
|
||||||
|
|
||||||
def int_or_inf(value: str) -> float | int:
|
def int_or_inf(value: str) -> float | int:
|
||||||
if value == "inf":
|
if value == "inf":
|
||||||
|
@ -415,7 +417,9 @@ def main():
|
||||||
world_builder = Agent(
|
world_builder = Agent(
|
||||||
"dungeon master",
|
"dungeon master",
|
||||||
format_prompt(
|
format_prompt(
|
||||||
"world_generate_dungeon_master", flavor=args.flavor, theme=world.theme
|
"world_generate_dungeon_master",
|
||||||
|
flavor=world_prompt.flavor,
|
||||||
|
theme=world_prompt.theme,
|
||||||
),
|
),
|
||||||
{},
|
{},
|
||||||
llm,
|
llm,
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING, Dict
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from pydantic import RootModel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
else:
|
else:
|
||||||
|
@ -16,6 +18,14 @@ class BaseModel:
|
||||||
id: str
|
id: str
|
||||||
|
|
||||||
|
|
||||||
|
def dump_model(cls, model: BaseModel) -> Dict:
|
||||||
|
return RootModel[cls](model).model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
def dump_model_json(cls, model: BaseModel) -> str:
|
||||||
|
return RootModel[cls](model).model_dump_json(indent=2)
|
||||||
|
|
||||||
|
|
||||||
def uuid() -> str:
|
def uuid() -> str:
|
||||||
return uuid4().hex
|
return uuid4().hex
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,6 @@ from uuid import uuid4
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import RootModel
|
|
||||||
|
|
||||||
from taleweave.context import (
|
from taleweave.context import (
|
||||||
broadcast,
|
broadcast,
|
||||||
|
@ -20,6 +19,7 @@ from taleweave.context import (
|
||||||
set_character_agent,
|
set_character_agent,
|
||||||
subscribe,
|
subscribe,
|
||||||
)
|
)
|
||||||
|
from taleweave.models.base import dump_model
|
||||||
from taleweave.models.config import WebsocketServerConfig
|
from taleweave.models.config import WebsocketServerConfig
|
||||||
from taleweave.models.entity import World, WorldEntity
|
from taleweave.models.entity import World, WorldEntity
|
||||||
from taleweave.models.event import (
|
from taleweave.models.event import (
|
||||||
|
@ -343,7 +343,7 @@ def server_system(world: World, turn: int, data: Any | None = None):
|
||||||
|
|
||||||
|
|
||||||
def server_event(event: GameEvent):
|
def server_event(event: GameEvent):
|
||||||
json_event: Dict[str, Any] = RootModel[event.__class__](event).model_dump()
|
json_event: Dict[str, Any] = dump_model(event.__class__, event)
|
||||||
json_event.update(
|
json_event.update(
|
||||||
{
|
{
|
||||||
"id": event.id,
|
"id": event.id,
|
||||||
|
|
|
@ -5,13 +5,13 @@ from typing import Dict, List, Sequence
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from packit.agent import Agent, agent_easy_connect
|
from packit.agent import Agent, agent_easy_connect
|
||||||
from pydantic import RootModel
|
|
||||||
|
|
||||||
from taleweave.context import (
|
from taleweave.context import (
|
||||||
get_all_character_agents,
|
get_all_character_agents,
|
||||||
get_game_config,
|
get_game_config,
|
||||||
set_character_agent,
|
set_character_agent,
|
||||||
)
|
)
|
||||||
|
from taleweave.models.base import dump_model, dump_model_json
|
||||||
from taleweave.models.entity import World
|
from taleweave.models.entity import World
|
||||||
from taleweave.player import LocalPlayer
|
from taleweave.player import LocalPlayer
|
||||||
from taleweave.utils.template import format_prompt
|
from taleweave.utils.template import format_prompt
|
||||||
|
@ -58,10 +58,9 @@ def graph_world(world: World, turn: int):
|
||||||
|
|
||||||
def snapshot_world(world: World, turn: int):
|
def snapshot_world(world: World, turn: int):
|
||||||
# save the world itself, along with the turn number and the memory of each agent
|
# save the world itself, along with the turn number and the memory of each agent
|
||||||
json_world = RootModel[World](world).model_dump()
|
json_world = dump_model(World, world)
|
||||||
|
|
||||||
json_memory = {}
|
json_memory = {}
|
||||||
|
|
||||||
for character, agent in get_all_character_agents():
|
for character, agent in get_all_character_agents():
|
||||||
json_memory[character.name] = list(agent.memory or [])
|
json_memory[character.name] = list(agent.memory or [])
|
||||||
|
|
||||||
|
@ -97,7 +96,7 @@ def restore_memory(
|
||||||
|
|
||||||
def save_world(world, filename):
|
def save_world(world, filename):
|
||||||
with open(filename, "w") as f:
|
with open(filename, "w") as f:
|
||||||
json_world = RootModel[World](world).model_dump_json(indent=2)
|
json_world = dump_model_json(World, world)
|
||||||
f.write(json_world)
|
f.write(json_world)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from pydantic import RootModel
|
from taleweave.models.base import dump_model
|
||||||
|
|
||||||
from taleweave.utils.file import load_yaml, save_yaml
|
from taleweave.utils.file import load_yaml, save_yaml
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,6 +10,6 @@ def load_system_data(cls, file):
|
||||||
|
|
||||||
|
|
||||||
def save_system_data(cls, file, model):
|
def save_system_data(cls, file, model):
|
||||||
data = RootModel[cls](model).model_dump()
|
data = dump_model(cls, model)
|
||||||
with open(file, "w") as f:
|
with open(file, "w") as f:
|
||||||
save_yaml(f, data)
|
save_yaml(f, data)
|
||||||
|
|
|
@ -28,7 +28,7 @@ def the_prefix(name: str) -> str:
|
||||||
return f"the {name}"
|
return f"the {name}"
|
||||||
|
|
||||||
|
|
||||||
def punctuate(name: str, suffix: str) -> str:
|
def punctuate(name: str, suffix: str = ".") -> str:
|
||||||
if name[-1] in [".", "!", "?", suffix]:
|
if name[-1] in [".", "!", "?", suffix]:
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue