1
0
Fork 0
taleweave-ai/adventure/state.py

112 lines
3.5 KiB
Python
Raw Normal View History

2024-05-02 11:56:57 +00:00
from collections import deque
from json import dump
from os import path
2024-05-02 11:56:57 +00:00
from typing import Dict, List, Sequence
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from packit.agent import Agent, agent_easy_connect
from pydantic import RootModel
from adventure.context import get_all_character_agents, set_character_agent
2024-05-09 02:11:16 +00:00
from adventure.models.entity import World
from adventure.player import LocalPlayer
2024-05-02 11:56:57 +00:00
2024-05-27 03:42:08 +00:00
# Maximum number of entries kept in an agent's restored memory; used as the
# deque ``maxlen`` in restore_memory, so older entries beyond this are dropped.
MEMORY_LIMIT = 25
2024-05-02 11:56:57 +00:00
def create_agents(
    world: World,
    memory: Dict[str, List[str | Dict[str, str]]] | None = None,
    players: List[str] | None = None,
):
    """Create and register an agent for every character in ``world``.

    Characters whose name appears in ``players`` get a ``LocalPlayer``; all
    other characters get an ``Agent`` that shares the single connection
    returned by ``agent_easy_connect()``. Each agent's history is rebuilt with
    ``restore_memory`` from ``memory[character.name]`` (empty when missing),
    and the agent is registered with ``set_character_agent``.

    Args:
        world: the world whose rooms' characters need agents.
        memory: saved memory per character name; each entry is a string or a
            dict with ``"content"`` and ``"type"`` keys. Defaults to no memory.
        players: names of characters controlled by local players. Defaults to
            no players.
    """
    # None defaults replace the previous shared mutable {} / [] defaults.
    memory = {} if memory is None else memory
    players = [] if players is None else players

    llm = agent_easy_connect()
    for room in world.rooms:
        for character in room.characters:
            saved_memory = restore_memory(memory.get(character.name, []))
            if character.name in players:
                agent = LocalPlayer(character.name, character.backstory)
                # LocalPlayer receives its history via load_history rather
                # than by assigning a memory attribute.
                agent.load_history(saved_memory)
            else:
                agent = Agent(character.name, character.backstory, {}, llm)
                agent.memory = saved_memory
            set_character_agent(character.name, character, agent)
2024-05-02 11:56:57 +00:00
2024-05-27 12:54:36 +00:00
def graph_world(world: World, turn: int):
    """Render a PNG diagram of the world's rooms and portals for ``turn``.

    Every room becomes a node whose label is the room name followed by the
    names of the characters in it, one per line. Every portal becomes a
    directed edge from its room to ``portal.destination``, labelled with the
    portal's name. The graph is named ``<basename of world.name>-<turn>`` and
    rendered with graphviz into the directory part of ``world.name``.
    """
    import graphviz

    base_name = path.basename(world.name)
    diagram = graphviz.Digraph(f"{base_name}-{turn}", format="png")

    for room in world.rooms:
        label_lines = [room.name]
        label_lines.extend(character.name for character in room.characters)
        diagram.node(room.name, "\n".join(label_lines))

        for portal in room.portals:
            diagram.edge(room.name, portal.destination, label=portal.name)

    output_dir = path.dirname(world.name)
    diagram.render(directory=output_dir)
2024-05-02 11:56:57 +00:00
2024-05-27 12:54:36 +00:00
def snapshot_world(world: World, turn: int):
    """Build a snapshot dict of the world, every agent's memory, and the turn.

    Returns a dict with three keys:

    - ``"world"``: the pydantic ``model_dump()`` of ``world``.
    - ``"memory"``: character name -> list of that agent's memory entries
      (an empty list when the agent's memory is falsy).
    - ``"turn"``: the given turn number.

    Memory entries are copied as-is; message objects among them are left for
    the caller's JSON encoder to handle.
    """
    world_data = RootModel[World](world).model_dump()
    memory_data = {
        character.name: list(agent.memory or [])
        for character, agent in get_all_character_agents()
    }
    return {"world": world_data, "memory": memory_data, "turn": turn}
def restore_memory(
    data: Sequence[str | Dict[str, str]]
) -> deque[str | AIMessage | HumanMessage | SystemMessage]:
    """Rebuild an agent's memory from its saved form.

    Handling of each item in ``data``:

    - Strings are kept unchanged.
    - Dicts become a ``HumanMessage``, ``SystemMessage`` or ``AIMessage``
      according to their ``"type"`` value (``"human"``, ``"system"`` or
      ``"ai"``), using their ``"content"`` value as the message content.
    - Dicts with any other type, and items that are neither strings nor
      dicts, are skipped.

    The result is a deque bounded by ``MEMORY_LIMIT``, so only the most
    recent entries survive.

    Raises:
        KeyError: if a dict item lacks ``"content"`` or ``"type"``.
    """
    # Ordered (type name, message class) pairs, checked with ==.
    message_kinds = (
        ("human", HumanMessage),
        ("system", SystemMessage),
        ("ai", AIMessage),
    )

    restored: list = []
    for entry in data:
        if isinstance(entry, str):
            restored.append(entry)
            continue
        if not isinstance(entry, dict):
            continue

        content = entry["content"]
        kind = entry["type"]
        message_class = next(
            (cls for name, cls in message_kinds if kind == name), None
        )
        if message_class is not None:
            restored.append(message_class(content=content))

    return deque(restored, maxlen=MEMORY_LIMIT)
2024-05-02 11:56:57 +00:00
def save_world(world, filename):
    """Write ``world`` to ``filename`` as JSON indented by two spaces.

    The JSON text is produced before the file is opened. If serialization
    fails, an existing file at ``filename`` is therefore left untouched
    instead of being truncated to empty.
    """
    json_world = RootModel[World](world).model_dump_json(indent=2)
    # NOTE(review): pydantic's JSON output may contain non-ASCII text;
    # consider encoding="utf-8" here once the matching loader is confirmed
    # to read with the same encoding.
    with open(filename, "w") as f:
        f.write(json_world)
2024-05-27 12:54:36 +00:00
def save_world_state(world, turn, filename):
    """Render the world graph for ``turn`` and save a JSON snapshot of the state.

    Calls ``graph_world`` to render the room/portal diagram. It then writes
    the dict built by ``snapshot_world`` (world, per-agent memory, turn) to
    ``filename`` as JSON indented by two spaces. Message objects in memory
    are converted by ``world_json``, which raises ``ValueError`` for any
    other value JSON cannot encode.

    The whole JSON text is built before ``filename`` is opened. A
    serialization error therefore no longer leaves a truncated, half-written
    file behind.
    """
    from json import dumps  # the module imports only ``dump`` at top level

    graph_world(world, turn)
    json_state = snapshot_world(world, turn)
    json_text = dumps(json_state, default=world_json, indent=2)

    with open(filename, "w") as f:
        f.write(json_text)
2024-05-02 11:56:57 +00:00
2024-05-04 04:18:21 +00:00
def world_json(obj):
    """Serve as the ``default`` hook for the json encoder and serialize langchain messages.

    A ``BaseMessage`` becomes ``{"content": ..., "type": ...}``, the dict
    shape that ``restore_memory`` reads back. Anything else is rejected.

    Raises:
        ValueError: if ``obj`` is not a ``BaseMessage``.
    """
    # NOTE(review): the json module's convention for default hooks is
    # TypeError; ValueError is kept so existing callers are unaffected.
    if not isinstance(obj, BaseMessage):
        raise ValueError(f"Cannot serialize {obj}")

    return {"content": obj.content, "type": obj.type}