2024-05-04 04:18:21 +00:00
|
|
|
import asyncio
|
2024-05-12 05:08:53 +00:00
|
|
|
from base64 import b64encode
|
2024-05-04 20:35:42 +00:00
|
|
|
from collections import deque
|
2024-05-12 05:08:53 +00:00
|
|
|
from io import BytesIO
|
2024-05-05 14:14:54 +00:00
|
|
|
from json import dumps, loads
|
2024-05-04 04:18:21 +00:00
|
|
|
from logging import getLogger
|
|
|
|
from threading import Thread
|
2024-05-12 05:08:53 +00:00
|
|
|
from typing import Any, Dict, Literal, MutableSequence
|
2024-05-05 18:54:39 +00:00
|
|
|
from uuid import uuid4
|
2024-05-04 04:18:21 +00:00
|
|
|
|
|
|
|
import websockets
|
2024-05-12 05:08:53 +00:00
|
|
|
from PIL import Image
|
2024-05-10 04:45:10 +00:00
|
|
|
from pydantic import RootModel
|
2024-05-04 04:18:21 +00:00
|
|
|
|
2024-05-27 13:10:24 +00:00
|
|
|
from taleweave.context import (
|
2024-05-12 20:47:18 +00:00
|
|
|
broadcast,
|
2024-05-27 01:32:03 +00:00
|
|
|
get_character_agent_for_name,
|
2024-05-12 20:47:18 +00:00
|
|
|
get_current_world,
|
2024-05-27 01:32:03 +00:00
|
|
|
set_character_agent,
|
2024-05-18 21:58:11 +00:00
|
|
|
subscribe,
|
2024-05-12 20:47:18 +00:00
|
|
|
)
|
2024-05-27 13:10:24 +00:00
|
|
|
from taleweave.models.config import DEFAULT_CONFIG, WebsocketServerConfig
|
|
|
|
from taleweave.models.entity import Character, Item, Room, World
|
|
|
|
from taleweave.models.event import (
|
2024-05-12 05:08:53 +00:00
|
|
|
GameEvent,
|
|
|
|
PlayerEvent,
|
|
|
|
PlayerListEvent,
|
|
|
|
PromptEvent,
|
|
|
|
RenderEvent,
|
|
|
|
)
|
2024-05-27 13:10:24 +00:00
|
|
|
from taleweave.player import (
|
2024-05-08 01:42:10 +00:00
|
|
|
RemotePlayer,
|
|
|
|
get_player,
|
|
|
|
has_player,
|
|
|
|
list_players,
|
|
|
|
remove_player,
|
|
|
|
set_player,
|
|
|
|
)
|
2024-05-27 13:10:24 +00:00
|
|
|
from taleweave.render.comfy import render_entity, render_event
|
|
|
|
from taleweave.state import snapshot_world, world_json
|
|
|
|
from taleweave.utils.search import find_character, find_item, find_portal, find_room
|
2024-05-04 04:18:21 +00:00
|
|
|
|
|
|
|
logger = getLogger(__name__)

# websockets of every currently-connected client; the broadcast targets
connected = set()

# serialized JSON of the most recent world snapshot, replayed to new clients
last_snapshot: str | None = None

# client id -> display name chosen by that client
player_names: Dict[str, str] = {}

# bounded history of recent game events, searched by id for render requests
recent_events: MutableSequence[GameEvent] = deque(maxlen=100)

# bounded history of serialized outgoing messages, replayed to new clients
recent_json: MutableSequence[str] = deque(maxlen=100)

# active websocket server settings; replaced by launch_server
server_config: WebsocketServerConfig = DEFAULT_CONFIG.server.websocket
|
2024-05-11 10:17:03 +00:00
|
|
|
|
|
|
|
|
|
|
|
def get_player_name(client_id: str) -> str:
    """Return the display name chosen by a client, falling back to its id."""
    if client_id in player_names:
        return player_names[client_id]
    return client_id
|
2024-05-04 04:18:21 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _release_character(client_id: str, player_name: str) -> None:
    """Detach a client from the character it controls, if any.

    Drops the client's player registration. If that player was a
    ``RemotePlayer``, every client is notified that it left, and the character
    is handed back to the ``fallback_agent`` the ``RemotePlayer`` was created
    with.

    Args:
        client_id: connection id assigned by ``handler``.
        player_name: display name used in the log and the leave event.
    """
    player = get_player(client_id)
    if not player:
        return

    remove_player(client_id)
    if not isinstance(player, RemotePlayer):
        return

    logger.info("disconnecting player %s from %s", player_name, player.name)
    broadcast_player_event(player.name, player_name, "leave")
    broadcast_player_list()

    character, _ = get_character_agent_for_name(player.name)
    if character and player.fallback_agent:
        logger.info("restoring LLM agent for %s", player.name)
        set_character_agent(player.name, character, player.fallback_agent)


