1
0
Fork 0
taleweave-ai/adventure/server.py

194 lines
5.5 KiB
Python
Raw Normal View History

2024-05-04 04:18:21 +00:00
import asyncio
2024-05-04 20:35:42 +00:00
from collections import deque
2024-05-05 14:14:54 +00:00
from json import dumps, loads
2024-05-04 04:18:21 +00:00
from logging import getLogger
from threading import Thread
from typing import Dict
from uuid import uuid4
2024-05-04 04:18:21 +00:00
import websockets
from adventure.context import get_actor_agent_for_name, set_actor_agent_for_name
2024-05-04 04:18:21 +00:00
from adventure.models import Actor, Room, World
2024-05-05 14:14:54 +00:00
from adventure.player import RemotePlayer
2024-05-04 04:18:21 +00:00
from adventure.state import snapshot_world, world_json
logger = getLogger(__name__)

# Every currently open client websocket; send_and_append() broadcasts to this set.
connected: set = set()

# RemotePlayer instances, keyed by the client id (uuid4 hex) that controls them.
characters: Dict[str, RemotePlayer] = {}

# The last 100 serialized messages, replayed to each client when it connects.
recent_events: deque = deque(maxlen=100)

# Serialized JSON of the latest world snapshot, or None before the first one.
recent_world = None
async def handler(websocket):
    """Serve a single websocket client until its connection closes.

    On connect the client is given a fresh id (sent as ``{"type": "id"}``),
    followed by the latest world snapshot and the buffered recent events.
    Afterwards it may send JSON messages:

    * ``{"type": "player", "become": <name>}`` -- take control of the named
      actor. The actor's agent is swapped for a ``RemotePlayer`` that prompts
      this client, with the previous agent kept as its fallback. Any character
      this client already controlled is released first.
    * ``{"type": "input", "input": <value>}`` -- queue input for the character
      this client controls (ignored if it controls none).

    When the connection closes (cleanly or not) the socket is removed from the
    broadcast set and any controlled character is handed back to its fallback
    agent.
    """
    client_id = uuid4().hex
    logger.info("Client connected, given id: %s", client_id)
    # The event loop that owns this websocket. sync_turn() is invoked from
    # outside this loop and must schedule its sends here.
    loop = asyncio.get_running_loop()
    connected.add(websocket)

    async def next_turn(character: str, prompt: str) -> None:
        # Prompt this client to take its turn as `character`.
        await websocket.send(
            dumps(
                {
                    "type": "prompt",
                    "id": client_id,
                    "character": character,
                    "prompt": prompt,
                    "actions": [],
                }
            ),
        )

    def sync_turn(character: str, prompt: str) -> bool:
        # Blocking bridge handed to RemotePlayer. Returns False once this
        # client no longer controls a character, so the caller can fall back.
        if client_id not in characters:
            return False
        try:
            running = asyncio.get_running_loop()
        except RuntimeError:
            running = None
        if running is loop:
            # Blocking on the future below from the loop's own thread would
            # deadlock; fail loudly instead.
            raise RuntimeError("sync_turn cannot be called from the server event loop")
        # The websocket is bound to the server loop, so run the send there
        # (a fresh loop via asyncio.run cannot drive it) and wait for it.
        asyncio.run_coroutine_threadsafe(next_turn(character, prompt), loop).result()
        return True

    def release_character() -> None:
        # Hand the character controlled by this client (if any) back to the
        # agent it had before the client took it over.
        player = characters.pop(client_id, None)
        if player is None:
            return
        actor, _ = get_actor_agent_for_name(player.name)
        if actor:
            set_actor_agent_for_name(player.name, actor, player.fallback_agent)

    def become(character_name: str) -> None:
        # Make this client the controller of `character_name`.
        actor, _ = get_actor_agent_for_name(character_name)
        if not actor:
            logger.error("Failed to find actor %s", character_name)
            return
        if any(
            other_id != client_id and other.name == character_name
            for other_id, other in characters.items()
        ):
            logger.error("Character %s is already in use", character_name)
            return

        # Release first, then re-read the agent: if the client re-selects its
        # own character this yields the original agent rather than the
        # RemotePlayer we are replacing.
        release_character()
        actor, llm_agent = get_actor_agent_for_name(character_name)

        player = RemotePlayer(
            actor.name, actor.backstory, sync_turn, fallback_agent=llm_agent
        )
        characters[client_id] = player
        logger.info("Client %s is now character %s", websocket, character_name)
        # swap out the LLM agent
        set_actor_agent_for_name(actor.name, actor, player)
        # notify all clients that this character is now active
        send_and_append({"type": "player", "name": character_name, "id": client_id})

    def handle_message(data) -> None:
        # Dispatch one decoded client message by its "type" field.
        message_type = data.get("type", None)
        if message_type == "player":
            become(data["become"])
        elif message_type == "input" and client_id in characters:
            player = characters[client_id]
            logger.info("queueing input for player %s: %s", player.name, data)
            player.input_queue.put(data["input"])

    try:
        await websocket.send(dumps({"type": "id", "id": client_id}))
        if recent_world:
            await websocket.send(recent_world)
        # Iterate a snapshot: other threads may append while we await each
        # send, and mutating a deque during iteration raises RuntimeError.
        for message in list(recent_events):
            await websocket.send(message)
    except Exception:
        logger.exception("Failed to send recent messages to new client")

    try:
        while True:
            try:
                message = await websocket.recv()
            except websockets.ConnectionClosed:
                # Covers both clean (OK) and abnormal (Error) closes.
                break
            logger.info("Received message for %s: %s", client_id, message)
            try:
                handle_message(loads(message))
            except Exception:
                logger.exception("Failed to parse message")
    finally:
        # Always clean up, even on an abnormal close or task cancellation, so
        # the socket is not left in the broadcast set and the character is
        # returned to its original agent.
        connected.discard(websocket)
        release_character()
        logger.info("Client disconnected")
