1
0
Fork 0

add discord bot and render to comfy

This commit is contained in:
Sean Sube 2024-05-07 20:42:10 -05:00
parent 5117db7150
commit d72c1326f1
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 555 additions and 24 deletions

276
adventure/discord_bot.py Normal file
View File

@ -0,0 +1,276 @@
# from functools import cache
from json import loads
from logging import getLogger
from os import environ
from queue import Queue
from re import sub
from threading import Thread
from typing import Literal
from discord import Client, Embed, File, Intents
from packit.utils import could_be_json
from adventure.context import (
get_actor_agent_for_name,
get_current_world,
set_actor_agent_for_name,
)
from adventure.models import Actor, Room
from adventure.player import RemotePlayer, get_player, has_player, set_player
from adventure.render_comfy import generate_image_tool
# Module-level logger for the Discord bot.
logger = getLogger(__name__)
# The running Discord client; assigned by launch_bot() and reset to None by stop_bot().
client = None
# Pending broadcasts as (character name or None, text or Embed) tuples.
prompt_queue: Queue = Queue()
def remove_tags(text: str) -> str:
    """
    Strip every angle-bracket tag (anything shaped like <foo>) from the text.
    """
    tag_pattern = r"<[^>]*>"
    stripped = sub(tag_pattern, "", text)
    return stripped
class AdventureClient(Client):
    """
    Discord client that lets users join the adventure as a character and
    request rendered images for messages by reacting to them.
    """

    async def on_ready(self):
        """Log which user the bot is logged in as."""
        logger.info(f"Logged in as {self.user}")

    async def on_reaction_add(self, reaction, user):
        """
        Render images for a message when someone reacts to it with 📷.

        If the message carries an embed, its title and description are treated
        as the room and actor names and must match the current world; when
        there is no world, room, or actor match, nothing is rendered.
        Otherwise the message text (with tags removed) is used as the prompt.
        The rendered files are sent as a reply to the original message.
        """
        if user == self.user:
            return

        logger.info(f"Reaction added: {reaction} by {user}")
        if reaction.emoji == "📷":
            from asyncio import Lock, to_thread

            # message_id = reaction.message.id
            # TODO: look up event that caused this message, get the room and actors
            if len(reaction.message.embeds) > 0:
                embed = reaction.message.embeds[0]
                room_name = embed.title
                actor_name = embed.description
                prompt = f"{room_name}. {actor_name}."
                await reaction.message.channel.send(f"Generating image for: {prompt}")

                world = get_current_world()
                if not world:
                    return

                room = next(
                    (room for room in world.rooms if room.name == room_name), None
                )
                if not room:
                    return

                actor = next(
                    (actor for actor in room.actors if actor.name == actor_name), None
                )
                if not actor:
                    return

                prompt = f"{room.name}. {actor.name}."
            else:
                prompt = remove_tags(reaction.message.content)

            # The renderer writes to fixed file names, so only one render may
            # be in flight at a time; the lock is created lazily on first use.
            render_lock = getattr(self, "_render_lock", None)
            if render_lock is None:
                render_lock = self._render_lock = Lock()

            async with render_lock:
                # generate_image_tool blocks on HTTP and websocket I/O; run it
                # in a worker thread so the event loop keeps servicing Discord.
                paths = await to_thread(generate_image_tool, prompt, 2)
                logger.info(f"Generated images: {paths}")

                files = [File(filename) for filename in paths]
                await reaction.message.channel.send(
                    files=files, reference=reaction.message
                )

    async def on_message(self, message):
        """
        Handle chat commands and forward player input.

        Recognized commands are `!adventure`, `!help`, `!join <character>` and
        `!leave`. Any other message from a user who has joined is queued as
        input for their character; everyone else is told how to join.
        """
        if message.author == self.user:
            return

        author = message.author
        channel = message.channel
        user_name = author.name  # include nick

        world = get_current_world()
        if world:
            active_world = f"Active world: {world.name} (theme: {world.theme})"
        else:
            active_world = "No active world"

        if message.content.startswith("!adventure"):
            await message.channel.send(f"Hello! Welcome to Adventure! {active_world}")
            return

        if message.content.startswith("!help"):
            await message.channel.send("Type `!join` to start playing!")
            return

        if message.content.startswith("!join"):
            character_name = remove_tags(message.content).replace("!join", "").strip()
            if has_player(character_name):
                await channel.send(f"{character_name} has already been taken!")
                return

            actor, agent = get_actor_agent_for_name(character_name)
            if not actor:
                await channel.send(f"Character `{character_name}` not found!")
                return

            def prompt_player(character: str, prompt: str):
                # Callback handed to RemotePlayer: queue the prompt for broadcast.
                logger.info(
                    "append prompt for character %s (user %s) to queue: %s",
                    character,
                    user_name,
                    prompt,
                )
                prompt_queue.put((character, prompt))
                return True

            player = RemotePlayer(
                actor.name, actor.backstory, prompt_player, fallback_agent=agent
            )
            # swap the remote player in for the character's agent
            set_actor_agent_for_name(character_name, actor, player)
            set_player(user_name, player)

            logger.info(f"{user_name} has joined the game as {actor.name}!")
            await message.channel.send(
                f"{user_name} has joined the game as {actor.name}!"
            )
            return

        if message.content.startswith("!leave"):
            # TODO: revert to LLM agent
            logger.info(f"{user_name} has left the game!")
            await message.channel.send(f"{user_name} has left the game!")
            return

        player = get_player(user_name)
        if player and isinstance(player, RemotePlayer):
            content = remove_tags(message.content)
            player.input_queue.put(content)
            logger.info(
                f"Received message from {user_name} for {player.name}: {content}"
            )
            return

        await message.channel.send(
            "You are not currently playing Adventure! Type `!join` to start playing!"
        )
        return
# Broadcasts currently in flight; the prompt loop in launch_bot() waits for this set to empty before sending the next one.
active_tasks = set()
def launch_bot():
    """
    Start the Discord client and the prompt dispatcher on separate threads.

    The bot thread creates the module-level client and runs it with the token
    from the DISCORD_TOKEN environment variable. The prompt thread polls
    prompt_queue every half second and broadcasts one queued item at a time.
    """

    def bot_main():
        global client

        intents = Intents.default()
        # intents.message_content = True
        client = AdventureClient(intents=intents)
        client.run(environ["DISCORD_TOKEN"])

    def prompt_main():
        from asyncio import run_coroutine_threadsafe
        from time import sleep

        while True:
            sleep(0.5)
            if prompt_queue.empty():
                continue

            # wait for the previous broadcast to finish before starting another
            if len(active_tasks) > 0:
                continue

            character, prompt = prompt_queue.get()
            logger.info("Prompting character %s: %s", character, prompt)

            if client:
                # This thread is not the event loop's thread, so the coroutine
                # must be submitted through the thread-safe entry point rather
                # than loop.create_task().
                prompt_task = run_coroutine_threadsafe(
                    broadcast_event(prompt), client.loop
                )
                active_tasks.add(prompt_task)
                prompt_task.add_done_callback(active_tasks.discard)

    bot_thread = Thread(target=bot_main)
    bot_thread.start()

    prompt_thread = Thread(target=prompt_main)
    prompt_thread.start()
def stop_bot():
    """
    Ask the Discord client to close and forget the module-level reference.

    Does nothing when no client has been created.
    """
    global client

    if client:
        from asyncio import run_coroutine_threadsafe

        # Client.close() is a coroutine; it only runs once it is scheduled on
        # the bot's event loop, which lives on another thread.
        close_coro = client.close()
        try:
            run_coroutine_threadsafe(close_coro, client.loop)
        except (AttributeError, RuntimeError):
            # the loop was never started or is already closed
            close_coro.close()
            logger.warning("Discord client loop is not running, nothing to close")

        client = None
