add discord bot and render to comfy
This commit is contained in:
parent
5117db7150
commit
d72c1326f1
|
@ -0,0 +1,276 @@
|
||||||
|
# from functools import cache
|
||||||
|
from json import loads
|
||||||
|
from logging import getLogger
|
||||||
|
from os import environ
|
||||||
|
from queue import Queue
|
||||||
|
from re import sub
|
||||||
|
from threading import Thread
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from discord import Client, Embed, File, Intents
|
||||||
|
from packit.utils import could_be_json
|
||||||
|
|
||||||
|
from adventure.context import (
|
||||||
|
get_actor_agent_for_name,
|
||||||
|
get_current_world,
|
||||||
|
set_actor_agent_for_name,
|
||||||
|
)
|
||||||
|
from adventure.models import Actor, Room
|
||||||
|
from adventure.player import RemotePlayer, get_player, has_player, set_player
|
||||||
|
from adventure.render_comfy import generate_image_tool
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
client = None
|
||||||
|
prompt_queue: Queue = Queue()
|
||||||
|
|
||||||
|
|
||||||
|
def remove_tags(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Remove any <foo> tags.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return sub(r"<[^>]*>", "", text)
|
||||||
|
|
||||||
|
|
||||||
|
class AdventureClient(Client):
|
||||||
|
async def on_ready(self):
|
||||||
|
logger.info(f"Logged in as {self.user}")
|
||||||
|
|
||||||
|
async def on_reaction_add(self, reaction, user):
|
||||||
|
if user == self.user:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Reaction added: {reaction} by {user}")
|
||||||
|
if reaction.emoji == "📷":
|
||||||
|
# message_id = reaction.message.id
|
||||||
|
# TODO: look up event that caused this message, get the room and actors
|
||||||
|
if len(reaction.message.embeds) > 0:
|
||||||
|
embed = reaction.message.embeds[0]
|
||||||
|
room_name = embed.title
|
||||||
|
actor_name = embed.description
|
||||||
|
prompt = f"{room_name}. {actor_name}."
|
||||||
|
await reaction.message.channel.send(f"Generating image for: {prompt}")
|
||||||
|
|
||||||
|
world = get_current_world()
|
||||||
|
if not world:
|
||||||
|
return
|
||||||
|
|
||||||
|
room = next(
|
||||||
|
(room for room in world.rooms if room.name == room_name), None
|
||||||
|
)
|
||||||
|
if not room:
|
||||||
|
return
|
||||||
|
|
||||||
|
actor = next(
|
||||||
|
(actor for actor in room.actors if actor.name == actor_name), None
|
||||||
|
)
|
||||||
|
if not actor:
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt = f"{room.name}. {actor.name}."
|
||||||
|
else:
|
||||||
|
prompt = remove_tags(reaction.message.content)
|
||||||
|
|
||||||
|
paths = generate_image_tool(prompt, 2)
|
||||||
|
logger.info(f"Generated images: {paths}")
|
||||||
|
|
||||||
|
files = [File(filename) for filename in paths]
|
||||||
|
await reaction.message.channel.send(files=files, reference=reaction.message)
|
||||||
|
|
||||||
|
async def on_message(self, message):
|
||||||
|
if message.author == self.user:
|
||||||
|
return
|
||||||
|
|
||||||
|
author = message.author
|
||||||
|
channel = message.channel
|
||||||
|
user_name = author.name # include nick
|
||||||
|
|
||||||
|
world = get_current_world()
|
||||||
|
if world:
|
||||||
|
active_world = f"Active world: {world.name} (theme: {world.theme})"
|
||||||
|
else:
|
||||||
|
active_world = "No active world"
|
||||||
|
|
||||||
|
if message.content.startswith("!adventure"):
|
||||||
|
await message.channel.send(f"Hello! Welcome to Adventure! {active_world}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if message.content.startswith("!help"):
|
||||||
|
await message.channel.send("Type `!join` to start playing!")
|
||||||
|
return
|
||||||
|
|
||||||
|
if message.content.startswith("!join"):
|
||||||
|
character_name = remove_tags(message.content).replace("!join", "").strip()
|
||||||
|
if has_player(character_name):
|
||||||
|
await channel.send(f"{character_name} has already been taken!")
|
||||||
|
return
|
||||||
|
|
||||||
|
actor, agent = get_actor_agent_for_name(character_name)
|
||||||
|
if not actor:
|
||||||
|
await channel.send(f"Character `{character_name}` not found!")
|
||||||
|
return
|
||||||
|
|
||||||
|
def prompt_player(character: str, prompt: str):
|
||||||
|
logger.info(
|
||||||
|
"append prompt for character %s (user %s) to queue: %s",
|
||||||
|
character,
|
||||||
|
user_name,
|
||||||
|
prompt,
|
||||||
|
)
|
||||||
|
prompt_queue.put((character, prompt))
|
||||||
|
return True
|
||||||
|
|
||||||
|
player = RemotePlayer(
|
||||||
|
actor.name, actor.backstory, prompt_player, fallback_agent=agent
|
||||||
|
)
|
||||||
|
set_actor_agent_for_name(character_name, actor, player)
|
||||||
|
set_player(user_name, player)
|
||||||
|
|
||||||
|
logger.info(f"{user_name} has joined the game as {actor.name}!")
|
||||||
|
await message.channel.send(
|
||||||
|
f"{user_name} has joined the game as {actor.name}!"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if message.content.startswith("!leave"):
|
||||||
|
# TODO: revert to LLM agent
|
||||||
|
logger.info(f"{user_name} has left the game!")
|
||||||
|
await message.channel.send(f"{user_name} has left the game!")
|
||||||
|
return
|
||||||
|
|
||||||
|
player = get_player(user_name)
|
||||||
|
if player and isinstance(player, RemotePlayer):
|
||||||
|
content = remove_tags(message.content)
|
||||||
|
player.input_queue.put(content)
|
||||||
|
logger.info(
|
||||||
|
f"Received message from {user_name} for {player.name}: {content}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
await message.channel.send(
|
||||||
|
"You are not currently playing Adventure! Type `!join` to start playing!"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
active_tasks = set()
|
||||||
|
|
||||||
|
|
||||||
|
def launch_bot():
|
||||||
|
def bot_main():
|
||||||
|
global client
|
||||||
|
|
||||||
|
intents = Intents.default()
|
||||||
|
# intents.message_content = True
|
||||||
|
|
||||||
|
client = AdventureClient(intents=intents)
|
||||||
|
client.run(environ["DISCORD_TOKEN"])
|
||||||
|
|
||||||
|
def prompt_main():
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
while True:
|
||||||
|
sleep(0.5)
|
||||||
|
if prompt_queue.empty():
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(active_tasks) > 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
character, prompt = prompt_queue.get()
|
||||||
|
logger.info("Prompting character %s: %s", character, prompt)
|
||||||
|
|
||||||
|
if client:
|
||||||
|
prompt_task = client.loop.create_task(broadcast_event(prompt))
|
||||||
|
active_tasks.add(prompt_task)
|
||||||
|
prompt_task.add_done_callback(active_tasks.discard)
|
||||||
|
|
||||||
|
bot_thread = Thread(target=bot_main)
|
||||||
|
bot_thread.start()
|
||||||
|
|
||||||
|
prompt_thread = Thread(target=prompt_main)
|
||||||
|
prompt_thread.start()
|
||||||
|
|
||||||
|
|
||||||
|
def stop_bot():
|
||||||
|
global client
|
||||||
|
|
||||||
|
if client:
|
||||||
|
client.close()
|
||||||
|
client = None
|
||||||
|
|
||||||
|
|
||||||
|
# @cache
|
||||||
|
def get_active_channels():
|
||||||
|
if not client:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# return client.private_channels
|
||||||
|
return [
|
||||||
|
channel
|
||||||
|
for guild in client.guilds
|
||||||
|
for channel in guild.text_channels
|
||||||
|
if channel.name == "bots"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def broadcast_event(message: str | Embed):
|
||||||
|
if not client:
|
||||||
|
logger.warning("No Discord client available")
|
||||||
|
return
|
||||||
|
|
||||||
|
active_channels = get_active_channels()
|
||||||
|
if not active_channels:
|
||||||
|
logger.warning("No active channels")
|
||||||
|
return
|
||||||
|
|
||||||
|
for channel in active_channels:
|
||||||
|
if isinstance(message, str):
|
||||||
|
logger.info("Broadcasting to channel %s: %s", channel, message)
|
||||||
|
await channel.send(content=message)
|
||||||
|
elif isinstance(message, Embed):
|
||||||
|
logger.info(
|
||||||
|
"Broadcasting to channel %s: %s - %s",
|
||||||
|
channel,
|
||||||
|
message.title,
|
||||||
|
message.description,
|
||||||
|
)
|
||||||
|
await channel.send(embed=message)
|
||||||
|
|
||||||
|
|
||||||
|
def bot_action(room: Room, actor: Actor, message: str):
|
||||||
|
try:
|
||||||
|
action_embed = Embed(title=room.name, description=actor.name)
|
||||||
|
|
||||||
|
if could_be_json(message):
|
||||||
|
action_data = loads(message)
|
||||||
|
action_name = action_data["function"].replace("action_", "").title()
|
||||||
|
action_parameters = action_data.get("parameters", {})
|
||||||
|
|
||||||
|
action_embed.add_field(name="Action", value=action_name)
|
||||||
|
|
||||||
|
for key, value in action_parameters.items():
|
||||||
|
action_embed.add_field(name=key.replace("_", " ").title(), value=value)
|
||||||
|
else:
|
||||||
|
action_embed.add_field(name="Message", value=message)
|
||||||
|
|
||||||
|
prompt_queue.put((actor.name, action_embed))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to broadcast action: %s", e)
|
||||||
|
|
||||||
|
|
||||||
|
def bot_event(message: str):
|
||||||
|
prompt_queue.put((None, message))
|
||||||
|
|
||||||
|
|
||||||
|
def bot_result(room: Room, actor: Actor, action: str):
|
||||||
|
result_embed = Embed(title=room.name, description=actor.name)
|
||||||
|
result_embed.add_field(name="Result", value=action)
|
||||||
|
prompt_queue.put((actor.name, result_embed))
|
||||||
|
|
||||||
|
|
||||||
|
def player_event(character: str, id: str, event: Literal["join", "leave"]):
|
||||||
|
if event == "join":
|
||||||
|
prompt_queue.put((character, f"{character} has joined the game!"))
|
||||||
|
elif event == "leave":
|
||||||
|
prompt_queue.put((character, f"{character} has left the game!"))
|
|
@ -267,6 +267,14 @@ def main():
|
||||||
input_callbacks = []
|
input_callbacks = []
|
||||||
result_callbacks = []
|
result_callbacks = []
|
||||||
|
|
||||||
|
if args.discord:
|
||||||
|
from adventure.discord_bot import bot_action, bot_event, bot_result, launch_bot
|
||||||
|
|
||||||
|
launch_bot()
|
||||||
|
event_callbacks.append(bot_event)
|
||||||
|
input_callbacks.append(bot_action)
|
||||||
|
result_callbacks.append(bot_result)
|
||||||
|
|
||||||
if args.server:
|
if args.server:
|
||||||
from adventure.server import (
|
from adventure.server import (
|
||||||
launch_server,
|
launch_server,
|
||||||
|
|
|
@ -2,7 +2,7 @@ from json import dumps
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from readline import add_history
|
from readline import add_history
|
||||||
from typing import Any, Callable, Dict, List, Sequence
|
from typing import Any, Callable, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
from packit.agent import Agent
|
from packit.agent import Agent
|
||||||
|
@ -11,6 +11,50 @@ from packit.utils import could_be_json
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Dict[client, player]
|
||||||
|
active_players: Dict[str, "BasePlayer"] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_player(client: str) -> Optional["BasePlayer"]:
|
||||||
|
"""
|
||||||
|
Get a player by name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return active_players.get(client, None)
|
||||||
|
|
||||||
|
|
||||||
|
def set_player(client: str, player: "BasePlayer"):
|
||||||
|
"""
|
||||||
|
Add a player to the active players.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if has_player(player.name):
|
||||||
|
raise ValueError(f"Someone is already playing as {player.name}!")
|
||||||
|
|
||||||
|
active_players[client] = player
|
||||||
|
|
||||||
|
|
||||||
|
def remove_player(client: str):
|
||||||
|
"""
|
||||||
|
Remove a player from the active players.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if client in active_players:
|
||||||
|
del active_players[client]
|
||||||
|
|
||||||
|
|
||||||
|
def has_player(character_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a character is already being played.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return character_name in [player.name for player in active_players.values()]
|
||||||
|
|
||||||
|
|
||||||
|
def list_players():
|
||||||
|
return {client: player.name for client, player in active_players.items()}
|
||||||
|
|
||||||
|
|
||||||
class BasePlayer:
|
class BasePlayer:
|
||||||
"""
|
"""
|
||||||
A human agent that can interact with the world.
|
A human agent that can interact with the world.
|
||||||
|
|
|
@ -0,0 +1,195 @@
|
||||||
|
# This is an example that uses the websockets api to know when a prompt execution is done
|
||||||
|
# Once the prompt execution is done it downloads the images using the /history endpoint
|
||||||
|
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import urllib.parse
|
||||||
|
import urllib.request
|
||||||
|
import uuid
|
||||||
|
from logging import getLogger
|
||||||
|
from os import environ, path
|
||||||
|
from random import choice, randint
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
server_address = environ["COMFY_API"]
|
||||||
|
client_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def generate_cfg():
|
||||||
|
return randint(5, 8)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_steps():
|
||||||
|
return 30
|
||||||
|
|
||||||
|
|
||||||
|
def generate_batches(
|
||||||
|
count: int,
|
||||||
|
batch_size: int = 3,
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Generate count images in batches of at most batch_size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batches = []
|
||||||
|
for i in range(0, count, batch_size):
|
||||||
|
batches.append(min(count - i, batch_size))
|
||||||
|
|
||||||
|
return batches
|
||||||
|
|
||||||
|
|
||||||
|
def queue_prompt(prompt):
|
||||||
|
p = {"prompt": prompt, "client_id": client_id}
|
||||||
|
data = json.dumps(p).encode("utf-8")
|
||||||
|
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
|
||||||
|
return json.loads(urllib.request.urlopen(req).read())
|
||||||
|
|
||||||
|
|
||||||
|
def get_image(filename, subfolder, folder_type):
|
||||||
|
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||||
|
url_values = urllib.parse.urlencode(data)
|
||||||
|
with urllib.request.urlopen(
|
||||||
|
"http://{}/view?{}".format(server_address, url_values)
|
||||||
|
) as response:
|
||||||
|
return response.read()
|
||||||
|
|
||||||
|
|
||||||
|
def get_history(prompt_id):
|
||||||
|
with urllib.request.urlopen(
|
||||||
|
"http://{}/history/{}".format(server_address, prompt_id)
|
||||||
|
) as response:
|
||||||
|
return json.loads(response.read())
|
||||||
|
|
||||||
|
|
||||||
|
def get_images(ws, prompt):
|
||||||
|
prompt_id = queue_prompt(prompt)["prompt_id"]
|
||||||
|
output_images = {}
|
||||||
|
while True:
|
||||||
|
out = ws.recv()
|
||||||
|
if isinstance(out, str):
|
||||||
|
message = json.loads(out)
|
||||||
|
if message["type"] == "executing":
|
||||||
|
data = message["data"]
|
||||||
|
if data["node"] is None and data["prompt_id"] == prompt_id:
|
||||||
|
break # Execution is done
|
||||||
|
else:
|
||||||
|
continue # previews are binary data
|
||||||
|
|
||||||
|
history = get_history(prompt_id)[prompt_id]
|
||||||
|
for o in history["outputs"]:
|
||||||
|
for node_id in history["outputs"]:
|
||||||
|
node_output = history["outputs"][node_id]
|
||||||
|
if "images" in node_output:
|
||||||
|
images_output = []
|
||||||
|
for image in node_output["images"]:
|
||||||
|
image_data = get_image(
|
||||||
|
image["filename"], image["subfolder"], image["type"]
|
||||||
|
)
|
||||||
|
images_output.append(image_data)
|
||||||
|
output_images[node_id] = images_output
|
||||||
|
|
||||||
|
return output_images
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image_tool(prompt, count, size="landscape"):
|
||||||
|
output_paths = []
|
||||||
|
for i, count in enumerate(generate_batches(count)):
|
||||||
|
results = generate_images(prompt, count, size, prefix=f"output-{i}")
|
||||||
|
output_paths.extend(results)
|
||||||
|
|
||||||
|
return output_paths
|
||||||
|
|
||||||
|
|
||||||
|
sizes = {
|
||||||
|
"landscape": (1024, 768),
|
||||||
|
"portrait": (768, 1024),
|
||||||
|
"square": (768, 768),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_images(
|
||||||
|
prompt: str, count: int, size="landscape", prefix="output"
|
||||||
|
) -> List[str]:
|
||||||
|
cfg = generate_cfg()
|
||||||
|
width, height = sizes.get(size, (512, 512))
|
||||||
|
steps = generate_steps()
|
||||||
|
seed = randint(0, 10000000)
|
||||||
|
checkpoint = choice(["diffusion-sdxl-dynavision-0-5-5-7.safetensors"])
|
||||||
|
logger.info(
|
||||||
|
"generating %s images at %s by %s with prompt: %s", count, width, height, prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
# parsing here helps ensure the template emits valid JSON
|
||||||
|
prompt_workflow = {
|
||||||
|
"3": {
|
||||||
|
"class_type": "KSampler",
|
||||||
|
"inputs": {
|
||||||
|
"cfg": cfg,
|
||||||
|
"denoise": 1,
|
||||||
|
"latent_image": ["5", 0],
|
||||||
|
"model": ["4", 0],
|
||||||
|
"negative": ["7", 0],
|
||||||
|
"positive": ["6", 0],
|
||||||
|
"sampler_name": "euler_ancestral",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"seed": seed,
|
||||||
|
"steps": steps,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"class_type": "CheckpointLoaderSimple",
|
||||||
|
"inputs": {"ckpt_name": checkpoint},
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"class_type": "EmptyLatentImage",
|
||||||
|
"inputs": {"batch_size": count, "height": height, "width": width},
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"inputs": {"text": prompt, "clip": ["4", 1]},
|
||||||
|
},
|
||||||
|
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": "", "clip": ["4", 1]}},
|
||||||
|
"8": {
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"inputs": {"samples": ["3", 0], "vae": ["4", 2]},
|
||||||
|
},
|
||||||
|
"9": {
|
||||||
|
"class_type": "SaveImage",
|
||||||
|
"inputs": {"filename_prefix": prefix, "images": ["8", 0]},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug("Connecting to Comfy API at %s", server_address)
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||||||
|
images = get_images(ws, prompt_workflow)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for node_id in images:
|
||||||
|
for image_data in images[node_id]:
|
||||||
|
image = Image.open(io.BytesIO(image_data))
|
||||||
|
results.append(image)
|
||||||
|
|
||||||
|
paths: List[str] = []
|
||||||
|
for j, image in enumerate(results):
|
||||||
|
image_path = path.join("/home/ssube/adventure-images", f"{prefix}-{j}.png")
|
||||||
|
with open(image_path, "wb") as f:
|
||||||
|
image_bytes = io.BytesIO()
|
||||||
|
image.save(image_bytes, format="PNG")
|
||||||
|
f.write(image_bytes.getvalue())
|
||||||
|
|
||||||
|
paths.append(image_path)
|
||||||
|
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
paths = generate_images(
|
||||||
|
"A painting of a beautiful sunset over a calm lake", 3, "landscape"
|
||||||
|
)
|
||||||
|
logger.info("Generated %d images: %s", len(paths), paths)
|
|
@ -3,20 +3,26 @@ from collections import deque
|
||||||
from json import dumps, loads
|
from json import dumps, loads
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Dict, Literal
|
from typing import Literal
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
from adventure.context import get_actor_agent_for_name, set_actor_agent_for_name
|
from adventure.context import get_actor_agent_for_name, set_actor_agent_for_name
|
||||||
from adventure.models import Actor, Room, World
|
from adventure.models import Actor, Room, World
|
||||||
from adventure.player import RemotePlayer
|
from adventure.player import (
|
||||||
|
RemotePlayer,
|
||||||
|
get_player,
|
||||||
|
has_player,
|
||||||
|
list_players,
|
||||||
|
remove_player,
|
||||||
|
set_player,
|
||||||
|
)
|
||||||
from adventure.state import snapshot_world, world_json
|
from adventure.state import snapshot_world, world_json
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
connected = set()
|
connected = set()
|
||||||
characters: Dict[str, RemotePlayer] = {}
|
|
||||||
recent_events = deque(maxlen=100)
|
recent_events = deque(maxlen=100)
|
||||||
recent_world = None
|
recent_world = None
|
||||||
|
|
||||||
|
@ -40,12 +46,13 @@ async def handler(websocket):
|
||||||
)
|
)
|
||||||
|
|
||||||
def sync_turn(character: str, prompt: str) -> bool:
|
def sync_turn(character: str, prompt: str) -> bool:
|
||||||
if id not in characters:
|
player = get_player(id)
|
||||||
return False
|
if player and player.name == character:
|
||||||
|
|
||||||
asyncio.run(next_turn(character, prompt))
|
asyncio.run(next_turn(character, prompt))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await websocket.send(dumps({"type": "id", "id": id}))
|
await websocket.send(dumps({"type": "id", "id": id}))
|
||||||
|
|
||||||
|
@ -67,9 +74,8 @@ async def handler(websocket):
|
||||||
data = loads(message)
|
data = loads(message)
|
||||||
message_type = data.get("type", None)
|
message_type = data.get("type", None)
|
||||||
if message_type == "player":
|
if message_type == "player":
|
||||||
character = characters.get(id)
|
# TODO: should this always remove?
|
||||||
if character:
|
remove_player(id)
|
||||||
del characters[id]
|
|
||||||
|
|
||||||
character_name = data["become"]
|
character_name = data["become"]
|
||||||
actor, llm_agent = get_actor_agent_for_name(character_name)
|
actor, llm_agent = get_actor_agent_for_name(character_name)
|
||||||
|
@ -84,9 +90,7 @@ async def handler(websocket):
|
||||||
)
|
)
|
||||||
llm_agent = llm_agent.fallback_agent
|
llm_agent = llm_agent.fallback_agent
|
||||||
|
|
||||||
if character_name in [
|
if has_player(character_name):
|
||||||
player.name for player in characters.values()
|
|
||||||
]:
|
|
||||||
logger.error(f"Character {character_name} is already in use")
|
logger.error(f"Character {character_name} is already in use")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -94,7 +98,7 @@ async def handler(websocket):
|
||||||
player = RemotePlayer(
|
player = RemotePlayer(
|
||||||
actor.name, actor.backstory, sync_turn, fallback_agent=llm_agent
|
actor.name, actor.backstory, sync_turn, fallback_agent=llm_agent
|
||||||
)
|
)
|
||||||
characters[id] = player
|
set_player(id, player)
|
||||||
logger.info(f"Client {id} is now character {character_name}")
|
logger.info(f"Client {id} is now character {character_name}")
|
||||||
|
|
||||||
# swap out the LLM agent
|
# swap out the LLM agent
|
||||||
|
@ -103,9 +107,12 @@ async def handler(websocket):
|
||||||
# notify all clients that this character is now active
|
# notify all clients that this character is now active
|
||||||
player_event(character_name, id, "join")
|
player_event(character_name, id, "join")
|
||||||
player_list()
|
player_list()
|
||||||
elif message_type == "input" and id in characters:
|
elif message_type == "input":
|
||||||
player = characters[id]
|
player = get_player(id)
|
||||||
logger.info("queueing input for player %s: %s", player.name, data)
|
if player and isinstance(player, RemotePlayer):
|
||||||
|
logger.info(
|
||||||
|
"queueing input for player %s: %s", player.name, data
|
||||||
|
)
|
||||||
player.input_queue.put(data["input"])
|
player.input_queue.put(data["input"])
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -116,9 +123,9 @@ async def handler(websocket):
|
||||||
connected.remove(websocket)
|
connected.remove(websocket)
|
||||||
|
|
||||||
# swap out the character for the original agent when they disconnect
|
# swap out the character for the original agent when they disconnect
|
||||||
if id in characters:
|
player = get_player(id)
|
||||||
player = characters[id]
|
if player and isinstance(player, RemotePlayer):
|
||||||
del characters[id]
|
remove_player(id)
|
||||||
|
|
||||||
logger.info("Disconnecting player for %s", player.name)
|
logger.info("Disconnecting player for %s", player.name)
|
||||||
player_event(player.name, id, "leave")
|
player_event(player.name, id, "leave")
|
||||||
|
@ -217,8 +224,9 @@ def player_event(character: str, id: str, event: Literal["join", "leave"]):
|
||||||
|
|
||||||
|
|
||||||
def player_list():
|
def player_list():
|
||||||
json_broadcast ={
|
players = {value: key for key, value in list_players()}
|
||||||
|
json_broadcast = {
|
||||||
"type": "players",
|
"type": "players",
|
||||||
"players": {player.name: player_id for player_id, player in characters.items()},
|
"players": players,
|
||||||
}
|
}
|
||||||
send_and_append(json_broadcast)
|
send_and_append(json_broadcast)
|
||||||
|
|
Loading…
Reference in New Issue