1
0
Fork 0

start using config, render entities, render world graph in browser client

This commit is contained in:
Sean Sube 2024-05-12 15:47:18 -05:00
parent e654ac2df9
commit 5f7dd3bb89
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
19 changed files with 703 additions and 260 deletions

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
adventure/custom_* adventure/custom_*
adventure/user_config.yaml
worlds/ worlds/
__pycache__/ __pycache__/
.env .env

View File

@ -13,6 +13,7 @@ from adventure.context import (
get_current_world, get_current_world,
set_actor_agent, set_actor_agent,
) )
from adventure.models.config import DiscordBotConfig
from adventure.models.event import ( from adventure.models.event import (
ActionEvent, ActionEvent,
GameEvent, GameEvent,
@ -24,11 +25,18 @@ from adventure.models.event import (
ResultEvent, ResultEvent,
StatusEvent, StatusEvent,
) )
from adventure.player import RemotePlayer, get_player, has_player, set_player from adventure.player import (
RemotePlayer,
get_player,
has_player,
remove_player,
set_player,
)
from adventure.render_comfy import render_event from adventure.render_comfy import render_event
logger = getLogger(__name__) logger = getLogger(__name__)
client = None client = None
bot_config: DiscordBotConfig = DiscordBotConfig(channels=["bots"])
active_tasks = set() active_tasks = set()
event_messages: Dict[str, str | GameEvent] = {} event_messages: Dict[str, str | GameEvent] = {}
@ -45,17 +53,17 @@ def remove_tags(text: str) -> str:
class AdventureClient(Client): class AdventureClient(Client):
async def on_ready(self): async def on_ready(self):
logger.info(f"Logged in as {self.user}") logger.info(f"logged in as {self.user}")
async def on_reaction_add(self, reaction, user): async def on_reaction_add(self, reaction, user):
if user == self.user: if user == self.user:
return return
logger.info(f"Reaction added: {reaction} by {user}") logger.info(f"reaction added: {reaction} by {user}")
if reaction.emoji == "📷": if reaction.emoji == "📷":
message_id = reaction.message.id message_id = reaction.message.id
if message_id not in event_messages: if message_id not in event_messages:
logger.warning(f"Message {message_id} not found in event messages") logger.warning(f"message {message_id} not found in event messages")
# TODO: return error message # TODO: return error message
return return
@ -119,19 +127,25 @@ class AdventureClient(Client):
return broadcast(join_event) return broadcast(join_event)
player = get_player(user_name) player = get_player(user_name)
if player: if isinstance(player, RemotePlayer):
if message.content.startswith("!leave"): if message.content.startswith("!leave"):
# TODO: check if player is playing remove_player(user_name)
# TODO: revert to LLM agent
logger.info(f"{user_name} has left the game!") # revert to LLM agent
actor, _ = get_actor_agent_for_name(player.name)
if actor and player.fallback_agent:
logger.info("restoring LLM agent for %s", player.name)
set_actor_agent(actor.name, actor, player.fallback_agent)
# broadcast leave event
logger.info("disconnecting player %s from %s", user_name, player.name)
leave_event = PlayerEvent("leave", player.name, user_name) leave_event = PlayerEvent("leave", player.name, user_name)
return broadcast(leave_event) return broadcast(leave_event)
else:
if isinstance(player, RemotePlayer):
content = remove_tags(message.content) content = remove_tags(message.content)
player.input_queue.put(content) player.input_queue.put(content)
logger.info( logger.info(
f"Received message from {user_name} for {player.name}: {content}" f"received message from {user_name} for {player.name}: {content}"
) )
return return
@ -141,11 +155,16 @@ class AdventureClient(Client):
return return
def launch_bot(): def launch_bot(config: DiscordBotConfig):
global bot_config
global client global client
bot_config = config
# message contents need to be enabled for multi-server bots
intents = Intents.default() intents = Intents.default()
# intents.message_content = True if bot_config.content_intent:
intents.message_content = True
client = AdventureClient(intents=intents) client = AdventureClient(intents=intents)
@ -164,6 +183,7 @@ def launch_bot():
# logger.debug("no events to prompt") # logger.debug("no events to prompt")
continue continue
# wait for pending messages to send, to keep them in order
if len(active_tasks) > 0: if len(active_tasks) > 0:
logger.debug("waiting for active tasks to complete") logger.debug("waiting for active tasks to complete")
continue continue
@ -178,6 +198,7 @@ def launch_bot():
else: else:
logger.warning("no Discord client available") logger.warning("no Discord client available")
logger.info("launching Discord bot")
bot_thread = Thread(target=bot_main, daemon=True) bot_thread = Thread(target=bot_main, daemon=True)
bot_thread.start() bot_thread.start()
@ -205,7 +226,7 @@ def get_active_channels():
channel channel
for guild in client.guilds for guild in client.guilds
for channel in guild.text_channels for channel in guild.text_channels
if channel.name == "bots" if channel.name in bot_config.channels
] ]
@ -286,6 +307,8 @@ def embed_from_event(event: GameEvent) -> Embed:
return embed_from_status(event) return embed_from_status(event)
elif isinstance(event, PlayerEvent): elif isinstance(event, PlayerEvent):
return embed_from_player(event) return embed_from_player(event)
elif isinstance(event, PromptEvent):
return embed_from_prompt(event)
else: else:
logger.warning("unknown event type: %s", event) logger.warning("unknown event type: %s", event)
@ -334,8 +357,14 @@ def embed_from_player(event: PlayerEvent):
return player_embed return player_embed
def embed_from_prompt(event: PromptEvent):
# TODO: ping the player
prompt_embed = Embed(title=event.room.name, description=event.actor.name)
prompt_embed.add_field(name="Prompt", value=event.prompt)
return prompt_embed
def embed_from_status(event: StatusEvent): def embed_from_status(event: StatusEvent):
# TODO: add room and actor
status_embed = Embed( status_embed = Embed(
title=event.room.name if event.room else "", title=event.room.name if event.room else "",
description=event.actor.name if event.actor else "", description=event.actor.name if event.actor else "",

View File

@ -5,7 +5,7 @@ from typing import List
from packit.agent import Agent from packit.agent import Agent
from packit.loops import loop_retry from packit.loops import loop_retry
from adventure.models.entity import Actor, Item, Room, World from adventure.models.entity import Actor, Item, Room, World, WorldEntity
from adventure.models.event import EventCallback, GenerateEvent from adventure.models.event import EventCallback, GenerateEvent
logger = getLogger(__name__) logger = getLogger(__name__)
@ -171,10 +171,18 @@ def generate_world(
) -> World: ) -> World:
room_count = room_count or randint(3, max_rooms) room_count = room_count or randint(3, max_rooms)
if callable(callback): def callback_wrapper(message: str | None = None, entity: WorldEntity | None = None):
callback( if message:
GenerateEvent.from_name(f"Generating a {theme} with {room_count} rooms") event = GenerateEvent.from_name(message)
) elif entity:
event = GenerateEvent.from_entity(entity)
else:
raise ValueError("Either message or entity must be provided")
if callable(callback):
callback(event)
callback_wrapper(message=f"Generating a {theme} with {room_count} rooms")
existing_actors: List[str] = [] existing_actors: List[str] = []
existing_items: List[str] = [] existing_items: List[str] = []
@ -186,17 +194,13 @@ def generate_world(
room = generate_room( room = generate_room(
agent, theme, existing_rooms=existing_rooms, callback=callback agent, theme, existing_rooms=existing_rooms, callback=callback
) )
callback_wrapper(entity=room)
rooms.append(room) rooms.append(room)
existing_rooms.append(room.name) existing_rooms.append(room.name)
item_count = randint(1, 3) item_count = randint(1, 3)
if callable(callback): callback_wrapper(f"Generating {item_count} items for room: {room.name}")
callback(
GenerateEvent.from_name(
f"Generating {item_count} items for room: {room.name}"
)
)
for j in range(item_count): for j in range(item_count):
item = generate_item( item = generate_item(
@ -206,17 +210,16 @@ def generate_world(
existing_items=existing_items, existing_items=existing_items,
callback=callback, callback=callback,
) )
callback_wrapper(entity=item)
room.items.append(item) room.items.append(item)
existing_items.append(item.name) existing_items.append(item.name)
actor_count = randint(1, 3) actor_count = randint(1, 3)
if callable(callback): callback_wrapper(
callback( message=f"Generating {actor_count} actors for room: {room.name}"
GenerateEvent.from_name( )
f"Generating {actor_count} actors for room: {room.name}"
)
)
for j in range(actor_count): for j in range(actor_count):
actor = generate_actor( actor = generate_actor(
@ -226,18 +229,15 @@ def generate_world(
existing_actors=existing_actors, existing_actors=existing_actors,
callback=callback, callback=callback,
) )
callback_wrapper(entity=actor)
room.actors.append(actor) room.actors.append(actor)
existing_actors.append(actor.name) existing_actors.append(actor.name)
# generate the actor's inventory # generate the actor's inventory
item_count = randint(0, 2) item_count = randint(0, 2)
if callable(callback): callback_wrapper(f"Generating {item_count} items for actor {actor.name}")
callback(
GenerateEvent.from_name(
f"Generating {item_count} items for actor {actor.name}"
)
)
for k in range(item_count): for k in range(item_count):
item = generate_item( item = generate_item(
@ -247,6 +247,8 @@ def generate_world(
existing_items=existing_items, existing_items=existing_items,
callback=callback, callback=callback,
) )
callback_wrapper(entity=item)
actor.items.append(item) actor.items.append(item)
existing_items.append(item.name) existing_items.append(item.name)

View File

@ -144,7 +144,7 @@ def format_logic(attributes: Attributes, rules: LogicTable, self=True) -> str:
logger.debug("label has no relevant description: %s", label) logger.debug("label has no relevant description: %s", label)
if len(labels) > 0: if len(labels) > 0:
logger.info("adding attribute labels: %s", labels) logger.debug("adding attribute labels: %s", labels)
return " ".join(labels) return " ".join(labels)

View File

@ -10,8 +10,9 @@ from yaml import Loader, load
from adventure.context import set_current_step, set_dungeon_master from adventure.context import set_current_step, set_dungeon_master
from adventure.generate import generate_world from adventure.generate import generate_world
from adventure.models.config import Config
from adventure.models.entity import World, WorldState from adventure.models.entity import World, WorldState
from adventure.models.event import EventCallback, GameEvent from adventure.models.event import EventCallback, GameEvent, GenerateEvent
from adventure.models.files import PromptFile, WorldPrompt from adventure.models.files import PromptFile, WorldPrompt
from adventure.plugins import load_plugin from adventure.plugins import load_plugin
from adventure.simulate import simulate_world from adventure.simulate import simulate_world
@ -36,7 +37,7 @@ except Exception as err:
print("error loading logging config: %s" % (err)) print("error loading logging config: %s" % (err))
logger = logger_with_colors(__name__, level="DEBUG") logger = logger_with_colors(__name__) # , level="DEBUG")
load_dotenv(environ.get("ADVENTURE_ENV", ".env"), override=True) load_dotenv(environ.get("ADVENTURE_ENV", ".env"), override=True)
@ -64,7 +65,15 @@ def parse_args():
help="Extra actions to include in the simulation", help="Extra actions to include in the simulation",
) )
parser.add_argument( parser.add_argument(
"--discord", type=bool, help="Whether to run the simulation in a Discord bot" "--config",
type=str,
default="config.yml",
help="The file to load the configuration from",
)
parser.add_argument(
"--discord",
action="store_true",
help="Whether to run the simulation in a Discord bot",
) )
parser.add_argument( parser.add_argument(
"--flavor", "--flavor",
@ -73,29 +82,50 @@ def parse_args():
help="Some additional flavor text for the generated world", help="Some additional flavor text for the generated world",
) )
parser.add_argument( parser.add_argument(
"--player", type=str, help="The name of the character to play as" "--max-rooms",
type=int,
help="The maximum number of rooms to generate",
) )
parser.add_argument( parser.add_argument(
"--rooms", type=int, default=5, help="The number of rooms to generate" "--optional-actions",
action="store_true",
help="Whether to include optional actions",
) )
parser.add_argument( parser.add_argument(
"--max-rooms", type=int, help="The maximum number of rooms to generate" "--player",
type=str,
help="The name of the character to play as",
) )
parser.add_argument( parser.add_argument(
"--optional-actions", type=bool, help="Whether to include optional actions" "--render",
action="store_true",
help="Whether to render the simulation",
) )
parser.add_argument("--render", type=bool, help="Whether to render the simulation")
parser.add_argument( parser.add_argument(
"--server", type=str, help="The address on which to run the server" "--render-generated",
action="store_true",
help="Whether to render entities as they are generated",
)
parser.add_argument(
"--rooms",
type=int,
help="The number of rooms to generate",
)
parser.add_argument(
"--server",
type=str,
help="The address on which to run the server",
) )
parser.add_argument( parser.add_argument(
"--state", "--state",
type=str, type=str,
# default="world.state.json",
help="The file to save the world state to. Defaults to $world.state.json, if not set", help="The file to save the world state to. Defaults to $world.state.json, if not set",
) )
parser.add_argument( parser.add_argument(
"--steps", type=int, default=10, help="The number of simulation steps to run" "--steps",
type=int,
default=10,
help="The number of simulation steps to run",
) )
parser.add_argument( parser.add_argument(
"--systems", "--systems",
@ -104,7 +134,10 @@ def parse_args():
help="Extra systems to run in the simulation", help="Extra systems to run in the simulation",
) )
parser.add_argument( parser.add_argument(
"--theme", type=str, default="fantasy", help="The theme of the generated world" "--theme",
type=str,
default="fantasy",
help="The theme of the generated world",
) )
parser.add_argument( parser.add_argument(
"--world", "--world",
@ -113,16 +146,16 @@ def parse_args():
help="The file to save the generated world to", help="The file to save the generated world to",
) )
parser.add_argument( parser.add_argument(
"--world-prompt", "--world-template",
type=str, type=str,
help="The file to load the world prompt from", help="The template file to load the world prompt from",
) )
return parser.parse_args() return parser.parse_args()
def get_world_prompt(args) -> WorldPrompt: def get_world_prompt(args) -> WorldPrompt:
if args.world_prompt: if args.world_template:
prompt_file, prompt_name = args.world_prompt.split(":") prompt_file, prompt_name = args.world_template.split(":")
with open(prompt_file, "r") as f: with open(prompt_file, "r") as f:
prompts = PromptFile(**load_yaml(f)) prompts = PromptFile(**load_yaml(f))
for prompt in prompts.prompts: for prompt in prompts.prompts:
@ -138,7 +171,9 @@ def get_world_prompt(args) -> WorldPrompt:
) )
def load_or_generate_world(args, players, callbacks, world_prompt: WorldPrompt): def load_or_generate_world(
args, players, callbacks, systems, world_prompt: WorldPrompt
):
world_file = args.world + ".json" world_file = args.world + ".json"
world_state_file = args.state or (args.world + ".state.json") world_state_file = args.state or (args.world + ".state.json")
@ -170,7 +205,7 @@ def load_or_generate_world(args, players, callbacks, world_prompt: WorldPrompt):
world = None world = None
def broadcast_callback(event: GameEvent): def broadcast_callback(event: GameEvent):
logger.info(event) logger.debug("broadcasting generation event: %s", event)
for callback in callbacks: for callback in callbacks:
callback(event) callback(event)
@ -184,6 +219,10 @@ def load_or_generate_world(args, players, callbacks, world_prompt: WorldPrompt):
) )
save_world(world, world_file) save_world(world, world_file)
# run the systems once to initialize everything
for system_update, _ in systems:
system_update(world, 0)
create_agents(world, memory=memory, players=players) create_agents(world, memory=memory, players=players)
return (world, world_state_file) return (world, world_state_file)
@ -191,6 +230,9 @@ def load_or_generate_world(args, players, callbacks, world_prompt: WorldPrompt):
def main(): def main():
args = parse_args() args = parse_args()
with open(args.config, "r") as f:
config = Config(**load_yaml(f))
players = [] players = []
if args.player: if args.player:
players.append(args.player) players.append(args.player)
@ -204,12 +246,22 @@ def main():
if args.render: if args.render:
from adventure.render_comfy import launch_render from adventure.render_comfy import launch_render
threads.extend(launch_render()) threads.extend(launch_render(config.render))
if args.render_generated:
from adventure.render_comfy import render_entity
def render_generated(event: GameEvent):
if isinstance(event, GenerateEvent) and event.entity:
logger.info("rendering generated entity: %s", event.entity.name)
render_entity(event.entity)
callbacks.append(render_generated)
if args.discord: if args.discord:
from adventure.bot_discord import bot_event, launch_bot from adventure.bot_discord import bot_event, launch_bot
threads.extend(launch_bot()) threads.extend(launch_bot(config.bot.discord))
callbacks.append(bot_event) callbacks.append(bot_event)
if args.server: if args.server:
@ -263,7 +315,7 @@ def main():
# load or generate the world # load or generate the world
world_prompt = get_world_prompt(args) world_prompt = get_world_prompt(args)
world, world_state_file = load_or_generate_world( world, world_state_file = load_or_generate_world(
args, players, callbacks, world_prompt=world_prompt args, players, callbacks, extra_systems, world_prompt=world_prompt
) )
# make sure the snapshot system runs last # make sure the snapshot system runs last
@ -273,9 +325,9 @@ def main():
extra_systems.append((snapshot_system, None)) extra_systems.append((snapshot_system, None))
# run the systems once to initialize everything # hack: send a snapshot to the websocket server
for system_update, _ in extra_systems: if args.server:
system_update(world, 0) server_system(world, 0)
# create the DM # create the DM
llm = agent_easy_connect() llm = agent_easy_connect()

