1
0
Fork 0
taleweave-ai/adventure/main.py

322 lines
10 KiB
Python
Raw Normal View History

from importlib import import_module
2024-05-02 11:56:57 +00:00
from json import load
2024-05-03 01:57:11 +00:00
from os import environ, path
2024-05-04 04:18:21 +00:00
from typing import Callable, Dict, Sequence, Tuple
2024-05-02 11:25:35 +00:00
2024-05-03 01:57:11 +00:00
from dotenv import load_dotenv
2024-05-02 11:25:35 +00:00
from packit.agent import Agent, agent_easy_connect
2024-05-03 01:57:11 +00:00
from packit.loops import loop_retry
2024-05-02 11:25:35 +00:00
from packit.results import multi_function_or_str_result
from packit.toolbox import Toolbox
from packit.utils import logger_with_colors
2024-05-02 11:56:57 +00:00
from adventure.actions import (
action_ask,
action_give,
action_look,
action_move,
action_take,
action_tell,
)
from adventure.context import (
2024-05-04 04:18:21 +00:00
get_actor_agent_for_name,
2024-05-02 11:56:57 +00:00
get_actor_for_agent,
get_current_world,
get_step,
set_current_actor,
set_current_room,
set_current_world,
set_step,
)
from adventure.generate import generate_world
2024-05-04 04:18:21 +00:00
from adventure.models import Actor, Room, World, WorldState
2024-05-02 11:56:57 +00:00
from adventure.state import create_agents, save_world, save_world_state
2024-05-02 11:25:35 +00:00
# Module-level logger shared by every function in this file.
logger = logger_with_colors(__name__)
2024-05-03 01:57:11 +00:00
# Load environment variables at import time from the file named by
# ADVENTURE_ENV, falling back to ".env" when that variable is unset.
# override=True is passed so values from the file take precedence over
# variables that are already set in the process environment.
load_dotenv(environ.get("ADVENTURE_ENV", ".env"), override=True)
2024-05-02 11:25:35 +00:00
# simulation
def world_result_parser(value, agent, **kwargs):
    """Publish the acting agent's actor and room, then parse its reply.

    Looks up the actor bound to ``agent``, finds the first room of the
    current world whose ``actors`` contains it (``None`` when no room does),
    stores both through ``set_current_room`` / ``set_current_actor`` and
    finally hands ``value`` to ``multi_function_or_str_result``.

    Args:
        value: The raw reply to parse.
        agent: The agent that produced ``value``.
        **kwargs: Forwarded unchanged to ``multi_function_or_str_result``.

    Returns:
        Whatever ``multi_function_or_str_result`` returns.

    Raises:
        ValueError: If no current world has been set.
    """
    world = get_current_world()
    if not world:
        raise ValueError(
            "The current world must be set before calling world_result_parser"
        )

    logger.debug(f"parsing action for {agent.name}: {value}")

    acting_actor = get_actor_for_agent(agent)

    # First room containing the actor; stays None when the actor is nowhere.
    containing_room = None
    for candidate in world.rooms:
        if acting_actor in candidate.actors:
            containing_room = candidate
            break

    set_current_room(containing_room)
    set_current_actor(acting_actor)

    return multi_function_or_str_result(value, agent=agent, **kwargs)
