From ebf4ccf1c49f7aa0a88ea19c2eae554705730996 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 3 May 2024 23:18:21 -0500 Subject: [PATCH] add server and logic systems --- adventure/generate.py | 26 ++++--- adventure/main.py | 150 ++++++++++++++++++++++++++++--------- adventure/models.py | 7 +- adventure/server.py | 90 ++++++++++++++++++++++ adventure/state.py | 16 ++-- adventure/systems/logic.py | 90 ++++++++++++++++++++++ 6 files changed, 322 insertions(+), 57 deletions(-) create mode 100644 adventure/server.py create mode 100644 adventure/systems/logic.py diff --git a/adventure/generate.py b/adventure/generate.py index cecfb73..9dbfe3a 100644 --- a/adventure/generate.py +++ b/adventure/generate.py @@ -49,7 +49,8 @@ def generate_item( name = agent( "Generate one item or object that would make sense in the world of {world_theme}. {dest_note}. " - 'Only respond with the item name, do not include a description or any other text. Do not prefix the name with "the", do not wrap it in quotes. ' + "Only respond with the item name, do not include a description or any other text. Do not prefix the " + 'name with "the", do not wrap it in quotes. Do not include the name of the room. ' "Do not create any duplicate items in the same room. Do not give characters any duplicate items. The existing items are: {existing_items}", dest_note=dest_note, existing_items=existing_items, @@ -72,7 +73,8 @@ def generate_actor( name = agent( "Generate one person or creature that would make sense in the world of {world_theme}. The character will be placed in the {dest_room} room. " 'Only respond with the character name, do not include a description or any other text. Do not prefix the name with "the", do not wrap it in quotes. ' - "Do not create any duplicate characters in the same room. The existing characters are: {existing_actors}", + "Do not include the name of the room. Do not give characters any duplicate names." + "Do not create any duplicate characters. The existing characters are: {existing_actors}", dest_room=dest_room, existing_actors=existing_actors, world_theme=world_theme, @@ -90,22 +92,22 @@ def generate_actor( name=name, ) - health = 100 - actions = {} - return Actor( name=name, backstory=backstory, description=description, - health=health, - actions=actions, + actions={}, ) def generate_world( - agent: Agent, name: str, theme: str, rooms: int | None = None, max_rooms: int = 5 + agent: Agent, + name: str, + theme: str, + room_count: int | None = None, + max_rooms: int = 5, ) -> World: - room_count = rooms or randint(3, max_rooms) + room_count = room_count or randint(3, max_rooms) logger.info(f"Generating a {theme} with {room_count} rooms") existing_actors: List[str] = [] @@ -150,7 +152,7 @@ def generate_world( "west": "east", } - # TODO: generate portals to link the rooms together + # generate portals to link the rooms together for room in rooms: directions = ["north", "south", "east", "west"] for direction in directions: @@ -180,4 +182,6 @@ def generate_world( room.portals[direction] = dest_room.name dest_room.portals[opposite_directions[direction]] = room.name - return World(name=name, rooms=rooms, theme=theme) + # ensure actors act in a stable order + order = [actor.name for room in rooms for actor in room.actors] + return World(name=name, rooms=rooms, theme=theme, order=order) diff --git a/adventure/main.py b/adventure/main.py index b3da68f..70effd6 100644 --- a/adventure/main.py +++ b/adventure/main.py @@ -1,6 +1,7 @@ from importlib import import_module from json import load from os import environ, path +from typing import Callable, Dict, Sequence, Tuple from dotenv import load_dotenv from packit.agent import Agent, agent_easy_connect @@ -18,8 +19,8 @@ from adventure.actions import ( action_tell, ) from adventure.context import ( + get_actor_agent_for_name, get_actor_for_agent, - get_agent_for_actor, get_current_world, get_step, set_current_actor, @@ -28,7 +29,7 @@ from adventure.context import ( set_step, ) from adventure.generate import generate_world -from adventure.models import World, WorldState +from adventure.models import Actor, Room, World, WorldState from adventure.state import create_agents, save_world, save_world_state logger = logger_with_colors(__name__) @@ -57,13 +58,20 @@ def world_result_parser(value, agent, **kwargs): return multi_function_or_str_result(value, agent=agent, **kwargs) -def simulate_world(world: World, steps: int = 10, callback=None, extra_actions=[]): +def simulate_world( + world: World, + steps: int = 10, + actions: Sequence[Callable[..., str]] = [], + systems: Sequence[ + Tuple[Callable[[World, int], None], Callable[[Dict[str, str]], str] | None] + ] = [], + input_callbacks: Sequence[Callable[[Room, Actor, str], None]] = [], + result_callbacks: Sequence[Callable[[Room, Actor, str], None]] = [], +): logger.info("Simulating the world") + set_current_world(world) - # collect actors, so they are only processed once - all_actors = [actor for room in world.rooms for actor in room.actors] - - # TODO: add actions for: drop, use, attack, cast, jump, climb, swim, fly, etc. + # build a toolbox for the actions action_tools = Toolbox( [ action_ask, @@ -72,40 +80,51 @@ def simulate_world(world: World, steps: int = 10, callback=None, extra_actions=[ action_move, action_take, action_tell, - *extra_actions, + *actions, ] ) action_names = action_tools.list_tools() - # create a result parser that will memorize the actor and room - set_current_world(world) - # simulate each actor for i in range(steps): current_step = get_step() logger.info(f"Simulating step {current_step}") - for actor in all_actors: - agent = get_agent_for_actor(actor) - if not agent: - logger.error(f"Agent not found for actor {actor.name}") + for actor_name in world.order: + actor, agent = get_actor_agent_for_name(actor_name) + if not agent or not actor: + logger.error(f"Agent or actor not found for name {actor_name}") continue room = next((room for room in world.rooms if actor in room.actors), None) if not room: - logger.error(f"Actor {actor.name} is not in a room") + logger.error(f"Actor {actor_name} is not in a room") continue room_actors = [actor.name for actor in room.actors] room_items = [item.name for item in room.items] room_directions = list(room.portals.keys()) - logger.info("starting actor %s turn", actor.name) + actor_attributes = " ".join( + system_format(actor.attributes) + for _, system_format in systems + if system_format + ) + actor_items = [item.name for item in actor.items] + + def result_parser(value, agent, **kwargs): + for callback in input_callbacks: + callback(room, actor, value) + + return world_result_parser(value, agent, **kwargs) + + logger.info("starting turn for actor: %s", actor_name) result = loop_retry( agent, ( - "You are currently in {room_name}. {room_description}. " - "The room contains the following characters: {actors}. " - "The room contains the following items: {items}. " + "You are currently in {room_name}. {room_description}. {attributes}. " + "The room contains the following characters: {visible_actors}. " + "The room contains the following items: {visible_items}. " + "Your inventory contains the following items: {actor_items}." "You can take the following actions: {actions}. " "You can move in the following directions: {directions}. " "What will you do next? Reply with a JSON function call, calling one of the actions." @@ -114,21 +133,26 @@ def simulate_world(world: World, steps: int = 10, callback=None, extra_actions=[ ), context={ "actions": action_names, - "actors": room_actors, + "actor_items": actor_items, + "attributes": actor_attributes, "directions": room_directions, - "items": room_items, "room_name": room.name, "room_description": room.description, + "visible_actors": room_actors, + "visible_items": room_items, }, - result_parser=world_result_parser, + result_parser=result_parser, toolbox=action_tools, ) logger.debug(f"{actor.name} step result: {result}") agent.memory.append(result) - if callback: - callback(world, current_step) + for callback in result_callbacks: + callback(room, actor, result) + + for system_update, _ in systems: + system_update(world, current_step) set_step(current_step + 1) @@ -138,10 +162,13 @@ def parse_args(): import argparse parser = argparse.ArgumentParser( - description="Generate and simulate a fantasy world" + description="Generate and simulate a text adventure world" ) parser.add_argument( - "--actions", type=str, help="Extra actions to include in the simulation" + "--actions", + type=str, + nargs="*", + help="Extra actions to include in the simulation", ) parser.add_argument( "--flavor", type=str, help="Some additional flavor text for the generated world" @@ -155,6 +182,9 @@ def parse_args(): parser.add_argument( "--max-rooms", type=int, help="The maximum number of rooms to generate" ) + parser.add_argument( + "--server", type=str, help="The address on which to run the server" + ) parser.add_argument( "--state", type=str, @@ -164,6 +194,12 @@ def parse_args(): parser.add_argument( "--steps", type=int, default=10, help="The number of simulation steps to run" ) + parser.add_argument( + "--systems", + type=str, + nargs="*", + help="Extra logic systems to run in the simulation", + ) parser.add_argument( "--theme", type=str, default="fantasy", help="The theme of the generated world" ) @@ -211,7 +247,11 @@ def main(): llm, ) world = generate_world( - agent, args.world, args.theme, rooms=args.rooms, max_rooms=args.max_rooms + agent, + args.world, + args.theme, + room_count=args.rooms, + max_rooms=args.max_rooms, ) save_world(world, world_file) @@ -219,25 +259,63 @@ def main(): # load extra actions extra_actions = [] - if args.actions: - logger.info(f"Loading extra actions from {args.actions}") - action_module, action_function = args.actions.rsplit(":", 1) - action_module = import_module(action_module) - action_function = getattr(action_module, action_function) - module_actions = action_function() + for action_name in args.actions: + logger.info(f"Loading extra actions from {action_name}") + module_actions = load_plugin(action_name) logger.info( f"Loaded extra actions: {[action.__name__ for action in module_actions]}" ) extra_actions.extend(module_actions) + # load extra systems + def snapshot_system(world: World, step: int) -> None: + logger.debug("Snapshotting world state") + save_world_state(world, step, world_state_file) + + extra_systems = [(snapshot_system, None)] + for system_name in args.systems: + logger.info(f"Loading extra systems from {system_name}") + module_systems = load_plugin(system_name) + logger.info( + f"Loaded extra systems: {[system.__name__ for system in module_systems]}" + ) + extra_systems.append(module_systems) + + # make sure the server system is last + input_callbacks = [] + result_callbacks = [] + + if args.server: + from adventure.server import ( + launch_server, + server_input, + server_result, + server_system, + ) + + launch_server() + extra_systems.append((server_system, None)) + input_callbacks.append(server_input) + result_callbacks.append(server_result) + + # start the sim logger.debug("Simulating world: %s", world) simulate_world( world, steps=args.steps, - callback=lambda w, s: save_world_state(w, s, world_state_file), - extra_actions=extra_actions, + actions=extra_actions, + systems=extra_systems, + input_callbacks=input_callbacks, + result_callbacks=result_callbacks, ) +def load_plugin(name): + module_name, function_name = name.rsplit(":", 1) + plugin_module = import_module(module_name) + plugin_entry = getattr(plugin_module, function_name) + return plugin_entry() + + if __name__ == "__main__": main() diff --git a/adventure/models.py b/adventure/models.py index 6ba0cc3..9aad3e4 100644 --- a/adventure/models.py +++ b/adventure/models.py @@ -16,6 +16,7 @@ class Item: name: str description: str actions: Actions = Field(default_factory=dict) + attributes: Dict[str, str] = Field(default_factory=dict) @dataclass @@ -23,9 +24,9 @@ class Actor: name: str backstory: str description: str - health: int actions: Actions = Field(default_factory=dict) items: List[Item] = Field(default_factory=list) + attributes: Dict[str, str] = Field(default_factory=dict) @dataclass @@ -36,17 +37,19 @@ class Room: items: List[Item] = Field(default_factory=list) actors: List[Actor] = Field(default_factory=list) actions: Actions = Field(default_factory=dict) + attributes: Dict[str, str] = Field(default_factory=dict) @dataclass class World: name: str + order: List[str] rooms: List[Room] theme: str @dataclass class WorldState: - world: World memory: Dict[str, List[str | Dict[str, str]]] step: int + world: World diff --git a/adventure/server.py b/adventure/server.py new file mode 100644 index 0000000..6cfcf99 --- /dev/null +++ b/adventure/server.py @@ -0,0 +1,90 @@ +import asyncio +from json import dumps +from logging import getLogger +from threading import Thread + +import websockets +from flask import Flask, send_from_directory + +from adventure.models import Actor, Room, World +from adventure.state import snapshot_world, world_json + +logger = getLogger(__name__) + +app = Flask(__name__) +connected = set() + + +@app.route("/") +def send_report(page: str): + print(f"Sending {page}") + return send_from_directory( + "/home/ssube/code/github/ssube/llm-adventure/web-ui", page + ) + + +async def handler(websocket): + connected.add(websocket) + while True: + try: + # await websocket.wait_closed() + message = await websocket.recv() + print(message) + except websockets.ConnectionClosedOK: + break + + connected.remove(websocket) + + +socket_thread = None +static_thread = None + + +def launch_server(): + global socket_thread, static_thread + + def run_sockets(): + asyncio.run(server_main()) + + def run_static(): + app.run(port=8000) + + socket_thread = Thread(target=run_sockets) + socket_thread.start() + + static_thread = Thread(target=run_static) + static_thread.start() + + +async def server_main(): + async with websockets.serve(handler, "", 8001): + logger.info("Server started") + await asyncio.Future() # run forever + + +def server_system(world: World, step: int): + json_state = { + **snapshot_world(world, step), + "type": "world", + } + websockets.broadcast(connected, dumps(json_state, default=world_json)) + + +def server_result(room: Room, actor: Actor, action: str): + json_action = { + "actor": actor.name, + "result": action, + "room": room.name, + "type": "result", + } + websockets.broadcast(connected, dumps(json_action)) + + +def server_input(room: Room, actor: Actor, message: str): + json_input = { + "actor": actor.name, + "input": message, + "room": room.name, + "type": "input", + } + websockets.broadcast(connected, dumps(json_input)) diff --git a/adventure/state.py b/adventure/state.py index 07a8a15..bc22fd9 100644 --- a/adventure/state.py +++ b/adventure/state.py @@ -95,14 +95,14 @@ def save_world_state(world, step, filename): graph_world(world, step) json_state = snapshot_world(world, step) with open(filename, "w") as f: + dump(json_state, f, default=world_json, indent=2) - def dumper(obj): - if isinstance(obj, BaseMessage): - return { - "content": obj.content, - "type": obj.type, - } - raise ValueError(f"Cannot serialize {obj}") +def world_json(obj): + if isinstance(obj, BaseMessage): + return { + "content": obj.content, + "type": obj.type, + } - dump(json_state, f, default=dumper, indent=2) + raise ValueError(f"Cannot serialize {obj}") diff --git a/adventure/systems/logic.py b/adventure/systems/logic.py new file mode 100644 index 0000000..98c9ee6 --- /dev/null +++ b/adventure/systems/logic.py @@ -0,0 +1,90 @@ +from logging import getLogger +from random import random +from typing import Dict, List, Optional + +from pydantic import Field +from yaml import Loader, load + +from adventure.models import World, dataclass + +logger = getLogger(__name__) + + +@dataclass +class LogicLabel: + backstory: str + description: str + + +@dataclass +class LogicRule: + match: Dict[str, str] + chance: float = 1.0 + remove: Optional[List[str]] = None + set: Optional[Dict[str, str]] = None + + +@dataclass +class LogicTable: + rules: List[LogicRule] + labels: Dict[str, Dict[str, LogicLabel]] = Field(default_factory=dict) + + +with open("./worlds/logic.yaml") as file: + logic_rules = LogicTable(**load(file, Loader=Loader)) + + +def update_attributes( + attributes: Dict[str, str], dataset: LogicTable +) -> Dict[str, str]: + for rule in dataset.rules: + if rule.match.items() <= attributes.items(): + logger.info("matched logic: %s", rule.match) + if rule.chance < 1: + if random() > rule.chance: + logger.info("logic skipped by chance: %s", rule.chance) + continue + + if rule.set: + attributes.update(rule.set) + logger.info("logic set state: %s", rule.set) + + for key in rule.remove or []: + attributes.pop(key, None) + + return attributes + + +def update_logic(world: World, step: int) -> None: + for room in world.rooms: + room.attributes = update_attributes(room.attributes, logic_rules) + for actor in room.actors: + actor.attributes = update_attributes(actor.attributes, logic_rules) + for item in actor.items: + item.attributes = update_attributes(item.attributes, logic_rules) + for item in room.items: + item.attributes = update_attributes(item.attributes, logic_rules) + + logger.info("updated world attributes") + + +def format_logic(attributes: Dict[str, str], self=True) -> str: + labels = [] + + for attribute, value in attributes.items(): + if attribute in logic_rules.labels and value in logic_rules.labels[attribute]: + label = logic_rules.labels[attribute][value] + if self: + labels.append(label.backstory) + else: + labels.append(label.description) + + if len(labels) > 0: + logger.info("adding labels: %s", labels) + + return " ".join(labels) + + +def init(): + logger.info("initialized logic system") + return (update_logic, format_logic)