1
0
Fork 0

allow manual control of characters, improve prompts and error handling

This commit is contained in:
Sean Sube 2024-05-02 18:17:13 -05:00
parent 4d7db75ffb
commit e803f40b75
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 146 additions and 51 deletions

View File

@ -84,7 +84,7 @@ def action_ask(character: str, question: str) -> str:
# sanity checks
if character == action_actor.name:
return "You cannot ask yourself a question. Stop talking to yourself."
return "You cannot ask yourself a question. You have wasted your turn. Stop talking to yourself."
question_actor, question_agent = get_actor_agent_for_name(character)
if not question_actor:
@ -95,8 +95,8 @@ def action_ask(character: str, question: str) -> str:
logger.info(f"{action_actor.name} asks {character}: {question}")
answer = question_agent(
f"{action_actor.name} asks you: {question}. Reply with your response. "
f"Do not include the question or any other text, only your reply to {action_actor.name}."
f"{action_actor.name} asks you: {question}. Reply with your response to them. "
f"Do not include the question or any JSON. Only include your answer for {action_actor.name}."
)
if could_be_json(answer) and action_tell.__name__ in answer:
@ -120,7 +120,7 @@ def action_tell(character: str, message: str) -> str:
# sanity checks
if character == action_actor.name:
return "You cannot tell yourself a message. Stop talking to yourself."
return "You cannot tell yourself a message. You have wasted your turn. Stop talking to yourself."
question_actor, question_agent = get_actor_agent_for_name(character)
if not question_actor:
@ -131,8 +131,8 @@ def action_tell(character: str, message: str) -> str:
logger.info(f"{action_actor.name} tells {character}: {message}")
answer = question_agent(
f"{action_actor.name} tells you: {message}. Reply with your response. "
f"Do not include the message or any other text, only your reply to {action_actor.name}."
f"{action_actor.name} tells you: {message}. Reply with your response to them. "
f"Do not include the message or any JSON. Only include your reply to {action_actor.name}."
)
if could_be_json(answer) and action_tell.__name__ in answer:

View File

