1
0
Fork 0
taleweave-ai/taleweave/engine.py

159 lines
5.1 KiB
Python

from functools import partial
from itertools import count
from logging import getLogger
from os import path
from typing import List
from packit.agent import Agent, agent_easy_connect
from packit.memory import make_limited_memory
from taleweave.context import (
get_current_turn,
get_system_data,
set_current_turn,
set_dungeon_master,
set_system_data,
)
from taleweave.game_system import GameSystem
from taleweave.generate import generate_room, generate_world, link_rooms
from taleweave.models.config import Config
from taleweave.models.entity import World, WorldState
from taleweave.models.files import WorldPrompt
from taleweave.state import create_agents, save_world
from taleweave.utils.file import load_yaml
from taleweave.utils.template import format_prompt
# module-level logger, named after this module
logger = getLogger(name=__name__)

def load_or_initialize_system_data(
    world_path: str, systems: List[GameSystem], world: World
):
    """Populate the stored data for every game system.

    A system with a ``data`` handler first tries to load
    ``<world_path>.<system name>.json``; when that file exists its contents are
    stored and the system is done. Otherwise (no file, or no ``data`` handler)
    the system's ``initialize`` hook, when present, is called with ``world`` and
    its result is stored instead.
    """
    for system in systems:
        if system.data:
            saved_file = f"{world_path}.{system.name}.json"
            if not path.exists(saved_file):
                logger.info(f"no system data found at {saved_file}")
            else:
                logger.info(f"loading system data from {saved_file}")
                loaded = system.data.load(saved_file)
                set_system_data(system.name, loaded)
                # saved data takes precedence; skip initialization
                continue

        if not system.initialize:
            continue

        logger.info(f"initializing system data for {system.name}")
        fresh = system.initialize(world)
        set_system_data(system.name, fresh)

def save_system_data(world_path: str, systems: List[GameSystem]):
    """Persist the stored data of every system that has a ``data`` handler.

    Each such system's current value from ``get_system_data`` is written via its
    handler to ``<world_path>.<system name>.json``; systems without a handler
    are skipped.
    """
    for system in (candidate for candidate in systems if candidate.data):
        target = f"{world_path}.{system.name}.json"
        logger.info(f"saving system data to {target}")
        system.data.save(target, get_system_data(system.name))