2024-05-02 11:25:35 +00:00
2024-05-04 04:18:21 +00:00
def simulate_world(
    world: World,
    steps: int = 10,
    actions: Sequence[Callable[..., str]] = (),
    systems: Sequence[
        Tuple[Callable[[World, int], None], Callable[[Dict[str, str]], str] | None]
    ] = (),
    input_callbacks: Sequence[Callable[[Room, Actor, str], None]] = (),
    result_callbacks: Sequence[Callable[[Room, Actor, str], None]] = (),
):
    """Run the turn-based simulation loop over every actor in ``world.order``.

    For each step, every actor named in ``world.order`` is prompted (through
    ``loop_retry``) with a description of its room, inventory, available
    actions and directions. The result is appended to the agent's memory and
    passed to ``result_callbacks``. After all actors have acted, each system's
    update function is called with ``(world, current_step)`` and the step
    counter is advanced with ``set_step``.

    Args:
        world: The world to simulate; it is installed as the current world.
        steps: Number of simulation steps to run.
        actions: Extra action callables added to the built-in action toolbox.
        systems: ``(update, format)`` pairs. ``update(world, step)`` runs once
            per step; ``format(actor.attributes)``, when not ``None``, adds
            text to each actor's prompt.
        input_callbacks: Called as ``callback(room, actor, value)`` with the
            raw reply before it is parsed.
        result_callbacks: Called as ``callback(room, actor, result)`` with the
            final result of each actor's turn.
    """
    logger.info("Simulating the world")
    set_current_world(world)

    # build a toolbox for the actions
    action_tools = Toolbox(
        [
            action_ask,
            action_give,
            action_look,
            action_move,
            action_take,
            action_tell,
            *actions,
        ]
    )
    action_names = action_tools.list_tools()

    # The template is loop-invariant, so build it once. Every sentence ends
    # with a separating space so adjacent literals do not run together.
    turn_prompt = (
        "You are currently in {room_name}. {room_description}. {attributes}. "
        "The room contains the following characters: {visible_actors}. "
        "The room contains the following items: {visible_items}. "
        "Your inventory contains the following items: {actor_items}. "
        "You can take the following actions: {actions}. "
        "You can move in the following directions: {directions}. "
        "What will you do next? Reply with a JSON function call, calling one of the actions. "
        "You can only take one action per turn. Pick the most important action and save the rest for later. "
        "What is your action?"
    )

    # simulate each actor
    for _ in range(steps):
        # The step number comes from shared context rather than the loop
        # index, so a run can continue from a previously stored step.
        current_step = get_step()
        logger.info(f"Simulating step {current_step}")

        for actor_name in world.order:
            actor, agent = get_actor_agent_for_name(actor_name)
            if not agent or not actor:
                logger.error(f"Agent or actor not found for name {actor_name}")
                continue

            room = next(
                (candidate for candidate in world.rooms if actor in candidate.actors),
                None,
            )
            if not room:
                logger.error(f"Actor {actor_name} is not in a room")
                continue

            room_actors = [occupant.name for occupant in room.actors]
            room_items = [item.name for item in room.items]
            room_directions = list(room.portals.keys())

            # Only systems that supply a formatter contribute prompt text.
            actor_attributes = " ".join(
                system_format(actor.attributes)
                for _, system_format in systems
                if system_format
            )
            actor_items = [item.name for item in actor.items]

            # NOTE(review): this closure reads the current iteration's
            # ``room`` and ``actor``; presumably loop_retry only invokes it
            # before returning -- confirm.
            def result_parser(value, agent, **kwargs):
                for callback in input_callbacks:
                    callback(room, actor, value)

                return world_result_parser(value, agent, **kwargs)

            logger.info("starting turn for actor: %s", actor_name)
            result = loop_retry(
                agent,
                turn_prompt,
                context={
                    "actions": action_names,
                    "actor_items": actor_items,
                    "attributes": actor_attributes,
                    "directions": room_directions,
                    "room_name": room.name,
                    "room_description": room.description,
                    "visible_actors": room_actors,
                    "visible_items": room_items,
                },
                result_parser=result_parser,
                toolbox=action_tools,
            )

            logger.debug(f"{actor.name} step result: {result}")
            agent.memory.append(result)

            for callback in result_callbacks:
                callback(room, actor, result)

        # Systems run once per step, after every actor has taken its turn.
        for system_update, _ in systems:
            system_update(world, current_step)

        set_step(current_step + 1)
