diff --git a/adventure/actions.py b/adventure/actions.py index 50f28ed..64f6eff 100644 --- a/adventure/actions.py +++ b/adventure/actions.py @@ -197,7 +197,7 @@ def action_give(character: str, item_name: str) -> str: if not item: return f"You do not have the {item_name} item in your inventory." - broadcast(f"{action_actor.name} gives {character} the {item_name} item") + broadcast(f"{action_actor.name} gives {character} the {item_name} item.") action_actor.items.remove(item) destination_actor.items.append(item) diff --git a/adventure/context.py b/adventure/context.py index 90bf55d..770494f 100644 --- a/adventure/context.py +++ b/adventure/context.py @@ -2,7 +2,7 @@ from typing import Callable, Dict, Tuple from packit.agent import Agent -from adventure.models import Actor, Room, World +from adventure.models.entity import Actor, Room, World current_broadcast: Callable[[str], None] | None = None current_world: World | None = None @@ -16,6 +16,16 @@ dungeon_master: Agent | None = None actor_agents: Dict[str, Tuple[Actor, Agent]] = {} +def broadcast(message: str): + if current_broadcast: + current_broadcast(message) + + +def has_dungeon_master(): + return dungeon_master is not None + + +# region context getters def get_current_context() -> Tuple[World, Room, Actor]: if not current_world: raise ValueError( @@ -47,11 +57,23 @@ def get_current_broadcast(): return current_broadcast -def broadcast(message: str): - if current_broadcast: - current_broadcast(message) +def get_current_step() -> int: + return current_step +def get_dungeon_master() -> Agent: + if not dungeon_master: + raise ValueError( + "The dungeon master must be set before calling action functions" + ) + + return dungeon_master + + +# endregion + + +# region context setters def set_current_broadcast(broadcast): global current_broadcast current_broadcast = broadcast @@ -72,15 +94,24 @@ def set_current_actor(actor: Actor | None): current_actor = actor -def get_current_step() -> int: - return current_step - - def set_current_step(step: int): global current_step current_step = step +def set_actor_agent(name, actor, agent): + actor_agents[name] = (actor, agent) + + +def set_dungeon_master(agent): + global dungeon_master + dungeon_master = agent + + +# endregion + + +# region search functions def get_actor_for_agent(agent): return next( ( @@ -114,27 +145,8 @@ def get_actor_agent_for_name(name): ) -def set_actor_agent_for_name(name, actor, agent): - actor_agents[name] = (actor, agent) - - def get_all_actor_agents(): return list(actor_agents.values()) -def set_dungeon_master(agent): - global dungeon_master - dungeon_master = agent - - -def get_dungeon_master() -> Agent: - if not dungeon_master: - raise ValueError( - "The dungeon master must be set before calling action functions" - ) - - return dungeon_master - - -def has_dungeon_master(): - return dungeon_master is not None +# endregion diff --git a/adventure/discord_bot.py b/adventure/discord_bot.py index eb6807f..94dcc97 100644 --- a/adventure/discord_bot.py +++ b/adventure/discord_bot.py @@ -1,27 +1,34 @@ -# from functools import cache -from json import loads from logging import getLogger from os import environ from queue import Queue from re import sub from threading import Thread -from typing import Literal +from typing import Tuple from discord import Client, Embed, File, Intents -from packit.utils import could_be_json from adventure.context import ( get_actor_agent_for_name, get_current_world, - set_actor_agent_for_name, + set_actor_agent, +) +from adventure.models.event import ( + ActionEvent, + GameEvent, + GenerateEvent, + PromptEvent, + ReplyEvent, + ResultEvent, + StatusEvent, ) -from adventure.models import Actor, Room from adventure.player import RemotePlayer, get_player, has_player, set_player from adventure.render_comfy import generate_image_tool logger = getLogger(__name__) client = None -prompt_queue: Queue = Queue() + +active_tasks = set() +prompt_queue: Queue[Tuple[GameEvent, Embed | str]] = Queue() def remove_tags(text: str) -> str: @@ -32,6 +39,54 @@ def remove_tags(text: str) -> str: return sub(r"<[^>]*>", "", text) +def find_embed_field(embed: Embed, name: str) -> str | None: + return next((field.value for field in embed.fields if field.name == name), None) + + +# TODO: becomes prompt from event +def prompt_from_embed(embed: Embed) -> str | None: + room_name = embed.title + actor_name = embed.description + + world = get_current_world() + if not world: + return + + room = next((room for room in world.rooms if room.name == room_name), None) + if not room: + return + + actor = next((actor for actor in room.actors if actor.name == actor_name), None) + if not actor: + return + + item_field = find_embed_field(embed, "Item") + + action_field = find_embed_field(embed, "Action") + if action_field: + if item_field: + item = next( + ( + item + for item in (room.items + actor.items) + if item.name == item_field + ), + None, + ) + if item: + return f"{actor.name} {action_field} the {item.name}. {item.description}. {actor.description}. {room.description}." + + return f"{actor.name} {action_field} the {item_field}. {actor.description}. {room.description}." + + return f"{actor.name} {action_field}. {actor.description}. {room.name}." + + result_field = find_embed_field(embed, "Result") + if result_field: + return f"{result_field}. {actor.description}. {room.description}." + + return + + class AdventureClient(Client): async def on_ready(self): logger.info(f"Logged in as {self.user}") @@ -46,31 +101,14 @@ class AdventureClient(Client): # TODO: look up event that caused this message, get the room and actors if len(reaction.message.embeds) > 0: embed = reaction.message.embeds[0] - room_name = embed.title - actor_name = embed.description - prompt = f"{room_name}. {actor_name}." - await reaction.message.channel.send(f"Generating image for: {prompt}") - - world = get_current_world() - if not world: - return - - room = next( - (room for room in world.rooms if room.name == room_name), None - ) - if not room: - return - - actor = next( - (actor for actor in room.actors if actor.name == actor_name), None - ) - if not actor: - return - - prompt = f"{room.name}. {actor.name}." + prompt = prompt_from_embed(embed) else: prompt = remove_tags(reaction.message.content) + if prompt.startswith("Generating"): + # TODO: get the entity from the message + pass + await reaction.message.add_reaction("📸") paths = generate_image_tool(prompt, 2) logger.info(f"Generated images: {paths}") @@ -110,20 +148,22 @@ class AdventureClient(Client): await channel.send(f"Character `{character_name}` not found!") return - def prompt_player(character: str, prompt: str): + def prompt_player(event: PromptEvent): logger.info( "append prompt for character %s (user %s) to queue: %s", - character, + event.actor.name, user_name, - prompt, + event.prompt, ) - prompt_queue.put((character, prompt)) + + # TODO: build an embed from the prompt + prompt_queue.put((event, event.prompt)) return True player = RemotePlayer( actor.name, actor.backstory, prompt_player, fallback_agent=agent ) - set_actor_agent_for_name(character_name, actor, player) + set_actor_agent(character_name, actor, player) set_player(user_name, player) logger.info(f"{user_name} has joined the game as {actor.name}!") @@ -153,9 +193,6 @@ class AdventureClient(Client): return -active_tasks = set() - - def launch_bot(): def bot_main(): global client @@ -170,27 +207,29 @@ def launch_bot(): from time import sleep while True: - sleep(0.5) + sleep(0.1) if prompt_queue.empty(): continue if len(active_tasks) > 0: continue - character, prompt = prompt_queue.get() - logger.info("Prompting character %s: %s", character, prompt) + event, prompt = prompt_queue.get() + logger.info("Prompting for event %s: %s", event, prompt) if client: prompt_task = client.loop.create_task(broadcast_event(prompt)) active_tasks.add(prompt_task) prompt_task.add_done_callback(active_tasks.discard) - bot_thread = Thread(target=bot_main) + bot_thread = Thread(target=bot_main, daemon=True) bot_thread.start() - prompt_thread = Thread(target=prompt_main) + prompt_thread = Thread(target=prompt_main, daemon=True) prompt_thread.start() + return [bot_thread, prompt_thread] + def stop_bot(): global client @@ -238,39 +277,48 @@ async def broadcast_event(message: str | Embed): await channel.send(embed=message) -def bot_action(room: Room, actor: Actor, message: str): - try: - action_embed = Embed(title=room.name, description=actor.name) +def bot_event(event: GameEvent): + if isinstance(event, GenerateEvent): + bot_generate(event) + elif isinstance(event, ResultEvent): + bot_result(event) + elif isinstance(event, (ActionEvent, ReplyEvent)): + bot_action(event) + elif isinstance(event, StatusEvent): + pass + else: + logger.warning("Unknown event type: %s", event) - if could_be_json(message): - action_data = loads(message) - action_name = action_data["function"].replace("action_", "").title() - action_parameters = action_data.get("parameters", {}) + +def bot_action(event: ActionEvent | ReplyEvent): + try: + action_embed = Embed(title=event.room.name, description=event.actor.name) + + if isinstance(event, ActionEvent): + action_name = event.action.replace("action_", "").title() + action_parameters = event.parameters action_embed.add_field(name="Action", value=action_name) for key, value in action_parameters.items(): action_embed.add_field(name=key.replace("_", " ").title(), value=value) else: - action_embed.add_field(name="Message", value=message) + action_embed.add_field(name="Message", value=event.text) - prompt_queue.put((actor.name, action_embed)) + prompt_queue.put((event, action_embed)) except Exception as e: logger.error("Failed to broadcast action: %s", e) -def bot_event(message: str): - prompt_queue.put((None, message)) +def bot_generate(event: GenerateEvent): + prompt_queue.put((event, event.name)) -def bot_result(room: Room, actor: Actor, action: str): - result_embed = Embed(title=room.name, description=actor.name) - result_embed.add_field(name="Result", value=action) - prompt_queue.put((actor.name, result_embed)) +def bot_result(event: ResultEvent): + text = event.result + if len(text) > 1000: + text = text[:1000] + "..." - -def player_event(character: str, id: str, event: Literal["join", "leave"]): - if event == "join": - prompt_queue.put((character, f"{character} has joined the game!")) - elif event == "leave": - prompt_queue.put((character, f"{character} has left the game!")) + result_embed = Embed(title=event.room.name, description=event.actor.name) + result_embed.add_field(name="Result", value=text) + prompt_queue.put((event, result_embed)) diff --git a/adventure/generate.py b/adventure/generate.py index 8a4617c..9017879 100644 --- a/adventure/generate.py +++ b/adventure/generate.py @@ -1,11 +1,12 @@ from logging import getLogger from random import choice, randint -from typing import Callable, List +from typing import List from packit.agent import Agent from packit.loops import loop_retry -from adventure.models import Actor, Item, Room, World +from adventure.models.entity import Actor, Item, Room, World +from adventure.models.event import EventCallback, GenerateEvent logger = getLogger(__name__) @@ -17,13 +18,10 @@ OPPOSITE_DIRECTIONS = { } -GenerateCallback = Callable[[str], None] - - def generate_room( agent: Agent, world_theme: str, - callback: GenerateCallback | None = None, + callback: EventCallback | None = None, existing_rooms: List[str] = [], ) -> Room: def unique_name(name: str, **kwargs): @@ -45,7 +43,7 @@ def generate_room( ) if callable(callback): - callback(f"Generating room: {name}") + callback(GenerateEvent.from_name(f"Generating room: {name}")) desc = agent( "Generate a detailed description of the {name} area. What does it look like? " @@ -65,7 +63,7 @@ def generate_room( def generate_item( agent: Agent, world_theme: str, - callback: Callable[[str], None] | None = None, + callback: EventCallback | None = None, dest_room: str | None = None, dest_actor: str | None = None, existing_items: List[str] = [], @@ -99,7 +97,7 @@ def generate_item( ) if callable(callback): - callback(f"Generating item: {name}") + callback(GenerateEvent.from_name(f"Generating item: {name}")) desc = agent( "Generate a detailed description of the {name} item. What does it look like? What is it made of? What does it do?", @@ -115,7 +113,7 @@ def generate_actor( agent: Agent, world_theme: str, dest_room: str, - callback: GenerateCallback | None = None, + callback: EventCallback | None = None, existing_actors: List[str] = [], ) -> Actor: def unique_name(name: str, **kwargs): @@ -141,7 +139,7 @@ def generate_actor( ) if callable(callback): - callback(f"Generating actor: {name}") + callback(GenerateEvent.from_name(f"Generating actor: {name}")) description = agent( "Generate a detailed description of the {name} character. What do they look like? What are they wearing? " @@ -169,12 +167,14 @@ def generate_world( theme: str, room_count: int | None = None, max_rooms: int = 5, - callback: Callable[[str], None] | None = None, + callback: EventCallback | None = None, ) -> World: room_count = room_count or randint(3, max_rooms) if callable(callback): - callback(f"Generating a {theme} with {room_count} rooms") + callback( + GenerateEvent.from_name(f"Generating a {theme} with {room_count} rooms") + ) existing_actors: List[str] = [] existing_items: List[str] = [] @@ -192,7 +192,11 @@ def generate_world( item_count = randint(1, 3) if callable(callback): - callback(f"Generating {item_count} items for room: {room.name}") + callback( + GenerateEvent.from_name( + f"Generating {item_count} items for room: {room.name}" + ) + ) for j in range(item_count): item = generate_item( @@ -208,7 +212,11 @@ def generate_world( actor_count = randint(1, 3) if callable(callback): - callback(f"Generating {actor_count} actors for room: {room.name}") + callback( + GenerateEvent.from_name( + f"Generating {actor_count} actors for room: {room.name}" + ) + ) for j in range(actor_count): actor = generate_actor( @@ -225,7 +233,11 @@ def generate_world( item_count = randint(0, 2) if callable(callback): - callback(f"Generating {item_count} items for actor {actor.name}") + callback( + GenerateEvent.from_name( + f"Generating {item_count} items for actor {actor.name}" + ) + ) for k in range(item_count): item = generate_item( diff --git a/adventure/logic.py b/adventure/logic.py index d4d6710..5820762 100644 --- a/adventure/logic.py +++ b/adventure/logic.py @@ -7,7 +7,7 @@ from pydantic import Field from rule_engine import Rule from yaml import Loader, load -from adventure.models import ( +from adventure.models.entity import ( Actor, Attributes, AttributeValue, diff --git a/adventure/main.py b/adventure/main.py index 3b01922..37af3e3 100644 --- a/adventure/main.py +++ b/adventure/main.py @@ -1,26 +1,33 @@ -from json import load +import atexit from logging.config import dictConfig from os import environ, path -from typing import Callable, Sequence, Tuple +from typing import List from dotenv import load_dotenv from packit.agent import Agent, agent_easy_connect -from packit.loops import loop_retry -from packit.results import multi_function_or_str_result -from packit.toolbox import Toolbox from packit.utils import logger_with_colors +from yaml import Loader, load -from adventure.context import set_current_broadcast, set_dungeon_master -from adventure.models import Attributes +from adventure.context import set_current_step, set_dungeon_master +from adventure.generate import generate_world +from adventure.models.entity import World, WorldState +from adventure.models.event import EventCallback, GameEvent +from adventure.models.files import PromptFile, WorldPrompt from adventure.plugins import load_plugin +from adventure.simulate import simulate_world +from adventure.state import create_agents, save_world, save_world_state -# Configure logging + +def load_yaml(file): + return load(file, Loader=Loader) + + +# configure logging LOG_PATH = "logging.json" -# LOG_PATH = "dev-logging.json" try: if path.exists(LOG_PATH): with open(LOG_PATH, "r") as f: - config_logging = load(f) + config_logging = load_yaml(f) dictConfig(config_logging) else: print("logging config not found") @@ -28,166 +35,12 @@ try: except Exception as err: print("error loading logging config: %s" % (err)) -if True: - from adventure.actions import ( - action_ask, - action_give, - action_look, - action_move, - action_take, - action_tell, - ) - from adventure.context import ( - get_actor_agent_for_name, - get_actor_for_agent, - get_current_step, - get_current_world, - set_current_actor, - set_current_room, - set_current_step, - set_current_world, - ) - from adventure.generate import generate_world - from adventure.models import Actor, Room, World, WorldState - from adventure.state import create_agents, save_world, save_world_state logger = logger_with_colors(__name__, level="DEBUG") load_dotenv(environ.get("ADVENTURE_ENV", ".env"), override=True) -# simulation -def world_result_parser(value, agent, **kwargs): - current_world = get_current_world() - if not current_world: - raise ValueError( - "The current world must be set before calling world_result_parser" - ) - - logger.debug(f"parsing action for {agent.name}: {value}") - - current_actor = get_actor_for_agent(agent) - current_room = next( - (room for room in current_world.rooms if current_actor in room.actors), None - ) - - set_current_room(current_room) - set_current_actor(current_actor) - - return multi_function_or_str_result(value, agent=agent, **kwargs) - - -def simulate_world( - world: World, - steps: int = 10, - actions: Sequence[Callable[..., str]] = [], - systems: Sequence[ - Tuple[Callable[[World, int], None], Callable[[Attributes], str] | None] - ] = [], - event_callbacks: Sequence[Callable[[str], None]] = [], - input_callbacks: Sequence[Callable[[Room, Actor, str], None]] = [], - result_callbacks: Sequence[Callable[[Room, Actor, str], None]] = [], -): - logger.info("Simulating the world") - set_current_world(world) - - # set up a broadcast callback - def broadcast_callback(message): - logger.info(message) - for callback in event_callbacks: - callback(message) - - set_current_broadcast(broadcast_callback) - - # build a toolbox for the actions - action_tools = Toolbox( - [ - action_ask, - action_give, - action_look, - action_move, - action_take, - action_tell, - *actions, - ] - ) - action_names = action_tools.list_tools() - - # simulate each actor - for i in range(steps): - current_step = get_current_step() - logger.info(f"Simulating step {current_step}") - for actor_name in world.order: - actor, agent = get_actor_agent_for_name(actor_name) - if not agent or not actor: - logger.error(f"Agent or actor not found for name {actor_name}") - continue - - room = next((room for room in world.rooms if actor in room.actors), None) - if not room: - logger.error(f"Actor {actor_name} is not in a room") - continue - - room_actors = [actor.name for actor in room.actors] - room_items = [item.name for item in room.items] - room_directions = list(room.portals.keys()) - - actor_attributes = " ".join( - system_format(actor.attributes) - for _, system_format in systems - if system_format - ) - actor_items = [item.name for item in actor.items] - - def result_parser(value, agent, **kwargs): - for callback in input_callbacks: - logger.info( - f"calling input callback for {actor_name}: {callback.__name__}" - ) - callback(room, actor, value) - - return world_result_parser(value, agent, **kwargs) - - logger.info("starting turn for actor: %s", actor_name) - result = loop_retry( - agent, - ( - "You are currently in {room_name}. {room_description}. {attributes}. " - "The room contains the following characters: {visible_actors}. " - "The room contains the following items: {visible_items}. " - "Your inventory contains the following items: {actor_items}." - "You can take the following actions: {actions}. " - "You can move in the following directions: {directions}. " - "What will you do next? Reply with a JSON function call, calling one of the actions." - "You can only perform one action per turn. What is your next action?" - # Pick the most important action and save the rest for later." - ), - context={ - "actions": action_names, - "actor_items": actor_items, - "attributes": actor_attributes, - "directions": room_directions, - "room_name": room.name, - "room_description": room.description, - "visible_actors": room_actors, - "visible_items": room_items, - }, - result_parser=result_parser, - toolbox=action_tools, - ) - - logger.debug(f"{actor.name} step result: {result}") - agent.memory.append(result) - - for callback in result_callbacks: - callback(room, actor, result) - - for system_update, _ in systems: - system_update(world, current_step) - - set_current_step(current_step + 1) - - # main def parse_args(): import argparse @@ -249,84 +102,72 @@ def parse_args(): default="world", help="The file to save the generated world to", ) + parser.add_argument( + "--world-prompt", + type=str, + help="The file to load the world prompt from", + ) return parser.parse_args() -def main(): - args = parse_args() +def get_world_prompt(args) -> WorldPrompt: + if args.world_prompt: + prompt_file, prompt_name = args.world_prompt.split(":") + with open(prompt_file, "r") as f: + prompts = PromptFile(**load_yaml(f)) + for prompt in prompts.prompts: + if prompt.name == prompt_name: + return prompt + logger.warning(f"prompt {prompt_name} not found in {prompt_file}") + + return WorldPrompt( + name=args.world, + theme=args.theme, + flavor=args.flavor, + ) + + +def load_or_generate_world(args, players, callbacks, world_prompt: WorldPrompt): world_file = args.world + ".json" world_state_file = args.state or (args.world + ".state.json") - players = [] - if args.player: - players.append(args.player) - - # set up callbacks - event_callbacks = [] - input_callbacks = [] - result_callbacks = [] - - if args.discord: - from adventure.discord_bot import bot_action, bot_event, bot_result, launch_bot - - launch_bot() - event_callbacks.append(bot_event) - input_callbacks.append(bot_action) - result_callbacks.append(bot_result) - - if args.server: - from adventure.server import ( - launch_server, - server_action, - server_event, - server_result, - server_system, - ) - - launch_server() - event_callbacks.append(server_event) - input_callbacks.append(server_action) - result_callbacks.append(server_result) - memory = {} if path.exists(world_state_file): - logger.info(f"Loading world state from {world_state_file}") + logger.info(f"loading world state from {world_state_file}") with open(world_state_file, "r") as f: - state = WorldState(**load(f)) + state = WorldState(**load_yaml(f)) set_current_step(state.step) memory = state.memory world = state.world - world.name = args.world elif path.exists(world_file): - logger.info(f"Loading world from {world_file}") + logger.info(f"loading world from {world_file}") with open(world_file, "r") as f: - world = World(**load(f)) + world = World(**load_yaml(f)) else: - logger.info(f"Generating a new {args.theme} world") + logger.info(f"generating a new world using theme: {world_prompt.theme}") llm = agent_easy_connect() world_builder = Agent( "World Builder", - f"You are an experienced game master creating a visually detailed {args.theme} world for a new adventure. {args.flavor}", + f"You are an experienced game master creating a visually detailed world for a new adventure. " + f"{world_prompt.flavor}. The theme is: {world_prompt.theme}.", {}, llm, ) world = None - def broadcast_callback(message): - logger.info(message) - for callback in event_callbacks: - callback(message) - if args.server and world: - server_system(world, 0) + def broadcast_callback(event: GameEvent): + logger.info(event) + for callback in callbacks: + callback(event) world = generate_world( world_builder, args.world, - args.theme, + world_prompt.theme, room_count=args.rooms, max_rooms=args.max_rooms, callback=broadcast_callback, @@ -334,40 +175,68 @@ def main(): save_world(world, world_file) create_agents(world, memory=memory, players=players) + return (world, world_state_file) + + +def main(): + args = parse_args() + + players = [] + if args.player: + players.append(args.player) + + # set up callbacks + callbacks: List[EventCallback] = [] + + # launch other threads + threads = [] + if args.discord: + from adventure.discord_bot import bot_event, launch_bot + + threads.extend(launch_bot()) + callbacks.append(bot_event) + if args.server: - server_system(world, 0) + from adventure.server import launch_server, server_event, server_system - # load extra actions + threads.extend(launch_server()) + callbacks.append(server_event) + + # register the thread shutdown handler + def shutdown_threads(): + for thread in threads: + thread.join(1.0) + + atexit.register(shutdown_threads) + + # load built-in but optional actions extra_actions = [] - for action_name in args.actions or []: - logger.info(f"Loading extra actions from {action_name}") - module_actions = load_plugin(action_name) - logger.info( - f"Loaded extra actions: {[action.__name__ for action in module_actions]}" - ) - extra_actions.extend(module_actions) - if args.optional_actions: - logger.info("Loading optional actions") + logger.info("loading optional actions") from adventure.optional_actions import init as init_optional_actions optional_actions = init_optional_actions() logger.info( - f"Loaded optional actions: {[action.__name__ for action in optional_actions]}" + f"loaded optional actions: {[action.__name__ for action in optional_actions]}" ) extra_actions.extend(optional_actions) - # load extra systems - def snapshot_system(world: World, step: int) -> None: - logger.debug("Snapshotting world state") - save_world_state(world, step, world_state_file) + # load extra actions from plugins + for action_name in args.actions or []: + logger.info(f"loading extra actions from {action_name}") + module_actions = load_plugin(action_name) + logger.info( + f"loaded extra actions: {[action.__name__ for action in module_actions]}" + ) + extra_actions.extend(module_actions) - extra_systems = [(snapshot_system, None)] + # load extra systems from plugins + extra_systems = [] for system_name in args.systems or []: - logger.info(f"Loading extra systems from {system_name}") + logger.info(f"loading extra systems from {system_name}") module_systems = load_plugin(system_name) logger.info( - f"Loaded extra systems: {[component.__name__ for system in module_systems for component in system]}" + f"loaded extra systems: {[component.__name__ for system in module_systems for component in system]}" ) extra_systems.extend(module_systems) @@ -375,6 +244,23 @@ def main(): if args.server: extra_systems.append((server_system, None)) + # load or generate the world + world_prompt = get_world_prompt(args) + world, world_state_file = load_or_generate_world( + args, players, callbacks, world_prompt=world_prompt + ) + + # make sure the snapshot system runs last + def snapshot_system(world: World, step: int) -> None: + logger.info("taking snapshot of world state") + save_world_state(world, step, world_state_file) + + extra_systems.append((snapshot_system, None)) + + # run the systems once to initialize everything + for system_update, _ in extra_systems: + system_update(world, 0) + # create the DM llm = agent_easy_connect() world_builder = Agent( @@ -390,14 +276,13 @@ def main(): set_dungeon_master(world_builder) # start the sim - logger.debug("Simulating world: %s", world) + logger.debug("simulating world: %s", world) simulate_world( world, steps=args.steps, actions=extra_actions, systems=extra_systems, - input_callbacks=input_callbacks, - result_callbacks=result_callbacks, + callbacks=callbacks, ) diff --git a/adventure/models/base.py b/adventure/models/base.py new file mode 100644 index 0000000..a910cb9 --- /dev/null +++ b/adventure/models/base.py @@ -0,0 +1,6 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dataclasses import dataclass +else: + from pydantic.dataclasses import dataclass as dataclass # noqa diff --git a/adventure/models.py b/adventure/models/entity.py similarity index 85% rename from adventure/models.py rename to adventure/models/entity.py index 86b8bd8..5eeaa27 100644 --- a/adventure/models.py +++ b/adventure/models/entity.py @@ -1,12 +1,8 @@ -from typing import TYPE_CHECKING, Callable, Dict, List +from typing import Callable, Dict, List from pydantic import Field -if TYPE_CHECKING: - from dataclasses import dataclass -else: - from pydantic.dataclasses import dataclass as dataclass # noqa - +from .base import dataclass Actions = Dict[str, Callable] AttributeValue = bool | int | str @@ -55,3 +51,6 @@ class WorldState: memory: Dict[str, List[str | Dict[str, str]]] step: int world: World + + +WorldEntity = Room | Actor | Item diff --git a/adventure/models/event.py b/adventure/models/event.py new file mode 100644 index 0000000..f87b593 --- /dev/null +++ b/adventure/models/event.py @@ -0,0 +1,137 @@ +from json import loads +from typing import Callable, Dict, Literal + +from .base import dataclass +from .entity import Actor, Item, Room, WorldEntity + + +@dataclass +class BaseEvent: + """ + A base event class. + """ + + event: str + + +@dataclass +class GenerateEvent: + """ + A new entity has been generated. + """ + + event = "generate" + name: str + entity: WorldEntity | None = None + + @staticmethod + def from_name(name: str) -> "GenerateEvent": + return GenerateEvent(name=name) + + @staticmethod + def from_entity(entity: WorldEntity) -> "GenerateEvent": + return GenerateEvent(name=entity.name, entity=entity) + + +@dataclass +class ActionEvent: + """ + An actor has taken an action. + """ + + event = "action" + action: str + parameters: Dict[str, str] + + room: Room + actor: Actor + item: Item | None = None + + @staticmethod + def from_json(json: str, room: Room, actor: Actor) -> "ActionEvent": + openai_json = loads(json) + return ActionEvent( + action=openai_json["function"], + parameters=openai_json["parameters"], + room=room, + actor=actor, + item=None, + ) + + +@dataclass +class PromptEvent: + """ + A prompt for an actor to take an action. + """ + + event = "prompt" + prompt: str + room: Room + actor: Actor + + @staticmethod + def from_text(prompt: str, room: Room, actor: Actor) -> "PromptEvent": + return PromptEvent(prompt=prompt, room=room, actor=actor) + + +@dataclass +class ReplyEvent: + """ + An actor has replied with text. + + This is the non-JSON version of an ActionEvent. + """ + + event = "text" + text: str + room: Room + actor: Actor + + @staticmethod + def from_text(text: str, room: Room, actor: Actor) -> "ReplyEvent": + return ReplyEvent(text=text, room=room, actor=actor) + + +@dataclass +class ResultEvent: + """ + A result of an action. + """ + + event = "result" + result: str + room: Room + actor: Actor + + +@dataclass +class StatusEvent: + """ + A status broadcast event with text. + """ + + event = "status" + text: str + room: Room | None = None + actor: Actor | None = None + + +@dataclass +class PlayerEvent: + """ + A player joining or leaving the game. + """ + + event = "player" + status: Literal["join", "leave"] + character: str + client: str + + +# event types +WorldEvent = ActionEvent | PromptEvent | ReplyEvent | ResultEvent | StatusEvent +GameEvent = GenerateEvent | PlayerEvent | WorldEvent + +# callback types +EventCallback = Callable[[GameEvent], None] diff --git a/adventure/models/files.py b/adventure/models/files.py new file mode 100644 index 0000000..5ce6b17 --- /dev/null +++ b/adventure/models/files.py @@ -0,0 +1,15 @@ +from typing import List + +from .base import dataclass + + +@dataclass +class WorldPrompt: + name: str + theme: str + flavor: str = "" + + +@dataclass +class PromptFile: + prompts: List[WorldPrompt] diff --git a/adventure/player.py b/adventure/player.py index 64dea40..e60fb08 100644 --- a/adventure/player.py +++ b/adventure/player.py @@ -8,6 +8,8 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from packit.agent import Agent from packit.utils import could_be_json +from adventure.models.event import PromptEvent + logger = getLogger(__name__) @@ -158,13 +160,13 @@ class LocalPlayer(BasePlayer): class RemotePlayer(BasePlayer): fallback_agent: Agent | None input_queue: Queue[str] - send_prompt: Callable[[str, str], bool] + send_prompt: Callable[[PromptEvent], bool] def __init__( self, name: str, backstory: str, - send_prompt: Callable[[str, str], bool], + send_prompt: Callable[[PromptEvent], bool], fallback_agent=None, ) -> None: super().__init__(name, backstory) @@ -180,9 +182,11 @@ class RemotePlayer(BasePlayer): formatted_prompt = prompt.format(**kwargs) self.memory.append(HumanMessage(content=formatted_prompt)) + prompt_event = PromptEvent.from_text(formatted_prompt, None, None) + try: logger.info(f"prompting remote player: {self.name}") - if self.send_prompt(self.name, formatted_prompt): + if self.send_prompt(prompt_event): reply = self.input_queue.get(timeout=60) logger.info(f"got reply from remote player: {reply}") return self.parse_input(reply) diff --git a/adventure/prompts.yml b/adventure/prompts.yml index 7613f44..5f48d89 100644 --- a/adventure/prompts.yml +++ b/adventure/prompts.yml @@ -1,32 +1,43 @@ prompts: - - theme: talking animal truckers in the Australian outback + - name: outback-animals + theme: talking animal truckers in the Australian outback flavor: create a fun and happy world where rough and tumble talking animals drive trucks and run saloons in the outback - - theme: grimdark future where humans wage a desperate war for survival against hedgehogs + - name: grimdark-hedgehogs + theme: grimdark future where humans wage a desperate war for survival against hedgehogs flavor: create a grim, dark story where a ragtag group of humans is barely surviving in a world overrun by vicious mutant hedgehogs - - theme: crowded apartment building in New York City + - name: nyc-apartment + theme: crowded apartment building in New York City flavor: | create a crowded apartment building in NYC in the late 90s with a cast of wild and wacky characters. include colorful characters and make sure they will fully utilize all of the actions available to them in this world, exploring and interacting with each other - - theme: opening scenes from Jurassic Park + - name: jurassic-park + theme: opening scenes from Jurassic Park flavor: | follow the script of the film Jurassic Park exactly. do not deviate from the script in any way. include accurate characters and make sure they will fully utilize all of the actions available to them in this world - - theme: opening scenes from Star Wars + - name: star-wars + theme: opening scenes from Star Wars flavor: | follow the script of the 1977 film Star Wars exactly. do not deviate from the script in any way. include accurate characters and make sure they will fully utilize all of the actions available to them in this world - - theme: wealthy cyberpunk utopia with a dark secret + - name: cyberpunk-utopia + theme: wealthy cyberpunk utopia with a dark secret flavor: make a strange and dangerous world where technology is pervasive and scarcity is unheard of - for the upper class, at least - - theme: post-apocalyptic world where the only survivors are sentient robots + - name: post-apocalyptic-robots + theme: post-apocalyptic world where the only survivors are sentient robots flavor: create a world where the only survivors of a nuclear apocalypse are sentient robots, who must now rebuild society from scratch - - theme: haunted house in the middle of nowhere + - name: haunted-house + theme: haunted house in the middle of nowhere flavor: create a spooky and suspenseful world where a group of people are trapped in a haunted house in the middle of nowhere - - theme: dangerous magical fantasy world + - name: magical-kingdom + theme: dangerous magical fantasy world flavor: make a strange and dangerous world where magic winds its way through everything and incredibly powerful beings drink, fight, and wander the halls - - theme: underwater city of mermaids + - name: underwater-mermaids + theme: underwater city of mermaids flavor: create a beautiful and mysterious world where mermaids live in an underwater city, exploring the depths and interacting with each other - - theme: a mysterious town in the Pacific Northwest filled with strange cryptids and private investigators searching for them + - name: cryptid-town + theme: a mysterious town in the Pacific Northwest filled with strange cryptids and private investigators searching for them flavor: | make a strange and creepy world where terrifying creatures that you could never imagine in daylight roam the back alleys and hard-bitten private investigators with rough voices search for answers. do not use the word cryptid in any names \ No newline at end of file diff --git a/adventure/render_comfy.py b/adventure/render_comfy.py index bebb14d..610a1eb 100644 --- a/adventure/render_comfy.py +++ b/adventure/render_comfy.py @@ -150,8 +150,26 @@ def generate_images( "inputs": {"batch_size": count, "height": height, "width": width}, }, "6": { - "class_type": "CLIPTextEncode", - "inputs": {"text": prompt, "clip": ["4", 1]}, + "class_type": "smZ CLIPTextEncode", + "inputs": { + "text": prompt, + "parser": "compel", + "mean_normalization": True, + "multi_conditioning": True, + "use_old_emphasis_implementation": False, + "with_SDXL": False, + "ascore": 6, + "width": width, + "height": height, + "crop_w": 0, + "crop_h": 0, + "target_width": width, + "target_height": height, + "text_g": "", + "text_l": "", + "smZ_steps": 1, + "clip": ["4", 1], + }, }, "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "", "clip": ["4", 1]}}, "8": { diff --git a/adventure/server.py b/adventure/server.py index 024b834..1512ee3 100644 --- a/adventure/server.py +++ b/adventure/server.py @@ -8,8 +8,17 @@ from uuid import uuid4 import websockets -from adventure.context import get_actor_agent_for_name, set_actor_agent_for_name -from adventure.models import Actor, Room, World +from adventure.context import get_actor_agent_for_name, set_actor_agent +from adventure.models.entity import Actor, Room, World +from adventure.models.event import ( + ActionEvent, + GameEvent, + GenerateEvent, + PromptEvent, + ReplyEvent, + ResultEvent, + StatusEvent, +) from adventure.player import ( RemotePlayer, get_player, @@ -24,7 +33,7 @@ logger = getLogger(__name__) connected = set() recent_events = deque(maxlen=100) -recent_world = None +last_snapshot = None async def handler(websocket): @@ -45,10 +54,10 @@ async def handler(websocket): ), ) - def sync_turn(character: str, prompt: str) -> bool: + def sync_turn(event: PromptEvent) -> bool: player = get_player(id) - if player and player.name == character: - asyncio.run(next_turn(character, prompt)) + if player and player.name == event.actor.name: + asyncio.run(next_turn(event.actor.name, event.prompt)) return True return False @@ -56,8 +65,8 @@ async def handler(websocket): try: await websocket.send(dumps({"type": "id", "id": id})) - if recent_world: - await websocket.send(recent_world) + if last_snapshot: + await websocket.send(last_snapshot) for message in recent_events: await websocket.send(message) @@ -74,10 +83,14 @@ async def handler(websocket): data = loads(message) message_type = data.get("type", None) if message_type == "player": + character_name = data["become"] + if has_player(character_name): + logger.error(f"Character {character_name} is already in use") + continue + # TODO: should this always remove? remove_player(id) - character_name = data["become"] actor, llm_agent = get_actor_agent_for_name(character_name) if not actor: logger.error(f"Failed to find actor {character_name}") @@ -90,10 +103,6 @@ async def handler(websocket): ) llm_agent = llm_agent.fallback_agent - if has_player(character_name): - logger.error(f"Character {character_name} is already in use") - continue - # player_name = data["player"] player = RemotePlayer( actor.name, actor.backstory, sync_turn, fallback_agent=llm_agent @@ -102,7 +111,7 @@ async def handler(websocket): logger.info(f"Client {id} is now character {character_name}") # swap out the LLM agent - set_actor_agent_for_name(actor.name, actor, player) + set_actor_agent(actor.name, actor, player) # notify all clients that this character is now active player_event(character_name, id, "join") @@ -134,7 +143,7 @@ async def handler(websocket): actor, _ = get_actor_agent_for_name(player.name) if actor and player.fallback_agent: logger.info("Restoring LLM agent for %s", player.name) - set_actor_agent_for_name(player.name, actor, player.fallback_agent) + set_actor_agent(player.name, actor, player.fallback_agent) logger.info("Client disconnected: %s", id) @@ -166,9 +175,11 @@ def launch_server(): def run_sockets(): asyncio.run(server_main()) - socket_thread = Thread(target=run_sockets) + socket_thread = Thread(target=run_sockets, daemon=True) socket_thread.start() + return [socket_thread] + async def server_main(): async with websockets.serve(handler, "", 8001): @@ -177,12 +188,12 @@ async def server_main(): def server_system(world: World, step: int): - global recent_world + global last_snapshot json_state = { **snapshot_world(world, step), "type": "world", } - recent_world = send_and_append(json_state) + last_snapshot = send_and_append(json_state) def server_result(room: Room, actor: Actor, action: str): @@ -205,14 +216,29 @@ def server_action(room: Room, actor: Actor, message: str): send_and_append(json_input) -def server_event(message: str): +def server_generate(event: GenerateEvent): json_broadcast = { - "message": message, - "type": "event", + "name": event.name, + "type": "generate", } send_and_append(json_broadcast) +def server_event(event: GameEvent): + if isinstance(event, GenerateEvent): + return server_generate(event) + elif isinstance(event, ActionEvent): + return server_action(event.room, event.actor, event.action) + elif isinstance(event, ReplyEvent): + return server_action(event.room, event.actor, event.text) + elif isinstance(event, ResultEvent): + return server_result(event.room, event.actor, event.result) + elif isinstance(event, StatusEvent): + pass + else: + logger.warning("Unknown event type: %s", event) + + def player_event(character: str, id: str, event: Literal["join", "leave"]): json_broadcast = { "type": "player", diff --git a/adventure/sim_systems/environment_triggers.py b/adventure/sim_systems/environment_triggers.py index bb9d9f6..82ff522 100644 --- a/adventure/sim_systems/environment_triggers.py +++ b/adventure/sim_systems/environment_triggers.py @@ -1,4 +1,4 @@ -from adventure.models import Attributes, Room +from adventure.models.entity import Attributes, Room def hot_room(room: Room, attributes: Attributes): diff --git a/adventure/simulate.py b/adventure/simulate.py new file mode 100644 index 0000000..aaa4141 --- /dev/null +++ b/adventure/simulate.py @@ -0,0 +1,177 @@ +from logging import getLogger +from typing import Callable, Sequence, Tuple + +from packit.loops import loop_retry +from packit.results import multi_function_or_str_result +from packit.toolbox import Toolbox +from packit.utils import could_be_json + +from adventure.actions import ( + action_ask, + action_give, + action_look, + action_move, + action_take, + action_tell, +) +from adventure.context import ( + get_actor_agent_for_name, + get_actor_for_agent, + get_current_step, + get_current_world, + set_current_actor, + set_current_broadcast, + set_current_room, + set_current_step, + set_current_world, +) +from adventure.models.entity import Attributes, World +from adventure.models.event import ( + ActionEvent, + EventCallback, + ReplyEvent, + ResultEvent, + StatusEvent, +) + +logger = getLogger(__name__) + + +def world_result_parser(value, agent, **kwargs): + current_world = get_current_world() + if not current_world: + raise ValueError( + "The current world must be set before calling world_result_parser" + ) + + logger.debug(f"parsing action for {agent.name}: {value}") + + current_actor = get_actor_for_agent(agent) + current_room = next( + (room for room in current_world.rooms if current_actor in room.actors), None + ) + + set_current_room(current_room) + set_current_actor(current_actor) + + return multi_function_or_str_result(value, agent=agent, **kwargs) + + +def simulate_world( + world: World, + steps: int = 10, + actions: Sequence[Callable[..., str]] = [], + systems: Sequence[ + Tuple[Callable[[World, int], None], Callable[[Attributes], str] | None] + ] = [], + callbacks: Sequence[EventCallback] = [], +): + logger.info("Simulating the world") + set_current_world(world) + + # set up a broadcast callback + def broadcast_callback(message): + logger.info(message) + event = StatusEvent(text=message) + for callback in callbacks: + callback(event) + + set_current_broadcast(broadcast_callback) + + # build a toolbox for the actions + action_tools = Toolbox( + [ + action_ask, + action_give, + action_look, + action_move, + action_take, + action_tell, + *actions, + ] + ) + action_names = action_tools.list_tools() + + # simulate each actor + for i in range(steps): + current_step = get_current_step() + logger.info(f"Simulating step {current_step}") + for actor_name in world.order: + actor, agent = get_actor_agent_for_name(actor_name) + if not agent or not actor: + logger.error(f"Agent or actor not found for name {actor_name}") + continue + + room = next((room for room in world.rooms if actor in room.actors), None) + if not room: + logger.error(f"Actor {actor_name} is not in a room") + continue + + room_actors = [actor.name for actor in room.actors] + room_items = [item.name for item in room.items] + room_directions = list(room.portals.keys()) + + actor_attributes = " ".join( + system_format(actor.attributes) + for _, system_format in systems + if system_format + ) + actor_items = [item.name for item in actor.items] + + def result_parser(value, agent, **kwargs): + if not room or not actor: + raise ValueError( + "Room and actor must be set before parsing results" + ) + + if could_be_json(value): + event = ActionEvent.from_json(value, room, actor) + else: + event = ReplyEvent.from_text(value, room, actor) + + for callback in callbacks: + logger.info( + f"calling input callback for {actor_name}: {callback.__name__}" + ) + callback(event) + + return world_result_parser(value, agent, **kwargs) + + logger.info("starting turn for actor: %s", actor_name) + result = loop_retry( + agent, + ( + "You are currently in {room_name}. {room_description}. {attributes}. " + "The room contains the following characters: {visible_actors}. " + "The room contains the following items: {visible_items}. " + "Your inventory contains the following items: {actor_items}." + "You can take the following actions: {actions}. " + "You can move in the following directions: {directions}. " + "What will you do next? Reply with a JSON function call, calling one of the actions." + "You can only perform one action per turn. What is your next action?" + ), + context={ + "actions": action_names, + "actor_items": actor_items, + "attributes": actor_attributes, + "directions": room_directions, + "room_name": room.name, + "room_description": room.description, + "visible_actors": room_actors, + "visible_items": room_items, + }, + result_parser=result_parser, + toolbox=action_tools, + ) + + logger.debug(f"{actor.name} step result: {result}") + agent.memory.append(result) + + result_event = ResultEvent(result=result, room=room, actor=actor) + for callback in callbacks: + callback(result_event) + + for system_update, _ in systems: + system_update(world, current_step) + + set_current_step(current_step + 1) diff --git a/adventure/state.py b/adventure/state.py index bc22fd9..fbd8921 100644 --- a/adventure/state.py +++ b/adventure/state.py @@ -7,8 +7,8 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, System from packit.agent import Agent, agent_easy_connect from pydantic import RootModel -from adventure.context import get_all_actor_agents, set_actor_agent_for_name -from adventure.models import World +from adventure.context import get_all_actor_agents, set_actor_agent +from adventure.models.entity import World from adventure.player import LocalPlayer @@ -29,7 +29,7 @@ def create_agents( else: agent = Agent(actor.name, actor.backstory, {}, llm) agent.memory = restore_memory(memory.get(actor.name, [])) - set_actor_agent_for_name(actor.name, actor, agent) + set_actor_agent(actor.name, actor, agent) def graph_world(world: World, step: int):