1
0
Fork 0
taleweave-ai/taleweave/bot/discord.py

481 lines
15 KiB
Python
Raw Permalink Normal View History

2024-05-08 01:42:10 +00:00
from logging import getLogger
from os import environ
from queue import Queue
from re import sub
from threading import Thread
from typing import Dict
2024-05-08 01:42:10 +00:00
from discord import Client, Embed, File, Intents
from taleweave.context import (
broadcast,
get_character_agent_for_name,
2024-05-08 01:42:10 +00:00
get_current_world,
get_game_config,
set_character_agent,
2024-05-18 21:58:11 +00:00
subscribe,
2024-05-09 02:11:16 +00:00
)
from taleweave.models.config import DiscordBotConfig
from taleweave.models.event import (
2024-05-09 02:11:16 +00:00
ActionEvent,
GameEvent,
GenerateEvent,
PlayerEvent,
2024-05-09 02:11:16 +00:00
PromptEvent,
RenderEvent,
2024-05-09 02:11:16 +00:00
ReplyEvent,
ResultEvent,
StatusEvent,
2024-05-08 01:42:10 +00:00
)
from taleweave.player import (
RemotePlayer,
get_player,
has_player,
list_players,
remove_player,
set_player,
)
from taleweave.render.comfy import render_event
2024-06-04 03:00:14 +00:00
from taleweave.utils.search import list_characters
from taleweave.utils.template import format_prompt
2024-05-08 01:42:10 +00:00
# module-level logger for the Discord bot
logger = getLogger(__name__)
# the active Discord client, assigned by launch_bot and cleared by stop_bot
client = None
2024-05-09 02:11:16 +00:00
# tasks scheduled on the client's event loop that have not finished yet
active_tasks = set()
# Discord message ID -> the text or game event that was sent as that message
event_messages: Dict[int, str | GameEvent] = {}
# game events waiting to be broadcast to the Discord channels
event_queue: Queue[GameEvent] = Queue()
# Discord user name -> mention string, used to ping players when they are prompted
player_mentions: Dict[str, str] = {}
2024-05-08 01:42:10 +00:00
def remove_tags(text: str) -> str:
    """
    Strip anything that looks like a <tag> from the text, then trim the surrounding
    whitespace.
    """
    without_tags = sub(r"<[^>]*>", "", text)
    return without_tags.strip()
