From e803f40b7583b8fca3309ae657094bfdbca99863 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 2 May 2024 18:17:13 -0500 Subject: [PATCH] allow manual control of characters, improve prompts and error handling --- adventure/actions.py | 12 +++--- adventure/main.py | 70 +++++++++++++------------------- adventure/player.py | 96 ++++++++++++++++++++++++++++++++++++++++++++ adventure/state.py | 19 +++++++-- 4 files changed, 146 insertions(+), 51 deletions(-) create mode 100644 adventure/player.py diff --git a/adventure/actions.py b/adventure/actions.py index 20c53cd..792b942 100644 --- a/adventure/actions.py +++ b/adventure/actions.py @@ -84,7 +84,7 @@ def action_ask(character: str, question: str) -> str: # sanity checks if character == action_actor.name: - return "You cannot ask yourself a question. Stop talking to yourself." + return "You cannot ask yourself a question. You have wasted your turn. Stop talking to yourself." question_actor, question_agent = get_actor_agent_for_name(character) if not question_actor: @@ -95,8 +95,8 @@ def action_ask(character: str, question: str) -> str: logger.info(f"{action_actor.name} asks {character}: {question}") answer = question_agent( - f"{action_actor.name} asks you: {question}. Reply with your response. " - f"Do not include the question or any other text, only your reply to {action_actor.name}." + f"{action_actor.name} asks you: {question}. Reply with your response to them. " + f"Do not include the question or any JSON. Only include your answer for {action_actor.name}." ) if could_be_json(answer) and action_tell.__name__ in answer: @@ -120,7 +120,7 @@ def action_tell(character: str, message: str) -> str: # sanity checks if character == action_actor.name: - return "You cannot tell yourself a message. Stop talking to yourself." + return "You cannot tell yourself a message. You have wasted your turn. Stop talking to yourself." question_actor, question_agent = get_actor_agent_for_name(character) if not question_actor: @@ -131,8 +131,8 @@ def action_tell(character: str, message: str) -> str: logger.info(f"{action_actor.name} tells {character}: {message}") answer = question_agent( - f"{action_actor.name} tells you: {message}. Reply with your response. " - f"Do not include the message or any other text, only your reply to {action_actor.name}." + f"{action_actor.name} tells you: {message}. Reply with your response to them. " + f"Do not include the message or any JSON. Only include your reply to {action_actor.name}." ) if could_be_json(answer) and action_tell.__name__ in answer: diff --git a/adventure/main.py b/adventure/main.py index f4a55b4..f855296 100644 --- a/adventure/main.py +++ b/adventure/main.py @@ -72,6 +72,7 @@ def simulate_world(world: World, steps: int = 10, callback=None, extra_actions=[ *extra_actions, ] ) + action_names = action_tools.list_tools() # create a result parser that will memorize the actor and room set_current_world(world) @@ -105,17 +106,11 @@ def simulate_world(world: World, steps: int = 10, callback=None, extra_actions=[ "You can take the following actions: {actions}. " "You can move in the following directions: {directions}. " "What will you do next? Reply with a JSON function call, calling one of the actions." + "You can only take one action per turn. Pick the most important action and save the rest for later." + "What is your action?" ), context={ - # TODO: add custom action names or remove this list entirely - "actions": [ - "ask", - "give", - "look", - "move", - "take", - "tell", - ], # , "use"], + "actions": action_names, "actors": room_actors, "directions": room_directions, "items": room_items, @@ -127,25 +122,7 @@ def simulate_world(world: World, steps: int = 10, callback=None, extra_actions=[ ) logger.info(f"{actor.name} step result: {result}") - - # if result was JSON, it has already been parsed and executed. anything remaining is flavor text - # that should be presented back to the actor - # TODO: inject this directly in the agent's memory rather than reprompting them - response = agent( - "The result of your last action was: {result}. Your turn is over, no further actions will be accepted. " - 'If you understand, reply with the word "end".', - result=result, - ) - - logger.debug(f"{actor.name} step response: '{response}'") - if response.strip().lower() not in ["end", ""]: - logger.warning( - f"{actor.name} responded after the end of their turn: %s", response - ) - response = agent( - "Your turn is over, no further actions will be accepted. Do not reply." - ) - logger.debug(f"{actor.name} warning response: {response}") + agent.memory.append(result) if callback: callback(world, current_step) @@ -163,6 +140,18 @@ def parse_args(): parser.add_argument( "--actions", type=str, help="Extra actions to include in the simulation" ) + parser.add_argument( + "--flavor", type=str, help="Some additional flavor text for the generated world" + ) + parser.add_argument( + "--player", type=str, help="The name of the character to play as" + ) + parser.add_argument( + "--state", + type=str, + # default="world.state.json", + help="The file to save the world state to. Defaults to $world.state.json, if not set", + ) parser.add_argument( "--steps", type=int, default=10, help="The number of simulation steps to run" ) @@ -175,12 +164,6 @@ def parse_args(): default="world", help="The file to save the generated world to", ) - parser.add_argument( - "--state", - type=str, - # default="world-state.json", - help="The file to save the world state to", - ) return parser.parse_args() @@ -190,34 +173,39 @@ def main(): world_file = args.world + ".json" world_state_file = args.state or (args.world + ".state.json") + players = [] + if args.player: + players.append(args.player) + + memory = {} if path.exists(world_state_file): logger.info(f"Loading world state from {world_state_file}") with open(world_state_file, "r") as f: state = WorldState(**load(f)) set_step(state.step) - create_agents(state.world, state.memory) + memory = state.memory world = state.world world.name = args.world elif path.exists(world_file): logger.info(f"Loading world from {world_file}") with open(world_file, "r") as f: - world = World(**load(f), name=args.world) - create_agents(world) + world = World(**load(f)) else: logger.info(f"Generating a new {args.theme} world") llm = agent_easy_connect() agent = Agent( - "world builder", - f"You are an experienced game master creating a visually detailed {args.theme} world for a new adventure.", + "World Builder", + f"You are an experienced game master creating a visually detailed {args.theme} world for a new adventure. {args.flavor}", {}, llm, ) world = generate_world(agent, args.world, args.theme) - create_agents(world) save_world(world, world_file) + create_agents(world, memory=memory, players=players) + # load extra actions extra_actions = [] if args.actions: @@ -229,7 +217,7 @@ def main(): logger.info( f"Loaded extra actions: {[action.__name__ for action in module_actions]}" ) - extra_actions.append(module_actions) + extra_actions.extend(module_actions) logger.debug("Simulating world: %s", world) simulate_world( diff --git a/adventure/player.py b/adventure/player.py new file mode 100644 index 0000000..cbb29c3 --- /dev/null +++ b/adventure/player.py @@ -0,0 +1,96 @@ +from json import dumps +from readline import add_history +from typing import Any, Dict, List, Sequence + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from packit.utils import could_be_json + + +class LocalPlayer: + """ + A human agent that can interact with the world. + """ + + name: str + backstory: str + memory: List[str | BaseMessage] + + def __init__(self, name: str, backstory: str) -> None: + self.name = name + self.backstory = backstory + self.memory = [] + + def load_history(self, lines: Sequence[str | BaseMessage]): + """ + Load the history of the player's input. + """ + + self.memory.extend(lines) + + for line in lines: + if isinstance(line, BaseMessage): + add_history(str(line.content)) + else: + add_history(line) + + def invoke(self, prompt: str, context: Dict[str, Any], **kwargs) -> Any: + """ + Ask the player for input. + """ + + return self(prompt, **context) + + def __call__(self, prompt: str, **kwargs) -> str: + """ + Ask the player for input. + """ + + formatted_prompt = prompt.format(**kwargs) + self.memory.append(HumanMessage(content=formatted_prompt)) + print(formatted_prompt) + + reply = input(">>> ") + reply = reply.strip() + + # if the reply starts with a tilde, it is a literal response and should be returned without the tilde + if reply.startswith("~"): + reply = reply[1:] + self.memory.append(AIMessage(content=reply)) + return reply + + # if the reply is JSON or a special command, return it as-is + if could_be_json(reply) or reply.lower() in ["end", ""]: + self.memory.append(AIMessage(content=reply)) + return reply + + # turn other replies into a JSON function call + action, *param_rest = reply.split(":", 1) + param_str = ",".join(param_rest or []) + param_pairs = param_str.split(",") + + def parse_value(value: str) -> str | bool | float | int: + if value.startswith("~"): + return value[1:] + if value.lower() in ["true", "false"]: + return value.lower() == "true" + if value.isdecimal(): + return float(value) + if value.isnumeric(): + return int(value) + return value + + params = { + key.strip(): parse_value(value.strip()) + for key, value in ( + pair.split("=", 1) for pair in param_pairs if len(pair.strip()) > 0 + ) + } + + reply_json = dumps( + { + "function": action, + "parameters": params, + } + ) + self.memory.append(AIMessage(content=reply_json)) + return reply_json diff --git a/adventure/state.py b/adventure/state.py index 15a1b2e..07a8a15 100644 --- a/adventure/state.py +++ b/adventure/state.py @@ -9,23 +9,34 @@ from pydantic import RootModel from adventure.context import get_all_actor_agents, set_actor_agent_for_name from adventure.models import World +from adventure.player import LocalPlayer -def create_agents(world: World, memory: Dict[str, List[str | Dict[str, str]]] = {}): +def create_agents( + world: World, + memory: Dict[str, List[str | Dict[str, str]]] = {}, + players: List[str] = [], +): # set up agents for each actor llm = agent_easy_connect() for room in world.rooms: for actor in room.actors: - agent = Agent(actor.name, actor.backstory, {}, llm) - agent.memory = restore_memory(memory.get(actor.name, [])) + if actor.name in players: + agent = LocalPlayer(actor.name, actor.backstory) + agent_memory = restore_memory(memory.get(actor.name, [])) + agent.load_history(agent_memory) + else: + agent = Agent(actor.name, actor.backstory, {}, llm) + agent.memory = restore_memory(memory.get(actor.name, [])) set_actor_agent_for_name(actor.name, actor, agent) def graph_world(world: World, step: int): import graphviz - graph = graphviz.Digraph(f"{world.theme}-{step}", format="png") + graph_name = f"{path.basename(world.name)}-{step}" + graph = graphviz.Digraph(graph_name, format="png") for room in world.rooms: room_label = "\n".join([room.name, *[actor.name for actor in room.actors]]) graph.node(room.name, room_label)