2024-05-02 11:25:35 +00:00
# main
def parse_args():
import argparse
parser = argparse.ArgumentParser(
2024-05-04 04:18:21 +00:00
description="Generate and simulate a text adventure world"
2024-05-02 11:25:35 +00:00
)
parser.add_argument(
2024-05-04 04:18:21 +00:00
"--actions",
type=str,
nargs="*",
help="Extra actions to include in the simulation",
)
parser.add_argument(
"--flavor", type=str, help="Some additional flavor text for the generated world"
)
parser.add_argument(
"--player", type=str, help="The name of the character to play as"
)
2024-05-03 01:57:11 +00:00
parser.add_argument(
"--rooms", type=int, default=5, help="The number of rooms to generate"
)
parser.add_argument(
"--max-rooms", type=int, help="The maximum number of rooms to generate"
)
2024-05-04 04:18:21 +00:00
parser.add_argument(
"--server", type=str, help="The address on which to run the server"
)
parser.add_argument(
"--state",
type=str,
# default="world.state.json",
help="The file to save the world state to. Defaults to $world.state.json, if not set",
)
2024-05-02 11:25:35 +00:00
parser.add_argument(
"--steps", type=int, default=10, help="The number of simulation steps to run"
)
2024-05-04 04:18:21 +00:00
parser.add_argument(
"--systems",
type=str,
nargs="*",
help="Extra logic systems to run in the simulation",
)
2024-05-02 11:25:35 +00:00
parser.add_argument(
"--theme", type=str, default="fantasy", help="The theme of the generated world"
)
parser.add_argument(
"--world",
type=str,
default="world",
2024-05-02 11:25:35 +00:00
help="The file to save the generated world to",
)
return parser.parse_args()
def main():
    """Load or generate a world, wire up plugins, and run the simulation.

    The world comes from the first source that exists: the saved world state
    file (which also restores the step counter and agent memory), the saved
    world file, or a freshly generated world that is then written to the
    world file. Extra actions and systems named on the command line are
    loaded with ``load_plugin``. A snapshot system that saves the world state
    is always registered, and the server system and callbacks are added when
    ``--server`` is given.
    """
    args = parse_args()
    world_file = args.world + ".json"
    world_state_file = args.state or (args.world + ".state.json")

    players = []
    if args.player:
        players.append(args.player)

    memory = {}
    if path.exists(world_state_file):
        logger.info(f"Loading world state from {world_state_file}")
        with open(world_state_file, "r") as f:
            state = WorldState(**load(f))

        set_step(state.step)
        memory = state.memory
        world = state.world
        world.name = args.world
    elif path.exists(world_file):
        logger.info(f"Loading world from {world_file}")
        with open(world_file, "r") as f:
            world = World(**load(f))
    else:
        logger.info(f"Generating a new {args.theme} world")
        llm = agent_easy_connect()
        # --flavor is optional; without this guard the literal text "None"
        # would be interpolated into the world builder's prompt.
        flavor = args.flavor or ""
        agent = Agent(
            "World Builder",
            f"You are an experienced game master creating a visually detailed {args.theme} world for a new adventure. {flavor}",
            {},
            llm,
        )
        world = generate_world(
            agent,
            args.world,
            args.theme,
            room_count=args.rooms,
            max_rooms=args.max_rooms,
        )
        save_world(world, world_file)

    create_agents(world, memory=memory, players=players)

    # load extra actions
    extra_actions = []
    # nargs="*" options are None when the flag is omitted, so fall back to
    # an empty list instead of iterating None.
    for action_name in args.actions or []:
        logger.info(f"Loading extra actions from {action_name}")
        module_actions = load_plugin(action_name)
        logger.info(
            f"Loaded extra actions: {[action.__name__ for action in module_actions]}"
        )
        extra_actions.extend(module_actions)

    # load extra systems
    def snapshot_system(world: World, step: int) -> None:
        logger.debug("Snapshotting world state")
        save_world_state(world, step, world_state_file)

    extra_systems = [(snapshot_system, None)]
    for system_name in args.systems or []:
        logger.info(f"Loading extra systems from {system_name}")
        module_systems = load_plugin(system_name)
        logger.info(
            f"Loaded extra systems: {[system.__name__ for system in module_systems]}"
        )
        # NOTE(review): the plugin result is appended as ONE entry, so a
        # system plugin presumably returns a single (update, format) pair
        # rather than a list of pairs -- confirm against the plugin contract.
        extra_systems.append(module_systems)

    # make sure the server system is last
    input_callbacks = []
    result_callbacks = []
    if args.server:
        from adventure.server import (
            launch_server,
            server_input,
            server_result,
            server_system,
        )

        launch_server()
        extra_systems.append((server_system, None))
        input_callbacks.append(server_input)
        result_callbacks.append(server_result)

    # start the sim
    logger.debug("Simulating world: %s", world)
    simulate_world(
        world,
        steps=args.steps,
        actions=extra_actions,
        systems=extra_systems,
        input_callbacks=input_callbacks,
        result_callbacks=result_callbacks,
    )
2024-05-02 11:25:35 +00:00
2024-05-04 04:18:21 +00:00
def load_plugin(name):
module_name, function_name = name.rsplit(":", 1)
plugin_module = import_module(module_name)
plugin_entry = getattr(plugin_module, function_name)
return plugin_entry()
2024-05-02 11:25:35 +00:00
# Run the simulation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()