From ef8529ef62942d8547781f5f8a6dc48ab970b75d Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 1 Jun 2024 14:46:11 -0500 Subject: [PATCH] make sure config is used consistently, start adding tests --- .gitignore | 2 + docs/cli.md | 80 +++++++++++++++++++++++++++++++++ taleweave/actions/planning.py | 18 +++++--- taleweave/bot/discord.py | 24 +++++----- taleweave/context.py | 15 +++++++ taleweave/generate.py | 19 ++++++-- taleweave/main.py | 4 +- taleweave/models/config.py | 13 +++++- taleweave/models/event.py | 1 + taleweave/player.py | 7 ++- taleweave/render/comfy.py | 32 ++++++------- taleweave/server/websocket.py | 29 ++++++------ taleweave/simulate.py | 17 ++++--- taleweave/utils/conversation.py | 10 ++--- tests/utils/__init__.py | 0 tests/utils/test_search.py | 30 +++++++++++++ 16 files changed, 232 insertions(+), 69 deletions(-) create mode 100644 docs/cli.md create mode 100644 tests/utils/__init__.py create mode 100644 tests/utils/test_search.py diff --git a/.gitignore b/.gitignore index 82d783a..99417be 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ venv/ client/node_modules/ client/out/ taleweave/custom_* +.coverage +coverage.* diff --git a/docs/cli.md b/docs/cli.md new file mode 100644 index 0000000..1f30bd9 --- /dev/null +++ b/docs/cli.md @@ -0,0 +1,80 @@ +# TaleWeave AI Command Line Options + +The following command line arguments are available when launching the TaleWeave AI engine: + +- **--actions** + - **Type:** String + - **Description:** Additional actions to include in the simulation. Note: More than one argument is allowed. + +- **--add-rooms** + - **Type:** Integer + - **Default:** 0 + - **Description:** The number of new rooms to generate before starting the simulation. + +- **--config** + - **Type:** String + - **Description:** The file to load additional configuration from. + +- **--discord** + - **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option. + - **Description:** Run the simulation in a Discord bot. + +- **--flavor** + - **Type:** String + - **Default:** "" + - **Description:** Additional flavor text for the generated world. + +- **--optional-actions** + - **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option. + - **Description:** Include optional actions in the simulation. + +- **--player** + - **Type:** String + - **Description:** The name of the character to play as. + +- **--prompts** + - **Type:** String + - **Description:** The file to load game prompts from. Note: More than one argument is allowed. + +- **--render** + - **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option. + - **Description:** Run the render thread. + +- **--render-generated** + - **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option. + - **Description:** Render entities as they are generated. + +- **--rooms** + - **Type:** Integer + - **Description:** The number of rooms to generate. + +- **--server** + - **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option. + - **Description:** Run the websocket server. + +- **--state** + - **Type:** String + - **Description:** The file to save the world state to. Defaults to `$world.state.json` if not set. + +- **--turns** + - **Type:** Integer or "inf" + - **Default:** 10 + - **Description:** The number of simulation turns to run. + +- **--systems** + - **Type:** String + - **Description:** Extra systems to run in the simulation. Note: More than one argument is allowed. + +- **--theme** + - **Type:** String + - **Default:** "fantasy" + - **Description:** The theme of the generated world. + +- **--world** + - **Type:** String + - **Default:** "world" + - **Description:** The file to save the generated world to. + +- **--world-template** + - **Type:** String + - **Description:** The template file to load the world prompt from. \ No newline at end of file diff --git a/taleweave/actions/planning.py b/taleweave/actions/planning.py index 1c4cc2a..1d052e4 100644 --- a/taleweave/actions/planning.py +++ b/taleweave/actions/planning.py @@ -2,16 +2,14 @@ from taleweave.context import ( action_context, get_agent_for_character, get_current_turn, + get_game_config, get_prompt, ) from taleweave.errors import ActionError -from taleweave.models.config import DEFAULT_CONFIG from taleweave.models.planning import CalendarEvent from taleweave.utils.planning import get_recent_notes from taleweave.utils.prompt import format_prompt -character_config = DEFAULT_CONFIG.world.character - def take_note(fact: str): """ @@ -22,11 +20,13 @@ def take_note(fact: str): fact: The fact to remember. """ + config = get_game_config() + with action_context() as (_, action_character): if fact in action_character.planner.notes: raise ActionError(get_prompt("action_take_note_error_duplicate")) - if len(action_character.planner.notes) >= character_config.note_limit: + if len(action_character.planner.notes) >= config.world.character.note_limit: raise ActionError(get_prompt("action_take_note_error_limit")) action_character.planner.notes.append(fact) @@ -103,6 +103,8 @@ def summarize_notes(limit: int) -> str: limit: The maximum number of notes to keep. """ + config = get_game_config() + with action_context() as (_, action_character): notes = action_character.planner.notes if len(notes) == 0: @@ -120,11 +122,11 @@ def summarize_notes(limit: int) -> str: ) new_notes = [note.strip() for note in summary.split("\n") if note.strip()] - if len(new_notes) > character_config.note_limit: + if len(new_notes) > config.world.character.note_limit: raise ActionError( format_prompt( "action_summarize_notes_error_limit", - limit=character_config.note_limit, + limit=config.world.character.note_limit, ) ) @@ -165,7 +167,9 @@ def check_calendar(count: int): count: The number of upcoming events to read. 5 is usually a good number. """ - count = min(count, character_config.event_limit) + config = get_game_config() + + count = min(count, config.world.character.event_limit) current_turn = get_current_turn() with action_context() as (_, action_character): diff --git a/taleweave/bot/discord.py b/taleweave/bot/discord.py index c34e147..d6802e1 100644 --- a/taleweave/bot/discord.py +++ b/taleweave/bot/discord.py @@ -11,10 +11,11 @@ from taleweave.context import ( broadcast, get_character_agent_for_name, get_current_world, + get_game_config, set_character_agent, subscribe, ) -from taleweave.models.config import DEFAULT_CONFIG, DiscordBotConfig +from taleweave.models.config import DiscordBotConfig from taleweave.models.event import ( ActionEvent, GameEvent, @@ -38,7 +39,6 @@ from taleweave.utils.prompt import format_prompt logger = getLogger(__name__) client = None -bot_config: DiscordBotConfig = DEFAULT_CONFIG.bot.discord active_tasks = set() event_messages: Dict[int, str | GameEvent] = {} @@ -78,21 +78,24 @@ class AdventureClient(Client): if message.author == self.user: return + config = get_game_config() author = message.author channel = message.channel user_name = author.name # include nick if message.content.startswith( - bot_config.command_prefix + bot_config.name_command + config.bot.discord.command_prefix + config.bot.discord.name_command ): world = get_current_world() if world: world_message = format_prompt( - "discord_world_active", bot_name=bot_config.name_title, world=world + "discord_world_active", + bot_name=config.bot.discord.name_title, + world=world, ) else: world_message = format_prompt( - "discord_world_none", bot_name=bot_config.name_title + "discord_world_none", bot_name=config.bot.discord.name_title ) await message.channel.send(world_message) @@ -100,7 +103,7 @@ class AdventureClient(Client): if message.content.startswith("!help"): await message.channel.send( - format_prompt("discord_help", bot_name=bot_config.name_command) + format_prompt("discord_help", bot_name=config.bot.discord.name_command) ) return @@ -172,14 +175,11 @@ class AdventureClient(Client): def launch_bot(config: DiscordBotConfig): - global bot_config global client - bot_config = config - # message contents need to be enabled for multi-server bots intents = Intents.default() - if bot_config.content_intent: + if config.content_intent: intents.message_content = True client = AdventureClient(intents=intents) @@ -246,12 +246,14 @@ def get_active_channels(): if not client: return [] + config = get_game_config() + # return client.private_channels return [ channel for guild in client.guilds for channel in guild.text_channels - if channel.name in bot_config.channels + if channel.name in config.bot.discord.channels ] diff --git a/taleweave/context.py b/taleweave/context.py index 3fdfa81..be36142 100644 --- a/taleweave/context.py +++ b/taleweave/context.py @@ -18,6 +18,7 @@ from packit.agent import Agent from pyee.base import EventEmitter from taleweave.game_system import GameSystem +from taleweave.models.config import DEFAULT_CONFIG, Config from taleweave.models.entity import Character, Room, World from taleweave.models.event import GameEvent, StatusEvent from taleweave.models.prompt import PromptLibrary @@ -34,6 +35,7 @@ dungeon_master: Agent | None = None # game context event_emitter = EventEmitter() +game_config: Config = DEFAULT_CONFIG game_systems: List[GameSystem] = [] prompt_library: PromptLibrary = PromptLibrary(prompts={}) system_data: Dict[str, Any] = {} @@ -160,6 +162,10 @@ def get_dungeon_master() -> Agent: return dungeon_master +def get_game_config() -> Config: + return game_config + + def get_game_systems() -> List[GameSystem]: return game_systems @@ -172,6 +178,10 @@ def get_prompt_library() -> PromptLibrary: return prompt_library +def get_system_config(system: str) -> Any | None: + return game_config.systems.data.get(system) + + def get_system_data(system: str) -> Any | None: return system_data.get(system) @@ -209,6 +219,11 @@ def set_dungeon_master(agent): dungeon_master = agent +def set_game_config(config: Config): + global game_config + game_config = config + + def set_game_systems(systems: Sequence[GameSystem]): global game_systems game_systems = list(systems) diff --git a/taleweave/generate.py b/taleweave/generate.py index cbc9d09..2c0d1dc 100644 --- a/taleweave/generate.py +++ b/taleweave/generate.py @@ -7,9 +7,14 @@ from packit.loops import loop_retry from packit.results import enum_result, int_result from packit.utils import could_be_json -from taleweave.context import broadcast, get_prompt, set_current_world, set_system_data +from taleweave.context import ( + broadcast, + get_game_config, + get_prompt, + set_current_world, + set_system_data, +) from taleweave.game_system import GameSystem -from taleweave.models.config import DEFAULT_CONFIG, WorldConfig from taleweave.models.effect import ( EffectPattern, FloatEffectPattern, @@ -33,7 +38,10 @@ from taleweave.utils.string import normalize_name logger = getLogger(__name__) -world_config: WorldConfig = DEFAULT_CONFIG.world + +def get_world_config(): + config = get_game_config() + return config.world def duplicate_name_parser(existing_names: List[str]): @@ -112,6 +120,7 @@ def generate_room( actions = {} room = Room(name=name, description=desc, items=[], characters=[], actions=actions) + world_config = get_world_config() item_count = resolve_int_range(world_config.size.room_items) or 0 broadcast_generated( format_prompt( @@ -276,6 +285,7 @@ def generate_item( item = Item(name=name, description=desc, actions=actions) generate_system_attributes(agent, world, item, systems) + world_config = get_world_config() effect_count = resolve_int_range(world_config.size.item_effects) or 0 broadcast_generated( message=format_prompt( @@ -343,6 +353,7 @@ def generate_character( generate_system_attributes(agent, world, character, systems) # generate the character's inventory + world_config = get_world_config() item_count = resolve_int_range(world_config.size.character_items) or 0 broadcast_generated( message=format_prompt( @@ -499,6 +510,7 @@ def link_rooms( rooms: List[Room] | None = None, ) -> None: rooms = rooms or world.rooms + world_config = get_world_config() for room in rooms: num_portals = resolve_int_range(world_config.size.portals) or 0 @@ -550,6 +562,7 @@ def generate_world( systems: List[GameSystem], room_count: int | None = None, ) -> World: + world_config = get_world_config() room_count = room_count or resolve_int_range(world_config.size.rooms) or 0 broadcast_generated(message=format_prompt("world_generate_world_broadcast_theme")) diff --git a/taleweave/main.py b/taleweave/main.py index 9e918be..8127cfa 100644 --- a/taleweave/main.py +++ b/taleweave/main.py @@ -27,7 +27,7 @@ except Exception as err: logger = logger_with_colors(__name__) # , level="DEBUG") -load_dotenv(environ.get("ADVENTURE_ENV", ".env"), override=True) +load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True) if True: from taleweave.context import ( @@ -35,6 +35,7 @@ if True: get_system_data, set_current_turn, set_dungeon_master, + set_game_config, set_system_data, subscribe, ) @@ -312,6 +313,7 @@ def main(): if args.config: with open(args.config, "r") as f: config = Config(**load_yaml(f)) + set_game_config(config) else: config = DEFAULT_CONFIG diff --git a/taleweave/models/config.py b/taleweave/models/config.py index 648e7f7..1ce28e2 100644 --- a/taleweave/models/config.py +++ b/taleweave/models/config.py @@ -1,6 +1,6 @@ from typing import Dict, List -from .base import IntRange, dataclass +from .base import Attributes, IntRange, dataclass @dataclass @@ -44,6 +44,15 @@ class ServerConfig: websocket: WebsocketServerConfig +@dataclass +class SystemsConfig: + """ + Configuration for the game systems. + """ + + data: Attributes + + @dataclass class WorldCharacterConfig: conversation_limit: int @@ -80,6 +89,7 @@ class Config: bot: BotConfig render: RenderConfig server: ServerConfig + systems: SystemsConfig world: WorldConfig @@ -107,6 +117,7 @@ DEFAULT_CONFIG = Config( steps=30, ), server=ServerConfig(websocket=WebsocketServerConfig(host="localhost", port=8001)), + systems=SystemsConfig(data={}), world=WorldConfig( character=WorldCharacterConfig( conversation_limit=2, diff --git a/taleweave/models/event.py b/taleweave/models/event.py index 5b5b12f..30bc592 100644 --- a/taleweave/models/event.py +++ b/taleweave/models/event.py @@ -60,6 +60,7 @@ class PromptEvent(BaseModel): A prompt for a character to take an action. """ + actions: Dict[str, Any] prompt: str room: Room character: Character diff --git a/taleweave/player.py b/taleweave/player.py index afb9870..2e6e14b 100644 --- a/taleweave/player.py +++ b/taleweave/player.py @@ -197,15 +197,20 @@ class RemotePlayer(BasePlayer): Ask the player for input. """ + actions = {} formatted_prompt = prompt.format(**kwargs) if toolbox: + actions = toolbox.list_definitions() formatted_prompt += self.format_psuedo_functions(toolbox) self.memory.append(HumanMessage(content=formatted_prompt)) with action_context() as (current_room, current_character): prompt_event = PromptEvent( - prompt=formatted_prompt, room=current_room, character=current_character + actions=actions, + prompt=formatted_prompt, + room=current_room, + character=current_character, ) try: diff --git a/taleweave/render/comfy.py b/taleweave/render/comfy.py index ced7dad..39ea993 100644 --- a/taleweave/render/comfy.py +++ b/taleweave/render/comfy.py @@ -15,9 +15,9 @@ from fnvhash import fnv1a_32 from jinja2 import Environment, FileSystemLoader, select_autoescape from PIL import Image -from taleweave.context import broadcast -from taleweave.models.base import uuid -from taleweave.models.config import DEFAULT_CONFIG, RenderConfig +from taleweave.context import broadcast, get_game_config +from taleweave.models.base import IntRange, uuid +from taleweave.models.config import RenderConfig from taleweave.models.entity import WorldEntity from taleweave.models.event import ( ActionEvent, @@ -36,7 +36,6 @@ logger = getLogger(__name__) server_address = environ["COMFY_API"] client_id = uuid() -render_config: RenderConfig = DEFAULT_CONFIG.render # requests to generate images for game events @@ -44,12 +43,17 @@ render_queue: Queue[GameEvent | WorldEntity] = Queue() render_thread: Thread | None = None -def generate_cfg(): - return resolve_int_range(render_config.cfg) +def get_render_config(): + config = get_game_config() + return config.render -def generate_steps(): - return resolve_int_range(render_config.steps) +def generate_cfg(cfg: int | IntRange): + return resolve_int_range(cfg) + + +def generate_steps(steps: int | IntRange): + return resolve_int_range(steps) def generate_batches( @@ -148,9 +152,10 @@ def generate_image_tool(prompt, count, size="landscape"): def generate_images( prompt: str, count: int, size="landscape", prefix="output" ) -> List[str]: - cfg = generate_cfg() + render_config = get_render_config() + cfg = generate_cfg(render_config.cfg) dims = render_config.sizes[size] - steps = generate_steps() + steps = generate_steps(render_config.steps) seed = randint(0, 10000000) checkpoint = choice(render_config.checkpoints) logger.info( @@ -248,6 +253,8 @@ def get_image_prefix(event: GameEvent | WorldEntity) -> str: def render_loop(): + render_config = get_render_config() + while True: event = render_queue.get() prefix = get_image_prefix(event) @@ -317,13 +324,8 @@ def render_generated(event: GameEvent): def launch_render(config: RenderConfig): - global render_config global render_thread - # update the config - logger.info("updating render config: %s", config) - render_config = config - # start the render thread logger.info("launching render thread") render_thread = Thread(target=render_loop, daemon=True) diff --git a/taleweave/server/websocket.py b/taleweave/server/websocket.py index e77b8b1..6c266a9 100644 --- a/taleweave/server/websocket.py +++ b/taleweave/server/websocket.py @@ -16,10 +16,11 @@ from taleweave.context import ( broadcast, get_character_agent_for_name, get_current_world, + get_game_config, set_character_agent, subscribe, ) -from taleweave.models.config import DEFAULT_CONFIG, WebsocketServerConfig +from taleweave.models.config import WebsocketServerConfig from taleweave.models.entity import World, WorldEntity from taleweave.models.event import ( GameEvent, @@ -47,7 +48,6 @@ last_snapshot: str | None = None player_names: Dict[str, str] = {} recent_events: MutableSequence[GameEvent] = deque(maxlen=100) recent_json: MutableSequence[str] = deque(maxlen=100) -server_config: WebsocketServerConfig = DEFAULT_CONFIG.server.websocket def get_player_name(client_id: str) -> str: @@ -59,16 +59,15 @@ async def handler(websocket): logger.info("client connected, given id: %s", id) connected.add(websocket) - async def next_turn(character: str, prompt: str) -> None: + async def next_turn(event: PromptEvent) -> None: await websocket.send( dumps( { - # TODO: these should be fields in the PromptEvent - "type": "prompt", - "client": id, - "character": character, - "prompt": prompt, - "actions": [], + "type": event.type, + "client": id, # TODO: this should be a field in the PromptEvent + "character": event.character, + "prompt": event.prompt, + "actions": event.actions, } ), ) @@ -77,7 +76,7 @@ async def handler(websocket): # TODO: nothing about this is good player = get_player(id) if player and player.name == event.character.name: - asyncio.run(next_turn(event.character.name, event.prompt)) + asyncio.run(next_turn(event)) return True return False @@ -303,10 +302,6 @@ def send_and_append(id: str, message: Dict): def launch_server(config: WebsocketServerConfig): global socket_thread - global server_config - - logger.info("configuring websocket server: %s", config) - server_config = config def run_sockets(): asyncio.run(server_main()) @@ -321,7 +316,11 @@ def launch_server(config: WebsocketServerConfig): async def server_main(): - async with websockets.serve(handler, server_config.host, server_config.port): + config = get_game_config() + + async with websockets.serve( + handler, config.server.websocket.host, config.server.websocket.port + ): logger.info("websocket server started") await asyncio.Future() # run forever diff --git a/taleweave/simulate.py b/taleweave/simulate.py index 2306e0e..216741a 100644 --- a/taleweave/simulate.py +++ b/taleweave/simulate.py @@ -36,6 +36,7 @@ from taleweave.context import ( get_character_for_agent, get_current_turn, get_current_world, + get_game_config, get_prompt, set_current_character, set_current_room, @@ -44,7 +45,6 @@ from taleweave.context import ( set_game_systems, ) from taleweave.game_system import GameSystem -from taleweave.models.config import DEFAULT_CONFIG from taleweave.models.entity import Character, Room, World from taleweave.models.event import ActionEvent, ResultEvent from taleweave.utils.conversation import make_keyword_condition, summarize_room @@ -57,9 +57,6 @@ from taleweave.utils.world import describe_entity, format_attributes logger = getLogger(__name__) -turn_config = DEFAULT_CONFIG.world.turn - - def world_result_parser(value, agent, **kwargs): current_world = get_current_world() if not current_world: @@ -121,10 +118,11 @@ def prompt_character_action( event = ActionEvent.from_json(value, room, character) else: # TODO: this path should be removed and throw - logger.warning( - "invalid action, emitting as result event - this is a bug somewhere" - ) - event = ResultEvent(value, room, character) + # logger.warning( + # "invalid action, emitting as result event - this is a bug somewhere" + # ) + # event = ResultEvent(value, room, character) + raise ValueError("invalid non-JSON action") broadcast(event) @@ -198,7 +196,8 @@ def prompt_character_planning( current_turn: int, max_steps: int | None = None, ) -> str: - max_steps = max_steps or turn_config.planning_steps + config = get_game_config() + max_steps = max_steps or config.world.turn.planning_steps notes_prompt, events_prompt = get_notes_events(character, current_turn) diff --git a/taleweave/utils/conversation.py b/taleweave/utils/conversation.py index 5cea46e..229a2f5 100644 --- a/taleweave/utils/conversation.py +++ b/taleweave/utils/conversation.py @@ -8,8 +8,7 @@ from packit.conditions import condition_or, condition_threshold, make_flag_condi from packit.results import multi_function_or_str_result from packit.utils import could_be_json -from taleweave.context import broadcast -from taleweave.models.config import DEFAULT_CONFIG +from taleweave.context import broadcast, get_game_config from taleweave.models.entity import Character, Room from taleweave.models.event import ReplyEvent from taleweave.utils.prompt import format_str @@ -19,9 +18,6 @@ from .string import and_list, normalize_name logger = getLogger(__name__) -character_config = DEFAULT_CONFIG.world.character - - def make_keyword_condition(end_message: str, keywords=["end", "stop"]): set_end, condition_end = make_flag_condition() @@ -99,8 +95,10 @@ def loop_conversation( Loop through a conversation between a series of agents, using metadata from their characters. """ + config = get_game_config() + if max_length is None: - max_length = character_config.conversation_limit + max_length = config.world.character.conversation_limit if len(characters) != len(agents): raise ValueError("The number of characters and agents must match.") diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/test_search.py b/tests/utils/test_search.py new file mode 100644 index 0000000..9b109bb --- /dev/null +++ b/tests/utils/test_search.py @@ -0,0 +1,30 @@ +from unittest import TestCase + +from taleweave.models.entity import Room, World +from taleweave.utils.search import find_room + + +class TestFindRoom(TestCase): + def test_existing_room(self): + world = World(name="Test World", rooms=[], theme="testing", order=[]) + room = Room( + name="Test Room", + description="A test room.", + characters=[], + items=[], + portals=[], + ) + world.rooms.append(room) + self.assertEqual(find_room(world, "Test Room"), room) + + def test_missing_room(self): + world = World(name="Test World", rooms=[], theme="testing", order=[]) + room = Room( + name="Test Room", + description="A test room.", + characters=[], + items=[], + portals=[], + ) + world.rooms.append(room) + self.assertEqual(find_room(world, "Missing Room"), None)