1
0
Fork 0
taleweave-ai/adventure/player.py

204 lines
5.6 KiB
Python
Raw Normal View History

from json import dumps
from logging import getLogger
2024-05-05 14:14:54 +00:00
from queue import Queue
from readline import add_history
2024-05-08 01:42:10 +00:00
from typing import Any, Callable, Dict, List, Optional, Sequence
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from packit.agent import Agent
from packit.utils import could_be_json
from adventure.context import action_context
2024-05-09 02:11:16 +00:00
from adventure.models.event import PromptEvent
logger = getLogger(__name__)

# Players currently connected, keyed by client: Dict[client, player]
active_players: Dict[str, "BasePlayer"] = dict()
def get_player(client: str) -> Optional["BasePlayer"]:
    """
    Look up the player registered for a client.

    Args:
        client: the client key used when the player was registered.

    Returns:
        The active player for ``client``, or None when that client has none.
    """
    try:
        return active_players[client]
    except KeyError:
        return None
def set_player(client: str, player: "BasePlayer"):
    """
    Register ``player`` as the active player for ``client``.

    Raises:
        ValueError: if an active player already uses the same character name.
    """
    character = player.name
    if not has_player(character):
        active_players[client] = player
        return

    raise ValueError(f"Someone is already playing as {character}!")
def remove_player(client: str):
    """
    Drop the active player registered for ``client``.

    Clients without an active player are ignored.
    """
    active_players.pop(client, None)
def has_player(character_name: str) -> bool:
    """
    Report whether any active player is using ``character_name``.
    """
    return any(
        player.name == character_name for player in active_players.values()
    )
def list_players() -> Dict[str, str]:
    """
    Map each client with an active player to that player's character name.
    """
    players = {}
    for client, player in active_players.items():
        players[client] = player.name

    return players
class BasePlayer:
    """
    A human agent that can interact with the world.

    Subclasses implement ``__call__`` to obtain a raw reply from the player. This base
    class keeps the conversation in ``memory`` and normalizes replies with
    ``parse_input``.
    """

    name: str
    backstory: str
    # prompts (HumanMessage), normalized replies (AIMessage), and any loaded history
    memory: List[str | BaseMessage]

    def __init__(self, name: str, backstory: str) -> None:
        self.name = name
        self.backstory = backstory
        self.memory = []

    def load_history(self, lines: Sequence[str | BaseMessage]):
        """
        Load previous input into the player's memory and the readline history.

        Args:
            lines: plain strings or messages. For messages, ``str(message.content)``
                is what gets added to the readline history.
        """
        # materialize once so both passes below see the same items
        lines = list(lines)
        self.memory.extend(lines)
        for line in lines:
            if isinstance(line, BaseMessage):
                add_history(str(line.content))
            else:
                add_history(line)

    def invoke(self, prompt: str, context: Dict[str, Any], **kwargs) -> Any:
        """
        Ask the player for input.

        ``context`` is forwarded to ``__call__`` as keyword arguments; any extra
        keyword arguments passed here are accepted but ignored.
        """
        return self(prompt, **context)

    def parse_input(self, reply: str) -> str:
        """
        Normalize a raw player reply and record it in memory.

        - ``~text`` is a literal response and is returned as ``text``.
        - Replies accepted by ``could_be_json``, ``end`` (any case), and the empty
          string are returned unchanged.
        - Anything else is read as ``action: key=value, key2=value2`` and converted
          into a JSON string ``{"function": action, "parameters": {...}}``.

        Returns:
            The normalized reply, which is also appended to memory as an AIMessage.
        """
        if reply.startswith("~"):
            # literal response: drop only the leading tilde
            result = reply[1:]
        elif could_be_json(reply) or reply.lower() in ("end", ""):
            # JSON or a special command: pass through as-is
            result = reply
        else:
            result = self._reply_to_function_call(reply)

        self.memory.append(AIMessage(content=result))
        return result

    @classmethod
    def _reply_to_function_call(cls, reply: str) -> str:
        """
        Convert ``action: key=value, ...`` shorthand into a JSON function call string.

        A pair without ``=`` becomes a parameter whose value is the empty string,
        instead of raising ``ValueError`` during unpacking.
        """
        action, _, param_str = reply.partition(":")

        params = {}
        for pair in param_str.split(","):
            if not pair.strip():
                continue
            key, _, value = pair.partition("=")
            params[key.strip()] = cls._parse_value(value.strip())

        return dumps(
            {
                "function": action.strip(),
                "parameters": params,
            }
        )

    @staticmethod
    def _parse_value(value: str) -> str | bool | float | int:
        """
        Convert one parameter value from the ``key=value`` syntax into a JSON type.

        - ``~value`` forces a literal string.
        - ``true``/``false`` (any case) become booleans.
        - Integer literals become ``int``; other finite decimal literals become ``float``.
        - Everything else stays a string.
        """
        from math import isfinite

        if value.startswith("~"):
            return value[1:]

        lowered = value.lower()
        if lowered in ("true", "false"):
            return lowered == "true"

        # int()/float() accept "_" digit separators, which a player would not mean as a number
        if "_" not in value:
            try:
                return int(value)
            except ValueError:
                pass

            try:
                number = float(value)
            except ValueError:
                pass
            else:
                # nan/inf would be serialized by json.dumps as non-standard JSON tokens
                if isfinite(number):
                    return number

        return value

    def __call__(self, prompt: str, **kwargs) -> str:
        """
        Ask the player for input; subclasses must implement this.
        """
        raise NotImplementedError("Subclasses must implement this method")