async def handler(websocket):
    """Serve one websocket client for the lifetime of its connection.

    On connect, the client is assigned a random id and sent that id. It is
    then sent the latest world snapshot (unless it is already in the backlog)
    followed by the buffered recent messages. After that, each incoming JSON
    message is dispatched on its ``type``:

    * ``"player"`` with ``"name"``: set the client's display name, which must
      not be used by another client;
    * ``"player"`` with ``"become"``: take control of a character by swapping
      its agent for a ``RemotePlayer`` bound to this connection;
    * ``"input"``: queue ``data["input"]`` for the client's ``RemotePlayer``;
    * ``"render"``: pass the request to ``render_input``.

    However the connection ends (clean or abnormal close, cancellation, or an
    unexpected error), the client is forgotten and any character it controlled
    is handed back to its original agent.
    """
    client_id = uuid4().hex
    logger.info("client connected, given id: %s", client_id)

    # the loop that owns this websocket; sends must be scheduled onto it
    loop = asyncio.get_running_loop()
    # strong references to tasks created by sync_turn, so they are not
    # garbage-collected before they finish
    pending_sends: set = set()

    async def next_turn(character: str, prompt: str) -> None:
        """Send a prompt message for ``character``'s turn to this client."""
        await websocket.send(
            dumps(
                {
                    # TODO: these should be fields in the PromptEvent
                    "type": "prompt",
                    "client": client_id,
                    "character": character,
                    "prompt": prompt,
                    "actions": [],
                }
            ),
        )

    def sync_turn(event: PromptEvent) -> bool:
        """Forward a prompt to this client if it controls the prompted character.

        This is a plain (non-async) callback handed to ``RemotePlayer``,
        presumably invoked from the game thread rather than the websocket
        server thread (TODO confirm). The websocket belongs to the server
        thread's event loop, so the send is scheduled there instead of being
        driven by a fresh loop from ``asyncio.run``.

        Returns:
            True if the prompt was sent (or scheduled) to this client; False if
            the client does not control the character or sending failed.
        """
        player = get_player(client_id)
        if not (player and player.name == event.character.name):
            return False

        send = next_turn(event.character.name, event.prompt)

        try:
            on_server_loop = asyncio.get_running_loop() is loop
        except RuntimeError:
            on_server_loop = False

        if on_server_loop:
            # blocking on the result here would deadlock the loop; schedule it
            task = loop.create_task(send)
            pending_sends.add(task)
            task.add_done_callback(pending_sends.discard)
            return True

        try:
            asyncio.run_coroutine_threadsafe(send, loop).result()
        except Exception:
            logger.exception("failed to send prompt to client %s", client_id)
            return False

        return True

    connected.add(websocket)
    try:
        try:
            await websocket.send(dumps({"type": "id", "client": client_id}))

            # iterate over a copy: messages can be appended to recent_json
            # while this coroutine is suspended in send(), and mutating a
            # deque during iteration raises RuntimeError
            backlog = list(recent_json)
            snapshot = last_snapshot

            # only send the snapshot once
            if snapshot and snapshot not in backlog:
                await websocket.send(snapshot)

            for message in backlog:
                await websocket.send(message)
        except Exception:
            logger.exception("failed to send recent messages to new client")

        while True:
            try:
                # if this socket is attached to a character and that character's turn is active, wait for input
                message = await websocket.recv()
            except websockets.ConnectionClosed:
                # base class of both ConnectionClosedOK (clean close) and
                # ConnectionClosedError (abnormal close)
                break

            player_name = get_player_name(client_id)
            logger.info("received message for %s: %s", player_name, message)

            try:
                data = loads(message)
                message_type = data.get("type", None)

                if message_type == "player":
                    if "name" in data:
                        new_player_name = data["name"]
                        # re-sending your own current name is not a conflict
                        existing_id = next(
                            (
                                other_id
                                for other_id, other_name in player_names.items()
                                if other_name == new_player_name
                                and other_id != client_id
                            ),
                            None,
                        )
                        if existing_id is not None:
                            logger.error(
                                f"name {new_player_name} is already in use by {existing_id}"
                            )
                            continue

                        logger.info(
                            f"changing player name for {client_id} to {new_player_name}"
                        )
                        player_names[client_id] = new_player_name

                    elif "become" in data:
                        character_name = data["become"]
                        if has_player(character_name):
                            logger.error(
                                f"character {character_name} is already in use"
                            )
                            continue

                        character, llm_agent = get_character_agent_for_name(
                            character_name
                        )
                        if not character:
                            logger.error(f"Failed to find character {character_name}")
                            continue

                        # give up any character this client already controls
                        # (restoring its original agent) only once the new one
                        # is known to exist, so a bad request changes nothing
                        _release_character(client_id, player_name)

                        # prevent any recursive fallback bugs
                        if isinstance(llm_agent, RemotePlayer):
                            logger.warning(
                                "patching recursive fallback for %s", character_name
                            )
                            llm_agent = llm_agent.fallback_agent

                        player = RemotePlayer(
                            character.name,
                            character.backstory,
                            sync_turn,
                            fallback_agent=llm_agent,
                        )
                        set_player(client_id, player)
                        logger.info(
                            f"client {player_name} is now character {character_name}"
                        )

                        # swap out the LLM agent
                        set_character_agent(character.name, character, player)

                        # notify all clients that this character is now active
                        broadcast_player_event(character_name, player_name, "join")
                        broadcast_player_list()

                elif message_type == "input":
                    player = get_player(client_id)
                    if player and isinstance(player, RemotePlayer):
                        logger.info(
                            "queueing input for player %s: %s", player.name, data
                        )
                        player.input_queue.put(data["input"])

                elif message_type == "render":
                    render_input(data)

            except Exception:
                logger.exception("failed to parse message")
    finally:
        connected.discard(websocket)
        # pop the display name before releasing, so the leave event and log
        # carry the chosen name instead of the raw id
        player_name = player_names.pop(client_id, client_id)
        # swap out the character for the original agent when they disconnect
        _release_character(client_id, player_name)
        logger.info("client disconnected: %s", client_id)