2024-05-08 01:42:10 +00:00
class AdventureClient(Client):
    """
    Discord client that connects chat users to characters in the current world.

    Camera reactions on broadcast messages request a render of the event behind the
    message. Messages in the configured channels are handled as bot commands or as
    input for the character that the author has joined as.
    """

    async def on_ready(self):
        """
        Log the bot user once the client is ready.
        """
        logger.info(f"logged in as {self.user}")

    async def on_reaction_add(self, reaction, user):
        """
        Request a render when someone adds the camera reaction to a broadcast message.

        Reactions from the bot itself and reactions with any other emoji are ignored.
        """
        if user == self.user:
            return

        logger.info(f"reaction added: {reaction} by {user}")

        if reaction.emoji != "📷":
            return

        message_id = reaction.message.id
        if message_id not in event_messages:
            logger.warning(f"message {message_id} not found in event messages")
            # an empty string is not a valid emoji, mark the failure with a cross
            await reaction.message.add_reaction("❌")
            return

        event = event_messages[message_id]
        if isinstance(event, GameEvent):
            # plain text messages have no event to render
            render_event(event)
            await reaction.message.add_reaction("📸")

    async def on_message(self, message):
        """
        Handle a chat message as a bot command or as input for the author's character.

        Messages from the bot itself, and from channels that are not active for the
        bot, are ignored. Commands are matched by prefix, in this order: the bot name
        command, help, join, characters, and players. Any other message from a user
        who has joined is either the leave command or input for their character, and
        users who have not joined are sent the new user prompt.
        """
        if message.author == self.user:
            return

        # make sure the message was in a valid channel
        active_channels = get_active_channels()
        if message.channel not in active_channels:
            return

        # get message contents
        config = get_game_config()
        discord_config = config.bot.discord
        prefix = discord_config.command_prefix

        author = message.author
        channel = message.channel
        user_name = author.name  # include nick
        content = remove_tags(message.content)

        if content.startswith(prefix + discord_config.name_command):
            world = get_current_world()
            if world:
                world_message = format_prompt(
                    "discord_world_active",
                    bot_name=discord_config.name_title,
                    world=world,
                )
            else:
                world_message = format_prompt(
                    "discord_world_none", bot_name=discord_config.name_title
                )

            await channel.send(world_message)
            return

        if content.startswith(prefix + "help"):
            await channel.send(
                format_prompt("discord_help", bot_name=discord_config.name_command)
            )
            return

        join_command = prefix + "join"
        if content.startswith(join_command):
            # strip the configured command, rather than a hard-coded "!join", and only
            # from the start so that the rest of the name is left alone
            character_name = content.removeprefix(join_command).strip()
            if has_player(character_name):
                await channel.send(
                    format_prompt("discord_join_error_taken", character=character_name)
                )
                return

            character, agent = get_character_agent_for_name(character_name)
            if not character:
                await channel.send(
                    format_prompt(
                        "discord_join_error_not_found", character=character_name
                    )
                )
                return

            def prompt_player(event: PromptEvent):
                # prompts for this player are broadcast through the shared event queue
                logger.info(
                    "append prompt for character %s (user %s) to queue: %s",
                    event.character.name,
                    user_name,
                    event.prompt,
                )
                event_queue.put(event)
                return True

            # keep the previous agent as a fallback, it is restored when the user leaves
            player = RemotePlayer(
                character.name, character.backstory, prompt_player, fallback_agent=agent
            )
            set_character_agent(character_name, character, player)
            set_player(user_name, player)
            player_mentions[user_name] = author.mention

            logger.info(f"{user_name} has joined the game as {character.name}!")
            join_event = PlayerEvent("join", character_name, user_name)
            return broadcast(join_event)

        if content.startswith(prefix + "characters"):
            world = get_current_world()
            if not world:
                await channel.send(
                    format_prompt(
                        "discord_characters_none",
                        bot_name=discord_config.name_title,
                    )
                )
                return

            characters = [character.name for character in list_characters(world)]
            await channel.send(
                format_prompt(
                    "discord_characters_list",
                    bot_name=discord_config.name_title,
                    characters=characters,
                )
            )
            return

        if content.startswith(prefix + "players"):
            players = list_players()
            await channel.send(embed=format_players(players))
            return

        player = get_player(user_name)
        if isinstance(player, RemotePlayer):
            if content.startswith(prefix + "leave"):
                remove_player(user_name)
                player_mentions.pop(user_name, None)

                # revert to LLM agent
                character, _ = get_character_agent_for_name(player.name)
                if character and player.fallback_agent:
                    logger.info("restoring LLM agent for %s", player.name)
                    set_character_agent(
                        character.name, character, player.fallback_agent
                    )

                # broadcast leave event
                logger.info("disconnecting player %s from %s", user_name, player.name)
                leave_event = PlayerEvent("leave", player.name, user_name)
                return broadcast(leave_event)

            # anything else from a joined user is input for their character
            player.input_queue.put(content)
            logger.info(
                f"received message from {user_name} for {player.name}: {content}"
            )
            return

        await channel.send(format_prompt("discord_user_new"))
        return
def format_players(players: Dict[str, str]):
    """
    Build the "Players" embed, with one field per entry: the key is used as the field
    name and the value as the field value.
    """
    embed = Embed(title="Players")
    for player_name in players:
        embed.add_field(name=player_name, value=players[player_name])

    return embed
def launch_bot(config: DiscordBotConfig):
    """
    Create the Discord client and start the background threads that drive it.

    One thread runs the client, using the token from the DISCORD_TOKEN environment
    variable. A second thread takes game events from the event queue and broadcasts
    them on the client's event loop, one at a time, so that messages stay in order.
    The bot is also subscribed to all game events.

    Returns a list with the bot thread followed by the sender thread. Both are daemon
    threads.
    """
    global client

    # the client's loop runs in another thread, so coroutines have to be handed over
    # with the thread-safe API rather than loop.create_task
    from asyncio import run_coroutine_threadsafe

    # message contents need to be enabled for multi-server bots
    intents = Intents.default()
    if config.content_intent:
        intents.message_content = True

    client = AdventureClient(intents=intents)

    def bot_main():
        if not client:
            raise ValueError("No Discord client available")

        client.run(environ["DISCORD_TOKEN"])

    def send_main():
        while True:
            # block until there is an event, rather than polling in a tight loop
            event = event_queue.get()
            logger.debug("broadcasting %s event", event.type)

            # take a local reference, stop_bot can clear the global at any time
            current_client = client
            if not current_client:
                logger.warning("no Discord client available")
                continue

            try:
                pending = run_coroutine_threadsafe(
                    broadcast_event(event), current_client.loop
                )
                active_tasks.add(pending)
                pending.add_done_callback(active_tasks.discard)

                # wait for the message to send before taking the next event, to keep
                # them in order
                pending.result()
            except Exception:
                # one failed broadcast should not stop the sender thread
                logger.exception("error broadcasting %s event", event.type)

    logger.info("launching Discord bot")
    bot_thread = Thread(target=bot_main, daemon=True)
    bot_thread.start()

    send_thread = Thread(target=send_main, daemon=True)
    send_thread.start()

    subscribe(GameEvent, bot_event)

    return [bot_thread, send_thread]
