From 03c324ef601055f73824dfbca47bfc22b18adfdd Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 May 2024 16:58:11 -0500 Subject: [PATCH] switch to a proper event emitter --- adventure/bot_discord.py | 5 +- adventure/context.py | 62 +++++++++++++++++------ adventure/generate.py | 62 ++++++++--------------- adventure/main.py | 45 ++++------------ adventure/models/config.py | 13 +++++ adventure/render_comfy.py | 9 +++- adventure/rpg_systems/crafting_actions.py | 3 +- adventure/rpg_systems/magic_actions.py | 1 + adventure/rpg_systems/movement_actions.py | 1 + adventure/server_socket.py | 11 +++- adventure/simulate.py | 36 ++----------- 11 files changed, 124 insertions(+), 124 deletions(-) diff --git a/adventure/bot_discord.py b/adventure/bot_discord.py index 3599b2e..284c72d 100644 --- a/adventure/bot_discord.py +++ b/adventure/bot_discord.py @@ -12,8 +12,9 @@ from adventure.context import ( get_actor_agent_for_name, get_current_world, set_actor_agent, + subscribe, ) -from adventure.models.config import DiscordBotConfig, DEFAULT_CONFIG +from adventure.models.config import DEFAULT_CONFIG, DiscordBotConfig from adventure.models.event import ( ActionEvent, GameEvent, @@ -205,6 +206,8 @@ def launch_bot(config: DiscordBotConfig): send_thread = Thread(target=send_main, daemon=True) send_thread.start() + subscribe(GameEvent, bot_event) + return [bot_thread, send_thread] diff --git a/adventure/context.py b/adventure/context.py index 47538ef..2cee189 100644 --- a/adventure/context.py +++ b/adventure/context.py @@ -1,5 +1,17 @@ -from typing import Callable, Dict, List, Sequence, Tuple from contextlib import contextmanager +from logging import getLogger +from types import UnionType +from typing import ( + Callable, + Dict, + List, + Sequence, + Tuple, + Type, + Union, + get_args, + get_origin, +) from packit.agent import Agent from pyee.base import EventEmitter @@ -8,8 +20,7 @@ from adventure.game_system import GameSystem from adventure.models.entity import Actor, Room, World from adventure.models.event import GameEvent -# TODO: replace with event emitter and a context manager -current_broadcast: Callable[[str | GameEvent], None] | None = None +logger = getLogger(__name__) # world context current_step = 0 @@ -28,8 +39,33 @@ actor_agents: Dict[str, Tuple[Actor, Agent]] = {} def broadcast(message: str | GameEvent): - if current_broadcast: - current_broadcast(message) + if isinstance(message, GameEvent): + logger.debug(f"broadcasting {message.type}") + event_emitter.emit(message.type, message) + else: + logger.warning("broadcasting a string message is deprecated") + event_emitter.emit("message", message) + + +def is_union(type_: Type | UnionType): + origin = get_origin(type_) + return origin is UnionType or origin is Union + + +def subscribe( + event_type: Type[str] | Type[GameEvent] | UnionType, + callback: Callable[[GameEvent], None], +): + if is_union(event_type): + for t in get_args(event_type): + subscribe(t, callback) + + return + + logger.debug(f"subscribing {callback.__name__} to {event_type}") + event_emitter.on( + event_type.type, callback + ) # TODO: should this use str or __name__? def has_dungeon_master(): @@ -37,7 +73,12 @@ def has_dungeon_master(): # region context manager -# TODO +@contextmanager +def with_action_context(): + room, actor = get_action_context() + yield room, actor + + # endregion @@ -80,10 +121,6 @@ def get_current_actor() -> Actor | None: return current_actor -def get_current_broadcast(): - return current_broadcast - - def get_current_step() -> int: return current_step @@ -105,11 +142,6 @@ def get_game_systems() -> List[GameSystem]: # region context setters -def set_current_broadcast(broadcast): - global current_broadcast - current_broadcast = broadcast - - def set_current_world(world: World | None): global current_world current_world = world diff --git a/adventure/generate.py b/adventure/generate.py index 291810d..61a532b 100644 --- a/adventure/generate.py +++ b/adventure/generate.py @@ -6,6 +6,7 @@ from packit.agent import Agent from packit.loops import loop_retry from packit.utils import could_be_json +from adventure.context import broadcast from adventure.game_system import GameSystem from adventure.models.entity import ( Actor, @@ -17,7 +18,7 @@ from adventure.models.entity import ( World, WorldEntity, ) -from adventure.models.event import EventCallback, GenerateEvent +from adventure.models.event import GenerateEvent logger = getLogger(__name__) @@ -31,7 +32,7 @@ OPPOSITE_DIRECTIONS = { def duplicate_name_parser(existing_names: List[str]): def name_parser(value: str, **kwargs): - print(f"validating generated name: {value}") + logger.debug(f"validating generated name: {value}") if value in existing_names: raise ValueError(f'"{value}" has already been used.') @@ -50,8 +51,7 @@ def duplicate_name_parser(existing_names: List[str]): return name_parser -def callback_wrapper( - callback: EventCallback | None, +def broadcast_generated( message: str | None = None, entity: WorldEntity | None = None, ): @@ -62,14 +62,12 @@ def callback_wrapper( else: raise ValueError("Either message or entity must be provided") - if callable(callback): - callback(event) + broadcast(event) def generate_room( agent: Agent, world_theme: str, - callback: EventCallback | None = None, existing_rooms: List[str] = [], ) -> Room: name = loop_retry( @@ -84,7 +82,7 @@ def generate_room( result_parser=duplicate_name_parser(existing_rooms), ) - callback_wrapper(callback, message=f"Generating room: {name}") + broadcast_generated(message=f"Generating room: {name}") desc = agent( "Generate a detailed description of the {name} area. What does it look like? " "What does it smell like? What can be seen or heard?", @@ -103,7 +101,6 @@ def generate_room( def generate_item( agent: Agent, world_theme: str, - callback: EventCallback | None = None, dest_room: str | None = None, dest_actor: str | None = None, existing_items: List[str] = [], @@ -130,7 +127,7 @@ def generate_item( result_parser=duplicate_name_parser(existing_items), ) - callback_wrapper(callback, message=f"Generating item: {name}") + broadcast_generated(message=f"Generating item: {name}") desc = agent( "Generate a detailed description of the {name} item. What does it look like? What is it made of? What does it do?", name=name, @@ -140,14 +137,12 @@ def generate_item( item = Item(name=name, description=desc, actions=actions) effect_count = randint(1, 2) - callback_wrapper( - callback, message=f"Generating {effect_count} effects for item: {name}" - ) + broadcast_generated(message=f"Generating {effect_count} effects for item: {name}") effects = [] for i in range(effect_count): try: - effect = generate_effect(agent, world_theme, entity=item, callback=callback) + effect = generate_effect(agent, world_theme, entity=item) effects.append(effect) except Exception: logger.exception("error generating effect") @@ -160,7 +155,6 @@ def generate_actor( agent: Agent, world_theme: str, dest_room: str, - callback: EventCallback | None = None, existing_actors: List[str] = [], ) -> Actor: name = loop_retry( @@ -179,7 +173,7 @@ def generate_actor( result_parser=duplicate_name_parser(existing_actors), ) - callback_wrapper(callback, message=f"Generating actor: {name}") + broadcast_generated(message=f"Generating actor: {name}") description = agent( "Generate a detailed description of the {name} character. What do they look like? What are they wearing? " "What are they doing? Describe their appearance from the perspective of an outside observer." @@ -200,9 +194,7 @@ def generate_actor( ) -def generate_effect( - agent: Agent, theme: str, entity: Item, callback: EventCallback | None = None -) -> Effect: +def generate_effect(agent: Agent, theme: str, entity: Item) -> Effect: entity_type = entity.type existing_effects = [effect.name for effect in entity.effects] @@ -222,7 +214,7 @@ def generate_effect( }, result_parser=duplicate_name_parser(existing_effects), ) - callback_wrapper(callback, message=f"Generating effect: {name}") + broadcast_generated(message=f"Generating effect: {name}") description = agent( "Generate a detailed description of the {name} effect. What does it look like? What does it do? " @@ -302,12 +294,11 @@ def generate_world( theme: str, room_count: int | None = None, max_rooms: int = 5, - callback: EventCallback | None = None, systems: List[GameSystem] = [], ) -> World: room_count = room_count or randint(3, max_rooms) - callback_wrapper(callback, message=f"Generating a {theme} with {room_count} rooms") + broadcast_generated(message=f"Generating a {theme} with {room_count} rooms") existing_actors: List[str] = [] existing_items: List[str] = [] @@ -317,11 +308,9 @@ def generate_world( rooms = [] for i in range(room_count): try: - room = generate_room( - agent, theme, existing_rooms=existing_rooms, callback=callback - ) + room = generate_room(agent, theme, existing_rooms=existing_rooms) generate_system_attributes(agent, theme, room, systems) - callback_wrapper(callback, entity=room) + broadcast_generated(entity=room) rooms.append(room) existing_rooms.append(room.name) except Exception: @@ -329,9 +318,7 @@ def generate_world( continue item_count = randint(1, 3) - callback_wrapper( - callback, f"Generating {item_count} items for room: {room.name}" - ) + broadcast_generated(f"Generating {item_count} items for room: {room.name}") for j in range(item_count): try: @@ -340,10 +327,9 @@ def generate_world( theme, dest_room=room.name, existing_items=existing_items, - callback=callback, ) generate_system_attributes(agent, theme, item, systems) - callback_wrapper(callback, entity=item) + broadcast_generated(entity=item) room.items.append(item) existing_items.append(item.name) @@ -351,8 +337,8 @@ def generate_world( logger.exception("error generating item") actor_count = randint(1, 3) - callback_wrapper( - callback, message=f"Generating {actor_count} actors for room: {room.name}" + broadcast_generated( + message=f"Generating {actor_count} actors for room: {room.name}" ) for j in range(actor_count): @@ -362,10 +348,9 @@ def generate_world( theme, dest_room=room.name, existing_actors=existing_actors, - callback=callback, ) generate_system_attributes(agent, theme, actor, systems) - callback_wrapper(callback, entity=actor) + broadcast_generated(entity=actor) room.actors.append(actor) existing_actors.append(actor.name) @@ -375,9 +360,7 @@ def generate_world( # generate the actor's inventory item_count = randint(0, 2) - callback_wrapper( - callback, f"Generating {item_count} items for actor {actor.name}" - ) + broadcast_generated(f"Generating {item_count} items for actor {actor.name}") for k in range(item_count): try: @@ -386,10 +369,9 @@ def generate_world( theme, dest_room=room.name, existing_items=existing_items, - callback=callback, ) generate_system_attributes(agent, theme, item, systems) - callback_wrapper(callback, entity=item) + broadcast_generated(entity=item) actor.items.append(item) existing_items.append(item.name) diff --git a/adventure/main.py b/adventure/main.py index ddeb1fb..f80457f 100644 --- a/adventure/main.py +++ b/adventure/main.py @@ -6,9 +6,10 @@ from typing import List from dotenv import load_dotenv from packit.agent import Agent, agent_easy_connect from packit.utils import logger_with_colors -from pyee.base import EventEmitter from yaml import Loader, load +from adventure.context import subscribe + def load_yaml(file): return load(file, Loader=Loader) @@ -36,9 +37,9 @@ if True: from adventure.context import set_current_step, set_dungeon_master from adventure.game_system import GameSystem from adventure.generate import generate_world - from adventure.models.config import Config, DEFAULT_CONFIG + from adventure.models.config import DEFAULT_CONFIG, Config from adventure.models.entity import World, WorldState - from adventure.models.event import EventCallback, GameEvent, GenerateEvent + from adventure.models.event import GenerateEvent from adventure.models.files import PromptFile, WorldPrompt from adventure.plugins import load_plugin from adventure.simulate import simulate_world @@ -181,9 +182,7 @@ def get_world_prompt(args) -> WorldPrompt: ) -def load_or_generate_world( - args, players, callbacks, systems, world_prompt: WorldPrompt -): +def load_or_generate_world(args, players, systems, world_prompt: WorldPrompt): world_file = args.world + ".json" world_state_file = args.state or (args.world + ".state.json") @@ -212,20 +211,12 @@ def load_or_generate_world( llm, ) - world = None - - def broadcast_callback(event: GameEvent): - logger.debug("broadcasting generation event: %s", event) - for callback in callbacks: - callback(event) - world = generate_world( world_builder, args.world, world_prompt.theme, room_count=args.rooms, max_rooms=args.max_rooms, - callback=broadcast_callback, ) save_world(world, world_file) @@ -251,38 +242,25 @@ def main(): if args.player: players.append(args.player) - # set up callbacks - callbacks: List[EventCallback] = [] - # launch other threads threads = [] if args.render: - from adventure.render_comfy import launch_render + from adventure.render_comfy import launch_render, render_generated threads.extend(launch_render(config.render)) - if args.render_generated: - from adventure.render_comfy import render_entity - - def render_generated(event: GameEvent): - if isinstance(event, GenerateEvent) and event.entity: - logger.info("rendering generated entity: %s", event.entity.name) - render_entity(event.entity) - - callbacks.append(render_generated) + subscribe(GenerateEvent, render_generated) if args.discord: - from adventure.bot_discord import bot_event, launch_bot + from adventure.bot_discord import launch_bot threads.extend(launch_bot(config.bot.discord)) - callbacks.append(bot_event) if args.server: - from adventure.server_socket import launch_server, server_event, server_system + from adventure.server_socket import launch_server, server_system - threads.extend(launch_server()) - callbacks.append(server_event) + threads.extend(launch_server(config.server.websocket)) # register the thread shutdown handler def shutdown_threads(): @@ -327,7 +305,7 @@ def main(): # load or generate the world world_prompt = get_world_prompt(args) world, world_state_file = load_or_generate_world( - args, players, callbacks, extra_systems, world_prompt=world_prompt + args, players, extra_systems, world_prompt=world_prompt ) # make sure the snapshot system runs last @@ -362,7 +340,6 @@ def main(): steps=args.steps, actions=extra_actions, systems=extra_systems, - callbacks=callbacks, ) diff --git a/adventure/models/config.py b/adventure/models/config.py index b2220cd..9b14b66 100644 --- a/adventure/models/config.py +++ b/adventure/models/config.py @@ -36,10 +36,22 @@ class RenderConfig: steps: Range +@dataclass +class WebsocketServerConfig: + host: str + port: int + + +@dataclass +class ServerConfig: + websocket: WebsocketServerConfig + + @dataclass class Config: bot: BotConfig render: RenderConfig + server: ServerConfig DEFAULT_CONFIG = Config( @@ -57,4 +69,5 @@ DEFAULT_CONFIG = Config( }, steps=Range(min=30, max=30), ), + server=ServerConfig(websocket=WebsocketServerConfig(host="localhost", port=8000)), ) diff --git a/adventure/render_comfy.py b/adventure/render_comfy.py index e68672e..a8aa95a 100644 --- a/adventure/render_comfy.py +++ b/adventure/render_comfy.py @@ -17,11 +17,12 @@ from jinja2 import Environment, FileSystemLoader, select_autoescape from PIL import Image from adventure.context import broadcast -from adventure.models.config import RenderConfig, DEFAULT_CONFIG +from adventure.models.config import DEFAULT_CONFIG, RenderConfig from adventure.models.entity import WorldEntity from adventure.models.event import ( ActionEvent, GameEvent, + GenerateEvent, RenderEvent, ReplyEvent, ResultEvent, @@ -327,6 +328,12 @@ def render_event(event: GameEvent): render_queue.put(event) +def render_generated(event: GameEvent): + if isinstance(event, GenerateEvent) and event.entity: + logger.info("rendering generated entity: %s", event.entity.name) + render_entity(event.entity) + + def launch_render(config: RenderConfig): global render_config global render_thread diff --git a/adventure/rpg_systems/crafting_actions.py b/adventure/rpg_systems/crafting_actions.py index bd7cee3..ff24030 100644 --- a/adventure/rpg_systems/crafting_actions.py +++ b/adventure/rpg_systems/crafting_actions.py @@ -1,8 +1,9 @@ from random import randint + from adventure.context import broadcast, get_current_context, get_dungeon_master from adventure.generate import generate_item -from adventure.models.entity import Item from adventure.models.base import dataclass +from adventure.models.entity import Item @dataclass diff --git a/adventure/rpg_systems/magic_actions.py b/adventure/rpg_systems/magic_actions.py index fdc39ec..6cc1b98 100644 --- a/adventure/rpg_systems/magic_actions.py +++ b/adventure/rpg_systems/magic_actions.py @@ -1,4 +1,5 @@ from random import randint + from adventure.context import broadcast, get_current_context, get_dungeon_master from adventure.search import find_actor_in_room diff --git a/adventure/rpg_systems/movement_actions.py b/adventure/rpg_systems/movement_actions.py index 5822003..2e2cb43 100644 --- a/adventure/rpg_systems/movement_actions.py +++ b/adventure/rpg_systems/movement_actions.py @@ -1,4 +1,5 @@ from random import randint + from adventure.context import broadcast, get_current_context, get_dungeon_master from adventure.search import find_item_in_room diff --git a/adventure/server_socket.py b/adventure/server_socket.py index 6bab3ff..050e1f2 100644 --- a/adventure/server_socket.py +++ b/adventure/server_socket.py @@ -17,7 +17,9 @@ from adventure.context import ( get_actor_agent_for_name, get_current_world, set_actor_agent, + subscribe, ) +from adventure.models.config import DEFAULT_CONFIG, WebsocketServerConfig from adventure.models.entity import Actor, Item, Room, World from adventure.models.event import ( GameEvent, @@ -45,6 +47,7 @@ last_snapshot: str | None = None player_names: Dict[str, str] = {} recent_events: MutableSequence[GameEvent] = deque(maxlen=100) recent_json: MutableSequence[str] = deque(maxlen=100) +server_config: WebsocketServerConfig = DEFAULT_CONFIG.server.websocket def get_player_name(client_id: str) -> str: @@ -257,8 +260,12 @@ def send_and_append(id: str, message: Dict): return json_message -def launch_server(): +def launch_server(config: WebsocketServerConfig): global socket_thread + global server_config + + logger.info("configuring websocket server: %s", config) + server_config = config def run_sockets(): asyncio.run(server_main()) @@ -267,6 +274,8 @@ def launch_server(): socket_thread = Thread(target=run_sockets, daemon=True) socket_thread.start() + subscribe(GameEvent, server_event) + return [socket_thread] diff --git a/adventure/simulate.py b/adventure/simulate.py index bf5f771..959c9f5 100644 --- a/adventure/simulate.py +++ b/adventure/simulate.py @@ -1,7 +1,7 @@ from itertools import count from logging import getLogger -from typing import Callable, Sequence from math import inf +from typing import Callable, Sequence from packit.loops import loop_retry from packit.results import multi_function_or_str_result @@ -17,12 +17,12 @@ from adventure.actions import ( action_tell, ) from adventure.context import ( + broadcast, get_actor_agent_for_name, get_actor_for_agent, get_current_step, get_current_world, set_current_actor, - set_current_broadcast, set_current_room, set_current_step, set_current_world, @@ -30,14 +30,7 @@ from adventure.context import ( ) from adventure.game_system import GameSystem from adventure.models.entity import World -from adventure.models.event import ( - ActionEvent, - EventCallback, - GameEvent, - ReplyEvent, - ResultEvent, - StatusEvent, -) +from adventure.models.event import ActionEvent, ReplyEvent, ResultEvent from adventure.utils.world import describe_entity, format_attributes logger = getLogger(__name__) @@ -67,26 +60,12 @@ def simulate_world( world: World, steps: float | int = inf, actions: Sequence[Callable[..., str]] = [], - callbacks: Sequence[EventCallback] = [], systems: Sequence[GameSystem] = [], ): logger.info("Simulating the world") set_current_world(world) set_game_systems(systems) - # set up a broadcast callback - def broadcast_callback(message: str | GameEvent): - logger.info(message) - if isinstance(message, str): - event = StatusEvent(text=message) - else: - event = message - - for callback in callbacks: - callback(event) - - set_current_broadcast(broadcast_callback) - # build a toolbox for the actions action_tools = Toolbox( [ @@ -135,11 +114,7 @@ def simulate_world( else: event = ReplyEvent.from_text(value, room, actor) - for callback in callbacks: - logger.info( - f"calling input callback for {actor_name}: {callback.__name__}" - ) - callback(event) + broadcast(event) return world_result_parser(value, agent, **kwargs) @@ -174,8 +149,7 @@ def simulate_world( agent.memory.append(result) result_event = ResultEvent(result=result, room=room, actor=actor) - for callback in callbacks: - callback(result_event) + broadcast(result_event) for system in systems: if system.simulate: