1
0
Fork 0

use events for images, show them in all servers

This commit is contained in:
Sean Sube 2024-05-12 00:08:53 -05:00
parent 7010da8ed2
commit 593f3981d1
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 357 additions and 133 deletions

View File

@ -3,6 +3,7 @@ from os import environ
from queue import Queue
from re import sub
from threading import Thread
from typing import Dict
from discord import Client, Embed, File, Intents
@ -18,17 +19,19 @@ from adventure.models.event import (
GenerateEvent,
PlayerEvent,
PromptEvent,
RenderEvent,
ReplyEvent,
ResultEvent,
StatusEvent,
)
from adventure.player import RemotePlayer, get_player, has_player, set_player
from adventure.render_comfy import generate_image_tool
from adventure.render_comfy import render_event
logger = getLogger(__name__)
client = None
active_tasks = set()
event_messages: Dict[str, str | GameEvent] = {}
event_queue: Queue[GameEvent] = Queue()
@ -40,54 +43,6 @@ def remove_tags(text: str) -> str:
return sub(r"<[^>]*>", "", text)
def find_embed_field(embed: Embed, name: str) -> str | None:
return next((field.value for field in embed.fields if field.name == name), None)
# TODO: becomes prompt from event
def prompt_from_embed(embed: Embed) -> str | None:
room_name = embed.title
actor_name = embed.description
world = get_current_world()
if not world:
return
room = next((room for room in world.rooms if room.name == room_name), None)
if not room:
return
actor = next((actor for actor in room.actors if actor.name == actor_name), None)
if not actor:
return
item_field = find_embed_field(embed, "Item")
action_field = find_embed_field(embed, "Action")
if action_field:
if item_field:
item = next(
(
item
for item in (room.items + actor.items)
if item.name == item_field
),
None,
)
if item:
return f"{actor.name} {action_field} the {item.name}. {item.description}. {actor.description}. {room.description}."
return f"{actor.name} {action_field} the {item_field}. {actor.description}. {room.description}."
return f"{actor.name} {action_field}. {actor.description}. {room.name}."
result_field = find_embed_field(embed, "Result")
if result_field:
return f"{result_field}. {actor.description}. {room.description}."
return
class AdventureClient(Client):
async def on_ready(self):
logger.info(f"Logged in as {self.user}")
@ -98,23 +53,16 @@ class AdventureClient(Client):
logger.info(f"Reaction added: {reaction} by {user}")
if reaction.emoji == "📷":
# message_id = reaction.message.id
# TODO: look up event that caused this message, get the room and actors
if len(reaction.message.embeds) > 0:
embed = reaction.message.embeds[0]
prompt = prompt_from_embed(embed)
else:
prompt = remove_tags(reaction.message.content)
if prompt.startswith("Generating"):
# TODO: get the entity from the message
pass
message_id = reaction.message.id
if message_id not in event_messages:
logger.warning(f"Message {message_id} not found in event messages")
# TODO: return error message
return
event = event_messages[message_id]
if isinstance(event, GameEvent):
render_event(event)
await reaction.message.add_reaction("📸")
paths = generate_image_tool(prompt, 2)
logger.info(f"Generated images: {paths}")
files = [File(filename) for filename in paths]
await reaction.message.channel.send(files=files, reference=reaction.message)
async def on_message(self, message):
if message.author == self.user:
@ -277,17 +225,54 @@ async def broadcast_event(message: str | GameEvent):
for channel in active_channels:
if isinstance(message, str):
logger.info("broadcasting to channel %s: %s", channel, message)
await channel.send(content=message)
elif isinstance(message, GameEvent):
# deprecated, use events instead
logger.warning(
"broadcasting non-event message to channel %s: %s", channel, message
)
event_message = await channel.send(content=message)
elif isinstance(message, RenderEvent):
# special handling to upload images
# find the source event
source_event_id = message.source.id
source_message_id = next(
(
message_id
for message_id, event in event_messages.items()
if isinstance(event, GameEvent) and event.id == source_event_id
),
None,
)
if not source_message_id:
logger.warning("source event not found: %s", source_event_id)
return
# open and upload images
files = [File(filename) for filename in message.paths]
try:
source_message = await channel.fetch_message(source_message_id)
except Exception as err:
logger.warning("source message not found: %s", err)
return
# send the images as a reply to the source message
event_message = await source_message.channel.send(
files=files, reference=source_message
)
else:
embed = embed_from_event(message)
if not embed:
logger.warning("no embed for event: %s", message)
return
logger.info(
"broadcasting to channel %s: %s - %s",
channel,
embed.title,
embed.description,
)
await channel.send(embed=embed)
event_message = await channel.send(embed=embed)
event_messages[event_message.id] = message
def embed_from_event(event: GameEvent) -> Embed:

