From 0927259ecee0fb226bf36a6cb0c9cf31174e38a8 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 2 Jun 2024 20:14:41 -0500 Subject: [PATCH] add missing fields to config, use config memory limit --- config.yml | 15 +++++++++++++++ taleweave/bot/discord.py | 28 +++++++++++++++++++++------- taleweave/main.py | 20 +++++++++----------- taleweave/models/config.py | 2 ++ 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/config.yml b/config.yml index 5a479ca..be93e37 100644 --- a/config.yml +++ b/config.yml @@ -1,6 +1,9 @@ bot: discord: channels: [taleweave] + command_prefix: "!" + name_command: taleweave + name_title: TaleWeave AI render: cfg: min: 5 @@ -8,6 +11,7 @@ render: checkpoints: [ "diffusion-sdxl-dynavision-0-5-5-7.safetensors", ] + count: 2 path: /tmp/taleweave-images sizes: landscape: @@ -26,7 +30,14 @@ server: websocket: host: 0.0.0.0 port: 8001 +systems: + data: {} world: + character: + conversation_limit: 2 + event_limit: 5 + memory_limit: 25 + note_limit: 10 size: character_items: min: 0 @@ -46,3 +57,7 @@ world: room_items: min: 0 max: 3 + turn: + action_retries: 5 + planning_retries: 3 + planning_steps: 3 \ No newline at end of file diff --git a/taleweave/bot/discord.py b/taleweave/bot/discord.py index 4fd133a..08f41dc 100644 --- a/taleweave/bot/discord.py +++ b/taleweave/bot/discord.py @@ -31,6 +31,7 @@ from taleweave.player import ( RemotePlayer, get_player, has_player, + list_players, remove_player, set_player, ) @@ -50,7 +51,7 @@ def remove_tags(text: str) -> str: Remove any tags. """ - return sub(r"<[^>]*>", "", text) + return sub(r"<[^>]*>", "", text).strip() class AdventureClient(Client): @@ -82,8 +83,9 @@ class AdventureClient(Client): author = message.author channel = message.channel user_name = author.name # include nick + content = remove_tags(message.content) - if message.content.startswith( + if content.startswith( config.bot.discord.command_prefix + config.bot.discord.name_command ): world = get_current_world() @@ -101,14 +103,14 @@ class AdventureClient(Client): await message.channel.send(world_message) return - if message.content.startswith("!help"): + if content.startswith("!help"): await message.channel.send( format_prompt("discord_help", bot_name=config.bot.discord.name_command) ) return - if message.content.startswith("!join"): - character_name = remove_tags(message.content).replace("!join", "").strip() + if content.startswith("!join"): + character_name = content.replace("!join", "").strip() if has_player(character_name): await channel.send( format_prompt("discord_join_error_taken", character=character_name) @@ -145,9 +147,14 @@ class AdventureClient(Client): join_event = PlayerEvent("join", character_name, user_name) return broadcast(join_event) + if content.startswith("!players"): + players = list_players() + await channel.send(embed=format_players(players)) + return + player = get_player(user_name) if isinstance(player, RemotePlayer): - if message.content.startswith("!leave"): + if content.startswith("!leave"): remove_player(user_name) # revert to LLM agent @@ -163,7 +170,6 @@ class AdventureClient(Client): leave_event = PlayerEvent("leave", player.name, user_name) return broadcast(leave_event) else: - content = remove_tags(message.content) player.input_queue.put(content) logger.info( f"received message from {user_name} for {player.name}: {content}" @@ -174,6 +180,14 @@ class AdventureClient(Client): return +def format_players(players: Dict[str, str]): + player_embed = Embed(title="Players") + for player, character in players.items(): + player_embed.add_field(name=player, value=character) + + return player_embed + + def launch_bot(config: DiscordBotConfig): global client diff --git a/taleweave/main.py b/taleweave/main.py index bd439df..503c4d2 100644 --- a/taleweave/main.py +++ b/taleweave/main.py @@ -48,12 +48,7 @@ if True: from taleweave.models.prompt import PromptLibrary from taleweave.plugins import load_plugin from taleweave.simulate import simulate_world - from taleweave.state import ( - MEMORY_LIMIT, - create_agents, - save_world, - save_world_state, - ) + from taleweave.state import create_agents, save_world, save_world_state from taleweave.utils.prompt import format_prompt # start the debugger, if needed @@ -65,9 +60,6 @@ if environ.get("DEBUG", "false").lower() == "true": debugpy.wait_for_client() -memory_factory = partial(make_limited_memory, limit=MEMORY_LIMIT) - - def int_or_inf(value: str) -> float | int: if value == "inf": return float("inf") @@ -241,7 +233,7 @@ def save_system_data(args, systems: List[GameSystem]): def load_or_generate_world( - args, players, systems: List[GameSystem], world_prompt: WorldPrompt + args, config: Config, players, systems: List[GameSystem], world_prompt: WorldPrompt ): world_file = args.world + ".json" world_state_file = args.state or (args.world + ".state.json") @@ -252,6 +244,9 @@ def load_or_generate_world( # prepare an agent for the world builder llm = agent_easy_connect() + memory_factory = partial( + make_limited_memory, limit=config.world.character.memory_limit + ) world_builder = Agent( "World Builder", format_prompt( @@ -388,7 +383,7 @@ def main(): # load or generate the world world_prompt = get_world_prompt(args) world, world_state_file, world_turn = load_or_generate_world( - args, players, extra_systems, world_prompt=world_prompt + args, config, players, extra_systems, world_prompt=world_prompt ) # make sure the snapshot system runs last @@ -404,6 +399,9 @@ def main(): # create the DM llm = agent_easy_connect() + memory_factory = partial( + make_limited_memory, limit=config.world.character.memory_limit + ) world_builder = Agent( "dungeon master", format_prompt( diff --git a/taleweave/models/config.py b/taleweave/models/config.py index 1ce28e2..f780b8b 100644 --- a/taleweave/models/config.py +++ b/taleweave/models/config.py @@ -57,6 +57,7 @@ class SystemsConfig: class WorldCharacterConfig: conversation_limit: int event_limit: int + memory_limit: int note_limit: int @@ -122,6 +123,7 @@ DEFAULT_CONFIG = Config( character=WorldCharacterConfig( conversation_limit=2, event_limit=5, + memory_limit=25, note_limit=10, ), size=WorldSizeConfig(