View File

@ -0,0 +1,41 @@
from typing import Dict, List
from .base import dataclass
@dataclass
class Range:
min: int
max: int
@dataclass
class Size:
width: int
height: int
@dataclass
class DiscordBotConfig:
channels: List[str]
content_intent: bool = False
@dataclass
class BotConfig:
discord: DiscordBotConfig
@dataclass
class RenderConfig:
cfg: Range
checkpoints: List[str]
path: str
sizes: Dict[str, Size]
steps: Range
@dataclass
class Config:
bot: BotConfig
render: RenderConfig

View File

@ -1,5 +1,5 @@
from json import loads from json import loads
from typing import Any, Callable, Dict, List, Literal from typing import Any, Callable, Dict, List, Literal, Union
from uuid import uuid4 from uuid import uuid4
from pydantic import Field from pydantic import Field
@ -175,7 +175,7 @@ class RenderEvent(BaseEvent):
id = Field(default_factory=uuid) id = Field(default_factory=uuid)
type = "render" type = "render"
paths: List[str] paths: List[str]
source: "GameEvent" source: Union["GameEvent", WorldEntity]
# event types # event types

View File

@ -1,22 +1,22 @@
# This is an example that uses the websockets api to know when a prompt execution is done
# Once the prompt execution is done it downloads the images using the /history endpoint
import io import io
import json import json
import urllib.parse import urllib.parse
import urllib.request import urllib.request
import uuid
from logging import getLogger from logging import getLogger
from os import environ, path from os import environ, path
from queue import Queue from queue import Queue
from random import choice, randint from random import choice, randint
from threading import Thread from threading import Thread
from typing import List from typing import List
from uuid import uuid4
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
from jinja2 import Environment, FileSystemLoader, select_autoescape
from PIL import Image from PIL import Image
from adventure.context import broadcast from adventure.context import broadcast
from adventure.models.config import Range, RenderConfig, Size
from adventure.models.entity import WorldEntity
from adventure.models.event import ( from adventure.models.event import (
ActionEvent, ActionEvent,
GameEvent, GameEvent,
@ -29,15 +29,39 @@ from adventure.models.event import (
logger = getLogger(__name__) logger = getLogger(__name__)
server_address = environ["COMFY_API"] server_address = environ["COMFY_API"]
client_id = str(uuid.uuid4()) client_id = uuid4().hex
render_config: RenderConfig = RenderConfig(
cfg=Range(min=5, max=8),
checkpoints=[
"diffusion-sdxl-dynavision-0-5-5-7.safetensors",
],
path="/tmp/adventure-images",
sizes={
"landscape": Size(width=1024, height=768),
"portrait": Size(width=768, height=1024),
"square": Size(width=768, height=768),
},
steps=Range(min=30, max=30),
)
# requests to generate images for game events
render_queue: Queue[GameEvent | WorldEntity] = Queue()
render_thread: Thread | None = None
def generate_cfg(): def generate_cfg():
return randint(5, 8) if render_config.cfg.min == render_config.cfg.max:
return render_config.cfg.min
return randint(render_config.cfg.min, render_config.cfg.max)
def generate_steps(): def generate_steps():
return 30 if render_config.steps.min == render_config.steps.max:
return render_config.steps.min
return randint(render_config.steps.min, render_config.steps.max)
def generate_batches( def generate_batches(
@ -93,7 +117,7 @@ def get_images(ws, prompt):
continue # previews are binary data continue # previews are binary data
history = get_history(prompt_id)[prompt_id] history = get_history(prompt_id)[prompt_id]
for o in history["outputs"]: for _ in history["outputs"]:
for node_id in history["outputs"]: for node_id in history["outputs"]:
node_output = history["outputs"][node_id] node_output = history["outputs"][node_id]
if "images" in node_output: if "images" in node_output:
@ -117,86 +141,47 @@ def generate_image_tool(prompt, count, size="landscape"):
return output_paths return output_paths
sizes = {
"landscape": (1024, 768),
"portrait": (768, 1024),
"square": (768, 768),
}
def generate_images( def generate_images(
prompt: str, count: int, size="landscape", prefix="output" prompt: str, count: int, size="landscape", prefix="output"
) -> List[str]: ) -> List[str]:
cfg = generate_cfg() cfg = generate_cfg()
width, height = sizes.get(size, (512, 512)) dims = render_config.sizes[size]
steps = generate_steps() steps = generate_steps()
seed = randint(0, 10000000) seed = randint(0, 10000000)
checkpoint = choice(["diffusion-sdxl-dynavision-0-5-5-7.safetensors"]) checkpoint = choice(render_config.checkpoints)
logger.info( logger.info(
"generating %s images at %s by %s with prompt: %s", count, width, height, prompt "generating %s images at %s by %s with prompt: %s",
count,
dims.width,
dims.height,
prompt,
)
env = Environment(
loader=FileSystemLoader(["adventure/templates"]),
autoescape=select_autoescape(["json"]),
)
template = env.get_template("comfy.json.j2")
result = template.render(
cfg=cfg,
height=dims.height,
width=dims.width,
steps=steps,
seed=seed,
checkpoint=checkpoint,
prompt=prompt.replace("\n", ". "),
negative_prompt="",
count=count,
prefix=prefix,
) )
# parsing here helps ensure the template emits valid JSON # parsing here helps ensure the template emits valid JSON
prompt_workflow = { logger.debug("template workflow: %s", result)
"3": { prompt_workflow = json.loads(result)
"class_type": "KSampler",
"inputs": {
"cfg": cfg,
"denoise": 1,
"latent_image": ["5", 0],
"model": ["4", 0],
"negative": ["7", 0],
"positive": ["6", 0],
"sampler_name": "euler_ancestral",
"scheduler": "normal",
"seed": seed,
"steps": steps,
},
},
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {"ckpt_name": checkpoint},
},
"5": {
"class_type": "EmptyLatentImage",
"inputs": {"batch_size": count, "height": height, "width": width},
},
"6": {
"class_type": "smZ CLIPTextEncode",
"inputs": {
"text": prompt,
"parser": "compel",
"mean_normalization": True,
"multi_conditioning": True,
"use_old_emphasis_implementation": False,
"with_SDXL": False,
"ascore": 6,
"width": width,
"height": height,
"crop_w": 0,
"crop_h": 0,
"target_width": width,
"target_height": height,
"text_g": "",
"text_l": "",
"smZ_steps": 1,
"clip": ["4", 1],
},
},
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": "", "clip": ["4", 1]}},
"8": {
"class_type": "VAEDecode",
"inputs": {"samples": ["3", 0], "vae": ["4", 2]},
},
"9": {
"class_type": "SaveImage",
"inputs": {"filename_prefix": prefix, "images": ["8", 0]},
},
}
logger.debug("Connecting to Comfy API at %s", server_address) logger.debug("connecting to Comfy API at %s", server_address)
ws = websocket.WebSocket() ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id), timeout=60)
images = get_images(ws, prompt_workflow) images = get_images(ws, prompt_workflow)
results = [] results = []
@ -207,8 +192,7 @@ def generate_images(
paths: List[str] = [] paths: List[str] = []
for j, image in enumerate(results): for j, image in enumerate(results):
# TODO: replace with environment variable image_path = path.join(render_config.path, f"{prefix}-{j}.png")
image_path = path.join("/home/ssube/adventure-images", f"{prefix}-{j}.png")
with open(image_path, "wb") as f: with open(image_path, "wb") as f:
image_bytes = io.BytesIO() image_bytes = io.BytesIO()
image.save(image_bytes, format="PNG") image.save(image_bytes, format="PNG")
@ -244,51 +228,85 @@ def prompt_from_event(event: GameEvent) -> str | None:
return None return None
def prefix_from_event(event: GameEvent) -> str: def prompt_from_entity(entity: WorldEntity) -> str:
return entity.description
def get_image_prefix(event: GameEvent | WorldEntity) -> str:
if isinstance(event, ActionEvent): if isinstance(event, ActionEvent):
return ( return f"event-action-{event.actor.name}-{event.action}"
f"{event.actor.name}-{event.action}-{event.item.name if event.item else ''}"
)
if isinstance(event, ReplyEvent): if isinstance(event, ReplyEvent):
return f"{event.actor.name}-reply" return f"event-reply-{event.actor.name}"
if isinstance(event, ResultEvent): if isinstance(event, ResultEvent):
return f"{event.actor.name}-result" return f"event-result-{event.actor.name}"
if isinstance(event, StatusEvent): if isinstance(event, StatusEvent):
return "status" return "status"
if isinstance(event, WorldEntity):
return f"entity-{event.__class__.__name__.lower()}-{event.name}"
return "unknown" return "unknown"
# requests to generate images for game events
render_queue: Queue[GameEvent] = Queue()
def render_loop(): def render_loop():
while True: while True:
event = render_queue.get() event = render_queue.get()
prompt = prompt_from_event(event) prefix = get_image_prefix(event)
# check if images already exist
image_index = 0
image_path = path.join(render_config.path, f"{prefix}-{image_index}.png")
existing_images = []
while path.exists(image_path):
existing_images.append(image_path)
image_index += 1
image_path = path.join(render_config.path, f"{prefix}-{image_index}.png")
if existing_images:
logger.info(
"using existing images for event %s: %s", event, existing_images
)
broadcast(RenderEvent(paths=existing_images, source=event))
continue
# generate the prompt
if isinstance(event, WorldEntity):
logger.info("rendering entity %s", event)
prompt = prompt_from_entity(event)
else:
logger.info("rendering event %s", event)
prompt = prompt_from_event(event)
# render or not
if prompt: if prompt:
logger.info("rendering prompt for event %s: %s", event, prompt) logger.info("rendering prompt for event %s: %s", event, prompt)
prefix = prefix_from_event(event)
image_paths = generate_images(prompt, 2, prefix=prefix) image_paths = generate_images(prompt, 2, prefix=prefix)
broadcast(RenderEvent(paths=image_paths, source=event)) broadcast(RenderEvent(paths=image_paths, source=event))
else: else:
logger.warning("no prompt for event %s", event) logger.warning("no prompt for event %s", event)
def render_entity(entity: WorldEntity):
render_queue.put(entity)
def render_event(event: GameEvent): def render_event(event: GameEvent):
render_queue.put(event) render_queue.put(event)
render_thread = None def launch_render(config: RenderConfig):
global render_config
def launch_render():
global render_thread global render_thread
# update the config
logger.info("updating render config: %s", config)
render_config = config
# start the render thread
logger.info("launching render thread")
render_thread = Thread(target=render_loop, daemon=True) render_thread = Thread(target=render_loop, daemon=True)
render_thread.start() render_thread.start()

View File

@ -12,7 +12,12 @@ import websockets
from PIL import Image from PIL import Image
from pydantic import RootModel from pydantic import RootModel
from adventure.context import broadcast, get_actor_agent_for_name, set_actor_agent from adventure.context import (
broadcast,
get_actor_agent_for_name,
get_current_world,
set_actor_agent,
)
from adventure.models.entity import Actor, Item, Room, World from adventure.models.entity import Actor, Item, Room, World
from adventure.models.event import ( from adventure.models.event import (
GameEvent, GameEvent,
@ -29,16 +34,16 @@ from adventure.player import (
remove_player, remove_player,
set_player, set_player,
) )
from adventure.render_comfy import render_event from adventure.render_comfy import render_entity, render_event
from adventure.state import snapshot_world, world_json from adventure.state import snapshot_world, world_json
logger = getLogger(__name__) logger = getLogger(__name__)
connected = set() connected = set()
last_snapshot: str | None = None
player_names: Dict[str, str] = {}
recent_events: MutableSequence[GameEvent] = deque(maxlen=100) recent_events: MutableSequence[GameEvent] = deque(maxlen=100)
recent_json: MutableSequence[str] = deque(maxlen=100) recent_json: MutableSequence[str] = deque(maxlen=100)
last_snapshot = None
player_names: Dict[str, str] = {}
def get_player_name(client_id: str) -> str: def get_player_name(client_id: str) -> str:
@ -47,13 +52,14 @@ def get_player_name(client_id: str) -> str:
async def handler(websocket): async def handler(websocket):
id = uuid4().hex id = uuid4().hex
logger.info("Client connected, given id: %s", id) logger.info("client connected, given id: %s", id)
connected.add(websocket) connected.add(websocket)
async def next_turn(character: str, prompt: str) -> None: async def next_turn(character: str, prompt: str) -> None:
await websocket.send( await websocket.send(
dumps( dumps(
{ {
# TODO: these should be fields in the PromptEvent
"type": "prompt", "type": "prompt",
"client": id, "client": id,
"character": character, "character": character,
@ -64,6 +70,7 @@ async def handler(websocket):
) )
def sync_turn(event: PromptEvent) -> bool: def sync_turn(event: PromptEvent) -> bool:
# TODO: nothing about this is good
player = get_player(id) player = get_player(id)
if player and player.name == event.actor.name: if player and player.name == event.actor.name:
asyncio.run(next_turn(event.actor.name, event.prompt)) asyncio.run(next_turn(event.actor.name, event.prompt))
@ -74,21 +81,21 @@ async def handler(websocket):
try: try:
await websocket.send(dumps({"type": "id", "client": id})) await websocket.send(dumps({"type": "id", "client": id}))
# TODO: only send this if the recent events don't contain a snapshot # only send the snapshot once
if last_snapshot and last_snapshot not in recent_json: if last_snapshot and last_snapshot not in recent_json:
await websocket.send(last_snapshot) await websocket.send(last_snapshot)
for message in recent_json: for message in recent_json:
await websocket.send(message) await websocket.send(message)
except Exception: except Exception:
logger.exception("Failed to send recent messages to new client") logger.exception("failed to send recent messages to new client")
while True: while True:
try: try:
# if this socket is attached to a character and that character's turn is active, wait for input # if this socket is attached to a character and that character's turn is active, wait for input
message = await websocket.recv() message = await websocket.recv()
player_name = get_player_name(id) player_name = get_player_name(id)
logger.info(f"Received message for {player_name}: {message}") logger.info(f"received message for {player_name}: {message}")
try: try:
data = loads(message) data = loads(message)
@ -106,7 +113,7 @@ async def handler(websocket):
) )
if existing_id is not None: if existing_id is not None:
logger.error( logger.error(
f"Name {new_player_name} is already in use by {existing_id}" f"name {new_player_name} is already in use by {existing_id}"
) )
continue continue
@ -119,7 +126,7 @@ async def handler(websocket):
character_name = data["become"] character_name = data["become"]
if has_player(character_name): if has_player(character_name):
logger.error( logger.error(
f"Character {character_name} is already in use" f"character {character_name} is already in use"
) )
continue continue
@ -146,7 +153,7 @@ async def handler(websocket):
) )
set_player(id, player) set_player(id, player)
logger.info( logger.info(
f"Client {player_name} is now character {character_name}" f"client {player_name} is now character {character_name}"
) )
# swap out the LLM agent # swap out the LLM agent
@ -163,15 +170,10 @@ async def handler(websocket):
) )
player.input_queue.put(data["input"]) player.input_queue.put(data["input"])
elif message_type == "render": elif message_type == "render":
event_id = data["event"] render_input(data)
event = next((e for e in recent_events if e.id == event_id), None)
if event:
render_event(event)
else:
logger.error(f"Failed to find event {event_id}")
except Exception: except Exception:
logger.exception("Failed to parse message") logger.exception("failed to parse message")
except websockets.ConnectionClosedOK: except websockets.ConnectionClosedOK:
break break
@ -197,6 +199,56 @@ async def handler(websocket):
logger.info("client disconnected: %s", id) logger.info("client disconnected: %s", id)
def render_input(data):
world = get_current_world()
if not world:
logger.error("no world available")
return
if "event" in data:
event_id = data["event"]
event = next((e for e in recent_events if e.id == event_id), None)
if event:
render_event(event)
else:
logger.error(f"failed to find event {event_id}")
elif "actor" in data:
actor_name = data["actor"]
actor = next(
(a for r in world.rooms for a in r.actors if a.name == actor_name), None
)
if actor:
render_entity(actor)
else:
logger.error(f"failed to find actor {actor_name}")
elif "room" in data:
room_name = data["room"]
room = next((r for r in world.rooms if r.name == room_name), None)
if room:
render_entity(room)
else:
logger.error(f"failed to find room {room_name}")
elif "item" in data:
item_name = data["item"]
item = None
for room in world.rooms:
item = next((i for i in room.items if i.name == item_name), None)
if item:
break
for actor in room.actors:
item = next((i for i in actor.items if i.name == item_name), None)
if item:
break
if item:
render_entity(item)
else:
logger.error(f"failed to find item {item_name}")
else:
logger.error(f"failed to find entity in {data}")
socket_thread = None socket_thread = None
@ -220,6 +272,7 @@ def launch_server():
def run_sockets(): def run_sockets():
asyncio.run(server_main()) asyncio.run(server_main())
logger.info("launching websocket server")
socket_thread = Thread(target=run_sockets, daemon=True) socket_thread = Thread(target=run_sockets, daemon=True)
socket_thread.start() socket_thread.start()
@ -228,7 +281,7 @@ def launch_server():
async def server_main(): async def server_main():
async with websockets.serve(handler, "", 8001): async with websockets.serve(handler, "", 8001):
logger.info("Server started") logger.info("websocket server started")
await asyncio.Future() # run forever await asyncio.Future() # run forever

View File

@ -0,0 +1,59 @@
{
"3": {
"class_type": "KSampler",
"inputs": {
"cfg": {{ cfg }},
"denoise": 1,
"latent_image": ["5", 0],
"model": ["4", 0],
"negative": ["7", 0],
"positive": ["6", 0],
"sampler_name": "euler_ancestral",
"scheduler": "normal",
"seed": {{ seed }},
"steps": {{ steps }}
}
},
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "{{ checkpoint }}"
}
},
"5": {
"class_type": "EmptyLatentImage",
"inputs": {
"batch_size": {{ count }},
"height": {{ height }},
"width": {{ width }}
}
},
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": {{ prompt | tojson }},
"clip": ["4", 1]
}
},
"7": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": "",
"clip": ["4", 1]
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": ["3", 0],
"vae": ["4", 2]
}
},
"9": {
"class_type": "SaveImage",
"inputs": {
"filename_prefix": {{ prefix | tojson }},
"images": ["8", 0]
}
}
}