|
2024-05-04 04:18:21 +00:00
|
|
|
|
|
|
|
|
2024-05-19 18:09:52 +00:00
|
|
|
def find_recent_event(event_id: str) -> GameEvent | None:
    """Return the buffered recent event with ``event_id``, or None if absent."""
    for candidate in recent_events:
        if candidate.id == event_id:
            return candidate
    return None
|
|
|
|
|
|
|
|
|
2024-05-12 20:47:18 +00:00
|
|
|
def render_input(data):
    """Render the event or entity named in a client's ``render`` request.

    ``data`` is the decoded client message. The first key present among
    ``event``, ``character``, ``item``, ``portal`` and ``room`` (checked in
    that order) selects what to render. Lookups that find nothing are logged
    rather than raised.
    """
    world = get_current_world()
    if not world:
        logger.error("no world available")
        return

    if "event" in data:
        event_id = data["event"]
        event = find_recent_event(event_id)
        if not event:
            logger.error(f"failed to find event {event_id}")
            return
        render_event(event)
        return

    # (request key, lookup) pairs, tried in priority order
    finders = (
        ("character", lambda name: find_character(world, name)),
        (
            "item",
            lambda name: find_item(
                world,
                name,
                include_character_inventory=True,
                include_item_inventory=True,
            ),
        ),
        ("portal", lambda name: find_portal(world, name)),
        ("room", lambda name: find_room(world, name)),
    )

    for kind, find in finders:
        if kind not in data:
            continue

        name = data[kind]
        entity = find(name)
        if entity:
            render_entity(entity)
        else:
            logger.error(f"failed to find {kind} {name}")
        return

    logger.error(f"failed to find entity in {data}")
|
|
|
|
|
|
|
|
|
2024-05-04 04:18:21 +00:00
|
|
|
socket_thread = None
|
|
|
|
|
|
|
|
|
2024-05-04 20:35:42 +00:00
|
|
|
def server_json(obj):
    """``json.dumps`` default hook for outgoing server messages.

    Characters, items and rooms nested in a message are reduced to their
    name; every other object is delegated to ``world_json``.
    """
    entity_types = (Character, Item, Room)
    if not isinstance(obj, entity_types):
        return world_json(obj)
    return obj.name