@ -72,6 +72,7 @@ def simulate_world(world: World, steps: int = 10, callback=None, extra_actions=[
*extra_actions,
]
)
action_names = action_tools.list_tools()
# create a result parser that will memorize the actor and room
set_current_world(world)
@ -105,17 +106,11 @@ def simulate_world(world: World, steps: int = 10, callback=None, extra_actions=[
"You can take the following actions: {actions}. "
"You can move in the following directions: {directions}. "
"What will you do next? Reply with a JSON function call, calling one of the actions."
"You can only take one action per turn. Pick the most important action and save the rest for later."
"What is your action?"
),
context={
# TODO: add custom action names or remove this list entirely
"actions": [
"ask",
"give",
"look",
"move",
"take",
"tell",
], # , "use"],
"actions": action_names,
"actors": room_actors,
"directions": room_directions,
"items": room_items,
@ -127,25 +122,7 @@ def simulate_world(world: World, steps: int = 10, callback=None, extra_actions=[
)
logger.info(f"{actor.name} step result: {result}")
# if result was JSON, it has already been parsed and executed. anything remaining is flavor text
# that should be presented back to the actor
# TODO: inject this directly in the agent's memory rather than reprompting them
response = agent(
"The result of your last action was: {result}. Your turn is over, no further actions will be accepted. "
'If you understand, reply with the word "end".',
result=result,
)
logger.debug(f"{actor.name} step response: '{response}'")
if response.strip().lower() not in ["end", ""]:
logger.warning(
f"{actor.name} responded after the end of their turn: %s", response
)
response = agent(
"Your turn is over, no further actions will be accepted. Do not reply."
)
logger.debug(f"{actor.name} warning response: {response}")
agent.memory.append(result)
if callback:
callback(world, current_step)
@ -163,6 +140,18 @@ def parse_args():
parser.add_argument(
"--actions", type=str, help="Extra actions to include in the simulation"
)
parser.add_argument(
"--flavor", type=str, help="Some additional flavor text for the generated world"
)
parser.add_argument(
"--player", type=str, help="The name of the character to play as"
)
parser.add_argument(
"--state",
type=str,
# default="world.state.json",
help="The file to save the world state to. Defaults to $world.state.json, if not set",
)
parser.add_argument(
"--steps", type=int, default=10, help="The number of simulation steps to run"
)
@ -175,12 +164,6 @@ def parse_args():
default="world",
help="The file to save the generated world to",
)
parser.add_argument(
"--state",
type=str,
# default="world-state.json",
help="The file to save the world state to",
)
return parser.parse_args()
@ -190,34 +173,39 @@ def main():
world_file = args.world + ".json"
world_state_file = args.state or (args.world + ".state.json")
players = []
if args.player:
players.append(args.player)
memory = {}
if path.exists(world_state_file):
logger.info(f"Loading world state from {world_state_file}")
with open(world_state_file, "r") as f:
state = WorldState(**load(f))
set_step(state.step)
create_agents(state.world, state.memory)
memory = state.memory
world = state.world
world.name = args.world
elif path.exists(world_file):
logger.info(f"Loading world from {world_file}")
with open(world_file, "r") as f:
world = World(**load(f), name=args.world)
create_agents(world)
world = World(**load(f))
else:
logger.info(f"Generating a new {args.theme} world")
llm = agent_easy_connect()
agent = Agent(
"world builder",
f"You are an experienced game master creating a visually detailed {args.theme} world for a new adventure.",
"World Builder",
f"You are an experienced game master creating a visually detailed {args.theme} world for a new adventure. {args.flavor}",
{},
llm,
)
world = generate_world(agent, args.world, args.theme)
create_agents(world)
save_world(world, world_file)
create_agents(world, memory=memory, players=players)
# load extra actions
extra_actions = []
if args.actions:
@ -229,7 +217,7 @@ def main():
logger.info(
f"Loaded extra actions: {[action.__name__ for action in module_actions]}"
)
extra_actions.append(module_actions)
extra_actions.extend(module_actions)
logger.debug("Simulating world: %s", world)
simulate_world(

96
adventure/player.py Normal file
View File

@ -0,0 +1,96 @@
from json import dumps
from readline import add_history
from typing import Any, Dict, List, Sequence
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from packit.utils import could_be_json
class LocalPlayer:
"""
A human agent that can interact with the world.
"""
name: str
backstory: str
memory: List[str | BaseMessage]
def __init__(self, name: str, backstory: str) -> None:
self.name = name
self.backstory = backstory
self.memory = []
def load_history(self, lines: Sequence[str | BaseMessage]):
"""
Load the history of the player's input.
"""
self.memory.extend(lines)
for line in lines:
if isinstance(line, BaseMessage):
add_history(str(line.content))
else:
add_history(line)
def invoke(self, prompt: str, context: Dict[str, Any], **kwargs) -> Any:
"""
Ask the player for input.
"""
return self(prompt, **context)
def __call__(self, prompt: str, **kwargs) -> str:
"""
Ask the player for input.
"""
formatted_prompt = prompt.format(**kwargs)
self.memory.append(HumanMessage(content=formatted_prompt))
print(formatted_prompt)
reply = input(">>> ")
reply = reply.strip()
# if the reply starts with a tilde, it is a literal response and should be returned without the tilde
if reply.startswith("~"):
reply = reply[1:]
self.memory.append(AIMessage(content=reply))
return reply
# if the reply is JSON or a special command, return it as-is
if could_be_json(reply) or reply.lower() in ["end", ""]:
self.memory.append(AIMessage(content=reply))
return reply
# turn other replies into a JSON function call
action, *param_rest = reply.split(":", 1)
param_str = ",".join(param_rest or [])
param_pairs = param_str.split(",")
def parse_value(value: str) -> str | bool | float | int:
if value.startswith("~"):
return value[1:]
if value.lower() in ["true", "false"]:
return value.lower() == "true"
if value.isdecimal():
return float(value)
if value.isnumeric():
return int(value)
return value
params = {
key.strip(): parse_value(value.strip())
for key, value in (
pair.split("=", 1) for pair in param_pairs if len(pair.strip()) > 0
)
}
reply_json = dumps(
{
"function": action,
"parameters": params,
}
)
self.memory.append(AIMessage(content=reply_json))
return reply_json

View File

@ -9,23 +9,34 @@ from pydantic import RootModel
from adventure.context import get_all_actor_agents, set_actor_agent_for_name
from adventure.models import World
from adventure.player import LocalPlayer
def create_agents(world: World, memory: Dict[str, List[str | Dict[str, str]]] = {}):
def create_agents(
world: World,
memory: Dict[str, List[str | Dict[str, str]]] = {},
players: List[str] = [],
):
# set up agents for each actor
llm = agent_easy_connect()
for room in world.rooms:
for actor in room.actors:
agent = Agent(actor.name, actor.backstory, {}, llm)
agent.memory = restore_memory(memory.get(actor.name, []))
if actor.name in players:
agent = LocalPlayer(actor.name, actor.backstory)
agent_memory = restore_memory(memory.get(actor.name, []))
agent.load_history(agent_memory)
else:
agent = Agent(actor.name, actor.backstory, {}, llm)
agent.memory = restore_memory(memory.get(actor.name, []))
set_actor_agent_for_name(actor.name, actor, agent)
def graph_world(world: World, step: int):
import graphviz
graph = graphviz.Digraph(f"{world.theme}-{step}", format="png")
graph_name = f"{path.basename(world.name)}-{step}"
graph = graphviz.Digraph(graph_name, format="png")
for room in world.rooms:
room_label = "\n".join([room.name, *[actor.name for actor in room.actors]])
graph.node(room.name, room_label)