# Background thread handles; socket_thread is set by launch_server().
socket_thread = static_thread = None
def server_json(obj):
    """``default`` hook for ``json.dumps`` when serializing broadcast payloads.

    Actors and rooms are serialized by name only; every other object is
    delegated to ``world_json``.
    """
    if isinstance(obj, (Actor, Room)):
        return obj.name
    return world_json(obj)
def send_and_append(message):
    """Serialize *message*, buffer it for replay, and broadcast it to all clients.

    ``server_json`` is used as the JSON ``default`` hook, so actors and rooms
    are sent by name. The serialized message is appended to ``recent_events``
    (replayed to newly connected clients) and broadcast to every socket in
    ``connected``.

    Returns:
        The JSON string that was broadcast.
    """
    json_message = dumps(message, default=server_json)
    recent_events.append(json_message)
    # Broadcast to a snapshot of the set: handler() adds and removes sockets
    # on the server thread while this function is also called from the game
    # thread, and iterating a set that changes size raises RuntimeError.
    # NOTE(review): websockets connections are not documented as thread-safe;
    # calling broadcast() off the server's event loop thread may still be
    # unsafe -- consider scheduling it with loop.call_soon_threadsafe.
    websockets.broadcast(set(connected), json_message)
    return json_message
def launch_server():
    """Start the websocket server in a background thread.

    The thread object is stored in the module-level ``socket_thread``.
    """
    global socket_thread

    def run_sockets():
        # Each thread needs its own event loop; asyncio.run creates one.
        asyncio.run(server_main())

    thread = Thread(target=run_sockets)
    socket_thread = thread
    thread.start()
async def server_main():
    """Accept websocket clients on port 8001 (host "", i.e. all interfaces) forever."""
    server = websockets.serve(handler, "", 8001)
    async with server:
        logger.info("Server started")
        # A future that nothing ever resolves keeps the server open indefinitely.
        never_done = asyncio.Future()
        await never_done
def server_system(world: World, step: int):
    """Broadcast a snapshot of *world* at *step* and keep it for new clients."""
    global recent_world

    # dict(...) builds a fresh mapping, so the snapshot itself is not mutated.
    state = dict(snapshot_world(world, step), type="world")
    recent_world = send_and_append(state)
def server_result(room: Room, actor: Actor, action: str):
    """Broadcast a ``result`` message describing *actor*'s outcome in *room*."""
    payload = dict(actor=actor, result=action, room=room, type="result")
    send_and_append(payload)
def server_action(room: Room, actor: Actor, message: str):
    """Broadcast an ``action`` message carrying *actor*'s input in *room*."""
    payload = dict(actor=actor, input=message, room=room, type="action")
    send_and_append(payload)
def server_event(message: str):
    """Broadcast a free-form ``event`` message to every connected client."""
    payload = dict(message=message, type="event")
    send_and_append(payload)