diff --git a/.gitignore b/.gitignore index 3e7caac..ec021af 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ adventure/custom_* +adventure/user_config.yaml worlds/ __pycache__/ .env diff --git a/adventure/bot_discord.py b/adventure/bot_discord.py index 760b6e2..ff67f50 100644 --- a/adventure/bot_discord.py +++ b/adventure/bot_discord.py @@ -13,6 +13,7 @@ from adventure.context import ( get_current_world, set_actor_agent, ) +from adventure.models.config import DiscordBotConfig from adventure.models.event import ( ActionEvent, GameEvent, @@ -24,11 +25,18 @@ from adventure.models.event import ( ResultEvent, StatusEvent, ) -from adventure.player import RemotePlayer, get_player, has_player, set_player +from adventure.player import ( + RemotePlayer, + get_player, + has_player, + remove_player, + set_player, +) from adventure.render_comfy import render_event logger = getLogger(__name__) client = None +bot_config: DiscordBotConfig = DiscordBotConfig(channels=["bots"]) active_tasks = set() event_messages: Dict[str, str | GameEvent] = {} @@ -45,17 +53,17 @@ def remove_tags(text: str) -> str: class AdventureClient(Client): async def on_ready(self): - logger.info(f"Logged in as {self.user}") + logger.info(f"logged in as {self.user}") async def on_reaction_add(self, reaction, user): if user == self.user: return - logger.info(f"Reaction added: {reaction} by {user}") + logger.info(f"reaction added: {reaction} by {user}") if reaction.emoji == "📷": message_id = reaction.message.id if message_id not in event_messages: - logger.warning(f"Message {message_id} not found in event messages") + logger.warning(f"message {message_id} not found in event messages") # TODO: return error message return @@ -119,19 +127,25 @@ class AdventureClient(Client): return broadcast(join_event) player = get_player(user_name) - if player: + if isinstance(player, RemotePlayer): if message.content.startswith("!leave"): - # TODO: check if player is playing - # TODO: revert to LLM agent - logger.info(f"{user_name} has left the game!") + remove_player(user_name) + + # revert to LLM agent + actor, _ = get_actor_agent_for_name(player.name) + if actor and player.fallback_agent: + logger.info("restoring LLM agent for %s", player.name) + set_actor_agent(actor.name, actor, player.fallback_agent) + + # broadcast leave event + logger.info("disconnecting player %s from %s", user_name, player.name) leave_event = PlayerEvent("leave", player.name, user_name) return broadcast(leave_event) - - if isinstance(player, RemotePlayer): + else: content = remove_tags(message.content) player.input_queue.put(content) logger.info( - f"Received message from {user_name} for {player.name}: {content}" + f"received message from {user_name} for {player.name}: {content}" ) return @@ -141,11 +155,16 @@ class AdventureClient(Client): return -def launch_bot(): +def launch_bot(config: DiscordBotConfig): + global bot_config global client + bot_config = config + + # message contents need to be enabled for multi-server bots intents = Intents.default() - # intents.message_content = True + if bot_config.content_intent: + intents.message_content = True client = AdventureClient(intents=intents) @@ -164,6 +183,7 @@ def launch_bot(): # logger.debug("no events to prompt") continue + # wait for pending messages to send, to keep them in order if len(active_tasks) > 0: logger.debug("waiting for active tasks to complete") continue @@ -178,6 +198,7 @@ def launch_bot(): else: logger.warning("no Discord client available") + logger.info("launching Discord bot") bot_thread = Thread(target=bot_main, daemon=True) bot_thread.start() @@ -205,7 +226,7 @@ def get_active_channels(): channel for guild in client.guilds for channel in guild.text_channels - if channel.name == "bots" + if channel.name in bot_config.channels ] @@ -286,6 +307,8 @@ def embed_from_event(event: GameEvent) -> Embed: return embed_from_status(event) elif isinstance(event, PlayerEvent): return embed_from_player(event) + elif isinstance(event, PromptEvent): + return embed_from_prompt(event) else: logger.warning("unknown event type: %s", event) @@ -334,8 +357,14 @@ def embed_from_player(event: PlayerEvent): return player_embed +def embed_from_prompt(event: PromptEvent): + # TODO: ping the player + prompt_embed = Embed(title=event.room.name, description=event.actor.name) + prompt_embed.add_field(name="Prompt", value=event.prompt) + return prompt_embed + + def embed_from_status(event: StatusEvent): - # TODO: add room and actor status_embed = Embed( title=event.room.name if event.room else "", description=event.actor.name if event.actor else "", diff --git a/adventure/generate.py b/adventure/generate.py index 9017879..d22b23f 100644 --- a/adventure/generate.py +++ b/adventure/generate.py @@ -5,7 +5,7 @@ from typing import List from packit.agent import Agent from packit.loops import loop_retry -from adventure.models.entity import Actor, Item, Room, World +from adventure.models.entity import Actor, Item, Room, World, WorldEntity from adventure.models.event import EventCallback, GenerateEvent logger = getLogger(__name__) @@ -171,10 +171,18 @@ def generate_world( ) -> World: room_count = room_count or randint(3, max_rooms) - if callable(callback): - callback( - GenerateEvent.from_name(f"Generating a {theme} with {room_count} rooms") - ) + def callback_wrapper(message: str | None = None, entity: WorldEntity | None = None): + if message: + event = GenerateEvent.from_name(message) + elif entity: + event = GenerateEvent.from_entity(entity) + else: + raise ValueError("Either message or entity must be provided") + + if callable(callback): + callback(event) + + callback_wrapper(message=f"Generating a {theme} with {room_count} rooms") existing_actors: List[str] = [] existing_items: List[str] = [] @@ -186,17 +194,13 @@ def generate_world( room = generate_room( agent, theme, existing_rooms=existing_rooms, callback=callback ) + callback_wrapper(entity=room) rooms.append(room) existing_rooms.append(room.name) item_count = randint(1, 3) - if callable(callback): - callback( - GenerateEvent.from_name( - f"Generating {item_count} items for room: {room.name}" - ) - ) + callback_wrapper(f"Generating {item_count} items for room: {room.name}") for j in range(item_count): item = generate_item( @@ -206,17 +210,16 @@ def generate_world( existing_items=existing_items, callback=callback, ) + callback_wrapper(entity=item) + room.items.append(item) existing_items.append(item.name) actor_count = randint(1, 3) - if callable(callback): - callback( - GenerateEvent.from_name( - f"Generating {actor_count} actors for room: {room.name}" - ) - ) + callback_wrapper( + message=f"Generating {actor_count} actors for room: {room.name}" + ) for j in range(actor_count): actor = generate_actor( @@ -226,18 +229,15 @@ def generate_world( existing_actors=existing_actors, callback=callback, ) + callback_wrapper(entity=actor) + room.actors.append(actor) existing_actors.append(actor.name) # generate the actor's inventory item_count = randint(0, 2) - if callable(callback): - callback( - GenerateEvent.from_name( - f"Generating {item_count} items for actor {actor.name}" - ) - ) + callback_wrapper(f"Generating {item_count} items for actor {actor.name}") for k in range(item_count): item = generate_item( @@ -247,6 +247,8 @@ def generate_world( existing_items=existing_items, callback=callback, ) + callback_wrapper(entity=item) + actor.items.append(item) existing_items.append(item.name) diff --git a/adventure/logic.py b/adventure/logic.py index 5820762..de39817 100644 --- a/adventure/logic.py +++ b/adventure/logic.py @@ -144,7 +144,7 @@ def format_logic(attributes: Attributes, rules: LogicTable, self=True) -> str: logger.debug("label has no relevant description: %s", label) if len(labels) > 0: - logger.info("adding attribute labels: %s", labels) + logger.debug("adding attribute labels: %s", labels) return " ".join(labels) diff --git a/adventure/main.py b/adventure/main.py index b4f76dd..d036423 100644 --- a/adventure/main.py +++ b/adventure/main.py @@ -10,8 +10,9 @@ from yaml import Loader, load from adventure.context import set_current_step, set_dungeon_master from adventure.generate import generate_world +from adventure.models.config import Config from adventure.models.entity import World, WorldState -from adventure.models.event import EventCallback, GameEvent +from adventure.models.event import EventCallback, GameEvent, GenerateEvent from adventure.models.files import PromptFile, WorldPrompt from adventure.plugins import load_plugin from adventure.simulate import simulate_world @@ -36,7 +37,7 @@ except Exception as err: print("error loading logging config: %s" % (err)) -logger = logger_with_colors(__name__, level="DEBUG") +logger = logger_with_colors(__name__) # , level="DEBUG") load_dotenv(environ.get("ADVENTURE_ENV", ".env"), override=True) @@ -64,7 +65,15 @@ def parse_args(): help="Extra actions to include in the simulation", ) parser.add_argument( - "--discord", type=bool, help="Whether to run the simulation in a Discord bot" + "--config", + type=str, + default="config.yml", + help="The file to load the configuration from", + ) + parser.add_argument( + "--discord", + action="store_true", + help="Whether to run the simulation in a Discord bot", ) parser.add_argument( "--flavor", @@ -73,29 +82,50 @@ def parse_args(): help="Some additional flavor text for the generated world", ) parser.add_argument( - "--player", type=str, help="The name of the character to play as" + "--max-rooms", + type=int, + help="The maximum number of rooms to generate", ) parser.add_argument( - "--rooms", type=int, default=5, help="The number of rooms to generate" + "--optional-actions", + action="store_true", + help="Whether to include optional actions", ) parser.add_argument( - "--max-rooms", type=int, help="The maximum number of rooms to generate" + "--player", + type=str, + help="The name of the character to play as", ) parser.add_argument( - "--optional-actions", type=bool, help="Whether to include optional actions" + "--render", + action="store_true", + help="Whether to render the simulation", ) - parser.add_argument("--render", type=bool, help="Whether to render the simulation") parser.add_argument( - "--server", type=str, help="The address on which to run the server" + "--render-generated", + action="store_true", + help="Whether to render entities as they are generated", + ) + parser.add_argument( + "--rooms", + type=int, + help="The number of rooms to generate", + ) + parser.add_argument( + "--server", + type=str, + help="The address on which to run the server", ) parser.add_argument( "--state", type=str, - # default="world.state.json", help="The file to save the world state to. Defaults to $world.state.json, if not set", ) parser.add_argument( - "--steps", type=int, default=10, help="The number of simulation steps to run" + "--steps", + type=int, + default=10, + help="The number of simulation steps to run", ) parser.add_argument( "--systems", @@ -104,7 +134,10 @@ def parse_args(): help="Extra systems to run in the simulation", ) parser.add_argument( - "--theme", type=str, default="fantasy", help="The theme of the generated world" + "--theme", + type=str, + default="fantasy", + help="The theme of the generated world", ) parser.add_argument( "--world", @@ -113,16 +146,16 @@ def parse_args(): help="The file to save the generated world to", ) parser.add_argument( - "--world-prompt", + "--world-template", type=str, - help="The file to load the world prompt from", + help="The template file to load the world prompt from", ) return parser.parse_args() def get_world_prompt(args) -> WorldPrompt: - if args.world_prompt: - prompt_file, prompt_name = args.world_prompt.split(":") + if args.world_template: + prompt_file, prompt_name = args.world_template.split(":") with open(prompt_file, "r") as f: prompts = PromptFile(**load_yaml(f)) for prompt in prompts.prompts: @@ -138,7 +171,9 @@ def get_world_prompt(args) -> WorldPrompt: ) -def load_or_generate_world(args, players, callbacks, world_prompt: WorldPrompt): +def load_or_generate_world( + args, players, callbacks, systems, world_prompt: WorldPrompt +): world_file = args.world + ".json" world_state_file = args.state or (args.world + ".state.json") @@ -170,7 +205,7 @@ def load_or_generate_world(args, players, callbacks, world_prompt: WorldPrompt): world = None def broadcast_callback(event: GameEvent): - logger.info(event) + logger.debug("broadcasting generation event: %s", event) for callback in callbacks: callback(event) @@ -184,6 +219,10 @@ def load_or_generate_world(args, players, callbacks, world_prompt: WorldPrompt): ) save_world(world, world_file) + # run the systems once to initialize everything + for system_update, _ in systems: + system_update(world, 0) + create_agents(world, memory=memory, players=players) return (world, world_state_file) @@ -191,6 +230,9 @@ def load_or_generate_world(args, players, callbacks, world_prompt: WorldPrompt): def main(): args = parse_args() + with open(args.config, "r") as f: + config = Config(**load_yaml(f)) + players = [] if args.player: players.append(args.player) @@ -204,12 +246,22 @@ def main(): if args.render: from adventure.render_comfy import launch_render - threads.extend(launch_render()) + threads.extend(launch_render(config.render)) + + if args.render_generated: + from adventure.render_comfy import render_entity + + def render_generated(event: GameEvent): + if isinstance(event, GenerateEvent) and event.entity: + logger.info("rendering generated entity: %s", event.entity.name) + render_entity(event.entity) + + callbacks.append(render_generated) if args.discord: from adventure.bot_discord import bot_event, launch_bot - threads.extend(launch_bot()) + threads.extend(launch_bot(config.bot.discord)) callbacks.append(bot_event) if args.server: @@ -263,7 +315,7 @@ def main(): # load or generate the world world_prompt = get_world_prompt(args) world, world_state_file = load_or_generate_world( - args, players, callbacks, world_prompt=world_prompt + args, players, callbacks, extra_systems, world_prompt=world_prompt ) # make sure the snapshot system runs last @@ -273,9 +325,9 @@ def main(): extra_systems.append((snapshot_system, None)) - # run the systems once to initialize everything - for system_update, _ in extra_systems: - system_update(world, 0) + # hack: send a snapshot to the websocket server + if args.server: + server_system(world, 0) # create the DM llm = agent_easy_connect() diff --git a/adventure/models/config.py b/adventure/models/config.py new file mode 100644 index 0000000..566a66f --- /dev/null +++ b/adventure/models/config.py @@ -0,0 +1,41 @@ +from typing import Dict, List + +from .base import dataclass + + +@dataclass +class Range: + min: int + max: int + + +@dataclass +class Size: + width: int + height: int + + +@dataclass +class DiscordBotConfig: + channels: List[str] + content_intent: bool = False + + +@dataclass +class BotConfig: + discord: DiscordBotConfig + + +@dataclass +class RenderConfig: + cfg: Range + checkpoints: List[str] + path: str + sizes: Dict[str, Size] + steps: Range + + +@dataclass +class Config: + bot: BotConfig + render: RenderConfig diff --git a/adventure/models/event.py b/adventure/models/event.py index 9262e38..07f607d 100644 --- a/adventure/models/event.py +++ b/adventure/models/event.py @@ -1,5 +1,5 @@ from json import loads -from typing import Any, Callable, Dict, List, Literal +from typing import Any, Callable, Dict, List, Literal, Union from uuid import uuid4 from pydantic import Field @@ -175,7 +175,7 @@ class RenderEvent(BaseEvent): id = Field(default_factory=uuid) type = "render" paths: List[str] - source: "GameEvent" + source: Union["GameEvent", WorldEntity] # event types diff --git a/adventure/render_comfy.py b/adventure/render_comfy.py index 8f3821b..450f082 100644 --- a/adventure/render_comfy.py +++ b/adventure/render_comfy.py @@ -1,22 +1,22 @@ -# This is an example that uses the websockets api to know when a prompt execution is done -# Once the prompt execution is done it downloads the images using the /history endpoint - import io import json import urllib.parse import urllib.request -import uuid from logging import getLogger from os import environ, path from queue import Queue from random import choice, randint from threading import Thread from typing import List +from uuid import uuid4 import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +from jinja2 import Environment, FileSystemLoader, select_autoescape from PIL import Image from adventure.context import broadcast +from adventure.models.config import Range, RenderConfig, Size +from adventure.models.entity import WorldEntity from adventure.models.event import ( ActionEvent, GameEvent, @@ -29,15 +29,39 @@ from adventure.models.event import ( logger = getLogger(__name__) server_address = environ["COMFY_API"] -client_id = str(uuid.uuid4()) +client_id = uuid4().hex +render_config: RenderConfig = RenderConfig( + cfg=Range(min=5, max=8), + checkpoints=[ + "diffusion-sdxl-dynavision-0-5-5-7.safetensors", + ], + path="/tmp/adventure-images", + sizes={ + "landscape": Size(width=1024, height=768), + "portrait": Size(width=768, height=1024), + "square": Size(width=768, height=768), + }, + steps=Range(min=30, max=30), +) + + +# requests to generate images for game events +render_queue: Queue[GameEvent | WorldEntity] = Queue() +render_thread: Thread | None = None def generate_cfg(): - return randint(5, 8) + if render_config.cfg.min == render_config.cfg.max: + return render_config.cfg.min + + return randint(render_config.cfg.min, render_config.cfg.max) def generate_steps(): - return 30 + if render_config.steps.min == render_config.steps.max: + return render_config.steps.min + + return randint(render_config.steps.min, render_config.steps.max) def generate_batches( @@ -93,7 +117,7 @@ def get_images(ws, prompt): continue # previews are binary data history = get_history(prompt_id)[prompt_id] - for o in history["outputs"]: + for _ in history["outputs"]: for node_id in history["outputs"]: node_output = history["outputs"][node_id] if "images" in node_output: @@ -117,86 +141,47 @@ def generate_image_tool(prompt, count, size="landscape"): return output_paths -sizes = { - "landscape": (1024, 768), - "portrait": (768, 1024), - "square": (768, 768), -} - - def generate_images( prompt: str, count: int, size="landscape", prefix="output" ) -> List[str]: cfg = generate_cfg() - width, height = sizes.get(size, (512, 512)) + dims = render_config.sizes[size] steps = generate_steps() seed = randint(0, 10000000) - checkpoint = choice(["diffusion-sdxl-dynavision-0-5-5-7.safetensors"]) + checkpoint = choice(render_config.checkpoints) logger.info( - "generating %s images at %s by %s with prompt: %s", count, width, height, prompt + "generating %s images at %s by %s with prompt: %s", + count, + dims.width, + dims.height, + prompt, + ) + + env = Environment( + loader=FileSystemLoader(["adventure/templates"]), + autoescape=select_autoescape(["json"]), + ) + template = env.get_template("comfy.json.j2") + result = template.render( + cfg=cfg, + height=dims.height, + width=dims.width, + steps=steps, + seed=seed, + checkpoint=checkpoint, + prompt=prompt.replace("\n", ". "), + negative_prompt="", + count=count, + prefix=prefix, ) # parsing here helps ensure the template emits valid JSON - prompt_workflow = { - "3": { - "class_type": "KSampler", - "inputs": { - "cfg": cfg, - "denoise": 1, - "latent_image": ["5", 0], - "model": ["4", 0], - "negative": ["7", 0], - "positive": ["6", 0], - "sampler_name": "euler_ancestral", - "scheduler": "normal", - "seed": seed, - "steps": steps, - }, - }, - "4": { - "class_type": "CheckpointLoaderSimple", - "inputs": {"ckpt_name": checkpoint}, - }, - "5": { - "class_type": "EmptyLatentImage", - "inputs": {"batch_size": count, "height": height, "width": width}, - }, - "6": { - "class_type": "smZ CLIPTextEncode", - "inputs": { - "text": prompt, - "parser": "compel", - "mean_normalization": True, - "multi_conditioning": True, - "use_old_emphasis_implementation": False, - "with_SDXL": False, - "ascore": 6, - "width": width, - "height": height, - "crop_w": 0, - "crop_h": 0, - "target_width": width, - "target_height": height, - "text_g": "", - "text_l": "", - "smZ_steps": 1, - "clip": ["4", 1], - }, - }, - "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "", "clip": ["4", 1]}}, - "8": { - "class_type": "VAEDecode", - "inputs": {"samples": ["3", 0], "vae": ["4", 2]}, - }, - "9": { - "class_type": "SaveImage", - "inputs": {"filename_prefix": prefix, "images": ["8", 0]}, - }, - } + logger.debug("template workflow: %s", result) + prompt_workflow = json.loads(result) - logger.debug("Connecting to Comfy API at %s", server_address) + logger.debug("connecting to Comfy API at %s", server_address) ws = websocket.WebSocket() - ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) + ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id), timeout=60) images = get_images(ws, prompt_workflow) results = [] @@ -207,8 +192,7 @@ def generate_images( paths: List[str] = [] for j, image in enumerate(results): - # TODO: replace with environment variable - image_path = path.join("/home/ssube/adventure-images", f"{prefix}-{j}.png") + image_path = path.join(render_config.path, f"{prefix}-{j}.png") with open(image_path, "wb") as f: image_bytes = io.BytesIO() image.save(image_bytes, format="PNG") @@ -244,51 +228,85 @@ def prompt_from_event(event: GameEvent) -> str | None: return None -def prefix_from_event(event: GameEvent) -> str: +def prompt_from_entity(entity: WorldEntity) -> str: + return entity.description + + +def get_image_prefix(event: GameEvent | WorldEntity) -> str: if isinstance(event, ActionEvent): - return ( - f"{event.actor.name}-{event.action}-{event.item.name if event.item else ''}" - ) + return f"event-action-{event.actor.name}-{event.action}" if isinstance(event, ReplyEvent): - return f"{event.actor.name}-reply" + return f"event-reply-{event.actor.name}" if isinstance(event, ResultEvent): - return f"{event.actor.name}-result" + return f"event-result-{event.actor.name}" if isinstance(event, StatusEvent): return "status" + if isinstance(event, WorldEntity): + return f"entity-{event.__class__.__name__.lower()}-{event.name}" + return "unknown" -# requests to generate images for game events -render_queue: Queue[GameEvent] = Queue() - - def render_loop(): while True: event = render_queue.get() - prompt = prompt_from_event(event) + prefix = get_image_prefix(event) + + # check if images already exist + image_index = 0 + image_path = path.join(render_config.path, f"{prefix}-{image_index}.png") + existing_images = [] + while path.exists(image_path): + existing_images.append(image_path) + image_index += 1 + image_path = path.join(render_config.path, f"{prefix}-{image_index}.png") + + if existing_images: + logger.info( + "using existing images for event %s: %s", event, existing_images + ) + broadcast(RenderEvent(paths=existing_images, source=event)) + continue + + # generate the prompt + if isinstance(event, WorldEntity): + logger.info("rendering entity %s", event) + prompt = prompt_from_entity(event) + else: + logger.info("rendering event %s", event) + prompt = prompt_from_event(event) + + # render or not if prompt: logger.info("rendering prompt for event %s: %s", event, prompt) - prefix = prefix_from_event(event) image_paths = generate_images(prompt, 2, prefix=prefix) broadcast(RenderEvent(paths=image_paths, source=event)) else: logger.warning("no prompt for event %s", event) +def render_entity(entity: WorldEntity): + render_queue.put(entity) + + def render_event(event: GameEvent): render_queue.put(event) -render_thread = None - - -def launch_render(): +def launch_render(config: RenderConfig): + global render_config global render_thread + # update the config + logger.info("updating render config: %s", config) + render_config = config + + # start the render thread + logger.info("launching render thread") render_thread = Thread(target=render_loop, daemon=True) render_thread.start() diff --git a/adventure/server_socket.py b/adventure/server_socket.py index 23c4aa8..e2176c8 100644 --- a/adventure/server_socket.py +++ b/adventure/server_socket.py @@ -12,7 +12,12 @@ import websockets from PIL import Image from pydantic import RootModel -from adventure.context import broadcast, get_actor_agent_for_name, set_actor_agent +from adventure.context import ( + broadcast, + get_actor_agent_for_name, + get_current_world, + set_actor_agent, +) from adventure.models.entity import Actor, Item, Room, World from adventure.models.event import ( GameEvent, @@ -29,16 +34,16 @@ from adventure.player import ( remove_player, set_player, ) -from adventure.render_comfy import render_event +from adventure.render_comfy import render_entity, render_event from adventure.state import snapshot_world, world_json logger = getLogger(__name__) connected = set() +last_snapshot: str | None = None +player_names: Dict[str, str] = {} recent_events: MutableSequence[GameEvent] = deque(maxlen=100) recent_json: MutableSequence[str] = deque(maxlen=100) -last_snapshot = None -player_names: Dict[str, str] = {} def get_player_name(client_id: str) -> str: @@ -47,13 +52,14 @@ def get_player_name(client_id: str) -> str: async def handler(websocket): id = uuid4().hex - logger.info("Client connected, given id: %s", id) + logger.info("client connected, given id: %s", id) connected.add(websocket) async def next_turn(character: str, prompt: str) -> None: await websocket.send( dumps( { + # TODO: these should be fields in the PromptEvent "type": "prompt", "client": id, "character": character, @@ -64,6 +70,7 @@ async def handler(websocket): ) def sync_turn(event: PromptEvent) -> bool: + # TODO: nothing about this is good player = get_player(id) if player and player.name == event.actor.name: asyncio.run(next_turn(event.actor.name, event.prompt)) @@ -74,21 +81,21 @@ async def handler(websocket): try: await websocket.send(dumps({"type": "id", "client": id})) - # TODO: only send this if the recent events don't contain a snapshot + # only send the snapshot once if last_snapshot and last_snapshot not in recent_json: await websocket.send(last_snapshot) for message in recent_json: await websocket.send(message) except Exception: - logger.exception("Failed to send recent messages to new client") + logger.exception("failed to send recent messages to new client") while True: try: # if this socket is attached to a character and that character's turn is active, wait for input message = await websocket.recv() player_name = get_player_name(id) - logger.info(f"Received message for {player_name}: {message}") + logger.info(f"received message for {player_name}: {message}") try: data = loads(message) @@ -106,7 +113,7 @@ async def handler(websocket): ) if existing_id is not None: logger.error( - f"Name {new_player_name} is already in use by {existing_id}" + f"name {new_player_name} is already in use by {existing_id}" ) continue @@ -119,7 +126,7 @@ async def handler(websocket): character_name = data["become"] if has_player(character_name): logger.error( - f"Character {character_name} is already in use" + f"character {character_name} is already in use" ) continue @@ -146,7 +153,7 @@ async def handler(websocket): ) set_player(id, player) logger.info( - f"Client {player_name} is now character {character_name}" + f"client {player_name} is now character {character_name}" ) # swap out the LLM agent @@ -163,15 +170,10 @@ async def handler(websocket): ) player.input_queue.put(data["input"]) elif message_type == "render": - event_id = data["event"] - event = next((e for e in recent_events if e.id == event_id), None) - if event: - render_event(event) - else: - logger.error(f"Failed to find event {event_id}") + render_input(data) except Exception: - logger.exception("Failed to parse message") + logger.exception("failed to parse message") except websockets.ConnectionClosedOK: break @@ -197,6 +199,56 @@ async def handler(websocket): logger.info("client disconnected: %s", id) +def render_input(data): + world = get_current_world() + if not world: + logger.error("no world available") + return + + if "event" in data: + event_id = data["event"] + event = next((e for e in recent_events if e.id == event_id), None) + if event: + render_event(event) + else: + logger.error(f"failed to find event {event_id}") + elif "actor" in data: + actor_name = data["actor"] + actor = next( + (a for r in world.rooms for a in r.actors if a.name == actor_name), None + ) + if actor: + render_entity(actor) + else: + logger.error(f"failed to find actor {actor_name}") + elif "room" in data: + room_name = data["room"] + room = next((r for r in world.rooms if r.name == room_name), None) + if room: + render_entity(room) + else: + logger.error(f"failed to find room {room_name}") + elif "item" in data: + item_name = data["item"] + item = None + for room in world.rooms: + item = next((i for i in room.items if i.name == item_name), None) + if item: + break + + for actor in room.actors: + item = next((i for i in actor.items if i.name == item_name), None) + if item: + break + + if item: + render_entity(item) + else: + logger.error(f"failed to find item {item_name}") + else: + logger.error(f"failed to find entity in {data}") + + socket_thread = None @@ -220,6 +272,7 @@ def launch_server(): def run_sockets(): asyncio.run(server_main()) + logger.info("launching websocket server") socket_thread = Thread(target=run_sockets, daemon=True) socket_thread.start() @@ -228,7 +281,7 @@ def launch_server(): async def server_main(): async with websockets.serve(handler, "", 8001): - logger.info("Server started") + logger.info("websocket server started") await asyncio.Future() # run forever diff --git a/adventure/templates/comfy.json.j2 b/adventure/templates/comfy.json.j2 new file mode 100644 index 0000000..31da32c --- /dev/null +++ b/adventure/templates/comfy.json.j2 @@ -0,0 +1,59 @@ +{ + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": {{ cfg }}, + "denoise": 1, + "latent_image": ["5", 0], + "model": ["4", 0], + "negative": ["7", 0], + "positive": ["6", 0], + "sampler_name": "euler_ancestral", + "scheduler": "normal", + "seed": {{ seed }}, + "steps": {{ steps }} + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "{{ checkpoint }}" + } + }, + "5": { + "class_type": "EmptyLatentImage", + "inputs": { + "batch_size": {{ count }}, + "height": {{ height }}, + "width": {{ width }} + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "text": {{ prompt | tojson }}, + "clip": ["4", 1] + } + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": { + "text": "", + "clip": ["4", 1] + } + }, + "8": { + "class_type": "VAEDecode", + "inputs": { + "samples": ["3", 0], + "vae": ["4", 2] + } + }, + "9": { + "class_type": "SaveImage", + "inputs": { + "filename_prefix": {{ prefix | tojson }}, + "images": ["8", 0] + } + } +} \ No newline at end of file diff --git a/client/package.json b/client/package.json index 4600985..a26a3de 100644 --- a/client/package.json +++ b/client/package.json @@ -16,6 +16,7 @@ "@mui/x-tree-view": "^7.3.1", "@types/lodash": "^4.14.192", "@types/node": "^20.11.0", + "@viz-js/viz": "^3.5.0", "allotment": "^1.20.0", "browser-bunyan": "^1.8.0", "i18next": "^22.4.14", diff --git a/client/src/app.tsx b/client/src/app.tsx index 21c27a3..4151506 100644 --- a/client/src/app.tsx +++ b/client/src/app.tsx @@ -1,31 +1,26 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { Maybe, doesExist } from '@apextoaster/js-utils'; import { - Button, Container, CssBaseline, - Dialog, - DialogActions, - DialogContent, - DialogTitle, Stack, ThemeProvider, - Typography, createTheme, } from '@mui/material'; import { Allotment } from 'allotment'; -import React, { Fragment, useEffect } from 'react'; +import React, { useEffect } from 'react'; import useWebSocketModule from 'react-use-websocket'; import { useStore } from 'zustand'; import { HistoryPanel } from './history.js'; -import { Actor, GameEvent, Item, Room } from './models.js'; +import { Actor } from './models.js'; import { PlayerPanel } from './player.js'; -import { store, StoreState } from './store.js'; -import { WorldPanel } from './world.js'; import { Statusbar } from './status.js'; +import { StoreState, store } from './store.js'; +import { WorldPanel } from './world.js'; import 'allotment/dist/style.css'; +import { DetailDialog } from './details.js'; import './main.css'; const useWebSocket = (useWebSocketModule as any).default; @@ -34,51 +29,6 @@ export interface AppProps { socketUrl: string; } -export interface EntityDetailsProps { - entity: Maybe; - close: () => void; -} - -export function EntityDetails(props: EntityDetailsProps) { - const { entity, close } = props; - - // eslint-disable-next-line no-restricted-syntax - if (!doesExist(entity)) { - return ; - } - - return - {entity.name} - - - {entity.description} - - - - - - ; -} - -export function detailStateSelector(s: StoreState) { - return { - detailEntity: s.detailEntity, - clearDetailEntity: s.clearDetailEntity, - }; -} - -export function DetailDialog() { - const state = useStore(store, detailStateSelector); - const { detailEntity, clearDetailEntity } = state; - - return - - ; -} - export function appStateSelector(s: StoreState) { return { themeMode: s.themeMode, @@ -94,6 +44,10 @@ export function App(props: AppProps) { const { lastMessage, readyState, sendMessage } = useWebSocket(props.socketUrl); // socket senders + function renderEntity(type: string, entity: string) { + sendMessage(JSON.stringify({ type: 'render', [type]: entity })); + } + function renderEvent(event: string) { sendMessage(JSON.stringify({ type: 'render', event })); } @@ -138,14 +92,9 @@ export function App(props: AppProps) { return; case 'prompt': // prompts are broadcast to all players - if (event.client === clientId) { - // only notify the active player - setActiveTurn(true); - break; - } else { - setActiveTurn(false); - return; - } + // only notify the active player + setActiveTurn(event.client === clientId); + break; case 'player': if (event.status === 'join' && doesExist(world) && event.client === clientId) { const { character: characterName } = event; @@ -173,7 +122,7 @@ export function App(props: AppProps) { return - + @@ -184,7 +133,7 @@ export function App(props: AppProps) { - + diff --git a/client/src/details.tsx b/client/src/details.tsx new file mode 100644 index 0000000..cf3d618 --- /dev/null +++ b/client/src/details.tsx @@ -0,0 +1,111 @@ +import { Maybe, doesExist } from '@apextoaster/js-utils'; +import { Button, Dialog, DialogActions, DialogContent, DialogTitle, Typography } from '@mui/material'; +import { instance as graphviz } from '@viz-js/viz'; +import React, { Fragment, useEffect } from 'react'; +import { useStore } from 'zustand'; +import { Actor, Item, Room, World } from './models'; +import { StoreState, store } from './store'; + +export interface EntityDetailsProps { + entity: Maybe; + onClose: () => void; + onRender: (type: string, entity: string) => void; +} + +export function EntityDetails(props: EntityDetailsProps) { + const { entity, onClose, onRender } = props; + + // eslint-disable-next-line no-restricted-syntax + if (!doesExist(entity)) { + return ; + } + + return + {entity.name} + + + {entity.description} + + + + + + + ; +} + +export interface WorldDetailsProps { + world: World; +} + +export function WorldDetails(props: WorldDetailsProps) { + const { world } = props; + + useEffect(() => { + graphviz().then((viz) => { + const dot = worldGraph(world); + const svg = viz.renderSVGElement(dot); + const graph = document.getElementById('graph'); + if (doesExist(graph)) { + graph.replaceChildren(svg); + } + }).catch((err) => { + // eslint-disable-next-line no-console + console.error(err); + }); + }, [world]); + + return + {world.name} + + + Theme: {world.theme} + +
+ + ; +} + +export function detailStateSelector(s: StoreState) { + return { + detailEntity: s.detailEntity, + clearDetailEntity: s.clearDetailEntity, + }; +} + +export interface DetailDialogProps { + renderEntity: (type: string, entity: string) => void; +} + +export function DetailDialog(props: DetailDialogProps) { + const state = useStore(store, detailStateSelector); + const { detailEntity, clearDetailEntity } = state; + + let details; + if (isWorld(detailEntity)) { + details = ; + } else { + details = ; + } + + return {details}; +} + +export function isWorld(entity: Maybe): entity is World { + return doesExist(entity) && doesExist(entity.theme); +} + +export function worldGraph(world: World): string { + return `digraph { + ${world.rooms.map((room) => roomGraph(room).join('; ')).join('\n')} + }`; +} + +export function roomGraph(room: Room): Array { + return Object.entries(room.portals).map(([direction, destination]) => + `"${room.name}" -> "${destination}" [label="${direction}"]` + ); +} diff --git a/client/src/events.tsx b/client/src/events.tsx index a709adb..7199463 100644 --- a/client/src/events.tsx +++ b/client/src/events.tsx @@ -1,9 +1,12 @@ import { Avatar, IconButton, ImageList, ImageListItem, ListItem, ListItemAvatar, ListItemText, Typography } from '@mui/material'; import React, { Fragment, MutableRefObject } from 'react'; +import { Maybe, doesExist } from '@apextoaster/js-utils'; import { Camera } from '@mui/icons-material'; +import { useStore } from 'zustand'; import { formatters } from './format.js'; -import { GameEvent } from './models.js'; +import { Actor, GameEvent } from './models.js'; +import { StoreState, store } from './store.js'; export function openImage(image: string) { const byteCharacters = atob(image); @@ -23,7 +26,22 @@ export interface EventItemProps { // eslint-disable-next-line @typescript-eslint/no-explicit-any focusRef?: MutableRefObject; - renderEvent: (event: GameEvent) => void; + renderEntity: (type: string, entity: string) => void; + renderEvent: (event: string) => void; +} + +export function characterSelector(state: StoreState) { + return { + character: state.character, + }; +} + +export function sameCharacter(a: Maybe, b: Maybe): boolean { + if (doesExist(a) && doesExist(b)) { + return a.name === b.name; + } + + return false; } export function ActionEventItem(props: EventItemProps) { @@ -31,6 +49,14 @@ export function ActionEventItem(props: EventItemProps) { const { id, actor, room, type } = event; const content = formatters[type](event); + const state = useStore(store, characterSelector); + const { character } = state; + + const playerAction = sameCharacter(actor, character); + const typographyProps = { + color: playerAction ? 'success.text' : 'primary.text', + }; + return + + {Object.entries(images).map(([name, image]) => + openImage(image as string)}> + Render + + )} + + ; +} + +export function PromptEventItem(props: EventItemProps) { + const { event } = props; + const { character, prompt } = event; + + const state = useStore(store, characterSelector); + const { character: playerCharacter } = state; + + const playerPrompt = sameCharacter(playerCharacter, character); + const typographyProps = { + color: playerPrompt ? 'success.text' : 'primary.text', + }; + + return + + + - {Object.entries(images).map(([name, image]) => - openImage(image)} alt="Render" /> - )} - } + primary="Prompt" + primaryTypographyProps={typographyProps} + secondaryTypographyProps={typographyProps} + secondary={ + + Prompt for {character}: {prompt} + + } + /> + ; +} + +export function GenerateEventItem(props: EventItemProps) { + const { event, renderEntity } = props; + const { entity, name } = event; + + return renderEntity(entity.name)}> + + + } + > + + + + + {name} + + } /> ; } @@ -188,6 +281,10 @@ export function EventItem(props: EventItemProps) { return ; case 'snapshot': return ; + case 'prompt': + return ; + case 'generate': + return ; default: return diff --git a/client/src/history.tsx b/client/src/history.tsx index ea0cb79..f4e4cb4 100644 --- a/client/src/history.tsx +++ b/client/src/history.tsx @@ -3,7 +3,6 @@ import { Divider, List } from '@mui/material'; import React, { useEffect, useRef } from 'react'; import { useStore } from 'zustand'; import { EventItem } from './events'; -import { GameEvent } from './models'; import { StoreState, store } from './store'; export function historyStateSelector(s: StoreState) { @@ -14,13 +13,13 @@ export function historyStateSelector(s: StoreState) { } export interface HistoryPanelProps { - renderEvent: (event: GameEvent) => void; + renderEntity: (type: string, entity: string) => void; + renderEvent: (event: string) => void; } export function HistoryPanel(props: HistoryPanelProps) { const state = useStore(store, historyStateSelector); const { history, scroll } = state; - const { renderEvent } = props; const scrollRef = useRef>(undefined); @@ -34,10 +33,10 @@ export function HistoryPanel(props: HistoryPanelProps) { const items = history.map((item, index) => { if (index === history.length - 1) { - return ; + return ; } - return ; + return ; }); return diff --git a/client/src/store.ts b/client/src/store.ts index c35edb8..4f52375 100644 --- a/client/src/store.ts +++ b/client/src/store.ts @@ -10,7 +10,7 @@ export interface ClientState { autoScroll: boolean; clientId: string; clientName: string; - detailEntity: Maybe; + detailEntity: Maybe; eventHistory: Array; readyState: ReadyState; themeMode: PaletteMode; @@ -19,7 +19,7 @@ export interface ClientState { setAutoScroll: (autoScroll: boolean) => void; setClientId: (clientId: string) => void; setClientName: (name: string) => void; - setDetailEntity: (entity: Maybe) => void; + setDetailEntity: (entity: Maybe) => void; setReadyState: (state: ReadyState) => void; setThemeMode: (mode: PaletteMode) => void; diff --git a/client/src/world.tsx b/client/src/world.tsx index 02b2331..4fb799a 100644 --- a/client/src/world.tsx +++ b/client/src/world.tsx @@ -6,7 +6,7 @@ import React from 'react'; import { useStore } from 'zustand'; import { StoreState, store } from './store'; -import { Actor, Item, Room, World } from './models'; +import { Actor, Item, Room } from './models'; export type SetDetails = (entity: Maybe) => void; export type SetPlayer = (actor: Maybe) => void; @@ -33,6 +33,7 @@ export function itemStateSelector(s: StoreState) { export function worldStateSelector(s: StoreState) { return { world: s.world, + setDetailEntity: s.setDetailEntity, }; } @@ -91,7 +92,7 @@ export function RoomItem(props: { room: Room } & BaseEntityItemProps) { export function WorldPanel(props: BaseEntityItemProps) { const { setPlayer } = props; const state = useStore(store, worldStateSelector); - const { world } = state; + const { world, setDetailEntity } = state; // eslint-disable-next-line no-restricted-syntax if (!doesExist(world)) { @@ -111,6 +112,7 @@ export function WorldPanel(props: BaseEntityItemProps) { Theme: {world.theme} + setDetailEntity(world)} /> {world.rooms.map((room) => )} diff --git a/client/yarn.lock b/client/yarn.lock index 588a59f..b0f9020 100644 --- a/client/yarn.lock +++ b/client/yarn.lock @@ -798,6 +798,11 @@ resolved "https://registry.yarnpkg.com/@ungap/structured-clone/-/structured-clone-1.2.0.tgz#756641adb587851b5ccb3e095daf27ae581c8406" integrity sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ== +"@viz-js/viz@^3.5.0": + version "3.5.0" + resolved "https://registry.yarnpkg.com/@viz-js/viz/-/viz-3.5.0.tgz#9fd09729cd2cdcbc51b0ea293a1954e6839797b2" + integrity sha512-66iFqMC2m0lZhvmHXFyJY12Jn8v9hswFMR3nsumN1dfhNoVrAHsa/7xpB3BojIVyj8IeEc8ciLjxZVdUnhcOxw== + "@xobotyi/scrollbar-width@^1.9.5": version "1.9.5" resolved "https://registry.yarnpkg.com/@xobotyi/scrollbar-width/-/scrollbar-width-1.9.5.tgz#80224a6919272f405b87913ca13b92929bdf3c4d" diff --git a/config.yml b/config.yml new file mode 100644 index 0000000..00c6634 --- /dev/null +++ b/config.yml @@ -0,0 +1,24 @@ +bot: + discord: + channels: [bots] +render: + cfg: + min: 5 + max: 8 + checkpoints: [ + "diffusion-sdxl-dynavision-0-5-5-7.safetensors", + ] + path: /tmp/adventure-images + sizes: + landscape: + width: 1280 + height: 960 + portrait: + width: 960 + height: 1280 + square: + width: 1024 + height: 1024 + steps: + min: 30 + max: 50 \ No newline at end of file