class LocalPlayer(BasePlayer):
    """
    A player who answers prompts on the local terminal.
    """

    def __call__(self, prompt: str, **kwargs) -> str:
        """
        Ask the player for input.

        The prompt is formatted with ``kwargs``, recorded in memory, and printed. The
        reply read from ``input`` is stripped and normalized with ``parse_input``.
        """
        logger.info(f"prompting local player: {self.name}")

        formatted_prompt = prompt.format(**kwargs)
        self.memory.append(HumanMessage(content=formatted_prompt))
        print(formatted_prompt)

        reply = input(">>> ")
        reply = reply.strip()
        return self.parse_input(reply)
class RemotePlayer(BasePlayer):
    """
    A player who answers prompts over a remote connection.

    Prompts are delivered with ``send_prompt``; replies are read from ``input_queue``.
    When no reply is obtained, ``fallback_agent`` (if any) answers instead.
    """

    fallback_agent: Agent | None
    input_queue: Queue[str]
    send_prompt: Callable[[PromptEvent], bool]
    # seconds to wait on input_queue for a reply before falling back
    input_timeout: float

    def __init__(
        self,
        name: str,
        backstory: str,
        send_prompt: Callable[[PromptEvent], bool],
        fallback_agent=None,
        input_timeout: float = 60,
    ) -> None:
        super().__init__(name, backstory)
        self.fallback_agent = fallback_agent
        self.input_queue = Queue()
        self.send_prompt = send_prompt
        self.input_timeout = input_timeout

    def __call__(self, prompt: str, **kwargs) -> str:
        """
        Ask the player for input.

        Sends a PromptEvent built from the formatted prompt and the current action
        context, then waits up to ``input_timeout`` seconds for a reply on
        ``input_queue``.

        If ``send_prompt`` returns False, the wait times out, or an error occurs,
        ``fallback_agent`` is called with the unformatted prompt and ``kwargs``.
        Without a fallback agent, the empty string is returned.
        """
        from queue import Empty

        formatted_prompt = prompt.format(**kwargs)
        self.memory.append(HumanMessage(content=formatted_prompt))

        with action_context() as (current_room, current_actor):
            prompt_event = PromptEvent(
                prompt=formatted_prompt, room=current_room, actor=current_actor
            )

            try:
                logger.info(f"prompting remote player: {self.name}")
                if self.send_prompt(prompt_event):
                    # NOTE(review): a reply that arrives after the timeout stays queued and
                    # will be returned for the next prompt; confirm whether stale replies
                    # should be drained before sending.
                    reply = self.input_queue.get(timeout=self.input_timeout)
                    logger.info(f"got reply from remote player: {reply}")
                    return self.parse_input(reply)

                logger.warning(f"failed to send prompt to remote player: {self.name}")
            except Empty:
                logger.warning(
                    f"timed out after {self.input_timeout} seconds waiting for remote player: {self.name}"
                )
            except Exception:
                logger.exception("error getting reply from remote player")

        if self.fallback_agent:
            return self.fallback_agent(prompt, **kwargs)

        return ""