# @cache
def get_active_channels():
    """
    Collect every text channel named "bots" across the client's guilds.

    Returns an empty list when there is no client.
    """
    if not client:
        return []

    # return client.private_channels
    matching = []
    for guild in client.guilds:
        for text_channel in guild.text_channels:
            if text_channel.name == "bots":
                matching.append(text_channel)

    return matching
async def broadcast_event(message: str | Embed):
    """
    Send a message to every active channel.

    Strings are sent as message content and Embeds as embeds. A warning is
    logged and nothing is sent when there is no client or no active channel.
    """
    if not client:
        logger.warning("No Discord client available")
        return

    targets = get_active_channels()
    if not targets:
        logger.warning("No active channels")
        return

    is_text = isinstance(message, str)
    is_embed = isinstance(message, Embed)

    for channel in targets:
        if is_text:
            logger.info("Broadcasting to channel %s: %s", channel, message)
            await channel.send(content=message)
            continue

        if is_embed:
            logger.info(
                "Broadcasting to channel %s: %s - %s",
                channel,
                message.title,
                message.description,
            )
            await channel.send(embed=message)
def bot_action(room: Room, actor: Actor, message: str):
    """
    Queue an embed describing an actor's action for broadcast.

    The embed is titled with the room name and described with the actor name.
    JSON-looking messages are expanded into an "Action" field plus one field
    per parameter; anything else becomes a single "Message" field. Any failure
    is logged and the action is dropped.
    """
    try:
        embed = Embed(title=room.name, description=actor.name)

        if not could_be_json(message):
            embed.add_field(name="Message", value=message)
        else:
            payload = loads(message)
            title = payload["function"].replace("action_", "").title()
            parameters = payload.get("parameters", {})

            embed.add_field(name="Action", value=title)
            for param_name, param_value in parameters.items():
                field_name = param_name.replace("_", " ").title()
                embed.add_field(name=field_name, value=param_value)

        prompt_queue.put((actor.name, embed))
    except Exception as e:
        logger.error("Failed to broadcast action: %s", e)
def bot_event(message: str):
    """Queue a plain-text event for broadcast, not tied to any character."""
    entry = (None, message)
    prompt_queue.put(entry)
def bot_result(room: Room, actor: Actor, action: str):
    """
    Queue an embed carrying the result of an actor's action.

    The embed is titled with the room name, described with the actor name,
    and holds the result text in a single "Result" field.
    """
    embed = Embed(title=room.name, description=actor.name)
    embed.add_field(name="Result", value=action)

    entry = (actor.name, embed)
    prompt_queue.put(entry)
def player_event(character: str, id: str, event: Literal["join", "leave"]):
    """
    Queue an announcement that a character joined or left the game.

    The id argument is accepted but not used here; any other event value is
    ignored.
    """
    if event == "join":
        announcement = f"{character} has joined the game!"
    elif event == "leave":
        announcement = f"{character} has left the game!"
    else:
        return

    prompt_queue.put((character, announcement))

View File