2024-05-09 02:11:16 +00:00
2024-05-08 01:42:10 +00:00
def stop_bot():
    """
    Schedule the Discord client to close, if there is one, and clear the module-level
    reference to it.

    The close runs on the client's event loop and is not awaited here. It is tracked in
    the active tasks until it completes.
    """
    global client

    # this may be called from a thread other than the one running the client's loop,
    # so the coroutine is handed over with the thread-safe API
    from asyncio import run_coroutine_threadsafe

    if not client:
        return

    close_task = run_coroutine_threadsafe(client.close(), client.loop)
    active_tasks.add(close_task)

    def on_close_task_done(future):
        logger.info("discord client closed")
        active_tasks.discard(future)

    close_task.add_done_callback(on_close_task_done)

    client = None
# @cache
def get_active_channels():
    """
    List the text channels, across all of the client's guilds, whose names appear in
    the configured Discord channels. Returns an empty list when there is no client.
    """
    if not client:
        return []

    config = get_game_config()

    matching_channels = []
    for guild in client.guilds:
        for channel in guild.text_channels:
            if channel.name in config.bot.discord.channels:
                matching_channels.append(channel)

    return matching_channels
def bot_event(event: GameEvent) -> None:
    """
    Add a game event to the queue of events waiting to be broadcast to Discord.
    """
    event_queue.put(event)
async def broadcast_event(message: str | GameEvent):
    """
    Send a message or game event to the active channels and remember what was sent.

    Plain strings are sent as message content (deprecated, use events instead). Render
    events upload their images as a reply to the message that was sent for their source
    event. Every other event is converted to an embed. Each message that is sent is
    recorded in the event messages, by message ID, so that later reactions and renders
    can find the event behind it.

    Nothing is sent when there is no client, when there are no active channels, when the
    source of a render event was never broadcast, or when the event has no embed.
    """
    if not client:
        logger.warning("no Discord client available")
        return

    active_channels = get_active_channels()
    if not active_channels:
        logger.warning("no active channels")
        return

    # these lookups do not depend on the channel, so they only need to happen once
    source_message_id = None
    embed = None
    if isinstance(message, RenderEvent):
        # find the source event
        source_event_id = message.source.id
        source_message_id = next(
            (
                message_id
                for message_id, event in event_messages.items()
                if isinstance(event, GameEvent) and event.id == source_event_id
            ),
            None,
        )
        if source_message_id is None:
            logger.warning("source event not found: %s", source_event_id)
            return
    elif not isinstance(message, str):
        embed = embed_from_event(message)
        if not embed:
            logger.warning("no embed for event: %s", message)
            return

    for channel in active_channels:
        if isinstance(message, str):
            # deprecated, use events instead
            logger.warning(
                "broadcasting non-event message to channel %s: %s", channel, message
            )
            event_message = await channel.send(content=message)
        elif isinstance(message, RenderEvent):
            # special handling to upload images
            try:
                source_message = await channel.fetch_message(source_message_id)
            except Exception as err:
                # the source message may be in one of the other channels, keep going
                logger.warning("source message not found: %s", err)
                continue

            # open the images only once there is a message to reply to, so that no
            # files are opened when the lookup fails
            files = [File(filename) for filename in message.paths]

            # send the images as a reply to the source message
            event_message = await source_message.channel.send(
                files=files, reference=source_message
            )
        else:
            logger.info(
                "broadcasting to channel %s: %s - %s",
                channel,
                embed.title,
                embed.description,
            )
            event_message = await channel.send(embed=embed)

        event_messages[event_message.id] = message
2024-05-08 01:42:10 +00:00
def truncate(text: str, length: int = 1000) -> str:
    """
    Cut the text down to its first `length` characters, followed by an ellipsis. Text
    that already fits is returned unchanged.
    """
    if len(text) <= length:
        return text

    return f"{text[:length]}..."
