1
0
Fork 0
taleweave-ai/adventure/server/websocket.py

337 lines
11 KiB
Python
Raw Normal View History

2024-05-04 04:18:21 +00:00
import asyncio
from base64 import b64encode
2024-05-04 20:35:42 +00:00
from collections import deque
from io import BytesIO
2024-05-05 14:14:54 +00:00
from json import dumps, loads
2024-05-04 04:18:21 +00:00
from logging import getLogger
from threading import Thread
from typing import Any, Dict, Literal, MutableSequence
from uuid import uuid4
2024-05-04 04:18:21 +00:00
import websockets
from PIL import Image
from pydantic import RootModel
2024-05-04 04:18:21 +00:00
from adventure.context import (
broadcast,
get_actor_agent_for_name,
get_current_world,
set_actor_agent,
2024-05-18 21:58:11 +00:00
subscribe,
)
2024-05-18 21:58:11 +00:00
from adventure.models.config import DEFAULT_CONFIG, WebsocketServerConfig
from adventure.models.entity import Actor, Item, Room, World
from adventure.models.event import (
GameEvent,
PlayerEvent,
PlayerListEvent,
PromptEvent,
RenderEvent,
)
2024-05-08 01:42:10 +00:00
from adventure.player import (
RemotePlayer,
get_player,
has_player,
list_players,
remove_player,
set_player,
)
from adventure.render.comfy import render_entity, render_event
2024-05-04 04:18:21 +00:00
from adventure.state import snapshot_world, world_json
from adventure.utils.search import find_actor, find_item, find_room
2024-05-04 04:18:21 +00:00
# How many recent events / serialized messages are retained.
_HISTORY_SIZE = 100

logger = getLogger(__name__)

# Sockets of the clients that are currently connected (see handler()).
connected = set()

# Display names chosen by clients, keyed by client id.
player_names: Dict[str, str] = {}

# Most recent game events (filled by server_event()) and the most recent
# JSON messages that were broadcast (filled by send_and_append()).
recent_events: MutableSequence[GameEvent] = deque(maxlen=_HISTORY_SIZE)
recent_json: MutableSequence[str] = deque(maxlen=_HISTORY_SIZE)

# JSON text of the latest world snapshot; None until server_system() runs.
last_snapshot: str | None = None

# Active server settings; launch_server() replaces the default.
server_config: WebsocketServerConfig = DEFAULT_CONFIG.server.websocket
def get_player_name(client_id: str) -> str:
    """Return the display name registered for *client_id*.

    Falls back to the client id itself when the client never chose a name.
    """
    try:
        return player_names[client_id]
    except KeyError:
        return client_id
2024-05-04 04:18:21 +00:00
def _release_client(websocket, client_id: str) -> None:
    """Undo everything a connection registered once its socket is gone.

    Drops the socket from ``connected``, forgets the client's display name,
    and, if the client was playing a character, announces the departure and
    hands the character back to its fallback agent.
    """
    connected.discard(websocket)

    # Resolve the display name *before* forgetting it, so the "leave" event
    # identifies the client the same way the earlier "join" event did.
    player_name = get_player_name(client_id)
    player_names.pop(client_id, None)

    # swap out the character for the original agent when they disconnect
    player = get_player(client_id)
    if player and isinstance(player, RemotePlayer):
        remove_player(client_id)
        logger.info("disconnecting player %s from %s", player_name, player.name)
        broadcast_player_event(player.name, player_name, "leave")
        broadcast_player_list()

        actor, _ = get_actor_agent_for_name(player.name)
        if actor and player.fallback_agent:
            logger.info("restoring LLM agent for %s", player.name)
            set_actor_agent(player.name, actor, player.fallback_agent)

    logger.info("client disconnected: %s", client_id)


