use events for images, show them in all servers
This commit is contained in:
parent
7010da8ed2
commit
593f3981d1
|
@ -3,6 +3,7 @@ from os import environ
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from re import sub
|
from re import sub
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from discord import Client, Embed, File, Intents
|
from discord import Client, Embed, File, Intents
|
||||||
|
|
||||||
|
@ -18,17 +19,19 @@ from adventure.models.event import (
|
||||||
GenerateEvent,
|
GenerateEvent,
|
||||||
PlayerEvent,
|
PlayerEvent,
|
||||||
PromptEvent,
|
PromptEvent,
|
||||||
|
RenderEvent,
|
||||||
ReplyEvent,
|
ReplyEvent,
|
||||||
ResultEvent,
|
ResultEvent,
|
||||||
StatusEvent,
|
StatusEvent,
|
||||||
)
|
)
|
||||||
from adventure.player import RemotePlayer, get_player, has_player, set_player
|
from adventure.player import RemotePlayer, get_player, has_player, set_player
|
||||||
from adventure.render_comfy import generate_image_tool
|
from adventure.render_comfy import render_event
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
client = None
|
client = None
|
||||||
|
|
||||||
active_tasks = set()
|
active_tasks = set()
|
||||||
|
event_messages: Dict[str, str | GameEvent] = {}
|
||||||
event_queue: Queue[GameEvent] = Queue()
|
event_queue: Queue[GameEvent] = Queue()
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,54 +43,6 @@ def remove_tags(text: str) -> str:
|
||||||
return sub(r"<[^>]*>", "", text)
|
return sub(r"<[^>]*>", "", text)
|
||||||
|
|
||||||
|
|
||||||
def find_embed_field(embed: Embed, name: str) -> str | None:
|
|
||||||
return next((field.value for field in embed.fields if field.name == name), None)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: becomes prompt from event
|
|
||||||
def prompt_from_embed(embed: Embed) -> str | None:
|
|
||||||
room_name = embed.title
|
|
||||||
actor_name = embed.description
|
|
||||||
|
|
||||||
world = get_current_world()
|
|
||||||
if not world:
|
|
||||||
return
|
|
||||||
|
|
||||||
room = next((room for room in world.rooms if room.name == room_name), None)
|
|
||||||
if not room:
|
|
||||||
return
|
|
||||||
|
|
||||||
actor = next((actor for actor in room.actors if actor.name == actor_name), None)
|
|
||||||
if not actor:
|
|
||||||
return
|
|
||||||
|
|
||||||
item_field = find_embed_field(embed, "Item")
|
|
||||||
|
|
||||||
action_field = find_embed_field(embed, "Action")
|
|
||||||
if action_field:
|
|
||||||
if item_field:
|
|
||||||
item = next(
|
|
||||||
(
|
|
||||||
item
|
|
||||||
for item in (room.items + actor.items)
|
|
||||||
if item.name == item_field
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if item:
|
|
||||||
return f"{actor.name} {action_field} the {item.name}. {item.description}. {actor.description}. {room.description}."
|
|
||||||
|
|
||||||
return f"{actor.name} {action_field} the {item_field}. {actor.description}. {room.description}."
|
|
||||||
|
|
||||||
return f"{actor.name} {action_field}. {actor.description}. {room.name}."
|
|
||||||
|
|
||||||
result_field = find_embed_field(embed, "Result")
|
|
||||||
if result_field:
|
|
||||||
return f"{result_field}. {actor.description}. {room.description}."
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
class AdventureClient(Client):
|
class AdventureClient(Client):
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
logger.info(f"Logged in as {self.user}")
|
logger.info(f"Logged in as {self.user}")
|
||||||
|
@ -98,23 +53,16 @@ class AdventureClient(Client):
|
||||||
|
|
||||||
logger.info(f"Reaction added: {reaction} by {user}")
|
logger.info(f"Reaction added: {reaction} by {user}")
|
||||||
if reaction.emoji == "📷":
|
if reaction.emoji == "📷":
|
||||||
# message_id = reaction.message.id
|
message_id = reaction.message.id
|
||||||
# TODO: look up event that caused this message, get the room and actors
|
if message_id not in event_messages:
|
||||||
if len(reaction.message.embeds) > 0:
|
logger.warning(f"Message {message_id} not found in event messages")
|
||||||
embed = reaction.message.embeds[0]
|
# TODO: return error message
|
||||||
prompt = prompt_from_embed(embed)
|
return
|
||||||
else:
|
|
||||||
prompt = remove_tags(reaction.message.content)
|
|
||||||
if prompt.startswith("Generating"):
|
|
||||||
# TODO: get the entity from the message
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
event = event_messages[message_id]
|
||||||
|
if isinstance(event, GameEvent):
|
||||||
|
render_event(event)
|
||||||
await reaction.message.add_reaction("📸")
|
await reaction.message.add_reaction("📸")
|
||||||
paths = generate_image_tool(prompt, 2)
|
|
||||||
logger.info(f"Generated images: {paths}")
|
|
||||||
|
|
||||||
files = [File(filename) for filename in paths]
|
|
||||||
await reaction.message.channel.send(files=files, reference=reaction.message)
|
|
||||||
|
|
||||||
async def on_message(self, message):
|
async def on_message(self, message):
|
||||||
if message.author == self.user:
|
if message.author == self.user:
|
||||||
|
@ -277,17 +225,54 @@ async def broadcast_event(message: str | GameEvent):
|
||||||
|
|
||||||
for channel in active_channels:
|
for channel in active_channels:
|
||||||
if isinstance(message, str):
|
if isinstance(message, str):
|
||||||
logger.info("broadcasting to channel %s: %s", channel, message)
|
# deprecated, use events instead
|
||||||
await channel.send(content=message)
|
logger.warning(
|
||||||
elif isinstance(message, GameEvent):
|
"broadcasting non-event message to channel %s: %s", channel, message
|
||||||
|
)
|
||||||
|
event_message = await channel.send(content=message)
|
||||||
|
elif isinstance(message, RenderEvent):
|
||||||
|
# special handling to upload images
|
||||||
|
# find the source event
|
||||||
|
source_event_id = message.source.id
|
||||||
|
source_message_id = next(
|
||||||
|
(
|
||||||
|
message_id
|
||||||
|
for message_id, event in event_messages.items()
|
||||||
|
if isinstance(event, GameEvent) and event.id == source_event_id
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if not source_message_id:
|
||||||
|
logger.warning("source event not found: %s", source_event_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
# open and upload images
|
||||||
|
files = [File(filename) for filename in message.paths]
|
||||||
|
try:
|
||||||
|
source_message = await channel.fetch_message(source_message_id)
|
||||||
|
except Exception as err:
|
||||||
|
logger.warning("source message not found: %s", err)
|
||||||
|
return
|
||||||
|
|
||||||
|
# send the images as a reply to the source message
|
||||||
|
event_message = await source_message.channel.send(
|
||||||
|
files=files, reference=source_message
|
||||||
|
)
|
||||||
|
else:
|
||||||
embed = embed_from_event(message)
|
embed = embed_from_event(message)
|
||||||
|
if not embed:
|
||||||
|
logger.warning("no embed for event: %s", message)
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"broadcasting to channel %s: %s - %s",
|
"broadcasting to channel %s: %s - %s",
|
||||||
channel,
|
channel,
|
||||||
embed.title,
|
embed.title,
|
||||||
embed.description,
|
embed.description,
|
||||||
)
|
)
|
||||||
await channel.send(embed=embed)
|
event_message = await channel.send(embed=embed)
|
||||||
|
|
||||||
|
event_messages[event_message.id] = message
|
||||||
|
|
||||||
|
|
||||||
def embed_from_event(event: GameEvent) -> Embed:
|
def embed_from_event(event: GameEvent) -> Embed:
|
|
@ -84,6 +84,7 @@ def parse_args():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--optional-actions", type=bool, help="Whether to include optional actions"
|
"--optional-actions", type=bool, help="Whether to include optional actions"
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--render", type=bool, help="Whether to render the simulation")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--server", type=str, help="The address on which to run the server"
|
"--server", type=str, help="The address on which to run the server"
|
||||||
)
|
)
|
||||||
|
@ -199,14 +200,20 @@ def main():
|
||||||
|
|
||||||
# launch other threads
|
# launch other threads
|
||||||
threads = []
|
threads = []
|
||||||
|
|
||||||
|
if args.render:
|
||||||
|
from adventure.render_comfy import launch_render
|
||||||
|
|
||||||
|
threads.extend(launch_render())
|
||||||
|
|
||||||
if args.discord:
|
if args.discord:
|
||||||
from adventure.discord_bot import bot_event, launch_bot
|
from adventure.bot_discord import bot_event, launch_bot
|
||||||
|
|
||||||
threads.extend(launch_bot())
|
threads.extend(launch_bot())
|
||||||
callbacks.append(bot_event)
|
callbacks.append(bot_event)
|
||||||
|
|
||||||
if args.server:
|
if args.server:
|
||||||
from adventure.server import launch_server, server_event, server_system
|
from adventure.server_socket import launch_server, server_event, server_system
|
||||||
|
|
||||||
threads.extend(launch_server())
|
threads.extend(launch_server())
|
||||||
callbacks.append(server_event)
|
callbacks.append(server_event)
|
||||||
|
|
|
@ -1,25 +1,33 @@
|
||||||
from json import loads
|
from json import loads
|
||||||
from typing import Any, Callable, Dict, List, Literal
|
from typing import Any, Callable, Dict, List, Literal
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
from .base import dataclass
|
from .base import dataclass
|
||||||
from .entity import Actor, Item, Room, WorldEntity
|
from .entity import Actor, Item, Room, WorldEntity
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
def uuid() -> str:
|
||||||
|
return uuid4().hex
|
||||||
|
|
||||||
|
|
||||||
class BaseEvent:
|
class BaseEvent:
|
||||||
"""
|
"""
|
||||||
A base event class.
|
A base event class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
id: str
|
||||||
type: str
|
type: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenerateEvent:
|
class GenerateEvent(BaseEvent):
|
||||||
"""
|
"""
|
||||||
A new entity has been generated.
|
A new entity has been generated.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
id = Field(default_factory=uuid)
|
||||||
type = "generate"
|
type = "generate"
|
||||||
name: str
|
name: str
|
||||||
entity: WorldEntity | None = None
|
entity: WorldEntity | None = None
|
||||||
|
@ -34,11 +42,12 @@ class GenerateEvent:
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ActionEvent:
|
class ActionEvent(BaseEvent):
|
||||||
"""
|
"""
|
||||||
An actor has taken an action.
|
An actor has taken an action.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
id = Field(default_factory=uuid)
|
||||||
type = "action"
|
type = "action"
|
||||||
action: str
|
action: str
|
||||||
parameters: Dict[str, bool | float | int | str]
|
parameters: Dict[str, bool | float | int | str]
|
||||||
|
@ -60,11 +69,12 @@ class ActionEvent:
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PromptEvent:
|
class PromptEvent(BaseEvent):
|
||||||
"""
|
"""
|
||||||
A prompt for an actor to take an action.
|
A prompt for an actor to take an action.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
id = Field(default_factory=uuid)
|
||||||
type = "prompt"
|
type = "prompt"
|
||||||
prompt: str
|
prompt: str
|
||||||
room: Room
|
room: Room
|
||||||
|
@ -72,13 +82,14 @@ class PromptEvent:
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ReplyEvent:
|
class ReplyEvent(BaseEvent):
|
||||||
"""
|
"""
|
||||||
An actor has replied with text.
|
An actor has replied with text.
|
||||||
|
|
||||||
This is the non-JSON version of an ActionEvent.
|
This is the non-JSON version of an ActionEvent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
id = Field(default_factory=uuid)
|
||||||
type = "reply"
|
type = "reply"
|
||||||
text: str
|
text: str
|
||||||
room: Room
|
room: Room
|
||||||
|
@ -90,11 +101,12 @@ class ReplyEvent:
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ResultEvent:
|
class ResultEvent(BaseEvent):
|
||||||
"""
|
"""
|
||||||
A result of an action.
|
A result of an action.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
id = Field(default_factory=uuid)
|
||||||
type = "result"
|
type = "result"
|
||||||
result: str
|
result: str
|
||||||
room: Room
|
room: Room
|
||||||
|
@ -102,11 +114,12 @@ class ResultEvent:
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StatusEvent:
|
class StatusEvent(BaseEvent):
|
||||||
"""
|
"""
|
||||||
A status broadcast event with text.
|
A status broadcast event with text.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
id = Field(default_factory=uuid)
|
||||||
type = "status"
|
type = "status"
|
||||||
text: str
|
text: str
|
||||||
room: Room | None = None
|
room: Room | None = None
|
||||||
|
@ -114,7 +127,7 @@ class StatusEvent:
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SnapshotEvent:
|
class SnapshotEvent(BaseEvent):
|
||||||
"""
|
"""
|
||||||
A snapshot of the world state.
|
A snapshot of the world state.
|
||||||
|
|
||||||
|
@ -122,6 +135,7 @@ class SnapshotEvent:
|
||||||
That is especially important for the memory, which is a dictionary of actor names to lists of messages.
|
That is especially important for the memory, which is a dictionary of actor names to lists of messages.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
id = Field(default_factory=uuid)
|
||||||
type = "snapshot"
|
type = "snapshot"
|
||||||
world: Dict[str, Any]
|
world: Dict[str, Any]
|
||||||
memory: Dict[str, List[Any]]
|
memory: Dict[str, List[Any]]
|
||||||
|
@ -129,20 +143,45 @@ class SnapshotEvent:
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PlayerEvent:
|
class PlayerEvent(BaseEvent):
|
||||||
"""
|
"""
|
||||||
A player joining or leaving the game.
|
A player joining or leaving the game.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
id = Field(default_factory=uuid)
|
||||||
type = "player"
|
type = "player"
|
||||||
status: Literal["join", "leave"]
|
status: Literal["join", "leave"]
|
||||||
character: str
|
character: str
|
||||||
client: str
|
client: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PlayerListEvent(BaseEvent):
|
||||||
|
"""
|
||||||
|
A list of players in the game and the characters they are playing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id = Field(default_factory=uuid)
|
||||||
|
type = "players"
|
||||||
|
players: Dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RenderEvent(BaseEvent):
|
||||||
|
"""
|
||||||
|
Images have been rendered.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id = Field(default_factory=uuid)
|
||||||
|
type = "render"
|
||||||
|
paths: List[str]
|
||||||
|
source: "GameEvent"
|
||||||
|
|
||||||
|
|
||||||
# event types
|
# event types
|
||||||
WorldEvent = ActionEvent | PromptEvent | ReplyEvent | ResultEvent | StatusEvent
|
WorldEvent = ActionEvent | PromptEvent | ReplyEvent | ResultEvent | StatusEvent
|
||||||
GameEvent = GenerateEvent | PlayerEvent | WorldEvent
|
PlayerEventType = PlayerEvent | PlayerListEvent
|
||||||
|
GameEvent = GenerateEvent | PlayerEventType | RenderEvent | WorldEvent
|
||||||
|
|
||||||
# callback types
|
# callback types
|
||||||
EventCallback = Callable[[GameEvent], None]
|
EventCallback = Callable[[GameEvent], None]
|
||||||
|
|
|
@ -8,12 +8,24 @@ import urllib.request
|
||||||
import uuid
|
import uuid
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
|
from queue import Queue
|
||||||
from random import choice, randint
|
from random import choice, randint
|
||||||
|
from threading import Thread
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from adventure.context import broadcast
|
||||||
|
from adventure.models.event import (
|
||||||
|
ActionEvent,
|
||||||
|
GameEvent,
|
||||||
|
RenderEvent,
|
||||||
|
ReplyEvent,
|
||||||
|
ResultEvent,
|
||||||
|
StatusEvent,
|
||||||
|
)
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
server_address = environ["COMFY_API"]
|
server_address = environ["COMFY_API"]
|
||||||
|
@ -195,6 +207,7 @@ def generate_images(
|
||||||
|
|
||||||
paths: List[str] = []
|
paths: List[str] = []
|
||||||
for j, image in enumerate(results):
|
for j, image in enumerate(results):
|
||||||
|
# TODO: replace with environment variable
|
||||||
image_path = path.join("/home/ssube/adventure-images", f"{prefix}-{j}.png")
|
image_path = path.join("/home/ssube/adventure-images", f"{prefix}-{j}.png")
|
||||||
with open(image_path, "wb") as f:
|
with open(image_path, "wb") as f:
|
||||||
image_bytes = io.BytesIO()
|
image_bytes = io.BytesIO()
|
||||||
|
@ -206,6 +219,82 @@ def generate_images(
|
||||||
return paths
|
return paths
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_from_event(event: GameEvent) -> str | None:
|
||||||
|
if isinstance(event, ActionEvent):
|
||||||
|
if event.item:
|
||||||
|
return f"{event.actor.name} uses the {event.item.name}. {event.item.description}. {event.actor.description}. {event.room.description}."
|
||||||
|
|
||||||
|
return f"{event.actor.name} {event.action}. {event.actor.description}. {event.room.description}."
|
||||||
|
|
||||||
|
if isinstance(event, ReplyEvent):
|
||||||
|
return event.text
|
||||||
|
|
||||||
|
if isinstance(event, ResultEvent):
|
||||||
|
return f"{event.result}. {event.actor.description}. {event.room.description}."
|
||||||
|
|
||||||
|
if isinstance(event, StatusEvent):
|
||||||
|
if event.room:
|
||||||
|
if event.actor:
|
||||||
|
return f"{event.text}. {event.actor.description}. {event.room.description}."
|
||||||
|
|
||||||
|
return f"{event.text}. {event.room.description}."
|
||||||
|
|
||||||
|
return event.text
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def prefix_from_event(event: GameEvent) -> str:
|
||||||
|
if isinstance(event, ActionEvent):
|
||||||
|
return (
|
||||||
|
f"{event.actor.name}-{event.action}-{event.item.name if event.item else ''}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(event, ReplyEvent):
|
||||||
|
return f"{event.actor.name}-reply"
|
||||||
|
|
||||||
|
if isinstance(event, ResultEvent):
|
||||||
|
return f"{event.actor.name}-result"
|
||||||
|
|
||||||
|
if isinstance(event, StatusEvent):
|
||||||
|
return "status"
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
# requests to generate images for game events
|
||||||
|
render_queue: Queue[GameEvent] = Queue()
|
||||||
|
|
||||||
|
|
||||||
|
def render_loop():
|
||||||
|
while True:
|
||||||
|
event = render_queue.get()
|
||||||
|
prompt = prompt_from_event(event)
|
||||||
|
if prompt:
|
||||||
|
logger.info("rendering prompt for event %s: %s", event, prompt)
|
||||||
|
prefix = prefix_from_event(event)
|
||||||
|
image_paths = generate_images(prompt, 2, prefix=prefix)
|
||||||
|
broadcast(RenderEvent(paths=image_paths, source=event))
|
||||||
|
else:
|
||||||
|
logger.warning("no prompt for event %s", event)
|
||||||
|
|
||||||
|
|
||||||
|
def render_event(event: GameEvent):
|
||||||
|
render_queue.put(event)
|
||||||
|
|
||||||
|
|
||||||
|
render_thread = None
|
||||||
|
|
||||||
|
|
||||||
|
def launch_render():
|
||||||
|
global render_thread
|
||||||
|
|
||||||
|
render_thread = Thread(target=render_loop, daemon=True)
|
||||||
|
render_thread.start()
|
||||||
|
|
||||||
|
return [render_thread]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
paths = generate_images(
|
paths = generate_images(
|
||||||
"A painting of a beautiful sunset over a calm lake", 3, "landscape"
|
"A painting of a beautiful sunset over a calm lake", 3, "landscape"
|
||||||
|
|
|
@ -1,17 +1,26 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from base64 import b64encode
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from io import BytesIO
|
||||||
from json import dumps, loads
|
from json import dumps, loads
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Dict, Literal
|
from typing import Any, Dict, Literal, MutableSequence
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
from PIL import Image
|
||||||
from pydantic import RootModel
|
from pydantic import RootModel
|
||||||
|
|
||||||
from adventure.context import broadcast, get_actor_agent_for_name, set_actor_agent
|
from adventure.context import broadcast, get_actor_agent_for_name, set_actor_agent
|
||||||
from adventure.models.entity import Actor, Item, Room, World
|
from adventure.models.entity import Actor, Item, Room, World
|
||||||
from adventure.models.event import GameEvent, PlayerEvent, PromptEvent
|
from adventure.models.event import (
|
||||||
|
GameEvent,
|
||||||
|
PlayerEvent,
|
||||||
|
PlayerListEvent,
|
||||||
|
PromptEvent,
|
||||||
|
RenderEvent,
|
||||||
|
)
|
||||||
from adventure.player import (
|
from adventure.player import (
|
||||||
RemotePlayer,
|
RemotePlayer,
|
||||||
get_player,
|
get_player,
|
||||||
|
@ -20,12 +29,14 @@ from adventure.player import (
|
||||||
remove_player,
|
remove_player,
|
||||||
set_player,
|
set_player,
|
||||||
)
|
)
|
||||||
|
from adventure.render_comfy import render_event
|
||||||
from adventure.state import snapshot_world, world_json
|
from adventure.state import snapshot_world, world_json
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
connected = set()
|
connected = set()
|
||||||
recent_events = deque(maxlen=100)
|
recent_events: MutableSequence[GameEvent] = deque(maxlen=100)
|
||||||
|
recent_json: MutableSequence[str] = deque(maxlen=100)
|
||||||
last_snapshot = None
|
last_snapshot = None
|
||||||
player_names: Dict[str, str] = {}
|
player_names: Dict[str, str] = {}
|
||||||
|
|
||||||
|
@ -61,12 +72,13 @@ async def handler(websocket):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await websocket.send(dumps({"type": "id", "id": id}))
|
await websocket.send(dumps({"type": "id", "client": id}))
|
||||||
|
|
||||||
if last_snapshot:
|
# TODO: only send this if the recent events don't contain a snapshot
|
||||||
|
if last_snapshot and last_snapshot not in recent_json:
|
||||||
await websocket.send(last_snapshot)
|
await websocket.send(last_snapshot)
|
||||||
|
|
||||||
for message in recent_events:
|
for message in recent_json:
|
||||||
await websocket.send(message)
|
await websocket.send(message)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to send recent messages to new client")
|
logger.exception("Failed to send recent messages to new client")
|
||||||
|
@ -141,8 +153,8 @@ async def handler(websocket):
|
||||||
set_actor_agent(actor.name, actor, player)
|
set_actor_agent(actor.name, actor, player)
|
||||||
|
|
||||||
# notify all clients that this character is now active
|
# notify all clients that this character is now active
|
||||||
player_event(character_name, player_name, "join")
|
broadcast_player_event(character_name, player_name, "join")
|
||||||
player_list()
|
broadcast_player_list()
|
||||||
elif message_type == "input":
|
elif message_type == "input":
|
||||||
player = get_player(id)
|
player = get_player(id)
|
||||||
if player and isinstance(player, RemotePlayer):
|
if player and isinstance(player, RemotePlayer):
|
||||||
|
@ -150,6 +162,13 @@ async def handler(websocket):
|
||||||
"queueing input for player %s: %s", player.name, data
|
"queueing input for player %s: %s", player.name, data
|
||||||
)
|
)
|
||||||
player.input_queue.put(data["input"])
|
player.input_queue.put(data["input"])
|
||||||
|
elif message_type == "render":
|
||||||
|
event_id = data["event"]
|
||||||
|
event = next((e for e in recent_events if e.id == event_id), None)
|
||||||
|
if event:
|
||||||
|
render_event(event)
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to find event {event_id}")
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to parse message")
|
logger.exception("Failed to parse message")
|
||||||
|
@ -166,16 +185,16 @@ async def handler(websocket):
|
||||||
remove_player(id)
|
remove_player(id)
|
||||||
|
|
||||||
player_name = get_player_name(id)
|
player_name = get_player_name(id)
|
||||||
logger.info("Disconnecting player %s from %s", player_name, player.name)
|
logger.info("disconnecting player %s from %s", player_name, player.name)
|
||||||
player_event(player.name, player_name, "leave")
|
broadcast_player_event(player.name, player_name, "leave")
|
||||||
player_list()
|
broadcast_player_list()
|
||||||
|
|
||||||
actor, _ = get_actor_agent_for_name(player.name)
|
actor, _ = get_actor_agent_for_name(player.name)
|
||||||
if actor and player.fallback_agent:
|
if actor and player.fallback_agent:
|
||||||
logger.info("Restoring LLM agent for %s", player.name)
|
logger.info("restoring LLM agent for %s", player.name)
|
||||||
set_actor_agent(player.name, actor, player.fallback_agent)
|
set_actor_agent(player.name, actor, player.fallback_agent)
|
||||||
|
|
||||||
logger.info("Client disconnected: %s", id)
|
logger.info("client disconnected: %s", id)
|
||||||
|
|
||||||
|
|
||||||
socket_thread = None
|
socket_thread = None
|
||||||
|
@ -188,9 +207,9 @@ def server_json(obj):
|
||||||
return world_json(obj)
|
return world_json(obj)
|
||||||
|
|
||||||
|
|
||||||
def send_and_append(message):
|
def send_and_append(id: str, message: Dict):
|
||||||
json_message = dumps(message, default=server_json)
|
json_message = dumps(message, default=server_json)
|
||||||
recent_events.append(json_message)
|
recent_json.append(json_message)
|
||||||
websockets.broadcast(connected, json_message)
|
websockets.broadcast(connected, json_message)
|
||||||
return json_message
|
return json_message
|
||||||
|
|
||||||
|
@ -215,28 +234,50 @@ async def server_main():
|
||||||
|
|
||||||
def server_system(world: World, step: int):
|
def server_system(world: World, step: int):
|
||||||
global last_snapshot
|
global last_snapshot
|
||||||
|
id = uuid4().hex # TODO: should a server be allowed to generate event IDs?
|
||||||
json_state = {
|
json_state = {
|
||||||
**snapshot_world(world, step),
|
**snapshot_world(world, step),
|
||||||
|
"id": id,
|
||||||
"type": "snapshot",
|
"type": "snapshot",
|
||||||
}
|
}
|
||||||
last_snapshot = send_and_append(json_state)
|
last_snapshot = send_and_append(id, json_state)
|
||||||
|
|
||||||
|
|
||||||
def server_event(event: GameEvent):
|
def server_event(event: GameEvent):
|
||||||
json_event: Dict[str, Any] = RootModel[event.__class__](event).model_dump()
|
json_event: Dict[str, Any] = RootModel[event.__class__](event).model_dump()
|
||||||
json_event["type"] = event.type
|
json_event.update(
|
||||||
send_and_append(json_event)
|
{
|
||||||
|
"id": event.id,
|
||||||
|
"type": event.type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(event, RenderEvent):
|
||||||
|
# load and encode the images
|
||||||
|
image_paths = event.paths
|
||||||
|
image_data = {}
|
||||||
|
for path in image_paths:
|
||||||
|
with Image.open(path, "r") as image:
|
||||||
|
buffered = BytesIO()
|
||||||
|
image.save(
|
||||||
|
buffered, format="JPEG", quality=80, optimize=True, progressive=True
|
||||||
|
)
|
||||||
|
image_str = b64encode(buffered.getvalue())
|
||||||
|
image_data[path] = image_str.decode("utf-8")
|
||||||
|
|
||||||
|
json_event["images"] = image_data
|
||||||
|
|
||||||
|
recent_events.append(event)
|
||||||
|
send_and_append(event.id, json_event)
|
||||||
|
|
||||||
|
|
||||||
def player_event(character: str, client: str, status: Literal["join", "leave"]):
|
def broadcast_player_event(
|
||||||
|
character: str, client: str, status: Literal["join", "leave"]
|
||||||
|
):
|
||||||
event = PlayerEvent(status=status, character=character, client=client)
|
event = PlayerEvent(status=status, character=character, client=client)
|
||||||
broadcast(event)
|
broadcast(event)
|
||||||
|
|
||||||
|
|
||||||
def player_list():
|
def broadcast_player_list():
|
||||||
json_broadcast = {
|
event = PlayerListEvent(players=list_players())
|
||||||
"type": "players",
|
broadcast(event)
|
||||||
"players": list_players(),
|
|
||||||
}
|
|
||||||
# TODO: broadcast this
|
|
||||||
send_and_append(json_broadcast)
|
|
|
@ -19,7 +19,7 @@ import useWebSocketModule from 'react-use-websocket';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { HistoryPanel } from './history.js';
|
import { HistoryPanel } from './history.js';
|
||||||
import { Actor, Item, Room } from './models.js';
|
import { Actor, GameEvent, Item, Room } from './models.js';
|
||||||
import { PlayerPanel } from './player.js';
|
import { PlayerPanel } from './player.js';
|
||||||
import { store, StoreState } from './store.js';
|
import { store, StoreState } from './store.js';
|
||||||
import { WorldPanel } from './world.js';
|
import { WorldPanel } from './world.js';
|
||||||
|
@ -93,6 +93,11 @@ export function App(props: AppProps) {
|
||||||
// socket stuff
|
// socket stuff
|
||||||
const { lastMessage, readyState, sendMessage } = useWebSocket(props.socketUrl);
|
const { lastMessage, readyState, sendMessage } = useWebSocket(props.socketUrl);
|
||||||
|
|
||||||
|
// socket senders
|
||||||
|
function renderEvent(event: string) {
|
||||||
|
sendMessage(JSON.stringify({ type: 'render', event }));
|
||||||
|
}
|
||||||
|
|
||||||
function setPlayer(actor: Maybe<Actor>) {
|
function setPlayer(actor: Maybe<Actor>) {
|
||||||
// do not call setCharacter until the server confirms the player change
|
// do not call setCharacter until the server confirms the player change
|
||||||
if (doesExist(actor)) {
|
if (doesExist(actor)) {
|
||||||
|
@ -179,7 +184,7 @@ export function App(props: AppProps) {
|
||||||
<WorldPanel setPlayer={setPlayer} />
|
<WorldPanel setPlayer={setPlayer} />
|
||||||
</Stack>
|
</Stack>
|
||||||
<Stack direction="column" sx={{ minWidth: 600 }} className="scroll-history">
|
<Stack direction="column" sx={{ minWidth: 600 }} className="scroll-history">
|
||||||
<HistoryPanel />
|
<HistoryPanel renderEvent={renderEvent} />
|
||||||
</Stack>
|
</Stack>
|
||||||
</Allotment>
|
</Allotment>
|
||||||
</Stack>
|
</Stack>
|
||||||
|
|
|
@ -1,21 +1,45 @@
|
||||||
import { ListItem, ListItemText, ListItemAvatar, Avatar, Typography } from '@mui/material';
|
import { Avatar, IconButton, ImageList, ImageListItem, ListItem, ListItemAvatar, ListItemText, Typography } from '@mui/material';
|
||||||
import React, { MutableRefObject } from 'react';
|
import React, { Fragment, MutableRefObject } from 'react';
|
||||||
|
|
||||||
|
import { Camera } from '@mui/icons-material';
|
||||||
import { formatters } from './format.js';
|
import { formatters } from './format.js';
|
||||||
|
import { GameEvent } from './models.js';
|
||||||
|
|
||||||
|
export function openImage(image: string) {
|
||||||
|
const byteCharacters = atob(image);
|
||||||
|
const byteNumbers = new Array(byteCharacters.length);
|
||||||
|
for (let i = 0; i < byteCharacters.length; i++) {
|
||||||
|
byteNumbers[i] = byteCharacters.charCodeAt(i);
|
||||||
|
}
|
||||||
|
const byteArray = new Uint8Array(byteNumbers);
|
||||||
|
const file = new Blob([byteArray], { type: 'image/jpeg;base64' });
|
||||||
|
const fileURL = URL.createObjectURL(file);
|
||||||
|
window.open(fileURL, '_blank');
|
||||||
|
}
|
||||||
|
|
||||||
export interface EventItemProps {
|
export interface EventItemProps {
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
event: any;
|
event: any;
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
focusRef?: MutableRefObject<any>;
|
focusRef?: MutableRefObject<any>;
|
||||||
|
|
||||||
|
renderEvent: (event: GameEvent) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ActionEventItem(props: EventItemProps) {
|
export function ActionEventItem(props: EventItemProps) {
|
||||||
const { event } = props;
|
const { event, renderEvent } = props;
|
||||||
const { actor, room, type } = event;
|
const { id, actor, room, type } = event;
|
||||||
const content = formatters[type](event);
|
const content = formatters[type](event);
|
||||||
|
|
||||||
return <ListItem alignItems="flex-start" ref={props.focusRef}>
|
return <ListItem
|
||||||
|
alignItems="flex-start"
|
||||||
|
ref={props.focusRef}
|
||||||
|
secondaryAction={
|
||||||
|
<IconButton edge="end" aria-label="render" onClick={() => renderEvent(id)}>
|
||||||
|
<Camera />
|
||||||
|
</IconButton>
|
||||||
|
}
|
||||||
|
>
|
||||||
<ListItemAvatar>
|
<ListItemAvatar>
|
||||||
<Avatar alt={actor.name} src="/static/images/avatar/1.jpg" />
|
<Avatar alt={actor.name} src="/static/images/avatar/1.jpg" />
|
||||||
</ListItemAvatar>
|
</ListItemAvatar>
|
||||||
|
@ -41,23 +65,26 @@ export function ActionEventItem(props: EventItemProps) {
|
||||||
export function SnapshotEventItem(props: EventItemProps) {
|
export function SnapshotEventItem(props: EventItemProps) {
|
||||||
const { event } = props;
|
const { event } = props;
|
||||||
const { step, world } = event;
|
const { step, world } = event;
|
||||||
const { theme } = world;
|
const { name, theme } = world;
|
||||||
|
|
||||||
return <ListItem alignItems="flex-start" ref={props.focusRef}>
|
return <ListItem alignItems="flex-start" ref={props.focusRef}>
|
||||||
<ListItemAvatar>
|
<ListItemAvatar>
|
||||||
<Avatar alt={step.toString()} src="/static/images/avatar/1.jpg" />
|
<Avatar alt={step.toString()} src="/static/images/avatar/1.jpg" />
|
||||||
</ListItemAvatar>
|
</ListItemAvatar>
|
||||||
<ListItemText
|
<ListItemText
|
||||||
primary={theme}
|
primary={name}
|
||||||
secondary={
|
secondary={
|
||||||
|
<Fragment>
|
||||||
<Typography
|
<Typography
|
||||||
sx={{ display: 'block' }}
|
sx={{ display: 'block' }}
|
||||||
component="span"
|
component="span"
|
||||||
variant="body2"
|
variant="body2"
|
||||||
color="text.primary"
|
color="text.primary"
|
||||||
>
|
>
|
||||||
Step {step}
|
Step: {step}
|
||||||
</Typography>
|
</Typography>
|
||||||
|
World Theme: {theme}
|
||||||
|
</Fragment>
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
</ListItem>;
|
</ListItem>;
|
||||||
|
@ -102,7 +129,10 @@ export function PlayerEventItem(props: EventItemProps) {
|
||||||
secondary = `${client} has left the game. ${character} is now controlled by an LLM`;
|
secondary = `${client} has left the game. ${character} is now controlled by an LLM`;
|
||||||
}
|
}
|
||||||
|
|
||||||
return <ListItem alignItems="flex-start" ref={props.focusRef}>
|
return <ListItem
|
||||||
|
alignItems="flex-start"
|
||||||
|
ref={props.focusRef}
|
||||||
|
>
|
||||||
<ListItemAvatar>
|
<ListItemAvatar>
|
||||||
<Avatar alt={character} src="/static/images/avatar/1.jpg" />
|
<Avatar alt={character} src="/static/images/avatar/1.jpg" />
|
||||||
</ListItemAvatar>
|
</ListItemAvatar>
|
||||||
|
@ -122,6 +152,25 @@ export function PlayerEventItem(props: EventItemProps) {
|
||||||
</ListItem>;
|
</ListItem>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function RenderEventItem(props: EventItemProps) {
|
||||||
|
const { event } = props;
|
||||||
|
const { images } = event;
|
||||||
|
|
||||||
|
return <ListItem alignItems="flex-start" ref={props.focusRef}>
|
||||||
|
<ListItemAvatar>
|
||||||
|
<Avatar alt="Render" src="/static/images/avatar/1.jpg" />
|
||||||
|
</ListItemAvatar>
|
||||||
|
<ListItemText
|
||||||
|
primary="Render"
|
||||||
|
secondary={<ImageList cols={3} rowHeight={256}>
|
||||||
|
{Object.entries(images).map(([name, image]) => <ImageListItem key={name}>
|
||||||
|
<img src={`data:image/jpeg;base64,${image}`} onClick={() => openImage(image)} alt="Render" />
|
||||||
|
</ImageListItem>)}
|
||||||
|
</ImageList>}
|
||||||
|
/>
|
||||||
|
</ListItem>;
|
||||||
|
}
|
||||||
|
|
||||||
export function EventItem(props: EventItemProps) {
|
export function EventItem(props: EventItemProps) {
|
||||||
const { event } = props;
|
const { event } = props;
|
||||||
const { type } = event;
|
const { type } = event;
|
||||||
|
@ -129,13 +178,16 @@ export function EventItem(props: EventItemProps) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case 'action':
|
case 'action':
|
||||||
case 'result':
|
case 'result':
|
||||||
return <ActionEventItem event={event} focusRef={props.focusRef} />;
|
return <ActionEventItem {...props} />;
|
||||||
case 'reply':
|
case 'reply':
|
||||||
return <ReplyEventItem event={event} focusRef={props.focusRef} />;
|
case 'status': // TODO: should have a different component
|
||||||
|
return <ReplyEventItem {...props} />;
|
||||||
case 'player':
|
case 'player':
|
||||||
return <PlayerEventItem event={event} focusRef={props.focusRef} />;
|
return <PlayerEventItem {...props} />;
|
||||||
|
case 'render':
|
||||||
|
return <RenderEventItem {...props} />;
|
||||||
case 'snapshot':
|
case 'snapshot':
|
||||||
return <SnapshotEventItem event={event} focusRef={props.focusRef} />;
|
return <SnapshotEventItem {...props} />;
|
||||||
default:
|
default:
|
||||||
return <ListItem ref={props.focusRef}>
|
return <ListItem ref={props.focusRef}>
|
||||||
<ListItemText primary={`Unknown event type: ${type}`} />
|
<ListItemText primary={`Unknown event type: ${type}`} />
|
||||||
|
|
|
@ -3,7 +3,8 @@ import { Divider, List } from '@mui/material';
|
||||||
import React, { useEffect, useRef } from 'react';
|
import React, { useEffect, useRef } from 'react';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
import { EventItem } from './events';
|
import { EventItem } from './events';
|
||||||
import { store, StoreState } from './store';
|
import { GameEvent } from './models';
|
||||||
|
import { StoreState, store } from './store';
|
||||||
|
|
||||||
export function historyStateSelector(s: StoreState) {
|
export function historyStateSelector(s: StoreState) {
|
||||||
return {
|
return {
|
||||||
|
@ -12,9 +13,14 @@ export function historyStateSelector(s: StoreState) {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export function HistoryPanel() {
|
export interface HistoryPanelProps {
|
||||||
|
renderEvent: (event: GameEvent) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function HistoryPanel(props: HistoryPanelProps) {
|
||||||
const state = useStore(store, historyStateSelector);
|
const state = useStore(store, historyStateSelector);
|
||||||
const { history, scroll } = state;
|
const { history, scroll } = state;
|
||||||
|
const { renderEvent } = props;
|
||||||
|
|
||||||
const scrollRef = useRef<Maybe<Element>>(undefined);
|
const scrollRef = useRef<Maybe<Element>>(undefined);
|
||||||
|
|
||||||
|
@ -28,10 +34,10 @@ export function HistoryPanel() {
|
||||||
|
|
||||||
const items = history.map((item, index) => {
|
const items = history.map((item, index) => {
|
||||||
if (index === history.length - 1) {
|
if (index === history.length - 1) {
|
||||||
return <EventItem key={`item-${index}`} event={item} focusRef={scrollRef} />;
|
return <EventItem key={`item-${index}`} event={item} focusRef={scrollRef} renderEvent={renderEvent} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
return <EventItem key={`item-${index}`} event={item} />;
|
return <EventItem key={`item-${index}`} event={item} renderEvent={renderEvent} />;
|
||||||
});
|
});
|
||||||
|
|
||||||
return <List sx={{ width: '100%', bgcolor: 'background.paper' }}>
|
return <List sx={{ width: '100%', bgcolor: 'background.paper' }}>
|
||||||
|
|
Loading…
Reference in New Issue