def embed_from_event(event: GameEvent) -> Embed | None:
    """
    Build the embed for a game event, using the builder for the first event type that
    matches. Unknown event types are logged and produce no embed.
    """
    # checked in order, the first matching type wins
    builders = (
        (GenerateEvent, embed_from_generate),
        (ResultEvent, embed_from_result),
        (ActionEvent, embed_from_action),
        (ReplyEvent, embed_from_reply),
        (StatusEvent, embed_from_status),
        (PlayerEvent, embed_from_player),
        (PromptEvent, embed_from_prompt),
    )

    for event_type, build_embed in builders:
        if isinstance(event, event_type):
            return build_embed(event)

    logger.warning("unknown event type: %s", event)
    return None
2024-05-09 02:11:16 +00:00
2024-05-31 23:58:01 +00:00
def embed_from_action(event: ActionEvent):
    """
    Build the embed for an action: the room name as the title, the character name as
    the description, a field with the action name, and one field per parameter.

    The action name has its "action_" prefix removed and is title-cased. Parameter
    names have underscores replaced with spaces and are title-cased.
    """
    action_embed = Embed(title=event.room.name, description=event.character.name)
    action_name = event.action.replace("action_", "").title()
    action_embed.add_field(name="Action", value=action_name)

    for key, value in event.parameters.items():
        # parameters can carry free text, so shorten them like the text in the other
        # embeds, to keep long values from being rejected by Discord
        action_embed.add_field(
            name=key.replace("_", " ").title(),
            value=truncate(str(value)),
        )

    return action_embed
2024-05-08 01:42:10 +00:00
2024-05-31 23:58:01 +00:00
def embed_from_reply(event: ReplyEvent):
    """
    Build the embed for a reply: the room name as the title, the speaker name as the
    description, and the truncated reply text in a field.
    """
    embed = Embed(title=event.room.name, description=event.speaker.name)
    reply_text = truncate(event.text)
    embed.add_field(name="Reply", value=reply_text)
    return embed
def embed_from_generate(event: GenerateEvent) -> Embed:
    """
    Build the "Generating" embed, with the event name as the description.
    """
    return Embed(title="Generating", description=event.name)
2024-05-08 01:42:10 +00:00
def embed_from_result(event: ResultEvent):
    """
    Build the embed for a result: the room name as the title, the character name as the
    description, and the truncated result in a field.
    """
    embed = Embed(title=event.room.name, description=event.character.name)
    result_text = truncate(event.result)
    embed.add_field(name="Result", value=result_text)
    return embed
def embed_from_player(event: PlayerEvent):
    """
    Build the embed for a player joining or leaving, using the join prompts when the
    status is "join" and the leave prompts for any other status.
    """
    if event.status == "join":
        title_key, result_key = "discord_join_title", "discord_join_result"
    else:
        title_key, result_key = "discord_leave_title", "discord_leave_result"

    title = format_prompt(title_key, event=event)
    description = format_prompt(result_key, event=event)
    return Embed(title=title, description=truncate(description))
def embed_from_prompt(event: PromptEvent):
    """
    Build the embed for a prompt: the room name as the title, the character name as the
    description, the truncated prompt in a field, and one field per available action.

    When the character is controlled by a player, a "Player" field is added with the
    user's Discord mention if one was recorded, or their user name otherwise.
    """
    embed = Embed(title=event.room.name, description=event.character.name)
    embed.add_field(name="Prompt", value=truncate(event.prompt))

    if has_player(event.character.name):
        # find the first user who is playing this character
        user = None
        for player_name, character_name in list_players().items():
            if character_name == event.character.name:
                user = player_name
                break

        if user:
            # use Discord user.mention to ping the user
            embed.add_field(name="Player", value=player_mentions.get(user, user))

    for action in event.actions:
        # TODO: use a prompt template to summarize actions
        function = action["function"]
        embed.add_field(name=function["name"], value=function["description"])

    return embed
def embed_from_status(event: StatusEvent):
    """
    Build the embed for a status update, with the truncated status text in a field. The
    title and description are the room and character names, or empty strings when the
    event has no room or no character.
    """
    room_name = event.room.name if event.room else ""
    character_name = event.character.name if event.character else ""

    embed = Embed(title=room_name, description=character_name)
    embed.add_field(name="Status", value=truncate(event.text))
    return embed