View File

@ -84,6 +84,7 @@ def parse_args():
parser.add_argument(
"--optional-actions", type=bool, help="Whether to include optional actions"
)
parser.add_argument("--render", type=bool, help="Whether to render the simulation")
parser.add_argument(
"--server", type=str, help="The address on which to run the server"
)
@ -199,14 +200,20 @@ def main():
# launch other threads
threads = []
if args.render:
from adventure.render_comfy import launch_render
threads.extend(launch_render())
if args.discord:
from adventure.discord_bot import bot_event, launch_bot
from adventure.bot_discord import bot_event, launch_bot
threads.extend(launch_bot())
callbacks.append(bot_event)
if args.server:
from adventure.server import launch_server, server_event, server_system
from adventure.server_socket import launch_server, server_event, server_system
threads.extend(launch_server())
callbacks.append(server_event)

View File

@ -1,25 +1,33 @@
from json import loads
from typing import Any, Callable, Dict, List, Literal
from uuid import uuid4
from pydantic import Field
from .base import dataclass
from .entity import Actor, Item, Room, WorldEntity
@dataclass
def uuid() -> str:
return uuid4().hex
class BaseEvent:
"""
A base event class.
"""
id: str
type: str
@dataclass
class GenerateEvent:
class GenerateEvent(BaseEvent):
"""
A new entity has been generated.
"""
id = Field(default_factory=uuid)
type = "generate"
name: str
entity: WorldEntity | None = None
@ -34,11 +42,12 @@ class GenerateEvent:
@dataclass
class ActionEvent:
class ActionEvent(BaseEvent):
"""
An actor has taken an action.
"""
id = Field(default_factory=uuid)
type = "action"
action: str
parameters: Dict[str, bool | float | int | str]
@ -60,11 +69,12 @@ class ActionEvent:
@dataclass
class PromptEvent:
class PromptEvent(BaseEvent):
"""
A prompt for an actor to take an action.
"""
id = Field(default_factory=uuid)
type = "prompt"
prompt: str
room: Room
@ -72,13 +82,14 @@ class PromptEvent:
@dataclass
class ReplyEvent:
class ReplyEvent(BaseEvent):
"""
An actor has replied with text.
This is the non-JSON version of an ActionEvent.
"""
id = Field(default_factory=uuid)
type = "reply"
text: str
room: Room
@ -90,11 +101,12 @@ class ReplyEvent:
@dataclass
class ResultEvent:
class ResultEvent(BaseEvent):
"""
A result of an action.
"""
id = Field(default_factory=uuid)
type = "result"
result: str
room: Room
@ -102,11 +114,12 @@ class ResultEvent:
@dataclass
class StatusEvent:
class StatusEvent(BaseEvent):
"""
A status broadcast event with text.
"""
id = Field(default_factory=uuid)
type = "status"
text: str
room: Room | None = None
@ -114,7 +127,7 @@ class StatusEvent:
@dataclass
class SnapshotEvent:
class SnapshotEvent(BaseEvent):
"""
A snapshot of the world state.
@ -122,6 +135,7 @@ class SnapshotEvent:
That is especially important for the memory, which is a dictionary of actor names to lists of messages.
"""
id = Field(default_factory=uuid)
type = "snapshot"
world: Dict[str, Any]
memory: Dict[str, List[Any]]
@ -129,20 +143,45 @@ class SnapshotEvent:
@dataclass
class PlayerEvent:
class PlayerEvent(BaseEvent):
"""
A player joining or leaving the game.
"""
id = Field(default_factory=uuid)
type = "player"
status: Literal["join", "leave"]
character: str
client: str
@dataclass
class PlayerListEvent(BaseEvent):
"""
A list of players in the game and the characters they are playing.
"""
id = Field(default_factory=uuid)
type = "players"
players: Dict[str, str]
@dataclass
class RenderEvent(BaseEvent):
"""
Images have been rendered.
"""
id = Field(default_factory=uuid)
type = "render"
paths: List[str]
source: "GameEvent"
# event types
WorldEvent = ActionEvent | PromptEvent | ReplyEvent | ResultEvent | StatusEvent
GameEvent = GenerateEvent | PlayerEvent | WorldEvent
PlayerEventType = PlayerEvent | PlayerListEvent
GameEvent = GenerateEvent | PlayerEventType | RenderEvent | WorldEvent
# callback types
EventCallback = Callable[[GameEvent], None]