def load_or_generate_world(
    world_path: str,
    state_path: str | None,
    config: Config,
    players,  # TODO: type me
    systems: List[GameSystem],
    world_prompt: WorldPrompt,
    add_rooms: int = 0,
    room_count: int | None = None,
):
    """Load a saved world (or world state), or generate a new one.

    Sources are tried in this order:

    1. the state file (``state_path``, or ``<world_path>.state.json`` when it
       is ``None``), which also restores the saved ``memory`` and turn;
    2. the world file ``<world_path>.json``;
    3. otherwise a new world is generated from ``world_prompt.theme`` and
       saved, together with its system data.

    In every case a "World Builder" agent is created and registered as the
    dungeon master, and system data is loaded or initialized for ``systems``.
    Then ``add_rooms`` extra rooms are optionally generated and linked, and
    ``create_agents`` is called for the world.

    Args:
        world_path: path prefix for the world, state, and system data files.
        state_path: explicit state file path; ``None`` selects the default.
        config: ``config.world.character.memory_limit`` bounds the world
            builder's memory.
        players: passed through to ``create_agents``.
        systems: game systems whose data is loaded/initialized; also passed
            to the world and room generators.
        world_prompt: supplies ``theme`` and ``flavor`` for the world builder.
        add_rooms: number of extra rooms to generate after loading.
        room_count: forwarded to ``generate_world`` (new worlds only).

    Returns:
        ``(world, world_state_file, turn)``. ``turn`` is 0 unless the world
        was restored from the state file.
    """
    world_file = world_path + ".json"
    world_state_file = state_path or (world_path + ".state.json")

    # defaults, kept unless the world is restored from the state file
    memory = {}
    turn = 0

    # prepare an agent for the world builder
    llm = agent_easy_connect()
    memory_factory = partial(
        make_limited_memory, limit=config.world.character.memory_limit
    )
    world_builder = Agent(
        "World Builder",
        format_prompt(
            "world_generate_dungeon_master",
            flavor=world_prompt.flavor,
            theme=world_prompt.theme,
        ),
        {},
        llm,
        memory_factory=memory_factory,
    )
    # register the world builder as the dungeon master in taleweave.context
    set_dungeon_master(world_builder)

    if path.exists(world_state_file):
        # resume a saved session: world, memory, and turn come from the state
        logger.info(f"loading world state from {world_state_file}")
        # NOTE(review): the .json files are parsed with load_yaml, presumably
        # relying on JSON being valid YAML — confirm
        with open(world_state_file, "r") as f:
            state = WorldState(**load_yaml(f))

        set_current_turn(state.turn)
        load_or_initialize_system_data(world_path, systems, state.world)

        memory = state.memory
        turn = state.turn
        world = state.world
    elif path.exists(world_file):
        # load the base world; the returned turn stays 0 and memory stays empty
        logger.info(f"loading world from {world_file}")
        with open(world_file, "r") as f:
            world = World(**load_yaml(f))

        load_or_initialize_system_data(world_path, systems, world)
    else:
        logger.info(f"generating a new world using theme: {world_prompt.theme}")
        world = generate_world(
            world_builder,
            world_path,
            world_prompt.theme,
            systems,
            room_count=room_count,
        )
        load_or_initialize_system_data(world_path, systems, world)

        # persist the freshly generated world and its system data right away
        # TODO: check if there have been any changes before saving
        save_world(world, world_file)
        save_system_data(world_path, systems)

    # NOTE(review): rooms added below are not saved by this function;
    # presumably they are persisted later via world_state_file — confirm
    if add_rooms:
        new_rooms = []
        for i in range(add_rooms):
            logger.info(f"generating room {i + 1} of {add_rooms}")
            room = generate_room(
                world_builder, world, systems, current_room=i, total_rooms=add_rooms
            )
            new_rooms.append(room)
            world.rooms.append(room)

        # if the world was already full, no new rooms will be added
        # NOTE(review): as written every generate_room result is appended, so
        # new_rooms is non-empty whenever add_rooms > 0 — confirm how a full
        # world is signalled
        if new_rooms:
            link_rooms(world_builder, world, systems, new_rooms)

    # create agents for each character after adding any new rooms
    create_agents(world, memory=memory, players=players)
    return (world, world_state_file, turn)

def simulate_world(world: World, systems: List[GameSystem], turns: int):
    """Run every game system's ``simulate`` hook for ``turns`` turns.

    Each turn reads the current world turn, calls
    ``system.simulate(world, current_turn)`` for every system that defines a
    ``simulate`` hook, and then advances the world turn counter by one.

    Args:
        world: the world passed to each system's ``simulate`` hook.
        systems: game systems to run; systems without ``simulate`` are skipped.
        turns: number of turns to simulate. Values <= 0 simulate nothing. The
            loop is unbounded (``itertools.count``), so a limit that is never
            reached, such as ``float("inf")``, runs indefinitely.
    """
    logger.info(f"simulating the world for {turns} turns using systems: {systems}")
    for i in count():
        current_turn = get_current_turn()
        # Check the limit *before* running the turn so exactly `turns` turns
        # are simulated. Checking after the turn ran `turns + 1` turns (and one
        # turn even for turns=0).
        if i >= turns:
            logger.info("reached turn limit at world turn %s", current_turn)
            break

        # 1-based so the final turn reads "N of N"
        logger.info(f"simulating turn {i + 1} of {turns} (world turn {current_turn})")
        for system in systems:
            if system.simulate:
                logger.info(f"running system {system.name}")
                system.simulate(world, current_turn)

        set_current_turn(current_turn + 1)