1
0
Fork 0
taleweave-ai/taleweave/state.py

118 lines
3.7 KiB
Python
Raw Permalink Normal View History

2024-05-02 11:56:57 +00:00
from collections import deque
from json import dump
from os import path
2024-05-02 11:56:57 +00:00
from typing import Dict, List, Sequence
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from packit.agent import Agent, agent_easy_connect
from taleweave.context import (
get_all_character_agents,
get_game_config,
set_character_agent,
)
2024-06-08 02:18:56 +00:00
from taleweave.models.base import dump_model, dump_model_json
from taleweave.models.entity import World
from taleweave.player import LocalPlayer
from taleweave.utils.template import format_prompt
2024-05-02 11:56:57 +00:00
def create_agents(
    world: World,
    memory: Dict[str, List[str | Dict[str, str]]] | None = None,
    players: List[str] | None = None,
):
    """Create and register an agent for every character in ``world``.

    Characters whose names are listed in ``players`` get a ``LocalPlayer``
    built from the character's own backstory; every other character gets an
    LLM-backed ``Agent`` whose backstory is rendered from the
    ``world_agent_backstory`` prompt template. Each agent's history is rebuilt
    with ``restore_memory`` from ``memory[character.name]`` (empty when the
    character has no entry), and the agent is registered via
    ``set_character_agent``.

    Args:
        world: world whose rooms' characters need agents.
        memory: serialized memory per character name — lists of plain strings
            or ``{"content": ..., "type": ...}`` dicts, as accepted by
            ``restore_memory``. Defaults to no saved memory.
        players: names of characters controlled by a local player. Defaults to
            none (every character is LLM-controlled).
    """
    # None sentinels instead of shared mutable {} / [] default arguments.
    if memory is None:
        memory = {}
    if players is None:
        players = []

    # one LLM connection shared by all non-player agents
    llm = agent_easy_connect()
    for room in world.rooms:
        for character in room.characters:
            agent_memory = restore_memory(memory.get(character.name, []))

            if character.name in players:
                agent = LocalPlayer(character.name, character.backstory)
                agent.load_history(agent_memory)
            else:
                backstory = format_prompt(
                    "world_agent_backstory", character=character, world=world
                )
                agent = Agent(character.name, backstory, {}, llm)
                agent.memory = agent_memory

            set_character_agent(character.name, character, agent)
2024-05-02 11:56:57 +00:00
2024-05-27 12:54:36 +00:00
def graph_world(world: World, turn: int):
    """Render the room/portal layout of ``world`` as a PNG using Graphviz.

    Every room becomes a node whose label is the room name followed by the
    names of the characters in it, one per line. Every portal becomes an edge
    from its room to ``portal.destination``, labelled with the portal name.
    The graph is named ``<basename of world.name>-<turn>`` and rendered into
    the directory part of ``world.name``.
    """
    # deferred import: graphviz is only loaded when this function runs
    import graphviz

    digraph = graphviz.Digraph(
        "{}-{}".format(path.basename(world.name), turn), format="png"
    )

    for room in world.rooms:
        label_lines = [room.name]
        label_lines.extend(occupant.name for occupant in room.characters)
        digraph.node(room.name, "\n".join(label_lines))

        for portal in room.portals:
            digraph.edge(room.name, portal.destination, label=portal.name)

    output_dir = path.dirname(world.name)
    digraph.render(directory=output_dir)
2024-05-02 11:56:57 +00:00
2024-05-27 12:54:36 +00:00
def snapshot_world(world: World, turn: int):
    """Collect the world, every character agent's memory and the turn number.

    Returns a dict with three keys:
      - ``"world"``: ``world`` passed through ``dump_model``;
      - ``"memory"``: character name -> list copy of that agent's memory
        (an empty list when the agent's memory is falsy);
      - ``"turn"``: the ``turn`` argument, unchanged.

    Memory entries are copied as-is and may still be message objects, so the
    result is not necessarily plain JSON; ``save_world_state`` passes
    ``world_json`` as the ``default`` hook when writing it out.
    """
    serialized_world = dump_model(World, world)

    memories_by_name = {
        character.name: list(agent.memory or [])
        for character, agent in get_all_character_agents()
    }

    return {
        "world": serialized_world,
        "memory": memories_by_name,
        "turn": turn,
    }
def restore_memory(
data: Sequence[str | Dict[str, str]]
) -> deque[str | AIMessage | HumanMessage | SystemMessage]:
config = get_game_config()
2024-05-02 11:56:57 +00:00
memories = []
for memory in data:
if isinstance(memory, str):
memories.append(memory)
elif isinstance(memory, dict):
memory_content = memory["content"]
memory_type = memory["type"]
if memory_type == "human":
memories.append(HumanMessage(content=memory_content))
elif memory_type == "system":
memories.append(SystemMessage(content=memory_content))
elif memory_type == "ai":
memories.append(AIMessage(content=memory_content))
return deque(memories, maxlen=config.world.character.memory_limit)
2024-05-02 11:56:57 +00:00
def save_world(world, filename):
    """Write ``world`` to ``filename`` as the JSON text from ``dump_model_json``.

    The world is serialized *before* the file is opened: ``open(..., "w")``
    truncates the file immediately, so serializing inside the ``with`` block
    would leave an empty file (destroying any previous save) whenever
    serialization raised.
    """
    json_world = dump_model_json(World, world)

    with open(filename, "w") as f:
        f.write(json_world)
2024-05-27 12:54:36 +00:00
def save_world_state(world, turn, filename):
    """Render the world graph, then save a JSON snapshot of the game state.

    Calls ``graph_world`` to render the map for ``turn``, then writes the dict
    returned by ``snapshot_world`` (world, per-character memory, turn) to
    ``filename`` as JSON indented by 2 spaces. Values the JSON encoder cannot
    handle are passed to ``world_json``, which converts message objects and
    raises ``ValueError`` for anything else.

    The snapshot is fully encoded to a string before the file is opened, so an
    encoding error cannot leave a truncated or half-written save on disk
    (``json.dump`` would stream partial output into the already-truncated file).
    """
    from json import dumps

    graph_world(world, turn)
    json_state = snapshot_world(world, turn)
    encoded_state = dumps(json_state, default=world_json, indent=2)

    with open(filename, "w") as f:
        f.write(encoded_state)
2024-05-02 11:56:57 +00:00
2024-05-04 04:18:21 +00:00
def world_json(obj):
    """Serialize LangChain messages; used as the ``default=`` hook for json.

    A ``BaseMessage`` becomes ``{"content": obj.content, "type": obj.type}``,
    which is the dict shape ``restore_memory`` reads back. Any other object is
    rejected with ``ValueError``, which aborts the surrounding JSON encoding.
    """
    if not isinstance(obj, BaseMessage):
        raise ValueError(f"Cannot serialize {obj}")

    return {"content": obj.content, "type": obj.type}