View File

@ -16,6 +16,7 @@
"@mui/x-tree-view": "^7.3.1", "@mui/x-tree-view": "^7.3.1",
"@types/lodash": "^4.14.192", "@types/lodash": "^4.14.192",
"@types/node": "^20.11.0", "@types/node": "^20.11.0",
"@viz-js/viz": "^3.5.0",
"allotment": "^1.20.0", "allotment": "^1.20.0",
"browser-bunyan": "^1.8.0", "browser-bunyan": "^1.8.0",
"i18next": "^22.4.14", "i18next": "^22.4.14",

View File

@ -1,31 +1,26 @@
/* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-explicit-any */
import { Maybe, doesExist } from '@apextoaster/js-utils'; import { Maybe, doesExist } from '@apextoaster/js-utils';
import { import {
Button,
Container, Container,
CssBaseline, CssBaseline,
Dialog,
DialogActions,
DialogContent,
DialogTitle,
Stack, Stack,
ThemeProvider, ThemeProvider,
Typography,
createTheme, createTheme,
} from '@mui/material'; } from '@mui/material';
import { Allotment } from 'allotment'; import { Allotment } from 'allotment';
import React, { Fragment, useEffect } from 'react'; import React, { useEffect } from 'react';
import useWebSocketModule from 'react-use-websocket'; import useWebSocketModule from 'react-use-websocket';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { HistoryPanel } from './history.js'; import { HistoryPanel } from './history.js';
import { Actor, GameEvent, Item, Room } from './models.js'; import { Actor } from './models.js';
import { PlayerPanel } from './player.js'; import { PlayerPanel } from './player.js';
import { store, StoreState } from './store.js';
import { WorldPanel } from './world.js';
import { Statusbar } from './status.js'; import { Statusbar } from './status.js';
import { StoreState, store } from './store.js';
import { WorldPanel } from './world.js';
import 'allotment/dist/style.css'; import 'allotment/dist/style.css';
import { DetailDialog } from './details.js';
import './main.css'; import './main.css';
const useWebSocket = (useWebSocketModule as any).default; const useWebSocket = (useWebSocketModule as any).default;
@ -34,51 +29,6 @@ export interface AppProps {
socketUrl: string; socketUrl: string;
} }
export interface EntityDetailsProps {
entity: Maybe<Item | Actor | Room>;
close: () => void;
}
export function EntityDetails(props: EntityDetailsProps) {
const { entity, close } = props;
// eslint-disable-next-line no-restricted-syntax
if (!doesExist(entity)) {
return <Fragment />;
}
return <Fragment>
<DialogTitle>{entity.name}</DialogTitle>
<DialogContent dividers>
<Typography>
{entity.description}
</Typography>
</DialogContent>
<DialogActions>
<Button onClick={close}>Close</Button>
</DialogActions>
</Fragment>;
}
export function detailStateSelector(s: StoreState) {
return {
detailEntity: s.detailEntity,
clearDetailEntity: s.clearDetailEntity,
};
}
export function DetailDialog() {
const state = useStore(store, detailStateSelector);
const { detailEntity, clearDetailEntity } = state;
return <Dialog
open={doesExist(detailEntity)}
onClose={clearDetailEntity}
>
<EntityDetails entity={detailEntity} close={clearDetailEntity} />
</Dialog>;
}
export function appStateSelector(s: StoreState) { export function appStateSelector(s: StoreState) {
return { return {
themeMode: s.themeMode, themeMode: s.themeMode,
@ -94,6 +44,10 @@ export function App(props: AppProps) {
const { lastMessage, readyState, sendMessage } = useWebSocket(props.socketUrl); const { lastMessage, readyState, sendMessage } = useWebSocket(props.socketUrl);
// socket senders // socket senders
function renderEntity(type: string, entity: string) {
sendMessage(JSON.stringify({ type: 'render', [type]: entity }));
}
function renderEvent(event: string) { function renderEvent(event: string) {
sendMessage(JSON.stringify({ type: 'render', event })); sendMessage(JSON.stringify({ type: 'render', event }));
} }
@ -138,14 +92,9 @@ export function App(props: AppProps) {
return; return;
case 'prompt': case 'prompt':
// prompts are broadcast to all players // prompts are broadcast to all players
if (event.client === clientId) { // only notify the active player
// only notify the active player setActiveTurn(event.client === clientId);
setActiveTurn(true); break;
break;
} else {
setActiveTurn(false);
return;
}
case 'player': case 'player':
if (event.status === 'join' && doesExist(world) && event.client === clientId) { if (event.status === 'join' && doesExist(world) && event.client === clientId) {
const { character: characterName } = event; const { character: characterName } = event;
@ -173,7 +122,7 @@ export function App(props: AppProps) {
return <ThemeProvider theme={theme}> return <ThemeProvider theme={theme}>
<CssBaseline /> <CssBaseline />
<DetailDialog /> <DetailDialog renderEntity={renderEntity} />
<Container maxWidth='xl'> <Container maxWidth='xl'>
<Stack direction="column"> <Stack direction="column">
<Statusbar setName={setName} /> <Statusbar setName={setName} />
@ -184,7 +133,7 @@ export function App(props: AppProps) {
<WorldPanel setPlayer={setPlayer} /> <WorldPanel setPlayer={setPlayer} />
</Stack> </Stack>
<Stack direction="column" sx={{ minWidth: 600 }} className="scroll-history"> <Stack direction="column" sx={{ minWidth: 600 }} className="scroll-history">
<HistoryPanel renderEvent={renderEvent} /> <HistoryPanel renderEntity={renderEntity} renderEvent={renderEvent} />
</Stack> </Stack>
</Allotment> </Allotment>
</Stack> </Stack>

111
client/src/details.tsx Normal file
View File

@ -0,0 +1,111 @@
import { Maybe, doesExist } from '@apextoaster/js-utils';
import { Button, Dialog, DialogActions, DialogContent, DialogTitle, Typography } from '@mui/material';
import { instance as graphviz } from '@viz-js/viz';
import React, { Fragment, useEffect } from 'react';
import { useStore } from 'zustand';
import { Actor, Item, Room, World } from './models';
import { StoreState, store } from './store';
export interface EntityDetailsProps {
entity: Maybe<Item | Actor | Room>;
onClose: () => void;
onRender: (type: string, entity: string) => void;
}
export function EntityDetails(props: EntityDetailsProps) {
const { entity, onClose, onRender } = props;
// eslint-disable-next-line no-restricted-syntax
if (!doesExist(entity)) {
return <Fragment />;
}
return <Fragment>
<DialogTitle>{entity.name}</DialogTitle>
<DialogContent dividers>
<Typography>
{entity.description}
</Typography>
</DialogContent>
<DialogActions>
<Button onClick={() => onRender('actor', entity.name)}>Render</Button>
<Button onClick={onClose}>Close</Button>
</DialogActions>
</Fragment>;
}
export interface WorldDetailsProps {
world: World;
}
export function WorldDetails(props: WorldDetailsProps) {
const { world } = props;
useEffect(() => {
graphviz().then((viz) => {
const dot = worldGraph(world);
const svg = viz.renderSVGElement(dot);
const graph = document.getElementById('graph');
if (doesExist(graph)) {
graph.replaceChildren(svg);
}
}).catch((err) => {
// eslint-disable-next-line no-console
console.error(err);
});
}, [world]);
return <Fragment>
<DialogTitle>{world.name}</DialogTitle>
<DialogContent dividers>
<Typography variant='body2'>
Theme: {world.theme}
</Typography>
<div id="graph" />
</DialogContent>
</Fragment>;
}
export function detailStateSelector(s: StoreState) {
return {
detailEntity: s.detailEntity,
clearDetailEntity: s.clearDetailEntity,
};
}
export interface DetailDialogProps {
renderEntity: (type: string, entity: string) => void;
}
export function DetailDialog(props: DetailDialogProps) {
const state = useStore(store, detailStateSelector);
const { detailEntity, clearDetailEntity } = state;
let details;
if (isWorld(detailEntity)) {
details = <WorldDetails world={detailEntity} />;
} else {
details = <EntityDetails entity={detailEntity} onClose={clearDetailEntity} onRender={props.renderEntity} />;
}
return <Dialog
open={doesExist(detailEntity)}
onClose={clearDetailEntity}
>{details}</Dialog>;
}
export function isWorld(entity: Maybe<Item | Actor | Room | World>): entity is World {
return doesExist(entity) && doesExist(entity.theme);
}
export function worldGraph(world: World): string {
return `digraph {
${world.rooms.map((room) => roomGraph(room).join('; ')).join('\n')}
}`;
}
export function roomGraph(room: Room): Array<string> {
return Object.entries(room.portals).map(([direction, destination]) =>
`"${room.name}" -> "${destination}" [label="${direction}"]`
);
}

View File

@ -1,9 +1,12 @@
import { Avatar, IconButton, ImageList, ImageListItem, ListItem, ListItemAvatar, ListItemText, Typography } from '@mui/material'; import { Avatar, IconButton, ImageList, ImageListItem, ListItem, ListItemAvatar, ListItemText, Typography } from '@mui/material';
import React, { Fragment, MutableRefObject } from 'react'; import React, { Fragment, MutableRefObject } from 'react';
import { Maybe, doesExist } from '@apextoaster/js-utils';
import { Camera } from '@mui/icons-material'; import { Camera } from '@mui/icons-material';
import { useStore } from 'zustand';
import { formatters } from './format.js'; import { formatters } from './format.js';
import { GameEvent } from './models.js'; import { Actor, GameEvent } from './models.js';
import { StoreState, store } from './store.js';
export function openImage(image: string) { export function openImage(image: string) {
const byteCharacters = atob(image); const byteCharacters = atob(image);
@ -23,7 +26,22 @@ export interface EventItemProps {
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
focusRef?: MutableRefObject<any>; focusRef?: MutableRefObject<any>;
renderEvent: (event: GameEvent) => void; renderEntity: (type: string, entity: string) => void;
renderEvent: (event: string) => void;
}
export function characterSelector(state: StoreState) {
return {
character: state.character,
};
}
export function sameCharacter(a: Maybe<Actor>, b: Maybe<Actor>): boolean {
if (doesExist(a) && doesExist(b)) {
return a.name === b.name;
}
return false;
} }
export function ActionEventItem(props: EventItemProps) { export function ActionEventItem(props: EventItemProps) {
@ -31,6 +49,14 @@ export function ActionEventItem(props: EventItemProps) {
const { id, actor, room, type } = event; const { id, actor, room, type } = event;
const content = formatters[type](event); const content = formatters[type](event);
const state = useStore(store, characterSelector);
const { character } = state;
const playerAction = sameCharacter(actor, character);
const typographyProps = {
color: playerAction ? 'success.text' : 'primary.text',
};
return <ListItem return <ListItem
alignItems="flex-start" alignItems="flex-start"
ref={props.focusRef} ref={props.focusRef}
@ -45,6 +71,8 @@ export function ActionEventItem(props: EventItemProps) {
</ListItemAvatar> </ListItemAvatar>
<ListItemText <ListItemText
primary={room.name} primary={room.name}
primaryTypographyProps={typographyProps}
secondaryTypographyProps={typographyProps}
secondary={ secondary={
<React.Fragment> <React.Fragment>
<Typography <Typography
@ -160,13 +188,78 @@ export function RenderEventItem(props: EventItemProps) {
<ListItemAvatar> <ListItemAvatar>
<Avatar alt="Render" src="/static/images/avatar/1.jpg" /> <Avatar alt="Render" src="/static/images/avatar/1.jpg" />
</ListItemAvatar> </ListItemAvatar>
<ImageList cols={3} rowHeight={256}>
{Object.entries(images).map(([name, image]) => <ImageListItem key={name}>
<a href='#' onClick={() => openImage(image as string)}>
<img src={`data:image/jpeg;base64,${image}`} alt="Render" style={{ maxHeight: 256, maxWidth: 256 }} />
</a>
</ImageListItem>)}
</ImageList>
</ListItem>;
}
export function PromptEventItem(props: EventItemProps) {
const { event } = props;
const { character, prompt } = event;
const state = useStore(store, characterSelector);
const { character: playerCharacter } = state;
const playerPrompt = sameCharacter(playerCharacter, character);
const typographyProps = {
color: playerPrompt ? 'success.text' : 'primary.text',
};
return <ListItem alignItems="flex-start" ref={props.focusRef}>
<ListItemAvatar>
<Avatar alt="Prompt" src="/static/images/avatar/1.jpg" />
</ListItemAvatar>
<ListItemText <ListItemText
primary="Render" primary="Prompt"
secondary={<ImageList cols={3} rowHeight={256}> primaryTypographyProps={typographyProps}
{Object.entries(images).map(([name, image]) => <ImageListItem key={name}> secondaryTypographyProps={typographyProps}
<img src={`data:image/jpeg;base64,${image}`} onClick={() => openImage(image)} alt="Render" /> secondary={
</ImageListItem>)} <Typography
</ImageList>} sx={{ display: 'block' }}
component="span"
variant="body2"
color="text.primary"
>
Prompt for {character}: {prompt}
</Typography>
}
/>
</ListItem>;
}
export function GenerateEventItem(props: EventItemProps) {
const { event, renderEntity } = props;
const { entity, name } = event;
return <ListItem
alignItems="flex-start"
ref={props.focusRef}
secondaryAction={
<IconButton edge="end" aria-label="render" onClick={() => renderEntity(entity.name)}>
<Camera />
</IconButton>
}
>
<ListItemAvatar>
<Avatar alt="Generate" src="/static/images/avatar/1.jpg" />
</ListItemAvatar>
<ListItemText
primary="Generate"
secondary={
<Typography
sx={{ display: 'block' }}
component="span"
variant="body2"
color="text.primary"
>
{name}
</Typography>
}
/> />
</ListItem>; </ListItem>;
} }
@ -188,6 +281,10 @@ export function EventItem(props: EventItemProps) {
return <RenderEventItem {...props} />; return <RenderEventItem {...props} />;
case 'snapshot': case 'snapshot':
return <SnapshotEventItem {...props} />; return <SnapshotEventItem {...props} />;
case 'prompt':
return <PromptEventItem {...props} />;
case 'generate':
return <GenerateEventItem {...props} />;
default: default:
return <ListItem ref={props.focusRef}> return <ListItem ref={props.focusRef}>
<ListItemText primary={`Unknown event type: ${type}`} /> <ListItemText primary={`Unknown event type: ${type}`} />

View File

@ -3,7 +3,6 @@ import { Divider, List } from '@mui/material';
import React, { useEffect, useRef } from 'react'; import React, { useEffect, useRef } from 'react';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { EventItem } from './events'; import { EventItem } from './events';
import { GameEvent } from './models';
import { StoreState, store } from './store'; import { StoreState, store } from './store';
export function historyStateSelector(s: StoreState) { export function historyStateSelector(s: StoreState) {
@ -14,13 +13,13 @@ export function historyStateSelector(s: StoreState) {
} }
export interface HistoryPanelProps { export interface HistoryPanelProps {
renderEvent: (event: GameEvent) => void; renderEntity: (type: string, entity: string) => void;
renderEvent: (event: string) => void;
} }
export function HistoryPanel(props: HistoryPanelProps) { export function HistoryPanel(props: HistoryPanelProps) {
const state = useStore(store, historyStateSelector); const state = useStore(store, historyStateSelector);
const { history, scroll } = state; const { history, scroll } = state;
const { renderEvent } = props;
const scrollRef = useRef<Maybe<Element>>(undefined); const scrollRef = useRef<Maybe<Element>>(undefined);
@ -34,10 +33,10 @@ export function HistoryPanel(props: HistoryPanelProps) {
const items = history.map((item, index) => { const items = history.map((item, index) => {
if (index === history.length - 1) { if (index === history.length - 1) {
return <EventItem key={`item-${index}`} event={item} focusRef={scrollRef} renderEvent={renderEvent} />; return <EventItem {...props} key={`item-${index}`} event={item} focusRef={scrollRef} />;
} }
return <EventItem key={`item-${index}`} event={item} renderEvent={renderEvent} />; return <EventItem {...props} key={`item-${index}`} event={item} />;
}); });
return <List sx={{ width: '100%', bgcolor: 'background.paper' }}> return <List sx={{ width: '100%', bgcolor: 'background.paper' }}>

View File

@ -10,7 +10,7 @@ export interface ClientState {
autoScroll: boolean; autoScroll: boolean;
clientId: string; clientId: string;
clientName: string; clientName: string;
detailEntity: Maybe<Item | Actor | Room>; detailEntity: Maybe<Item | Actor | Room | World>;
eventHistory: Array<GameEvent>; eventHistory: Array<GameEvent>;
readyState: ReadyState; readyState: ReadyState;
themeMode: PaletteMode; themeMode: PaletteMode;
@ -19,7 +19,7 @@ export interface ClientState {
setAutoScroll: (autoScroll: boolean) => void; setAutoScroll: (autoScroll: boolean) => void;
setClientId: (clientId: string) => void; setClientId: (clientId: string) => void;
setClientName: (name: string) => void; setClientName: (name: string) => void;
setDetailEntity: (entity: Maybe<Item | Actor | Room>) => void; setDetailEntity: (entity: Maybe<Item | Actor | Room | World>) => void;
setReadyState: (state: ReadyState) => void; setReadyState: (state: ReadyState) => void;
setThemeMode: (mode: PaletteMode) => void; setThemeMode: (mode: PaletteMode) => void;

View File

@ -6,7 +6,7 @@ import React from 'react';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { StoreState, store } from './store'; import { StoreState, store } from './store';
import { Actor, Item, Room, World } from './models'; import { Actor, Item, Room } from './models';
export type SetDetails = (entity: Maybe<Item | Actor | Room>) => void; export type SetDetails = (entity: Maybe<Item | Actor | Room>) => void;
export type SetPlayer = (actor: Maybe<Actor>) => void; export type SetPlayer = (actor: Maybe<Actor>) => void;
@ -33,6 +33,7 @@ export function itemStateSelector(s: StoreState) {
export function worldStateSelector(s: StoreState) { export function worldStateSelector(s: StoreState) {
return { return {
world: s.world, world: s.world,
setDetailEntity: s.setDetailEntity,
}; };
} }
@ -91,7 +92,7 @@ export function RoomItem(props: { room: Room } & BaseEntityItemProps) {
export function WorldPanel(props: BaseEntityItemProps) { export function WorldPanel(props: BaseEntityItemProps) {
const { setPlayer } = props; const { setPlayer } = props;
const state = useStore(store, worldStateSelector); const state = useStore(store, worldStateSelector);
const { world } = state; const { world, setDetailEntity } = state;
// eslint-disable-next-line no-restricted-syntax // eslint-disable-next-line no-restricted-syntax
if (!doesExist(world)) { if (!doesExist(world)) {
@ -111,6 +112,7 @@ export function WorldPanel(props: BaseEntityItemProps) {
Theme: {world.theme} Theme: {world.theme}
</Typography> </Typography>
<SimpleTreeView> <SimpleTreeView>
<TreeItem itemId="world-graph" label="Graph" onClick={() => setDetailEntity(world)} />
{world.rooms.map((room) => <RoomItem key={room.name} room={room} setPlayer={setPlayer} />)} {world.rooms.map((room) => <RoomItem key={room.name} room={room} setPlayer={setPlayer} />)}
</SimpleTreeView> </SimpleTreeView>
</CardContent> </CardContent>

View File

@ -798,6 +798,11 @@
resolved "https://registry.yarnpkg.com/@ungap/structured-clone/-/structured-clone-1.2.0.tgz#756641adb587851b5ccb3e095daf27ae581c8406" resolved "https://registry.yarnpkg.com/@ungap/structured-clone/-/structured-clone-1.2.0.tgz#756641adb587851b5ccb3e095daf27ae581c8406"
integrity sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ== integrity sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==
"@viz-js/viz@^3.5.0":
version "3.5.0"
resolved "https://registry.yarnpkg.com/@viz-js/viz/-/viz-3.5.0.tgz#9fd09729cd2cdcbc51b0ea293a1954e6839797b2"
integrity sha512-66iFqMC2m0lZhvmHXFyJY12Jn8v9hswFMR3nsumN1dfhNoVrAHsa/7xpB3BojIVyj8IeEc8ciLjxZVdUnhcOxw==
"@xobotyi/scrollbar-width@^1.9.5": "@xobotyi/scrollbar-width@^1.9.5":
version "1.9.5" version "1.9.5"
resolved "https://registry.yarnpkg.com/@xobotyi/scrollbar-width/-/scrollbar-width-1.9.5.tgz#80224a6919272f405b87913ca13b92929bdf3c4d" resolved "https://registry.yarnpkg.com/@xobotyi/scrollbar-width/-/scrollbar-width-1.9.5.tgz#80224a6919272f405b87913ca13b92929bdf3c4d"

24
config.yml Normal file
View File

@ -0,0 +1,24 @@
bot:
discord:
channels: [bots]
render:
cfg:
min: 5
max: 8
checkpoints: [
"diffusion-sdxl-dynavision-0-5-5-7.safetensors",
]
path: /tmp/adventure-images
sizes:
landscape:
width: 1280
height: 960
portrait:
width: 960
height: 1280
square:
width: 1024
height: 1024
steps:
min: 30
max: 50