diff --git a/.gitignore b/.gitignore index 1b36d05..0103980 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ +adventure/custom_actions.py worlds/ __pycache__/ diff --git a/adventure/actions.py b/adventure/actions.py index dee3a9a..20c53cd 100644 --- a/adventure/actions.py +++ b/adventure/actions.py @@ -163,9 +163,3 @@ def action_give(character: str, item_name: str) -> str: destination_actor.items.append(item) return f"You give the {item_name} item to {character}." - - -def action_stop() -> str: - _, _, action_actor = get_current_context() - logger.info(f"{action_actor.name} end their turn") - return "You stop your actions and end your turn." diff --git a/adventure/generate.py b/adventure/generate.py index c089082..6b1cbd4 100644 --- a/adventure/generate.py +++ b/adventure/generate.py @@ -101,7 +101,7 @@ def generate_actor( ) -def generate_world(agent: Agent, theme: str) -> World: +def generate_world(agent: Agent, name: str, theme: str) -> World: room_count = randint(3, 5) logger.info(f"Generating a {theme} with {room_count} rooms") @@ -177,4 +177,4 @@ def generate_world(agent: Agent, theme: str) -> World: room.portals[direction] = dest_room.name dest_room.portals[opposite_directions[direction]] = room.name - return World(rooms=rooms, theme=theme) + return World(name=name, rooms=rooms, theme=theme) diff --git a/adventure/main.py b/adventure/main.py index 5cd4873..f4a55b4 100644 --- a/adventure/main.py +++ b/adventure/main.py @@ -1,3 +1,4 @@ +from importlib import import_module from json import load from os import path @@ -53,7 +54,7 @@ def world_result_parser(value, agent, **kwargs): return multi_function_or_str_result(value, agent=agent, **kwargs) -def simulate_world(world: World, steps: int = 10, callback=None): +def simulate_world(world: World, steps: int = 10, callback=None, extra_actions=[]): logger.info("Simulating the world") # collect actors, so they are only processed once @@ -68,6 +69,7 @@ def simulate_world(world: World, steps: int = 10, callback=None): action_move, action_take, action_tell, + *extra_actions, ] ) @@ -105,6 +107,7 @@ def simulate_world(world: World, steps: int = 10, callback=None): "What will you do next? Reply with a JSON function call, calling one of the actions." ), context={ + # TODO: add custom action names or remove this list entirely "actions": [ "ask", "give", @@ -147,7 +150,7 @@ def simulate_world(world: World, steps: int = 10, callback=None): if callback: callback(world, current_step) - current_step += 1 + set_step(current_step + 1) # main @@ -157,6 +160,9 @@ def parse_args(): parser = argparse.ArgumentParser( description="Generate and simulate a fantasy world" ) + parser.add_argument( + "--actions", type=str, help="Extra actions to include in the simulation" + ) parser.add_argument( "--steps", type=int, default=10, help="The number of simulation steps to run" ) @@ -166,13 +172,13 @@ def parse_args(): parser.add_argument( "--world", type=str, - default="world.json", + default="world", help="The file to save the generated world to", ) parser.add_argument( - "--world-state", + "--state", type=str, - default="world-state.json", + # default="world-state.json", help="The file to save the world state to", ) return parser.parse_args() @@ -181,18 +187,23 @@ def parse_args(): def main(): args = parse_args() - if args.world_state and path.exists(args.world_state): - logger.info(f"Loading world state from {args.world_state}") - with open(args.world_state, "r") as f: + world_file = args.world + ".json" + world_state_file = args.state or (args.world + ".state.json") + + if path.exists(world_state_file): + logger.info(f"Loading world state from {world_state_file}") + with open(world_state_file, "r") as f: state = WorldState(**load(f)) set_step(state.step) create_agents(state.world, state.memory) + world = state.world - elif args.world and path.exists(args.world): - logger.info(f"Loading world from {args.world}") - with open(args.world, "r") as f: - world = World(**load(f)) + world.name = args.world + elif path.exists(world_file): + logger.info(f"Loading world from {world_file}") + with open(world_file, "r") as f: + world = World(**load(f), name=args.world) create_agents(world) else: logger.info(f"Generating a new {args.theme} world") @@ -203,18 +214,29 @@ def main(): {}, llm, ) - world = generate_world(agent, args.theme) + world = generate_world(agent, args.world, args.theme) create_agents(world) + save_world(world, world_file) - logger.debug("Loaded world: %s", world) - - if args.world: - save_world(world, args.world) + # load extra actions + extra_actions = [] + if args.actions: + logger.info(f"Loading extra actions from {args.actions}") + action_module, action_function = args.actions.rsplit(":", 1) + action_module = import_module(action_module) + action_function = getattr(action_module, action_function) + module_actions = action_function() + logger.info( + f"Loaded extra actions: {[action.__name__ for action in module_actions]}" + ) + extra_actions.append(module_actions) + logger.debug("Simulating world: %s", world) simulate_world( world, steps=args.steps, - callback=lambda w, s: save_world_state(w, s, args.world_state), + callback=lambda w, s: save_world_state(w, s, world_state_file), + extra_actions=extra_actions, ) diff --git a/adventure/models.py b/adventure/models.py index a1cd826..6ba0cc3 100644 --- a/adventure/models.py +++ b/adventure/models.py @@ -40,6 +40,7 @@ class Room: @dataclass class World: + name: str rooms: List[Room] theme: str diff --git a/adventure/state.py b/adventure/state.py index 5c4b4f0..15a1b2e 100644 --- a/adventure/state.py +++ b/adventure/state.py @@ -1,5 +1,6 @@ from collections import deque from json import dump +from os import path from typing import Dict, List, Sequence from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage @@ -27,11 +28,12 @@ def graph_world(world: World, step: int): graph = graphviz.Digraph(f"{world.theme}-{step}", format="png") for room in world.rooms: room_label = "\n".join([room.name, *[actor.name for actor in room.actors]]) - graph.node(room.name, room_label) # , room.description) + graph.node(room.name, room_label) for direction, destination in room.portals.items(): graph.edge(room.name, destination, label=direction) - graph.render(directory="worlds", view=True) + graph_path = path.dirname(world.name) + graph.render(directory=graph_path) def snapshot_world(world: World, step: int):