@ -267,6 +267,14 @@ def main():
input_callbacks = []
result_callbacks = []
if args.discord:
from adventure.discord_bot import bot_action, bot_event, bot_result, launch_bot
launch_bot()
event_callbacks.append(bot_event)
input_callbacks.append(bot_action)
result_callbacks.append(bot_result)
if args.server:
from adventure.server import (
launch_server,

View File

@ -2,7 +2,7 @@ from json import dumps
from logging import getLogger
from queue import Queue
from readline import add_history
from typing import Any, Callable, Dict, List, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from packit.agent import Agent
@ -11,6 +11,50 @@ from packit.utils import could_be_json
logger = getLogger(__name__)
# Registry of active players keyed by client identifier: Dict[client, player]
active_players: Dict[str, "BasePlayer"] = {}
def get_player(client: str) -> Optional["BasePlayer"]:
    """
    Look up the player registered for a client, or None if there is none.
    """
    try:
        return active_players[client]
    except KeyError:
        return None
def set_player(client: str, player: "BasePlayer"):
    """
    Register a player for a client.

    Raises ValueError when the player's character is already being played.
    """
    character_taken = has_player(player.name)
    if character_taken:
        raise ValueError(f"Someone is already playing as {player.name}!")

    active_players.update({client: player})
def remove_player(client: str):
    """
    Drop the client's player from the registry; unknown clients are ignored.
    """
    active_players.pop(client, None)
def has_player(character_name: str) -> bool:
    """
    Report whether any active player is using the given character name.
    """
    for player in active_players.values():
        if player.name == character_name:
            return True

    return False
def list_players():
    """
    Return a dict mapping each client to the name of the player it controls.
    """
    roster = {}
    for client, player in active_players.items():
        roster[client] = player.name

    return roster
class BasePlayer:
"""
A human agent that can interact with the world.

195
adventure/render_comfy.py Normal file
View File

@ -0,0 +1,195 @@
# This is an example that uses the websockets api to know when a prompt execution is done
# Once the prompt execution is done it downloads the images using the /history endpoint
import io
import json
import urllib.parse
import urllib.request
import uuid
from logging import getLogger
from os import environ, path
from random import choice, randint
from typing import List
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
from PIL import Image
# Module-level logger for the Comfy renderer.
logger = getLogger(__name__)
# Address of the Comfy API, interpolated into the http:// and ws:// URLs below.
# Read at import time, so importing this module fails if COMFY_API is not set.
server_address = environ["COMFY_API"]
# Random ID generated once per process; sent with queued prompts and on the websocket URL.
client_id = str(uuid.uuid4())
def generate_cfg():
    """Pick a random integer cfg value between 5 and 8, inclusive."""
    lowest, highest = 5, 8
    return randint(lowest, highest)
def generate_steps():
    """Return the fixed number of sampler steps."""
    steps = 30
    return steps
def generate_batches(
    count: int,
    batch_size: int = 3,
) -> List[int]:
    """
    Split count images into batch sizes of at most batch_size each.

    Every batch is full except possibly the last, e.g. 7 becomes [3, 3, 1].
    """
    return [
        min(count - offset, batch_size) for offset in range(0, count, batch_size)
    ]
def queue_prompt(prompt):
    """
    Submit a workflow to the Comfy API's /prompt endpoint.

    The workflow is posted as JSON together with this process's client_id.
    Returns the decoded JSON response.
    """
    payload = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
    # close the response once it has been read, like the other HTTP helpers here
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(filename, subfolder, folder_type):
    """
    Download one image from the Comfy API's /view endpoint.

    Returns the raw bytes of the response body.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = "http://{}/view?{}".format(server_address, query)

    with urllib.request.urlopen(view_url) as response:
        image_bytes = response.read()

    return image_bytes
def get_history(prompt_id):
    """
    Fetch the /history entry for a prompt and return the decoded JSON.
    """
    history_url = "http://{}/history/{}".format(server_address, prompt_id)

    with urllib.request.urlopen(history_url) as response:
        body = response.read()

    return json.loads(body)
def get_images(ws, prompt):
    """
    Queue a workflow, wait for it to finish, and download its images.

    Reads messages from the websocket until an "executing" message arrives
    with no node and a matching prompt ID, then looks up the prompt's history
    and downloads every image listed in its outputs.

    Returns a dict mapping each output node ID that produced images to the
    list of raw image bytes for that node.
    """
    prompt_id = queue_prompt(prompt)["prompt_id"]

    while True:
        out = ws.recv()
        if not isinstance(out, str):
            continue  # previews are binary data

        message = json.loads(out)
        if message["type"] == "executing":
            data = message["data"]
            if data["node"] is None and data["prompt_id"] == prompt_id:
                break  # Execution is done

    history = get_history(prompt_id)[prompt_id]

    # Visit each output node exactly once; nodes without images are skipped
    # so they can never pick up another node's image list.
    output_images = {}
    for node_id, node_output in history["outputs"].items():
        if "images" not in node_output:
            continue

        output_images[node_id] = [
            get_image(image["filename"], image["subfolder"], image["type"])
            for image in node_output["images"]
        ]

    return output_images
def generate_image_tool(prompt, count, size="landscape"):
    """
    Render count images for a prompt, split into batches.

    Each batch is rendered with the file prefix "output-<batch index>".
    Returns the paths of all generated images in batch order.
    """
    batch_sizes = generate_batches(count)

    output_paths = []
    for batch_index, batch_count in enumerate(batch_sizes):
        batch_prefix = f"output-{batch_index}"
        output_paths += generate_images(prompt, batch_count, size, prefix=batch_prefix)

    return output_paths
# Named image sizes as (width, height) tuples, looked up by generate_images().
sizes = {
    "landscape": (1024, 768),
    "portrait": (768, 1024),
    "square": (768, 768),
}
def generate_images(
    prompt: str, count: int, size="landscape", prefix="output"
) -> List[str]:
    """
    Render a batch of images through the Comfy API and save them as PNG files.

    Args:
        prompt: text used for the positive prompt.
        count: number of images in the batch.
        size: key into the module-level sizes table; unknown names fall back
            to 512 by 512.
        prefix: file name prefix, used for the Comfy SaveImage node and for
            the local files.

    Returns:
        The paths of the files written, named "<prefix>-<index>.png".
    """
    cfg = generate_cfg()
    width, height = sizes.get(size, (512, 512))
    steps = generate_steps()
    seed = randint(0, 10000000)
    checkpoint = choice(["diffusion-sdxl-dynavision-0-5-5-7.safetensors"])
    logger.info(
        "generating %s images at %s by %s with prompt: %s", count, width, height, prompt
    )

    # Workflow graph keyed by node ID; an input of the form ["<id>", n]
    # refers to output n of node <id>.
    prompt_workflow = {
        "3": {
            "class_type": "KSampler",
            "inputs": {
                "cfg": cfg,
                "denoise": 1,
                "latent_image": ["5", 0],
                "model": ["4", 0],
                "negative": ["7", 0],
                "positive": ["6", 0],
                "sampler_name": "euler_ancestral",
                "scheduler": "normal",
                "seed": seed,
                "steps": steps,
            },
        },
        "4": {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {"ckpt_name": checkpoint},
        },
        "5": {
            "class_type": "EmptyLatentImage",
            "inputs": {"batch_size": count, "height": height, "width": width},
        },
        "6": {
            "class_type": "CLIPTextEncode",
            "inputs": {"text": prompt, "clip": ["4", 1]},
        },
        "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "", "clip": ["4", 1]}},
        "8": {
            "class_type": "VAEDecode",
            "inputs": {"samples": ["3", 0], "vae": ["4", 2]},
        },
        "9": {
            "class_type": "SaveImage",
            "inputs": {"filename_prefix": prefix, "images": ["8", 0]},
        },
    }

    logger.debug("Connecting to Comfy API at %s", server_address)
    ws = websocket.WebSocket()
    try:
        ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
        images = get_images(ws, prompt_workflow)
    finally:
        # always release the socket, even when connecting or rendering fails
        ws.close()

    results = []
    for node_id in images:
        for image_data in images[node_id]:
            image = Image.open(io.BytesIO(image_data))
            results.append(image)

    # NOTE(review): the output directory is hard-coded to one user's home;
    # consider making it configurable.
    paths: List[str] = []
    for j, image in enumerate(results):
        image_path = path.join("/home/ssube/adventure-images", f"{prefix}-{j}.png")
        with open(image_path, "wb") as f:
            image.save(f, format="PNG")

        paths.append(image_path)

    return paths
# Manual check: running this module directly renders three landscape images
# for a sample prompt and logs the resulting paths.
if __name__ == "__main__":
    paths = generate_images(
        "A painting of a beautiful sunset over a calm lake", 3, "landscape"
    )
    logger.info("Generated %d images: %s", len(paths), paths)

View File

@ -3,20 +3,26 @@ from collections import deque
from json import dumps, loads
from logging import getLogger
from threading import Thread
from typing import Dict, Literal
from typing import Literal
from uuid import uuid4
import websockets
from adventure.context import get_actor_agent_for_name, set_actor_agent_for_name
from adventure.models import Actor, Room, World
from adventure.player import RemotePlayer
from adventure.player import (
RemotePlayer,
get_player,
has_player,
list_players,
remove_player,
set_player,
)
from adventure.state import snapshot_world, world_json
logger = getLogger(__name__)
connected = set()
characters: Dict[str, RemotePlayer] = {}
recent_events = deque(maxlen=100)
recent_world = None
@ -40,11 +46,12 @@ async def handler(websocket):
)
def sync_turn(character: str, prompt: str) -> bool:
if id not in characters:
return False
player = get_player(id)
if player and player.name == character:
asyncio.run(next_turn(character, prompt))
return True
asyncio.run(next_turn(character, prompt))
return True
return False
try:
await websocket.send(dumps({"type": "id", "id": id}))
@ -67,9 +74,8 @@ async def handler(websocket):
data = loads(message)
message_type = data.get("type", None)
if message_type == "player":
character = characters.get(id)
if character:
del characters[id]
# TODO: should this always remove?
remove_player(id)
character_name = data["become"]
actor, llm_agent = get_actor_agent_for_name(character_name)
@ -84,9 +90,7 @@ async def handler(websocket):
)
llm_agent = llm_agent.fallback_agent
if character_name in [
player.name for player in characters.values()
]:
if has_player(character_name):
logger.error(f"Character {character_name} is already in use")
continue
@ -94,7 +98,7 @@ async def handler(websocket):
player = RemotePlayer(
actor.name, actor.backstory, sync_turn, fallback_agent=llm_agent
)
characters[id] = player
set_player(id, player)
logger.info(f"Client {id} is now character {character_name}")
# swap out the LLM agent
@ -103,10 +107,13 @@ async def handler(websocket):
# notify all clients that this character is now active
player_event(character_name, id, "join")
player_list()
elif message_type == "input" and id in characters:
player = characters[id]
logger.info("queueing input for player %s: %s", player.name, data)
player.input_queue.put(data["input"])
elif message_type == "input":
player = get_player(id)
if player and isinstance(player, RemotePlayer):
logger.info(
"queueing input for player %s: %s", player.name, data
)
player.input_queue.put(data["input"])
except Exception:
logger.exception("Failed to parse message")
@ -116,9 +123,9 @@ async def handler(websocket):
connected.remove(websocket)
# swap out the character for the original agent when they disconnect
if id in characters:
player = characters[id]
del characters[id]
player = get_player(id)
if player and isinstance(player, RemotePlayer):
remove_player(id)
logger.info("Disconnecting player for %s", player.name)
player_event(player.name, id, "leave")
@ -217,8 +224,9 @@ def player_event(character: str, id: str, event: Literal["join", "leave"]):
def player_list():
    """
    Build the "players" message and pass it to send_and_append.

    The message maps each active character name to the client playing it.
    """
    # list_players() returns a dict of client -> character name, so it has to
    # be walked with .items(); iterating the dict itself yields only the keys.
    players = {name: client for client, name in list_players().items()}
    json_broadcast = {
        "type": "players",
        "players": players,
    }
    send_and_append(json_broadcast)