View File

@ -8,12 +8,24 @@ import urllib.request
import uuid
from logging import getLogger
from os import environ, path
from queue import Queue
from random import choice, randint
from threading import Thread
from typing import List
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
from PIL import Image
from adventure.context import broadcast
from adventure.models.event import (
ActionEvent,
GameEvent,
RenderEvent,
ReplyEvent,
ResultEvent,
StatusEvent,
)
logger = getLogger(__name__)
server_address = environ["COMFY_API"]
@ -195,6 +207,7 @@ def generate_images(
paths: List[str] = []
for j, image in enumerate(results):
# TODO: replace with environment variable
image_path = path.join("/home/ssube/adventure-images", f"{prefix}-{j}.png")
with open(image_path, "wb") as f:
image_bytes = io.BytesIO()
@ -206,6 +219,82 @@ def generate_images(
return paths
def prompt_from_event(event: GameEvent) -> str | None:
if isinstance(event, ActionEvent):
if event.item:
return f"{event.actor.name} uses the {event.item.name}. {event.item.description}. {event.actor.description}. {event.room.description}."
return f"{event.actor.name} {event.action}. {event.actor.description}. {event.room.description}."
if isinstance(event, ReplyEvent):
return event.text
if isinstance(event, ResultEvent):
return f"{event.result}. {event.actor.description}. {event.room.description}."
if isinstance(event, StatusEvent):
if event.room:
if event.actor:
return f"{event.text}. {event.actor.description}. {event.room.description}."
return f"{event.text}. {event.room.description}."
return event.text
return None
def prefix_from_event(event: GameEvent) -> str:
if isinstance(event, ActionEvent):
return (
f"{event.actor.name}-{event.action}-{event.item.name if event.item else ''}"
)
if isinstance(event, ReplyEvent):
return f"{event.actor.name}-reply"
if isinstance(event, ResultEvent):
return f"{event.actor.name}-result"
if isinstance(event, StatusEvent):
return "status"
return "unknown"
# requests to generate images for game events
render_queue: Queue[GameEvent] = Queue()
def render_loop():
while True:
event = render_queue.get()
prompt = prompt_from_event(event)
if prompt:
logger.info("rendering prompt for event %s: %s", event, prompt)
prefix = prefix_from_event(event)
image_paths = generate_images(prompt, 2, prefix=prefix)
broadcast(RenderEvent(paths=image_paths, source=event))
else:
logger.warning("no prompt for event %s", event)
def render_event(event: GameEvent):
render_queue.put(event)
render_thread = None
def launch_render():
global render_thread
render_thread = Thread(target=render_loop, daemon=True)
render_thread.start()
return [render_thread]
if __name__ == "__main__":
paths = generate_images(
"A painting of a beautiful sunset over a calm lake", 3, "landscape"

View File

@ -1,17 +1,26 @@
import asyncio
from base64 import b64encode
from collections import deque
from io import BytesIO
from json import dumps, loads
from logging import getLogger
from threading import Thread
from typing import Any, Dict, Literal
from typing import Any, Dict, Literal, MutableSequence
from uuid import uuid4
import websockets
from PIL import Image
from pydantic import RootModel
from adventure.context import broadcast, get_actor_agent_for_name, set_actor_agent
from adventure.models.entity import Actor, Item, Room, World
from adventure.models.event import GameEvent, PlayerEvent, PromptEvent
from adventure.models.event import (
GameEvent,
PlayerEvent,
PlayerListEvent,
PromptEvent,
RenderEvent,
)
from adventure.player import (
RemotePlayer,
get_player,
@ -20,12 +29,14 @@ from adventure.player import (
remove_player,
set_player,
)
from adventure.render_comfy import render_event
from adventure.state import snapshot_world, world_json
logger = getLogger(__name__)
connected = set()
recent_events = deque(maxlen=100)
recent_events: MutableSequence[GameEvent] = deque(maxlen=100)
recent_json: MutableSequence[str] = deque(maxlen=100)
last_snapshot = None
player_names: Dict[str, str] = {}
@ -61,12 +72,13 @@ async def handler(websocket):
return False
try:
await websocket.send(dumps({"type": "id", "id": id}))
await websocket.send(dumps({"type": "id", "client": id}))
if last_snapshot:
# TODO: only send this if the recent events don't contain a snapshot
if last_snapshot and last_snapshot not in recent_json:
await websocket.send(last_snapshot)
for message in recent_events:
for message in recent_json:
await websocket.send(message)
except Exception:
logger.exception("Failed to send recent messages to new client")
@ -141,8 +153,8 @@ async def handler(websocket):
set_actor_agent(actor.name, actor, player)
# notify all clients that this character is now active
player_event(character_name, player_name, "join")
player_list()
broadcast_player_event(character_name, player_name, "join")
broadcast_player_list()
elif message_type == "input":
player = get_player(id)
if player and isinstance(player, RemotePlayer):
@ -150,6 +162,13 @@ async def handler(websocket):
"queueing input for player %s: %s", player.name, data
)
player.input_queue.put(data["input"])
elif message_type == "render":
event_id = data["event"]
event = next((e for e in recent_events if e.id == event_id), None)
if event:
render_event(event)
else:
logger.error(f"Failed to find event {event_id}")
except Exception:
logger.exception("Failed to parse message")
@ -166,16 +185,16 @@ async def handler(websocket):
remove_player(id)
player_name = get_player_name(id)
logger.info("Disconnecting player %s from %s", player_name, player.name)
player_event(player.name, player_name, "leave")
player_list()
logger.info("disconnecting player %s from %s", player_name, player.name)
broadcast_player_event(player.name, player_name, "leave")
broadcast_player_list()
actor, _ = get_actor_agent_for_name(player.name)
if actor and player.fallback_agent:
logger.info("Restoring LLM agent for %s", player.name)
logger.info("restoring LLM agent for %s", player.name)
set_actor_agent(player.name, actor, player.fallback_agent)
logger.info("Client disconnected: %s", id)
logger.info("client disconnected: %s", id)
socket_thread = None
@ -188,9 +207,9 @@ def server_json(obj):
return world_json(obj)
def send_and_append(message):
def send_and_append(id: str, message: Dict):
json_message = dumps(message, default=server_json)
recent_events.append(json_message)
recent_json.append(json_message)
websockets.broadcast(connected, json_message)
return json_message
@ -215,28 +234,50 @@ async def server_main():
def server_system(world: World, step: int):
global last_snapshot
id = uuid4().hex # TODO: should a server be allowed to generate event IDs?
json_state = {
**snapshot_world(world, step),
"id": id,
"type": "snapshot",
}
last_snapshot = send_and_append(json_state)
last_snapshot = send_and_append(id, json_state)
def server_event(event: GameEvent):
json_event: Dict[str, Any] = RootModel[event.__class__](event).model_dump()
json_event["type"] = event.type
send_and_append(json_event)
json_event.update(
{
"id": event.id,
"type": event.type,
}
)
if isinstance(event, RenderEvent):
# load and encode the images
image_paths = event.paths
image_data = {}
for path in image_paths:
with Image.open(path, "r") as image:
buffered = BytesIO()
image.save(
buffered, format="JPEG", quality=80, optimize=True, progressive=True
)
image_str = b64encode(buffered.getvalue())
image_data[path] = image_str.decode("utf-8")
json_event["images"] = image_data
recent_events.append(event)
send_and_append(event.id, json_event)
def player_event(character: str, client: str, status: Literal["join", "leave"]):
def broadcast_player_event(
character: str, client: str, status: Literal["join", "leave"]
):
event = PlayerEvent(status=status, character=character, client=client)
broadcast(event)
def player_list():
json_broadcast = {
"type": "players",
"players": list_players(),
}
# TODO: broadcast this
send_and_append(json_broadcast)
def broadcast_player_list():
event = PlayerListEvent(players=list_players())
broadcast(event)

View File

@ -19,7 +19,7 @@ import useWebSocketModule from 'react-use-websocket';
import { useStore } from 'zustand';
import { HistoryPanel } from './history.js';
import { Actor, Item, Room } from './models.js';
import { Actor, GameEvent, Item, Room } from './models.js';
import { PlayerPanel } from './player.js';
import { store, StoreState } from './store.js';
import { WorldPanel } from './world.js';
@ -93,6 +93,11 @@ export function App(props: AppProps) {
// socket stuff
const { lastMessage, readyState, sendMessage } = useWebSocket(props.socketUrl);
// socket senders
function renderEvent(event: string) {
sendMessage(JSON.stringify({ type: 'render', event }));
}
function setPlayer(actor: Maybe<Actor>) {
// do not call setCharacter until the server confirms the player change
if (doesExist(actor)) {
@ -179,7 +184,7 @@ export function App(props: AppProps) {
<WorldPanel setPlayer={setPlayer} />
</Stack>
<Stack direction="column" sx={{ minWidth: 600 }} className="scroll-history">
<HistoryPanel />
<HistoryPanel renderEvent={renderEvent} />
</Stack>
</Allotment>
</Stack>

View File

@ -1,21 +1,45 @@
import { ListItem, ListItemText, ListItemAvatar, Avatar, Typography } from '@mui/material';
import React, { MutableRefObject } from 'react';
import { Avatar, IconButton, ImageList, ImageListItem, ListItem, ListItemAvatar, ListItemText, Typography } from '@mui/material';
import React, { Fragment, MutableRefObject } from 'react';
import { Camera } from '@mui/icons-material';
import { formatters } from './format.js';
import { GameEvent } from './models.js';
export function openImage(image: string) {
const byteCharacters = atob(image);
const byteNumbers = new Array(byteCharacters.length);
for (let i = 0; i < byteCharacters.length; i++) {
byteNumbers[i] = byteCharacters.charCodeAt(i);
}
const byteArray = new Uint8Array(byteNumbers);
const file = new Blob([byteArray], { type: 'image/jpeg;base64' });
const fileURL = URL.createObjectURL(file);
window.open(fileURL, '_blank');
}
export interface EventItemProps {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
event: any;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
focusRef?: MutableRefObject<any>;
renderEvent: (event: GameEvent) => void;
}
export function ActionEventItem(props: EventItemProps) {
const { event } = props;
const { actor, room, type } = event;
const { event, renderEvent } = props;
const { id, actor, room, type } = event;
const content = formatters[type](event);
return <ListItem alignItems="flex-start" ref={props.focusRef}>
return <ListItem
alignItems="flex-start"
ref={props.focusRef}
secondaryAction={
<IconButton edge="end" aria-label="render" onClick={() => renderEvent(id)}>
<Camera />
</IconButton>
}
>
<ListItemAvatar>
<Avatar alt={actor.name} src="/static/images/avatar/1.jpg" />
</ListItemAvatar>
@ -41,23 +65,26 @@ export function ActionEventItem(props: EventItemProps) {
export function SnapshotEventItem(props: EventItemProps) {
const { event } = props;
const { step, world } = event;
const { theme } = world;
const { name, theme } = world;
return <ListItem alignItems="flex-start" ref={props.focusRef}>
<ListItemAvatar>
<Avatar alt={step.toString()} src="/static/images/avatar/1.jpg" />
</ListItemAvatar>
<ListItemText
primary={theme}
primary={name}
secondary={
<Fragment>
<Typography
sx={{ display: 'block' }}
component="span"
variant="body2"
color="text.primary"
>
Step {step}
Step: {step}
</Typography>
World Theme: {theme}
</Fragment>
}
/>
</ListItem>;
@ -102,7 +129,10 @@ export function PlayerEventItem(props: EventItemProps) {
secondary = `${client} has left the game. ${character} is now controlled by an LLM`;
}
return <ListItem alignItems="flex-start" ref={props.focusRef}>
return <ListItem
alignItems="flex-start"
ref={props.focusRef}
>
<ListItemAvatar>
<Avatar alt={character} src="/static/images/avatar/1.jpg" />
</ListItemAvatar>
@ -122,6 +152,25 @@ export function PlayerEventItem(props: EventItemProps) {
</ListItem>;
}
export function RenderEventItem(props: EventItemProps) {
const { event } = props;
const { images } = event;
return <ListItem alignItems="flex-start" ref={props.focusRef}>
<ListItemAvatar>
<Avatar alt="Render" src="/static/images/avatar/1.jpg" />
</ListItemAvatar>
<ListItemText
primary="Render"
secondary={<ImageList cols={3} rowHeight={256}>
{Object.entries(images).map(([name, image]) => <ImageListItem key={name}>
<img src={`data:image/jpeg;base64,${image}`} onClick={() => openImage(image)} alt="Render" />
</ImageListItem>)}
</ImageList>}
/>
</ListItem>;
}
export function EventItem(props: EventItemProps) {
const { event } = props;
const { type } = event;
@ -129,13 +178,16 @@ export function EventItem(props: EventItemProps) {
switch (type) {
case 'action':
case 'result':
return <ActionEventItem event={event} focusRef={props.focusRef} />;
return <ActionEventItem {...props} />;
case 'reply':
return <ReplyEventItem event={event} focusRef={props.focusRef} />;
case 'status': // TODO: should have a different component
return <ReplyEventItem {...props} />;
case 'player':
return <PlayerEventItem event={event} focusRef={props.focusRef} />;
return <PlayerEventItem {...props} />;
case 'render':
return <RenderEventItem {...props} />;
case 'snapshot':
return <SnapshotEventItem event={event} focusRef={props.focusRef} />;
return <SnapshotEventItem {...props} />;
default:
return <ListItem ref={props.focusRef}>
<ListItemText primary={`Unknown event type: ${type}`} />

View File

@ -3,7 +3,8 @@ import { Divider, List } from '@mui/material';
import React, { useEffect, useRef } from 'react';
import { useStore } from 'zustand';
import { EventItem } from './events';
import { store, StoreState } from './store';
import { GameEvent } from './models';
import { StoreState, store } from './store';
export function historyStateSelector(s: StoreState) {
return {
@ -12,9 +13,14 @@ export function historyStateSelector(s: StoreState) {
};
}
export function HistoryPanel() {
export interface HistoryPanelProps {
renderEvent: (event: GameEvent) => void;
}
export function HistoryPanel(props: HistoryPanelProps) {
const state = useStore(store, historyStateSelector);
const { history, scroll } = state;
const { renderEvent } = props;
const scrollRef = useRef<Maybe<Element>>(undefined);
@ -28,10 +34,10 @@ export function HistoryPanel() {
const items = history.map((item, index) => {
if (index === history.length - 1) {
return <EventItem key={`item-${index}`} event={item} focusRef={scrollRef} />;
return <EventItem key={`item-${index}`} event={item} focusRef={scrollRef} renderEvent={renderEvent} />;
}
return <EventItem key={`item-${index}`} event={item} />;
return <EventItem key={`item-${index}`} event={item} renderEvent={renderEvent} />;
});
return <List sx={{ width: '100%', bgcolor: 'background.paper' }}>