From 593f3981d18a1a47b8f9dabc021a2d027588814a Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 12 May 2024 00:08:53 -0500 Subject: [PATCH] use events for images, show them in all servers --- adventure/{discord_bot.py => bot_discord.py} | 123 ++++++++----------- adventure/main.py | 11 +- adventure/models/event.py | 59 +++++++-- adventure/render_comfy.py | 89 ++++++++++++++ adventure/{server.py => server_socket.py} | 93 ++++++++++---- client/src/app.tsx | 9 +- client/src/events.tsx | 92 +++++++++++--- client/src/history.tsx | 14 ++- 8 files changed, 357 insertions(+), 133 deletions(-) rename adventure/{discord_bot.py => bot_discord.py} (77%) rename adventure/{server.py => server_socket.py} (72%) diff --git a/adventure/discord_bot.py b/adventure/bot_discord.py similarity index 77% rename from adventure/discord_bot.py rename to adventure/bot_discord.py index bbca266..760b6e2 100644 --- a/adventure/discord_bot.py +++ b/adventure/bot_discord.py @@ -3,6 +3,7 @@ from os import environ from queue import Queue from re import sub from threading import Thread +from typing import Dict from discord import Client, Embed, File, Intents @@ -18,17 +19,19 @@ from adventure.models.event import ( GenerateEvent, PlayerEvent, PromptEvent, + RenderEvent, ReplyEvent, ResultEvent, StatusEvent, ) from adventure.player import RemotePlayer, get_player, has_player, set_player -from adventure.render_comfy import generate_image_tool +from adventure.render_comfy import render_event logger = getLogger(__name__) client = None active_tasks = set() +event_messages: Dict[str, str | GameEvent] = {} event_queue: Queue[GameEvent] = Queue() @@ -40,54 +43,6 @@ def remove_tags(text: str) -> str: return sub(r"<[^>]*>", "", text) -def find_embed_field(embed: Embed, name: str) -> str | None: - return next((field.value for field in embed.fields if field.name == name), None) - - -# TODO: becomes prompt from event -def prompt_from_embed(embed: Embed) -> str | None: - room_name = embed.title - actor_name = embed.description - - world = get_current_world() - if not world: - return - - room = next((room for room in world.rooms if room.name == room_name), None) - if not room: - return - - actor = next((actor for actor in room.actors if actor.name == actor_name), None) - if not actor: - return - - item_field = find_embed_field(embed, "Item") - - action_field = find_embed_field(embed, "Action") - if action_field: - if item_field: - item = next( - ( - item - for item in (room.items + actor.items) - if item.name == item_field - ), - None, - ) - if item: - return f"{actor.name} {action_field} the {item.name}. {item.description}. {actor.description}. {room.description}." - - return f"{actor.name} {action_field} the {item_field}. {actor.description}. {room.description}." - - return f"{actor.name} {action_field}. {actor.description}. {room.name}." - - result_field = find_embed_field(embed, "Result") - if result_field: - return f"{result_field}. {actor.description}. {room.description}." - - return - - class AdventureClient(Client): async def on_ready(self): logger.info(f"Logged in as {self.user}") @@ -98,23 +53,16 @@ class AdventureClient(Client): logger.info(f"Reaction added: {reaction} by {user}") if reaction.emoji == "📷": - # message_id = reaction.message.id - # TODO: look up event that caused this message, get the room and actors - if len(reaction.message.embeds) > 0: - embed = reaction.message.embeds[0] - prompt = prompt_from_embed(embed) - else: - prompt = remove_tags(reaction.message.content) - if prompt.startswith("Generating"): - # TODO: get the entity from the message - pass + message_id = reaction.message.id + if message_id not in event_messages: + logger.warning(f"Message {message_id} not found in event messages") + # TODO: return error message + return - await reaction.message.add_reaction("📸") - paths = generate_image_tool(prompt, 2) - logger.info(f"Generated images: {paths}") - - files = [File(filename) for filename in paths] - await reaction.message.channel.send(files=files, reference=reaction.message) + event = event_messages[message_id] + if isinstance(event, GameEvent): + render_event(event) + await reaction.message.add_reaction("📸") async def on_message(self, message): if message.author == self.user: @@ -277,17 +225,54 @@ async def broadcast_event(message: str | GameEvent): for channel in active_channels: if isinstance(message, str): - logger.info("broadcasting to channel %s: %s", channel, message) - await channel.send(content=message) - elif isinstance(message, GameEvent): + # deprecated, use events instead + logger.warning( + "broadcasting non-event message to channel %s: %s", channel, message + ) + event_message = await channel.send(content=message) + elif isinstance(message, RenderEvent): + # special handling to upload images + # find the source event + source_event_id = message.source.id + source_message_id = next( + ( + message_id + for message_id, event in event_messages.items() + if isinstance(event, GameEvent) and event.id == source_event_id + ), + None, + ) + if not source_message_id: + logger.warning("source event not found: %s", source_event_id) + return + + # open and upload images + files = [File(filename) for filename in message.paths] + try: + source_message = await channel.fetch_message(source_message_id) + except Exception as err: + logger.warning("source message not found: %s", err) + return + + # send the images as a reply to the source message + event_message = await source_message.channel.send( + files=files, reference=source_message + ) + else: embed = embed_from_event(message) + if not embed: + logger.warning("no embed for event: %s", message) + return + logger.info( "broadcasting to channel %s: %s - %s", channel, embed.title, embed.description, ) - await channel.send(embed=embed) + event_message = await channel.send(embed=embed) + + event_messages[event_message.id] = message def embed_from_event(event: GameEvent) -> Embed: diff --git a/adventure/main.py b/adventure/main.py index 2b818fe..b4f76dd 100644 --- a/adventure/main.py +++ b/adventure/main.py @@ -84,6 +84,7 @@ def parse_args(): parser.add_argument( "--optional-actions", type=bool, help="Whether to include optional actions" ) + parser.add_argument("--render", type=bool, help="Whether to render the simulation") parser.add_argument( "--server", type=str, help="The address on which to run the server" ) @@ -199,14 +200,20 @@ def main(): # launch other threads threads = [] + + if args.render: + from adventure.render_comfy import launch_render + + threads.extend(launch_render()) + if args.discord: - from adventure.discord_bot import bot_event, launch_bot + from adventure.bot_discord import bot_event, launch_bot threads.extend(launch_bot()) callbacks.append(bot_event) if args.server: - from adventure.server import launch_server, server_event, server_system + from adventure.server_socket import launch_server, server_event, server_system threads.extend(launch_server()) callbacks.append(server_event) diff --git a/adventure/models/event.py b/adventure/models/event.py index 957169c..9262e38 100644 --- a/adventure/models/event.py +++ b/adventure/models/event.py @@ -1,25 +1,33 @@ from json import loads from typing import Any, Callable, Dict, List, Literal +from uuid import uuid4 + +from pydantic import Field from .base import dataclass from .entity import Actor, Item, Room, WorldEntity -@dataclass +def uuid() -> str: + return uuid4().hex + + class BaseEvent: """ A base event class. """ + id: str type: str @dataclass -class GenerateEvent: +class GenerateEvent(BaseEvent): """ A new entity has been generated. """ + id = Field(default_factory=uuid) type = "generate" name: str entity: WorldEntity | None = None @@ -34,11 +42,12 @@ class GenerateEvent: @dataclass -class ActionEvent: +class ActionEvent(BaseEvent): """ An actor has taken an action. """ + id = Field(default_factory=uuid) type = "action" action: str parameters: Dict[str, bool | float | int | str] @@ -60,11 +69,12 @@ class ActionEvent: @dataclass -class PromptEvent: +class PromptEvent(BaseEvent): """ A prompt for an actor to take an action. """ + id = Field(default_factory=uuid) type = "prompt" prompt: str room: Room @@ -72,13 +82,14 @@ class PromptEvent: @dataclass -class ReplyEvent: +class ReplyEvent(BaseEvent): """ An actor has replied with text. This is the non-JSON version of an ActionEvent. """ + id = Field(default_factory=uuid) type = "reply" text: str room: Room @@ -90,11 +101,12 @@ class ReplyEvent: @dataclass -class ResultEvent: +class ResultEvent(BaseEvent): """ A result of an action. """ + id = Field(default_factory=uuid) type = "result" result: str room: Room @@ -102,11 +114,12 @@ class ResultEvent: @dataclass -class StatusEvent: +class StatusEvent(BaseEvent): """ A status broadcast event with text. """ + id = Field(default_factory=uuid) type = "status" text: str room: Room | None = None @@ -114,7 +127,7 @@ class StatusEvent: @dataclass -class SnapshotEvent: +class SnapshotEvent(BaseEvent): """ A snapshot of the world state. @@ -122,6 +135,7 @@ class SnapshotEvent: That is especially important for the memory, which is a dictionary of actor names to lists of messages. """ + id = Field(default_factory=uuid) type = "snapshot" world: Dict[str, Any] memory: Dict[str, List[Any]] @@ -129,20 +143,45 @@ class SnapshotEvent: @dataclass -class PlayerEvent: +class PlayerEvent(BaseEvent): """ A player joining or leaving the game. """ + id = Field(default_factory=uuid) type = "player" status: Literal["join", "leave"] character: str client: str +@dataclass +class PlayerListEvent(BaseEvent): + """ + A list of players in the game and the characters they are playing. + """ + + id = Field(default_factory=uuid) + type = "players" + players: Dict[str, str] + + +@dataclass +class RenderEvent(BaseEvent): + """ + Images have been rendered. + """ + + id = Field(default_factory=uuid) + type = "render" + paths: List[str] + source: "GameEvent" + + # event types WorldEvent = ActionEvent | PromptEvent | ReplyEvent | ResultEvent | StatusEvent -GameEvent = GenerateEvent | PlayerEvent | WorldEvent +PlayerEventType = PlayerEvent | PlayerListEvent +GameEvent = GenerateEvent | PlayerEventType | RenderEvent | WorldEvent # callback types EventCallback = Callable[[GameEvent], None] diff --git a/adventure/render_comfy.py b/adventure/render_comfy.py index 610a1eb..8f3821b 100644 --- a/adventure/render_comfy.py +++ b/adventure/render_comfy.py @@ -8,12 +8,24 @@ import urllib.request import uuid from logging import getLogger from os import environ, path +from queue import Queue from random import choice, randint +from threading import Thread from typing import List import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) from PIL import Image +from adventure.context import broadcast +from adventure.models.event import ( + ActionEvent, + GameEvent, + RenderEvent, + ReplyEvent, + ResultEvent, + StatusEvent, +) + logger = getLogger(__name__) server_address = environ["COMFY_API"] @@ -195,6 +207,7 @@ def generate_images( paths: List[str] = [] for j, image in enumerate(results): + # TODO: replace with environment variable image_path = path.join("/home/ssube/adventure-images", f"{prefix}-{j}.png") with open(image_path, "wb") as f: image_bytes = io.BytesIO() @@ -206,6 +219,82 @@ def generate_images( return paths +def prompt_from_event(event: GameEvent) -> str | None: + if isinstance(event, ActionEvent): + if event.item: + return f"{event.actor.name} uses the {event.item.name}. {event.item.description}. {event.actor.description}. {event.room.description}." + + return f"{event.actor.name} {event.action}. {event.actor.description}. {event.room.description}." + + if isinstance(event, ReplyEvent): + return event.text + + if isinstance(event, ResultEvent): + return f"{event.result}. {event.actor.description}. {event.room.description}." + + if isinstance(event, StatusEvent): + if event.room: + if event.actor: + return f"{event.text}. {event.actor.description}. {event.room.description}." + + return f"{event.text}. {event.room.description}." + + return event.text + + return None + + +def prefix_from_event(event: GameEvent) -> str: + if isinstance(event, ActionEvent): + return ( + f"{event.actor.name}-{event.action}-{event.item.name if event.item else ''}" + ) + + if isinstance(event, ReplyEvent): + return f"{event.actor.name}-reply" + + if isinstance(event, ResultEvent): + return f"{event.actor.name}-result" + + if isinstance(event, StatusEvent): + return "status" + + return "unknown" + + +# requests to generate images for game events +render_queue: Queue[GameEvent] = Queue() + + +def render_loop(): + while True: + event = render_queue.get() + prompt = prompt_from_event(event) + if prompt: + logger.info("rendering prompt for event %s: %s", event, prompt) + prefix = prefix_from_event(event) + image_paths = generate_images(prompt, 2, prefix=prefix) + broadcast(RenderEvent(paths=image_paths, source=event)) + else: + logger.warning("no prompt for event %s", event) + + +def render_event(event: GameEvent): + render_queue.put(event) + + +render_thread = None + + +def launch_render(): + global render_thread + + render_thread = Thread(target=render_loop, daemon=True) + render_thread.start() + + return [render_thread] + + if __name__ == "__main__": paths = generate_images( "A painting of a beautiful sunset over a calm lake", 3, "landscape" diff --git a/adventure/server.py b/adventure/server_socket.py similarity index 72% rename from adventure/server.py rename to adventure/server_socket.py index 833ac73..23c4aa8 100644 --- a/adventure/server.py +++ b/adventure/server_socket.py @@ -1,17 +1,26 @@ import asyncio +from base64 import b64encode from collections import deque +from io import BytesIO from json import dumps, loads from logging import getLogger from threading import Thread -from typing import Any, Dict, Literal +from typing import Any, Dict, Literal, MutableSequence from uuid import uuid4 import websockets +from PIL import Image from pydantic import RootModel from adventure.context import broadcast, get_actor_agent_for_name, set_actor_agent from adventure.models.entity import Actor, Item, Room, World -from adventure.models.event import GameEvent, PlayerEvent, PromptEvent +from adventure.models.event import ( + GameEvent, + PlayerEvent, + PlayerListEvent, + PromptEvent, + RenderEvent, +) from adventure.player import ( RemotePlayer, get_player, @@ -20,12 +29,14 @@ from adventure.player import ( remove_player, set_player, ) +from adventure.render_comfy import render_event from adventure.state import snapshot_world, world_json logger = getLogger(__name__) connected = set() -recent_events = deque(maxlen=100) +recent_events: MutableSequence[GameEvent] = deque(maxlen=100) +recent_json: MutableSequence[str] = deque(maxlen=100) last_snapshot = None player_names: Dict[str, str] = {} @@ -61,12 +72,13 @@ async def handler(websocket): return False try: - await websocket.send(dumps({"type": "id", "id": id})) + await websocket.send(dumps({"type": "id", "client": id})) - if last_snapshot: + # TODO: only send this if the recent events don't contain a snapshot + if last_snapshot and last_snapshot not in recent_json: await websocket.send(last_snapshot) - for message in recent_events: + for message in recent_json: await websocket.send(message) except Exception: logger.exception("Failed to send recent messages to new client") @@ -141,8 +153,8 @@ async def handler(websocket): set_actor_agent(actor.name, actor, player) # notify all clients that this character is now active - player_event(character_name, player_name, "join") - player_list() + broadcast_player_event(character_name, player_name, "join") + broadcast_player_list() elif message_type == "input": player = get_player(id) if player and isinstance(player, RemotePlayer): @@ -150,6 +162,13 @@ async def handler(websocket): "queueing input for player %s: %s", player.name, data ) player.input_queue.put(data["input"]) + elif message_type == "render": + event_id = data["event"] + event = next((e for e in recent_events if e.id == event_id), None) + if event: + render_event(event) + else: + logger.error(f"Failed to find event {event_id}") except Exception: logger.exception("Failed to parse message") @@ -166,16 +185,16 @@ async def handler(websocket): remove_player(id) player_name = get_player_name(id) - logger.info("Disconnecting player %s from %s", player_name, player.name) - player_event(player.name, player_name, "leave") - player_list() + logger.info("disconnecting player %s from %s", player_name, player.name) + broadcast_player_event(player.name, player_name, "leave") + broadcast_player_list() actor, _ = get_actor_agent_for_name(player.name) if actor and player.fallback_agent: - logger.info("Restoring LLM agent for %s", player.name) + logger.info("restoring LLM agent for %s", player.name) set_actor_agent(player.name, actor, player.fallback_agent) - logger.info("Client disconnected: %s", id) + logger.info("client disconnected: %s", id) socket_thread = None @@ -188,9 +207,9 @@ def server_json(obj): return world_json(obj) -def send_and_append(message): +def send_and_append(id: str, message: Dict): json_message = dumps(message, default=server_json) - recent_events.append(json_message) + recent_json.append(json_message) websockets.broadcast(connected, json_message) return json_message @@ -215,28 +234,50 @@ async def server_main(): def server_system(world: World, step: int): global last_snapshot + id = uuid4().hex # TODO: should a server be allowed to generate event IDs? json_state = { **snapshot_world(world, step), + "id": id, "type": "snapshot", } - last_snapshot = send_and_append(json_state) + last_snapshot = send_and_append(id, json_state) def server_event(event: GameEvent): json_event: Dict[str, Any] = RootModel[event.__class__](event).model_dump() - json_event["type"] = event.type - send_and_append(json_event) + json_event.update( + { + "id": event.id, + "type": event.type, + } + ) + + if isinstance(event, RenderEvent): + # load and encode the images + image_paths = event.paths + image_data = {} + for path in image_paths: + with Image.open(path, "r") as image: + buffered = BytesIO() + image.save( + buffered, format="JPEG", quality=80, optimize=True, progressive=True + ) + image_str = b64encode(buffered.getvalue()) + image_data[path] = image_str.decode("utf-8") + + json_event["images"] = image_data + + recent_events.append(event) + send_and_append(event.id, json_event) -def player_event(character: str, client: str, status: Literal["join", "leave"]): +def broadcast_player_event( + character: str, client: str, status: Literal["join", "leave"] +): event = PlayerEvent(status=status, character=character, client=client) broadcast(event) -def player_list(): - json_broadcast = { - "type": "players", - "players": list_players(), - } - # TODO: broadcast this - send_and_append(json_broadcast) +def broadcast_player_list(): + event = PlayerListEvent(players=list_players()) + broadcast(event) diff --git a/client/src/app.tsx b/client/src/app.tsx index e4b96ca..21c27a3 100644 --- a/client/src/app.tsx +++ b/client/src/app.tsx @@ -19,7 +19,7 @@ import useWebSocketModule from 'react-use-websocket'; import { useStore } from 'zustand'; import { HistoryPanel } from './history.js'; -import { Actor, Item, Room } from './models.js'; +import { Actor, GameEvent, Item, Room } from './models.js'; import { PlayerPanel } from './player.js'; import { store, StoreState } from './store.js'; import { WorldPanel } from './world.js'; @@ -93,6 +93,11 @@ export function App(props: AppProps) { // socket stuff const { lastMessage, readyState, sendMessage } = useWebSocket(props.socketUrl); + // socket senders + function renderEvent(event: string) { + sendMessage(JSON.stringify({ type: 'render', event })); + } + function setPlayer(actor: Maybe) { // do not call setCharacter until the server confirms the player change if (doesExist(actor)) { @@ -179,7 +184,7 @@ export function App(props: AppProps) { - + diff --git a/client/src/events.tsx b/client/src/events.tsx index a8295b7..a709adb 100644 --- a/client/src/events.tsx +++ b/client/src/events.tsx @@ -1,21 +1,45 @@ -import { ListItem, ListItemText, ListItemAvatar, Avatar, Typography } from '@mui/material'; -import React, { MutableRefObject } from 'react'; +import { Avatar, IconButton, ImageList, ImageListItem, ListItem, ListItemAvatar, ListItemText, Typography } from '@mui/material'; +import React, { Fragment, MutableRefObject } from 'react'; +import { Camera } from '@mui/icons-material'; import { formatters } from './format.js'; +import { GameEvent } from './models.js'; + +export function openImage(image: string) { + const byteCharacters = atob(image); + const byteNumbers = new Array(byteCharacters.length); + for (let i = 0; i < byteCharacters.length; i++) { + byteNumbers[i] = byteCharacters.charCodeAt(i); + } + const byteArray = new Uint8Array(byteNumbers); + const file = new Blob([byteArray], { type: 'image/jpeg;base64' }); + const fileURL = URL.createObjectURL(file); + window.open(fileURL, '_blank'); +} export interface EventItemProps { // eslint-disable-next-line @typescript-eslint/no-explicit-any event: any; // eslint-disable-next-line @typescript-eslint/no-explicit-any focusRef?: MutableRefObject; + + renderEvent: (event: GameEvent) => void; } export function ActionEventItem(props: EventItemProps) { - const { event } = props; - const { actor, room, type } = event; + const { event, renderEvent } = props; + const { id, actor, room, type } = event; const content = formatters[type](event); - return + return renderEvent(id)}> + + + } + > @@ -41,23 +65,26 @@ export function ActionEventItem(props: EventItemProps) { export function SnapshotEventItem(props: EventItemProps) { const { event } = props; const { step, world } = event; - const { theme } = world; + const { name, theme } = world; return - Step {step} - + + + Step: {step} + + World Theme: {theme} + } /> ; @@ -102,7 +129,10 @@ export function PlayerEventItem(props: EventItemProps) { secondary = `${client} has left the game. ${character} is now controlled by an LLM`; } - return + return @@ -122,6 +152,25 @@ export function PlayerEventItem(props: EventItemProps) { ; } +export function RenderEventItem(props: EventItemProps) { + const { event } = props; + const { images } = event; + + return + + + + + {Object.entries(images).map(([name, image]) => + openImage(image)} alt="Render" /> + )} + } + /> + ; +} + export function EventItem(props: EventItemProps) { const { event } = props; const { type } = event; @@ -129,13 +178,16 @@ export function EventItem(props: EventItemProps) { switch (type) { case 'action': case 'result': - return ; + return ; case 'reply': - return ; + case 'status': // TODO: should have a different component + return ; case 'player': - return ; + return ; + case 'render': + return ; case 'snapshot': - return ; + return ; default: return diff --git a/client/src/history.tsx b/client/src/history.tsx index c73a32e..ea0cb79 100644 --- a/client/src/history.tsx +++ b/client/src/history.tsx @@ -3,7 +3,8 @@ import { Divider, List } from '@mui/material'; import React, { useEffect, useRef } from 'react'; import { useStore } from 'zustand'; import { EventItem } from './events'; -import { store, StoreState } from './store'; +import { GameEvent } from './models'; +import { StoreState, store } from './store'; export function historyStateSelector(s: StoreState) { return { @@ -12,9 +13,14 @@ export function historyStateSelector(s: StoreState) { }; } -export function HistoryPanel() { +export interface HistoryPanelProps { + renderEvent: (event: GameEvent) => void; +} + +export function HistoryPanel(props: HistoryPanelProps) { const state = useStore(store, historyStateSelector); const { history, scroll } = state; + const { renderEvent } = props; const scrollRef = useRef>(undefined); @@ -28,10 +34,10 @@ export function HistoryPanel() { const items = history.map((item, index) => { if (index === history.length - 1) { - return ; + return ; } - return ; + return ; }); return