119 lines
3.7 KiB
Python
119 lines
3.7 KiB
Python
from collections import deque
|
|
from json import dump
|
|
from os import path
|
|
from typing import Dict, List, Sequence
|
|
|
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
|
from packit.agent import Agent, agent_easy_connect
|
|
from pydantic import RootModel
|
|
|
|
from taleweave.context import (
|
|
get_all_character_agents,
|
|
get_game_config,
|
|
set_character_agent,
|
|
)
|
|
from taleweave.models.entity import World
|
|
from taleweave.player import LocalPlayer
|
|
from taleweave.utils.template import format_prompt
|
|
|
|
|
|
def create_agents(
    world: World,
    memory: Dict[str, List[str | Dict[str, str]]] | None = None,
    players: List[str] | None = None,
):
    """Create an agent for every character in *world* and register it.

    Args:
        world: The world whose rooms and characters are walked.
        memory: Saved memory entries keyed by character name. Characters
            without an entry start with an empty memory. Defaults to no
            saved memory.
        players: Names of the characters that get a ``LocalPlayer``; every
            other character gets an ``Agent``. Defaults to no players.
    """
    # None sentinels instead of mutable defaults, so no dict/list object is
    # shared between calls.
    saved_memory = {} if memory is None else memory
    # Build the lookup once; membership is tested for every character.
    player_names = set(players or ())

    # set up agents for each character
    llm = agent_easy_connect()

    for room in world.rooms:
        for character in room.characters:
            restored = restore_memory(saved_memory.get(character.name, []))

            if character.name in player_names:
                agent = LocalPlayer(character.name, character.backstory)
                agent.load_history(restored)
            else:
                backstory = format_prompt(
                    "world_agent_backstory", character=character, world=world
                )
                agent = Agent(character.name, backstory, {}, llm)
                agent.memory = restored

            set_character_agent(character.name, character, agent)
|
|
|
|
|
|
def graph_world(world: World, turn: int):
    """Render the rooms and portals of *world* as a PNG graph.

    Each room becomes a node labelled with the room name followed by the
    names of its characters, one per line. Each portal becomes an edge from
    its room to the portal's destination, labelled with the portal name.
    The graph is named ``<basename of world.name>-<turn>`` and rendered into
    the directory part of ``world.name``.
    """
    # Imported here rather than at module level, as in the original.
    import graphviz

    diagram = graphviz.Digraph(
        f"{path.basename(world.name)}-{turn}",
        format="png",
    )

    for room in world.rooms:
        label_lines = [room.name]
        label_lines.extend(character.name for character in room.characters)
        diagram.node(room.name, "\n".join(label_lines))

        for portal in room.portals:
            diagram.edge(room.name, portal.destination, label=portal.name)

    diagram.render(directory=path.dirname(world.name))
|
|
|
|
|
|
def snapshot_world(world: World, turn: int):
    """Build a snapshot dict of the world, each agent's memory, and the turn.

    Returns:
        A dict with the keys ``"world"`` (the dumped world model),
        ``"memory"`` (a list of memory entries per character name) and
        ``"turn"``.
    """
    # save the world itself, along with the turn number and the memory of each agent
    world_dump = RootModel[World](world).model_dump()

    # A missing/empty agent memory is stored as an empty list.
    memory_by_character = {
        character.name: list(agent.memory or [])
        for character, agent in get_all_character_agents()
    }

    return {
        "world": world_dump,
        "memory": memory_by_character,
        "turn": turn,
    }
|
|
|
|
|
|
def restore_memory(
|
|
data: Sequence[str | Dict[str, str]]
|
|
) -> deque[str | AIMessage | HumanMessage | SystemMessage]:
|
|
config = get_game_config()
|
|
memories = []
|
|
|
|
for memory in data:
|
|
if isinstance(memory, str):
|
|
memories.append(memory)
|
|
elif isinstance(memory, dict):
|
|
memory_content = memory["content"]
|
|
memory_type = memory["type"]
|
|
|
|
if memory_type == "human":
|
|
memories.append(HumanMessage(content=memory_content))
|
|
elif memory_type == "system":
|
|
memories.append(SystemMessage(content=memory_content))
|
|
elif memory_type == "ai":
|
|
memories.append(AIMessage(content=memory_content))
|
|
|
|
return deque(memories, maxlen=config.world.character.memory_limit)
|
|
|
|
|
|
def save_world(world, filename):
    """Write *world* to *filename* as indented JSON.

    Args:
        world: The world to serialize (wrapped in ``RootModel[World]``).
        filename: Path of the file to write; it is overwritten.
    """
    # Serialize before opening the file: mode "w" truncates on open, so a
    # serialization error would otherwise wipe out an existing save.
    json_world = RootModel[World](world).model_dump_json(indent=2)

    with open(filename, "w") as f:
        f.write(json_world)
|
|
|
|
|
|
def save_world_state(world, turn, filename):
    """Render the world graph and write a full state snapshot to *filename*.

    Args:
        world: The world to graph and snapshot.
        turn: The current turn number, stored in the snapshot and used in
            the graph name.
        filename: Path of the JSON file to write; it is overwritten.

    Raises:
        ValueError: If the snapshot contains an object ``world_json`` cannot
            serialize. The target file is not touched in that case.
    """
    from json import dumps

    graph_world(world, turn)
    json_state = snapshot_world(world, turn)

    # Build the complete JSON text first: streaming into a file opened with
    # "w" would leave a truncated, invalid save if serialization failed
    # part-way through.
    serialized = dumps(json_state, default=world_json, indent=2)

    with open(filename, "w") as f:
        f.write(serialized)
|
|
|
|
|
|
def world_json(obj):
    """Convert a ``BaseMessage`` to a plain dict for JSON serialization.

    Passed as the ``default`` hook to ``json.dump`` in ``save_world_state``.

    Returns:
        A dict with the message's ``"content"`` and ``"type"``.

    Raises:
        ValueError: If *obj* is not a ``BaseMessage``.
    """
    if not isinstance(obj, BaseMessage):
        raise ValueError(f"Cannot serialize {obj}")

    return {"content": obj.content, "type": obj.type}
|