async def handler(websocket):
    """Serve one websocket client from connect to disconnect.

    On connect the client is given a fresh id, then sent the last snapshot
    (unless it is already part of the recent history) followed by the recent
    messages. After that, each incoming JSON message is dispatched on its
    ``"type"``:

    * ``"player"`` with ``"name"``: register a display name for the client,
      rejected when another client already uses it.
    * ``"player"`` with ``"become"``: take over the named character by
      installing a ``RemotePlayer`` in place of its current agent.
    * ``"input"``: queue ``data["input"]`` for the client's ``RemotePlayer``.
    * ``"render"``: forward the request to ``render_input``.

    Failures while handling a single message are logged and do not close the
    connection. Cleanup runs however the connection ends.
    """
    client_id = uuid4().hex
    logger.info("client connected, given id: %s", client_id)

    connected.add(websocket)

    # The loop that owns this socket. Turn prompts are raised from outside
    # this loop and must be handed back to it rather than run elsewhere.
    loop = asyncio.get_running_loop()

    async def next_turn(character: str, prompt: str) -> None:
        await websocket.send(
            dumps(
                {
                    # TODO: these should be fields in the PromptEvent
                    "type": "prompt",
                    "client": client_id,
                    "character": character,
                    "prompt": prompt,
                    "actions": [],
                }
            ),
        )

    def sync_turn(event: PromptEvent) -> bool:
        # Blocking bridge used as the RemotePlayer callback. It has to be
        # called from a thread other than the server loop's (the previous
        # asyncio.run() call had the same requirement); the send itself is
        # scheduled on the loop that owns the socket.
        player = get_player(client_id)
        if player and player.name == event.actor.name:
            pending = asyncio.run_coroutine_threadsafe(
                next_turn(event.actor.name, event.prompt), loop
            )
            # wait for the send and surface any error to the caller
            pending.result()
            return True

        return False

    try:
        try:
            await websocket.send(dumps({"type": "id", "client": client_id}))

            # only send the snapshot once
            if last_snapshot and last_snapshot not in recent_json:
                await websocket.send(last_snapshot)

            for message in recent_json:
                await websocket.send(message)
        except Exception:
            logger.exception("failed to send recent messages to new client")

        while True:
            try:
                # if this socket is attached to a character and that character's turn is active, wait for input
                message = await websocket.recv()
            except websockets.ConnectionClosed:
                # clean and abnormal closes both end the session
                break

            player_name = get_player_name(client_id)
            logger.info(f"received message for {player_name}: {message}")

            try:
                data = loads(message)
                message_type = data.get("type", None)
                if message_type == "player":
                    if "name" in data:
                        new_player_name = data["name"]
                        existing_id = next(
                            (
                                k
                                for k, v in player_names.items()
                                if v == new_player_name
                            ),
                            None,
                        )
                        if existing_id is not None:
                            logger.error(
                                f"name {new_player_name} is already in use by {existing_id}"
                            )
                            continue

                        logger.info(
                            f"changing player name for {client_id} to {new_player_name}"
                        )
                        player_names[client_id] = new_player_name
                    elif "become" in data:
                        character_name = data["become"]
                        if has_player(character_name):
                            logger.error(
                                f"character {character_name} is already in use"
                            )
                            continue

                        # TODO: should this always remove?
                        remove_player(client_id)

                        actor, llm_agent = get_actor_agent_for_name(character_name)
                        if not actor:
                            logger.error(f"Failed to find actor {character_name}")
                            continue

                        # prevent any recursive fallback bugs
                        if isinstance(llm_agent, RemotePlayer):
                            logger.warning(
                                "patching recursive fallback for %s", character_name
                            )
                            llm_agent = llm_agent.fallback_agent

                        player = RemotePlayer(
                            actor.name,
                            actor.backstory,
                            sync_turn,
                            fallback_agent=llm_agent,
                        )
                        set_player(client_id, player)
                        logger.info(
                            f"client {player_name} is now character {character_name}"
                        )

                        # swap out the LLM agent
                        set_actor_agent(actor.name, actor, player)

                        # notify all clients that this character is now active
                        broadcast_player_event(character_name, player_name, "join")
                        broadcast_player_list()
                elif message_type == "input":
                    player = get_player(client_id)
                    if player and isinstance(player, RemotePlayer):
                        logger.info(
                            "queueing input for player %s: %s", player.name, data
                        )
                        player.input_queue.put(data["input"])
                elif message_type == "render":
                    render_input(data)
            except Exception:
                logger.exception("failed to parse message")
    finally:
        # Runs for clean closes, abnormal closes and cancellation alike, so a
        # dropped connection never leaves a stale socket or player behind.
        _release_client(websocket, client_id)
2024-05-04 04:18:21 +00:00
def render_input(data):
    """Handle a client render request.

    ``data`` is the decoded request. The first of the keys ``"event"``,
    ``"actor"``, ``"room"`` or ``"item"`` found in it selects what gets
    rendered: a recent event looked up by id, or a world entity looked up by
    name. Anything that cannot be resolved is logged as an error.
    """
    world = get_current_world()
    if not world:
        logger.error("no world available")
        return

    if "event" in data:
        event_id = data["event"]
        found = None
        for candidate in recent_events:
            if candidate.id == event_id:
                found = candidate
                break

        if not found:
            logger.error(f"failed to find event {event_id}")
            return

        render_event(found)
        return

    if "actor" in data:
        actor_name = data["actor"]
        actor = find_actor(world, actor_name)
        if not actor:
            logger.error(f"failed to find actor {actor_name}")
            return

        render_entity(actor)
        return

    if "room" in data:
        room_name = data["room"]
        room = find_room(world, room_name)
        if not room:
            logger.error(f"failed to find room {room_name}")
            return

        render_entity(room)
        return

    if "item" in data:
        item_name = data["item"]
        # items may be held by actors or nested inside other items
        item = find_item(
            world,
            item_name,
            include_actor_inventory=True,
            include_item_inventory=True,
        )
        if not item:
            logger.error(f"failed to find item {item_name}")
            return

        render_entity(item)
        return

    logger.error(f"failed to find entity in {data}")
