diff --git a/adventure/generate.py b/adventure/generate.py index 8904d07..91d6d8e 100644 --- a/adventure/generate.py +++ b/adventure/generate.py @@ -17,8 +17,14 @@ OPPOSITE_DIRECTIONS = { } +GenerateCallback = Callable[[str], None] + + def generate_room( - agent: Agent, world_theme: str, existing_rooms: List[str], callback + agent: Agent, + world_theme: str, + callback: GenerateCallback | None = None, + existing_rooms: List[str] = [], ) -> Room: def unique_name(name: str, **kwargs): if name in existing_rooms: @@ -37,7 +43,10 @@ def generate_room( }, result_parser=unique_name, ) - callback(f"Generating room: {name}") + + if callable(callback): + callback(f"Generating room: {name}") + desc = agent( "Generate a detailed description of the {name} area. What does it look like? " "What does it smell like? What can be seen or heard?", @@ -56,10 +65,10 @@ def generate_room( def generate_item( agent: Agent, world_theme: str, - existing_items: List[str], - callback, + callback: Callable[[str], None] | None = None, dest_room: str | None = None, dest_actor: str | None = None, + existing_items: List[str] = [], ) -> Item: if dest_actor: dest_note = "The item will be held by the {dest_actor} character" @@ -87,7 +96,10 @@ def generate_item( }, result_parser=unique_name, ) - callback(f"Generating item: {name}") + + if callable(callback): + callback(f"Generating item: {name}") + desc = agent( "Generate a detailed description of the {name} item. What does it look like? What is it made of? What does it do?", name=name, @@ -99,7 +111,11 @@ def generate_item( def generate_actor( - agent: Agent, world_theme: str, dest_room: str, existing_actors: List[str], callback + agent: Agent, + world_theme: str, + dest_room: str, + callback: GenerateCallback | None = None, + existing_actors: List[str] = [], ) -> Actor: def unique_name(name: str, **kwargs): if name in existing_actors: @@ -120,7 +136,10 @@ def generate_actor( }, result_parser=unique_name, ) - callback(f"Generating actor: {name}") + + if callable(callback): + callback(f"Generating actor: {name}") + description = agent( "Generate a detailed description of the {name} character. What do they look like? What are they wearing? " "What are they doing? Describe their appearance from the perspective of an outside observer." @@ -147,10 +166,12 @@ def generate_world( theme: str, room_count: int | None = None, max_rooms: int = 5, - callback: Callable[[str], None] = lambda x: None, + callback: Callable[[str], None] | None = None, ) -> World: room_count = room_count or randint(3, max_rooms) - callback(f"Generating a {theme} with {room_count} rooms") + + if callable(callback): + callback(f"Generating a {theme} with {room_count} rooms") existing_actors: List[str] = [] existing_items: List[str] = [] @@ -159,12 +180,16 @@ def generate_world( # generate the rooms rooms = [] for i in range(room_count): - room = generate_room(agent, theme, existing_rooms, callback=callback) + room = generate_room( + agent, theme, existing_rooms=existing_rooms, callback=callback + ) rooms.append(room) existing_rooms.append(room.name) item_count = randint(0, 3) - callback(f"Generating {item_count} items for room: {room.name}") + + if callable(callback): + callback(f"Generating {item_count} items for room: {room.name}") for j in range(item_count): item = generate_item( @@ -178,7 +203,9 @@ def generate_world( existing_items.append(item.name) actor_count = randint(0, 3) - callback(f"Generating {actor_count} actors for room: {room.name}") + + if callable(callback): + callback(f"Generating {actor_count} actors for room: {room.name}") for j in range(actor_count): actor = generate_actor( @@ -193,7 +220,9 @@ def generate_world( # generate the actor's inventory item_count = randint(0, 3) - callback(f"Generating {item_count} items for actor {actor.name}") + + if callable(callback): + callback(f"Generating {item_count} items for actor {actor.name}") for k in range(item_count): item = generate_item( diff --git a/adventure/logic.py b/adventure/logic.py new file mode 100644 index 0000000..d5acf41 --- /dev/null +++ b/adventure/logic.py @@ -0,0 +1,142 @@ +from logging import getLogger +from random import random +from typing import Callable, Dict, List, Optional +from functools import partial + +from rule_engine import Rule +from pydantic import Field +from yaml import Loader, load + +from adventure.models import Actor, Item, Room, World, dataclass +from adventure.plugins import get_plugin_function + +logger = getLogger(__name__) + + +@dataclass +class LogicLabel: + backstory: str + description: str + + +@dataclass +class LogicRule: + chance: float = 1.0 + group: Optional[str] = None + match: Optional[Dict[str, str]] = None + remove: Optional[List[str]] = None + rule: Optional[str] = None + set: Optional[Dict[str, str]] = None + trigger: Optional[List[str]] = None + + +@dataclass +class LogicTable: + rules: List[LogicRule] + labels: Dict[str, Dict[str, LogicLabel]] = Field(default_factory=dict) + + +LogicTrigger = Callable[[Room | Actor | Item, Dict[str, str]], Dict[str, str]] +TriggerTable = Dict[LogicRule, List[LogicTrigger]] + + +def update_attributes( + entity: Room | Actor | Item, + attributes: Dict[str, str], + rules: LogicTable, + triggers: TriggerTable, +) -> Dict[str, str]: + entity_type = entity.__class__.__name__.lower() + skip_groups = set() + + for rule in rules.rules: + if rule.group: + if rule.group in skip_groups: + logger.debug("skipping logic group: %s", rule.group) + continue + + skip_groups.add(rule.group) + + typed_attributes = { + **attributes, + "type": entity_type, + } + + if rule.rule: + # TODO: pre-compile rules + rule_impl = Rule(rule.rule) + if not rule_impl.matches({ + "attributes": typed_attributes, + }): + logger.debug("logic rule did not match attributes: %s", rule.rule) + continue + + if rule.match and not(rule.match.items() <= typed_attributes.items()): + logger.debug("logic did not match attributes: %s", rule.match) + continue + + logger.info("matched logic: %s", rule.match) + if rule.chance < 1: + if random() > rule.chance: + logger.info("logic skipped by chance: %s", rule.chance) + continue + + for key in rule.remove or []: + attributes.pop(key, None) + + if rule.set: + attributes.update(rule.set) + logger.info("logic set state: %s", rule.set) + + if rule in triggers: + for trigger in triggers[rule]: + attributes = trigger(entity, attributes) + + return attributes + + +def update_logic(world: World, step: int, rules: LogicTable, triggers: TriggerTable) -> None: + for room in world.rooms: + room.attributes = update_attributes(room, room.attributes, rules=rules, triggers=triggers) + for actor in room.actors: + actor.attributes = update_attributes(actor, actor.attributes, rules=rules, triggers=triggers) + for item in actor.items: + item.attributes = update_attributes(item, item.attributes, rules=rules, triggers=triggers) + for item in room.items: + item.attributes = update_attributes(item, item.attributes, rules=rules, triggers=triggers) + + logger.info("updated world attributes") + + +def format_logic(attributes: Dict[str, str], rules: LogicTable, self=True) -> str: + labels = [] + + for attribute, value in attributes.items(): + if attribute in rules.labels and value in rules.labels[attribute]: + label = rules.labels[attribute][value] + if self: + labels.append(label.backstory) + else: + labels.append(label.description) + + if len(labels) > 0: + logger.info("adding labels: %s", labels) + + return " ".join(labels) + + +def init_from_file(filename: str): + logger.info("loading logic from file: %s", filename) + with open(filename) as file: + logic_rules = LogicTable(**load(file, Loader=Loader)) + logic_triggers = { + rule: [get_plugin_function(trigger) for trigger in rule.trigger] + for rule in logic_rules.rules + if rule.trigger + } + + logger.info("initialized logic system") + return ( + partial(update_logic, rules=logic_rules, triggers=logic_triggers), + partial(format_logic, rules=logic_rules) + ) diff --git a/adventure/main.py b/adventure/main.py index c79160a..d7546f3 100644 --- a/adventure/main.py +++ b/adventure/main.py @@ -228,7 +228,7 @@ def parse_args(): "--systems", type=str, nargs="*", - help="Extra logic systems to run in the simulation", + help="Extra systems to run in the simulation", ) parser.add_argument( "--theme", type=str, default="fantasy", help="The theme of the generated world" diff --git a/adventure/optional_actions.py b/adventure/optional_actions.py index 3f82a2d..94c96d6 100644 --- a/adventure/optional_actions.py +++ b/adventure/optional_actions.py @@ -3,13 +3,15 @@ from typing import Callable, List from packit.agent import Agent, agent_easy_connect -from adventure.context import broadcast, get_current_context +from adventure.context import broadcast, get_agent_for_actor, get_current_context from adventure.generate import OPPOSITE_DIRECTIONS, generate_item, generate_room logger = getLogger(__name__) llm = agent_easy_connect() + +# TODO: provide dungeon master with the world theme dungeon_master = Agent( "dungeon master", "You are the dungeon master in charge of a fantasy world.", @@ -37,7 +39,7 @@ def action_explore(direction: str) -> str: existing_rooms = [room.name for room in current_world.rooms] new_room = generate_room( - dungeon_master, current_world.theme, existing_rooms, callback=lambda x: x + dungeon_master, current_world.theme, existing_rooms=existing_rooms ) current_world.rooms.append(new_room) @@ -68,7 +70,6 @@ def action_search() -> str: action_world.theme, existing_items=existing_items, dest_room=action_room.name, - callback=lambda x: x, ) action_room.items.append(new_item) @@ -78,6 +79,46 @@ def action_search() -> str: return f"You search the room and find a new item: {new_item.name}" +def action_use(item: str, target: str) -> str: + """ + Use an item on yourself or another character in the room. + + Args: + item: The name of the item to use. + target: The name of the character to use the item on, or "self" to use the item on yourself. + """ + _, action_room, action_actor = get_current_context() + + available_items = [item.name for item in action_actor.items] + [item.name for item in action_room.items] + + if item not in available_items: + return f"The {item} item is not available to use." + + if target == "self": + target_actor = action_actor + target = action_actor.name + else: + target_actor = next( + (actor for actor in action_room.actors if actor.name == target), None + ) + if not target_actor: + return f"The {target} character is not in the room." + + broadcast(f"{action_actor.name} uses {item} on {target}") + outcome = dungeon_master( + f"{action_actor.name} uses {item} on {target}. {action_actor.description}. {target_actor.description}. What happens? How does {target} react? " + "Specify the outcome of the action. Do not include the question or any JSON. Only include the outcome of the action." + ) + broadcast(f"The action resulted in: {outcome}") + + # make sure both agents remember the outcome + target_agent = get_agent_for_actor(target_actor) + if target_agent: + target_agent.memory.append(outcome) + + return outcome + + def init() -> List[Callable]: """ Initialize the custom actions. @@ -85,4 +126,5 @@ def init() -> List[Callable]: return [ action_explore, action_search, + action_use, ] diff --git a/adventure/player.py b/adventure/player.py index cbb29c3..9a80271 100644 --- a/adventure/player.py +++ b/adventure/player.py @@ -1,12 +1,13 @@ from json import dumps from readline import add_history -from typing import Any, Dict, List, Sequence +from queue import Queue +from typing import Any, Callable, Dict, List, Sequence from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from packit.utils import could_be_json -class LocalPlayer: +class BasePlayer: """ A human agent that can interact with the world. """ @@ -40,18 +41,7 @@ class LocalPlayer: return self(prompt, **context) - def __call__(self, prompt: str, **kwargs) -> str: - """ - Ask the player for input. - """ - - formatted_prompt = prompt.format(**kwargs) - self.memory.append(HumanMessage(content=formatted_prompt)) - print(formatted_prompt) - - reply = input(">>> ") - reply = reply.strip() - + def parse_input(self, reply: str): # if the reply starts with a tilde, it is a literal response and should be returned without the tilde if reply.startswith("~"): reply = reply[1:] @@ -94,3 +84,50 @@ class LocalPlayer: ) self.memory.append(AIMessage(content=reply_json)) return reply_json + + def __call__(self, prompt: str, **kwargs) -> str: + raise NotImplementedError("Subclasses must implement this method") + + +class LocalPlayer(BasePlayer): + def __call__(self, prompt: str, **kwargs) -> str: + """ + Ask the player for input. + """ + + formatted_prompt = prompt.format(**kwargs) + self.memory.append(HumanMessage(content=formatted_prompt)) + print(formatted_prompt) + + reply = input(">>> ") + reply = reply.strip() + + return self.parse_input(reply) + + +class RemotePlayer(BasePlayer): + input_queue: Queue[str] + send_prompt: Callable[[str, str], bool] + + def __init__(self, name: str, backstory: str, send_prompt: Callable[[str, str], bool]) -> None: + super().__init__(name, backstory) + self.input_queue = Queue() + self.send_prompt = send_prompt + + def __call__(self, prompt: str, **kwargs) -> str: + """ + Ask the player for input. + """ + + formatted_prompt = prompt.format(**kwargs) + self.memory.append(HumanMessage(content=formatted_prompt)) + + try: + if self.send_prompt(self.name, formatted_prompt): + reply = self.input_queue.get(timeout=60) + return self.parse_input(reply) + except Exception: + pass + + # logger.warning("Failed to send prompt to remote player") + return "" diff --git a/adventure/server.py b/adventure/server.py index 98c021f..c9f0711 100644 --- a/adventure/server.py +++ b/adventure/server.py @@ -1,18 +1,22 @@ import asyncio from collections import deque -from json import dumps +from json import dumps, loads from logging import getLogger from threading import Thread +from typing import Dict, Tuple import websockets +from adventure.context import get_actor_agent_for_name from adventure.models import Actor, Room, World +from adventure.player import RemotePlayer from adventure.state import snapshot_world, world_json logger = getLogger(__name__) connected = set() -recent_events = deque(maxlen=10) +characters: Dict[str, RemotePlayer] = {} +recent_events = deque(maxlen=100) recent_world = None @@ -20,6 +24,20 @@ async def handler(websocket): logger.info("Client connected") connected.add(websocket) + async def next_turn(character: str, prompt: str) -> None: + await websocket.send(connected, dumps({ + "type": "turn", + "character": character, + "prompt": prompt, + })) + + def sync_turn(character: str, prompt: str) -> bool: + if websocket not in characters: + return False + + asyncio.run(next_turn(character, prompt)) + return True + try: if recent_world: await websocket.send(recent_world) @@ -31,12 +49,44 @@ async def handler(websocket): while True: try: + # if this socket is attached to a character and that character's turn is active, wait for input message = await websocket.recv() - print(message) + logger.info(f"Received message: {message}") + + try: + data = loads(message) + if "become" in data: + character = characters.get(websocket) + if character: + del characters[websocket] + + character_name = data["become"] + actor, _ = get_actor_agent_for_name(character_name) + if not actor: + logger.error(f"Failed to find actor {character_name}") + continue + + if character_name in [player.name for player in characters.values()]: + logger.error(f"Character {character_name} is already in use") + continue + + characters[websocket] = RemotePlayer(actor.name, actor.backstory, sync_turn) + logger.info(f"Client {websocket} is now character {character_name}") + elif websocket in characters: + player = characters[websocket] + player.input_queue.put(message) + + except Exception: + logger.exception("Failed to parse message") except websockets.ConnectionClosedOK: break connected.remove(websocket) + + # TODO: swap out the character for the original agent + if websocket in characters: + del characters[websocket] + logger.info("Client disconnected") diff --git a/adventure/systems/logic.py b/adventure/systems/logic.py deleted file mode 100644 index bab9c3e..0000000 --- a/adventure/systems/logic.py +++ /dev/null @@ -1,103 +0,0 @@ -from logging import getLogger -from random import random -from typing import Dict, List, Optional - -from pydantic import Field -from yaml import Loader, load - -from adventure.models import Actor, Item, Room, World, dataclass -from adventure.plugins import get_plugin_function - -logger = getLogger(__name__) - - -@dataclass -class LogicLabel: - backstory: str - description: str - - -@dataclass -class LogicRule: - match: Dict[str, str] - chance: float = 1.0 - remove: Optional[List[str]] = None - set: Optional[Dict[str, str]] = None - trigger: Optional[List[str]] = None - - -@dataclass -class LogicTable: - rules: List[LogicRule] - labels: Dict[str, Dict[str, LogicLabel]] = Field(default_factory=dict) - - -with open("./worlds/logic.yaml") as file: - logic_rules = LogicTable(**load(file, Loader=Loader)) - logic_triggers = { - rule: [get_plugin_function(trigger) for trigger in rule.trigger] - for rule in logic_rules.rules - if rule.trigger - } - - -def update_attributes( - entity: Room | Actor | Item, - attributes: Dict[str, str], - dataset: LogicTable, -) -> Dict[str, str]: - for rule in dataset.rules: - if rule.match.items() <= attributes.items(): - logger.info("matched logic: %s", rule.match) - if rule.chance < 1: - if random() > rule.chance: - logger.info("logic skipped by chance: %s", rule.chance) - continue - - if rule.set: - attributes.update(rule.set) - logger.info("logic set state: %s", rule.set) - - for key in rule.remove or []: - attributes.pop(key, None) - - if rule in logic_triggers: - for trigger in logic_triggers[rule]: - attributes = trigger(entity, attributes) - - return attributes - - -def update_logic(world: World, step: int) -> None: - for room in world.rooms: - room.attributes = update_attributes(room, room.attributes, logic_rules) - for actor in room.actors: - actor.attributes = update_attributes(actor, actor.attributes, logic_rules) - for item in actor.items: - item.attributes = update_attributes(item, item.attributes, logic_rules) - for item in room.items: - item.attributes = update_attributes(item, item.attributes, logic_rules) - - logger.info("updated world attributes") - - -def format_logic(attributes: Dict[str, str], self=True) -> str: - labels = [] - - for attribute, value in attributes.items(): - if attribute in logic_rules.labels and value in logic_rules.labels[attribute]: - label = logic_rules.labels[attribute][value] - if self: - labels.append(label.backstory) - else: - labels.append(label.description) - - if len(labels) > 0: - logger.info("adding labels: %s", labels) - - return " ".join(labels) - - -def init(): - logger.info("initialized logic system") - return (update_logic, format_logic)