From d72c1326f14373ce60b94cac4fe005bb72d3e389 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 7 May 2024 20:42:10 -0500 Subject: [PATCH] add discord bot and render to comfy --- adventure/discord_bot.py | 276 ++++++++++++++++++++++++++++++++++++++ adventure/main.py | 8 ++ adventure/player.py | 46 ++++++- adventure/render_comfy.py | 195 +++++++++++++++++++++++++++ adventure/server.py | 54 ++++---- 5 files changed, 555 insertions(+), 24 deletions(-) create mode 100644 adventure/discord_bot.py create mode 100644 adventure/render_comfy.py diff --git a/adventure/discord_bot.py b/adventure/discord_bot.py new file mode 100644 index 0000000..eb6807f --- /dev/null +++ b/adventure/discord_bot.py @@ -0,0 +1,276 @@ +# from functools import cache +from json import loads +from logging import getLogger +from os import environ +from queue import Queue +from re import sub +from threading import Thread +from typing import Literal + +from discord import Client, Embed, File, Intents +from packit.utils import could_be_json + +from adventure.context import ( + get_actor_agent_for_name, + get_current_world, + set_actor_agent_for_name, +) +from adventure.models import Actor, Room +from adventure.player import RemotePlayer, get_player, has_player, set_player +from adventure.render_comfy import generate_image_tool + +logger = getLogger(__name__) +client = None +prompt_queue: Queue = Queue() + + +def remove_tags(text: str) -> str: + """ + Remove any tags. + """ + + return sub(r"<[^>]*>", "", text) + + +class AdventureClient(Client): + async def on_ready(self): + logger.info(f"Logged in as {self.user}") + + async def on_reaction_add(self, reaction, user): + if user == self.user: + return + + logger.info(f"Reaction added: {reaction} by {user}") + if reaction.emoji == "📷": + # message_id = reaction.message.id + # TODO: look up event that caused this message, get the room and actors + if len(reaction.message.embeds) > 0: + embed = reaction.message.embeds[0] + room_name = embed.title + actor_name = embed.description + prompt = f"{room_name}. {actor_name}." + await reaction.message.channel.send(f"Generating image for: {prompt}") + + world = get_current_world() + if not world: + return + + room = next( + (room for room in world.rooms if room.name == room_name), None + ) + if not room: + return + + actor = next( + (actor for actor in room.actors if actor.name == actor_name), None + ) + if not actor: + return + + prompt = f"{room.name}. {actor.name}." + else: + prompt = remove_tags(reaction.message.content) + + paths = generate_image_tool(prompt, 2) + logger.info(f"Generated images: {paths}") + + files = [File(filename) for filename in paths] + await reaction.message.channel.send(files=files, reference=reaction.message) + + async def on_message(self, message): + if message.author == self.user: + return + + author = message.author + channel = message.channel + user_name = author.name # include nick + + world = get_current_world() + if world: + active_world = f"Active world: {world.name} (theme: {world.theme})" + else: + active_world = "No active world" + + if message.content.startswith("!adventure"): + await message.channel.send(f"Hello! Welcome to Adventure! {active_world}") + return + + if message.content.startswith("!help"): + await message.channel.send("Type `!join` to start playing!") + return + + if message.content.startswith("!join"): + character_name = remove_tags(message.content).replace("!join", "").strip() + if has_player(character_name): + await channel.send(f"{character_name} has already been taken!") + return + + actor, agent = get_actor_agent_for_name(character_name) + if not actor: + await channel.send(f"Character `{character_name}` not found!") + return + + def prompt_player(character: str, prompt: str): + logger.info( + "append prompt for character %s (user %s) to queue: %s", + character, + user_name, + prompt, + ) + prompt_queue.put((character, prompt)) + return True + + player = RemotePlayer( + actor.name, actor.backstory, prompt_player, fallback_agent=agent + ) + set_actor_agent_for_name(character_name, actor, player) + set_player(user_name, player) + + logger.info(f"{user_name} has joined the game as {actor.name}!") + await message.channel.send( + f"{user_name} has joined the game as {actor.name}!" + ) + return + + if message.content.startswith("!leave"): + # TODO: revert to LLM agent + logger.info(f"{user_name} has left the game!") + await message.channel.send(f"{user_name} has left the game!") + return + + player = get_player(user_name) + if player and isinstance(player, RemotePlayer): + content = remove_tags(message.content) + player.input_queue.put(content) + logger.info( + f"Received message from {user_name} for {player.name}: {content}" + ) + return + + await message.channel.send( + "You are not currently playing Adventure! Type `!join` to start playing!" + ) + return + + +active_tasks = set() + + +def launch_bot(): + def bot_main(): + global client + + intents = Intents.default() + # intents.message_content = True + + client = AdventureClient(intents=intents) + client.run(environ["DISCORD_TOKEN"]) + + def prompt_main(): + from time import sleep + + while True: + sleep(0.5) + if prompt_queue.empty(): + continue + + if len(active_tasks) > 0: + continue + + character, prompt = prompt_queue.get() + logger.info("Prompting character %s: %s", character, prompt) + + if client: + prompt_task = client.loop.create_task(broadcast_event(prompt)) + active_tasks.add(prompt_task) + prompt_task.add_done_callback(active_tasks.discard) + + bot_thread = Thread(target=bot_main) + bot_thread.start() + + prompt_thread = Thread(target=prompt_main) + prompt_thread.start() + + +def stop_bot(): + global client + + if client: + client.close() + client = None + + +# @cache +def get_active_channels(): + if not client: + return [] + + # return client.private_channels + return [ + channel + for guild in client.guilds + for channel in guild.text_channels + if channel.name == "bots" + ] + + +async def broadcast_event(message: str | Embed): + if not client: + logger.warning("No Discord client available") + return + + active_channels = get_active_channels() + if not active_channels: + logger.warning("No active channels") + return + + for channel in active_channels: + if isinstance(message, str): + logger.info("Broadcasting to channel %s: %s", channel, message) + await channel.send(content=message) + elif isinstance(message, Embed): + logger.info( + "Broadcasting to channel %s: %s - %s", + channel, + message.title, + message.description, + ) + await channel.send(embed=message) + + +def bot_action(room: Room, actor: Actor, message: str): + try: + action_embed = Embed(title=room.name, description=actor.name) + + if could_be_json(message): + action_data = loads(message) + action_name = action_data["function"].replace("action_", "").title() + action_parameters = action_data.get("parameters", {}) + + action_embed.add_field(name="Action", value=action_name) + + for key, value in action_parameters.items(): + action_embed.add_field(name=key.replace("_", " ").title(), value=value) + else: + action_embed.add_field(name="Message", value=message) + + prompt_queue.put((actor.name, action_embed)) + except Exception as e: + logger.error("Failed to broadcast action: %s", e) + + +def bot_event(message: str): + prompt_queue.put((None, message)) + + +def bot_result(room: Room, actor: Actor, action: str): + result_embed = Embed(title=room.name, description=actor.name) + result_embed.add_field(name="Result", value=action) + prompt_queue.put((actor.name, result_embed)) + + +def player_event(character: str, id: str, event: Literal["join", "leave"]): + if event == "join": + prompt_queue.put((character, f"{character} has joined the game!")) + elif event == "leave": + prompt_queue.put((character, f"{character} has left the game!")) diff --git a/adventure/main.py b/adventure/main.py index 250cbdc..3b01922 100644 --- a/adventure/main.py +++ b/adventure/main.py @@ -267,6 +267,14 @@ def main(): input_callbacks = [] result_callbacks = [] + if args.discord: + from adventure.discord_bot import bot_action, bot_event, bot_result, launch_bot + + launch_bot() + event_callbacks.append(bot_event) + input_callbacks.append(bot_action) + result_callbacks.append(bot_result) + if args.server: from adventure.server import ( launch_server, diff --git a/adventure/player.py b/adventure/player.py index 721e1b5..64dea40 100644 --- a/adventure/player.py +++ b/adventure/player.py @@ -2,7 +2,7 @@ from json import dumps from logging import getLogger from queue import Queue from readline import add_history -from typing import Any, Callable, Dict, List, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from packit.agent import Agent @@ -11,6 +11,50 @@ from packit.utils import could_be_json logger = getLogger(__name__) +# Dict[client, player] +active_players: Dict[str, "BasePlayer"] = {} + + +def get_player(client: str) -> Optional["BasePlayer"]: + """ + Get a player by name. + """ + + return active_players.get(client, None) + + +def set_player(client: str, player: "BasePlayer"): + """ + Add a player to the active players. + """ + + if has_player(player.name): + raise ValueError(f"Someone is already playing as {player.name}!") + + active_players[client] = player + + +def remove_player(client: str): + """ + Remove a player from the active players. + """ + + if client in active_players: + del active_players[client] + + +def has_player(character_name: str) -> bool: + """ + Check if a character is already being played. + """ + + return character_name in [player.name for player in active_players.values()] + + +def list_players(): + return {client: player.name for client, player in active_players.items()} + + class BasePlayer: """ A human agent that can interact with the world. diff --git a/adventure/render_comfy.py b/adventure/render_comfy.py new file mode 100644 index 0000000..bebb14d --- /dev/null +++ b/adventure/render_comfy.py @@ -0,0 +1,195 @@ +# This is an example that uses the websockets api to know when a prompt execution is done +# Once the prompt execution is done it downloads the images using the /history endpoint + +import io +import json +import urllib.parse +import urllib.request +import uuid +from logging import getLogger +from os import environ, path +from random import choice, randint +from typing import List + +import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +from PIL import Image + +logger = getLogger(__name__) + +server_address = environ["COMFY_API"] +client_id = str(uuid.uuid4()) + + +def generate_cfg(): + return randint(5, 8) + + +def generate_steps(): + return 30 + + +def generate_batches( + count: int, + batch_size: int = 3, +) -> List[int]: + """ + Generate count images in batches of at most batch_size. + """ + + batches = [] + for i in range(0, count, batch_size): + batches.append(min(count - i, batch_size)) + + return batches + + +def queue_prompt(prompt): + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode("utf-8") + req = urllib.request.Request("http://{}/prompt".format(server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + + +def get_image(filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen( + "http://{}/view?{}".format(server_address, url_values) + ) as response: + return response.read() + + +def get_history(prompt_id): + with urllib.request.urlopen( + "http://{}/history/{}".format(server_address, prompt_id) + ) as response: + return json.loads(response.read()) + + +def get_images(ws, prompt): + prompt_id = queue_prompt(prompt)["prompt_id"] + output_images = {} + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message["type"] == "executing": + data = message["data"] + if data["node"] is None and data["prompt_id"] == prompt_id: + break # Execution is done + else: + continue # previews are binary data + + history = get_history(prompt_id)[prompt_id] + for o in history["outputs"]: + for node_id in history["outputs"]: + node_output = history["outputs"][node_id] + if "images" in node_output: + images_output = [] + for image in node_output["images"]: + image_data = get_image( + image["filename"], image["subfolder"], image["type"] + ) + images_output.append(image_data) + output_images[node_id] = images_output + + return output_images + + +def generate_image_tool(prompt, count, size="landscape"): + output_paths = [] + for i, count in enumerate(generate_batches(count)): + results = generate_images(prompt, count, size, prefix=f"output-{i}") + output_paths.extend(results) + + return output_paths + + +sizes = { + "landscape": (1024, 768), + "portrait": (768, 1024), + "square": (768, 768), +} + + +def generate_images( + prompt: str, count: int, size="landscape", prefix="output" +) -> List[str]: + cfg = generate_cfg() + width, height = sizes.get(size, (512, 512)) + steps = generate_steps() + seed = randint(0, 10000000) + checkpoint = choice(["diffusion-sdxl-dynavision-0-5-5-7.safetensors"]) + logger.info( + "generating %s images at %s by %s with prompt: %s", count, width, height, prompt + ) + + # parsing here helps ensure the template emits valid JSON + prompt_workflow = { + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": cfg, + "denoise": 1, + "latent_image": ["5", 0], + "model": ["4", 0], + "negative": ["7", 0], + "positive": ["6", 0], + "sampler_name": "euler_ancestral", + "scheduler": "normal", + "seed": seed, + "steps": steps, + }, + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": checkpoint}, + }, + "5": { + "class_type": "EmptyLatentImage", + "inputs": {"batch_size": count, "height": height, "width": width}, + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": {"text": prompt, "clip": ["4", 1]}, + }, + "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "", "clip": ["4", 1]}}, + "8": { + "class_type": "VAEDecode", + "inputs": {"samples": ["3", 0], "vae": ["4", 2]}, + }, + "9": { + "class_type": "SaveImage", + "inputs": {"filename_prefix": prefix, "images": ["8", 0]}, + }, + } + + logger.debug("Connecting to Comfy API at %s", server_address) + ws = websocket.WebSocket() + ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) + images = get_images(ws, prompt_workflow) + + results = [] + for node_id in images: + for image_data in images[node_id]: + image = Image.open(io.BytesIO(image_data)) + results.append(image) + + paths: List[str] = [] + for j, image in enumerate(results): + image_path = path.join("/home/ssube/adventure-images", f"{prefix}-{j}.png") + with open(image_path, "wb") as f: + image_bytes = io.BytesIO() + image.save(image_bytes, format="PNG") + f.write(image_bytes.getvalue()) + + paths.append(image_path) + + return paths + + +if __name__ == "__main__": + paths = generate_images( + "A painting of a beautiful sunset over a calm lake", 3, "landscape" + ) + logger.info("Generated %d images: %s", len(paths), paths) diff --git a/adventure/server.py b/adventure/server.py index 4cdd42c..024b834 100644 --- a/adventure/server.py +++ b/adventure/server.py @@ -3,20 +3,26 @@ from collections import deque from json import dumps, loads from logging import getLogger from threading import Thread -from typing import Dict, Literal +from typing import Literal from uuid import uuid4 import websockets from adventure.context import get_actor_agent_for_name, set_actor_agent_for_name from adventure.models import Actor, Room, World -from adventure.player import RemotePlayer +from adventure.player import ( + RemotePlayer, + get_player, + has_player, + list_players, + remove_player, + set_player, +) from adventure.state import snapshot_world, world_json logger = getLogger(__name__) connected = set() -characters: Dict[str, RemotePlayer] = {} recent_events = deque(maxlen=100) recent_world = None @@ -40,11 +46,12 @@ async def handler(websocket): ) def sync_turn(character: str, prompt: str) -> bool: - if id not in characters: - return False + player = get_player(id) + if player and player.name == character: + asyncio.run(next_turn(character, prompt)) + return True - asyncio.run(next_turn(character, prompt)) - return True + return False try: await websocket.send(dumps({"type": "id", "id": id})) @@ -67,9 +74,8 @@ async def handler(websocket): data = loads(message) message_type = data.get("type", None) if message_type == "player": - character = characters.get(id) - if character: - del characters[id] + # TODO: should this always remove? + remove_player(id) character_name = data["become"] actor, llm_agent = get_actor_agent_for_name(character_name) @@ -84,9 +90,7 @@ async def handler(websocket): ) llm_agent = llm_agent.fallback_agent - if character_name in [ - player.name for player in characters.values() - ]: + if has_player(character_name): logger.error(f"Character {character_name} is already in use") continue @@ -94,7 +98,7 @@ async def handler(websocket): player = RemotePlayer( actor.name, actor.backstory, sync_turn, fallback_agent=llm_agent ) - characters[id] = player + set_player(id, player) logger.info(f"Client {id} is now character {character_name}") # swap out the LLM agent @@ -103,10 +107,13 @@ async def handler(websocket): # notify all clients that this character is now active player_event(character_name, id, "join") player_list() - elif message_type == "input" and id in characters: - player = characters[id] - logger.info("queueing input for player %s: %s", player.name, data) - player.input_queue.put(data["input"]) + elif message_type == "input": + player = get_player(id) + if player and isinstance(player, RemotePlayer): + logger.info( + "queueing input for player %s: %s", player.name, data + ) + player.input_queue.put(data["input"]) except Exception: logger.exception("Failed to parse message") @@ -116,9 +123,9 @@ async def handler(websocket): connected.remove(websocket) # swap out the character for the original agent when they disconnect - if id in characters: - player = characters[id] - del characters[id] + player = get_player(id) + if player and isinstance(player, RemotePlayer): + remove_player(id) logger.info("Disconnecting player for %s", player.name) player_event(player.name, id, "leave") @@ -217,8 +224,9 @@ def player_event(character: str, id: str, event: Literal["join", "leave"]): def player_list(): - json_broadcast ={ + players = {value: key for key, value in list_players()} + json_broadcast = { "type": "players", - "players": {player.name: player_id for player_id, player in characters.items()}, + "players": players, } send_and_append(json_broadcast)