2024-05-04 04:18:21 +00:00
# Thread running the websocket server; assigned by launch_server().
socket_thread: Thread | None = None
2024-05-04 20:35:42 +00:00
def server_json(obj):
    """JSON ``default`` hook used for outgoing messages.

    Actors, items and rooms are sent as just their name; every other object
    is delegated to ``world_json``.
    """
    is_entity = isinstance(obj, (Actor, Item, Room))
    if not is_entity:
        return world_json(obj)

    return obj.name
def send_and_append(id: str, message: Dict):
    """Serialize *message*, record it in the recent history, and broadcast it.

    ``id`` is part of the signature but is not used by this function.
    Returns the JSON text that was sent to the connected clients.
    """
    encoded = dumps(message, default=server_json)

    # keep a copy so handler() can replay it to clients that connect later
    recent_json.append(encoded)
    websockets.broadcast(connected, encoded)

    return encoded
2024-05-18 21:58:11 +00:00
def launch_server(config: WebsocketServerConfig):
    """Start the websocket server in a background daemon thread.

    Stores *config* as the active server configuration, starts the thread
    that runs ``server_main``, and subscribes ``server_event`` to game
    events. Returns a one-element list containing the server thread.
    """
    global server_config, socket_thread

    logger.info("configuring websocket server: %s", config)
    server_config = config

    def run_sockets():
        # each call gets its own event loop, owned by the server thread
        asyncio.run(server_main())

    logger.info("launching websocket server")
    socket_thread = Thread(target=run_sockets, daemon=True)
    socket_thread.start()

    subscribe(GameEvent, server_event)

    threads = [socket_thread]
    return threads
2024-05-04 04:18:21 +00:00
async def server_main():
    """Serve websocket clients on the configured host and port until cancelled."""
    server = websockets.serve(handler, server_config.host, server_config.port)
    async with server:
        logger.info("websocket server started")

        # a future that is never resolved keeps the server running forever
        never_done = asyncio.Future()
        await never_done
def server_system(world: World, step: int):
    """Broadcast a snapshot of *world* at *step* and remember it as the latest.

    The snapshot data is tagged with a fresh id and the type ``"snapshot"``
    before being sent; the serialized text is kept in ``last_snapshot``.
    """
    global last_snapshot

    snapshot_id = uuid4().hex  # TODO: should a server be allowed to generate event IDs?

    payload = dict(snapshot_world(world, step))
    # set after copying so these always win over keys in the snapshot data
    payload["id"] = snapshot_id
    payload["type"] = "snapshot"

    last_snapshot = send_and_append(snapshot_id, payload)
2024-05-04 04:18:21 +00:00
# Pillow image modes that the JPEG writer can store without conversion.
_JPEG_WRITABLE_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")


def _encode_image_as_jpeg(path: str) -> str:
    """Load the image at *path* and return it as base64-encoded JPEG text."""
    with Image.open(path, "r") as image:
        # JPEG cannot hold alpha or palette data; saving such an image raises
        # OSError, so convert it to RGB first. Modes that were already
        # writable are saved exactly as before.
        if image.mode in _JPEG_WRITABLE_MODES:
            writable = image
        else:
            writable = image.convert("RGB")

        buffered = BytesIO()
        writable.save(
            buffered, format="JPEG", quality=80, optimize=True, progressive=True
        )

    return b64encode(buffered.getvalue()).decode("utf-8")


def server_event(event: GameEvent):
    """Serialize a game event, record it, and broadcast it to all clients.

    The event is dumped through a pydantic ``RootModel`` and then stamped
    with its ``id`` and ``type``. For a ``RenderEvent`` every path in
    ``event.paths`` is additionally loaded, re-encoded as JPEG and attached
    under ``"images"`` as base64 text keyed by path.
    """
    json_event: Dict[str, Any] = RootModel[event.__class__](event).model_dump()
    json_event.update(
        {
            "id": event.id,
            "type": event.type,
        }
    )

    if isinstance(event, RenderEvent):
        # load and encode the images
        json_event["images"] = {
            path: _encode_image_as_jpeg(path) for path in event.paths
        }

    recent_events.append(event)
    send_and_append(event.id, json_event)
def broadcast_player_event(
    character: str, client: str, status: Literal["join", "leave"]
):
    """Broadcast a ``PlayerEvent`` saying *client* joined or left *character*."""
    broadcast(PlayerEvent(status=status, character=character, client=client))
def broadcast_player_list():
    """Broadcast a ``PlayerListEvent`` built from the current ``list_players()``."""
    current_players = list_players()
    broadcast(PlayerListEvent(players=current_players))