|
|
|
|
|
|
|
|
|
2024-05-12 05:08:53 +00:00
|
|
|
def send_and_append(id: str, message: Dict):
    """Serialize a message, keep it for replay, and broadcast it to all clients.

    Args:
        id: the message's id; not used by this function.
        message: JSON-serializable payload; entities are encoded through
            ``server_json``.

    Returns:
        The serialized JSON string that was appended and broadcast.
    """
    # NOTE(review): websockets.broadcast is not documented as thread-safe; if
    # this runs on the game thread instead of the server loop, it should be
    # marshalled onto that loop — confirm the calling threads.
    payload = dumps(message, default=server_json)
    recent_json.append(payload)
    websockets.broadcast(connected, payload)
    return payload
|
|
|
|
|
|
|
|
|
2024-05-18 21:58:11 +00:00
|
|
|
def launch_server(config: WebsocketServerConfig):
    """Start the websocket server on a background daemon thread.

    Stores ``config`` as the module-wide ``server_config`` (read by
    ``server_main``), starts a daemon thread that runs ``server_main`` in its
    own asyncio loop, and subscribes ``server_event`` to ``GameEvent``.

    Returns:
        A one-element list holding the started thread.
    """
    global socket_thread, server_config

    logger.info("configuring websocket server: %s", config)
    server_config = config

    def run_sockets():
        asyncio.run(server_main())

    logger.info("launching websocket server")
    thread = Thread(target=run_sockets, daemon=True)
    socket_thread = thread
    thread.start()

    subscribe(GameEvent, server_event)

    return [thread]
|
|
|
|
|
2024-05-04 04:18:21 +00:00
|
|
|
|
|
|
|
async def server_main():
    """Serve ``handler`` on the configured host and port until cancelled."""
    host, port = server_config.host, server_config.port
    async with websockets.serve(handler, host, port):
        logger.info("websocket server started")
        # a future nobody resolves: keeps the server up until cancellation
        await asyncio.get_running_loop().create_future()
|
|
|
|
|
|
|
|
|
2024-05-27 12:54:36 +00:00
|
|
|
def server_system(world: World, turn: int, data: Any | None = None):
    """Broadcast a full world snapshot and remember it for new clients.

    The snapshot message is ``snapshot_world(world, turn)`` plus a fresh
    ``id`` and ``"type": "snapshot"``. ``data`` is not used.
    """
    global last_snapshot

    snapshot_id = uuid4().hex  # TODO: should a server be allowed to generate event IDs?
    snapshot_message = {
        **snapshot_world(world, turn),
        "id": snapshot_id,
        "type": "snapshot",
    }
    last_snapshot = send_and_append(snapshot_id, snapshot_message)
|
2024-05-04 04:18:21 +00:00
|
|
|
|
|
|
|
|
2024-05-10 04:45:10 +00:00
|
|
|
def _encode_image(path: str) -> str:
    """Load the image at ``path`` and return it as base64-encoded JPEG text.

    JPEG cannot store alpha or palette data, so images in any mode other than
    RGB or L (e.g. RGBA or P, common for PNG output) are converted to RGB
    first. Saving them directly would raise ``OSError``.
    """
    with Image.open(path, "r") as image:
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")
        buffered = BytesIO()
        image.save(
            buffered, format="JPEG", quality=80, optimize=True, progressive=True
        )
    return b64encode(buffered.getvalue()).decode("utf-8")


def server_event(event: GameEvent):
    """Serialize a game event, buffer it, and broadcast it to all clients.

    The event's fields are dumped through pydantic, with ``id`` and ``type``
    set explicitly. For a ``RenderEvent``, each rendered image is inlined as
    base64 JPEG under ``"images"``, keyed by its path. An image that cannot be
    read or encoded is logged and omitted instead of dropping the whole event.
    """
    json_event: Dict[str, Any] = RootModel[event.__class__](event).model_dump()
    json_event.update({"id": event.id, "type": event.type})

    if isinstance(event, RenderEvent):
        # load and encode the images
        image_data = {}
        for path in event.paths:
            try:
                image_data[path] = _encode_image(path)
            except Exception:
                logger.exception("failed to encode image %s", path)

        json_event["images"] = image_data

    recent_events.append(event)
    send_and_append(event.id, json_event)
|
|
|
|
|
|
|
|
|
|
|
|
def broadcast_player_event(
    character: str, client: str, status: Literal["join", "leave"]
):
    """Publish a PlayerEvent saying ``client`` joined or left ``character``."""
    broadcast(PlayerEvent(status=status, character=character, client=client))
|
2024-05-06 01:17:00 +00:00
|
|
|
|
|
|
|
|
2024-05-12 05:08:53 +00:00
|
|
|
def broadcast_player_list():
    """Publish a PlayerListEvent carrying the result of ``list_players()``."""
    players = list_players()
    broadcast(PlayerListEvent(players=players))
|