1
0
Fork 0

make sure config is used consistently, start adding tests
Run Docker Build / build (push) Successful in 13s Details
Run Python Build / build (push) Successful in 23s Details

This commit is contained in:
Sean Sube 2024-06-01 14:46:11 -05:00
parent be37d58ade
commit ef8529ef62
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
16 changed files with 232 additions and 69 deletions

2
.gitignore vendored
View File

@ -5,3 +5,5 @@ venv/
client/node_modules/
client/out/
taleweave/custom_*
.coverage
coverage.*

80
docs/cli.md Normal file
View File

@ -0,0 +1,80 @@
# TaleWeave AI Command Line Options
The following command line arguments are available when launching the TaleWeave AI engine:
- **--actions**
- **Type:** String
- **Description:** Additional actions to include in the simulation. Note: More than one argument is allowed.
- **--add-rooms**
- **Type:** Integer
- **Default:** 0
- **Description:** The number of new rooms to generate before starting the simulation.
- **--config**
- **Type:** String
- **Description:** The file to load additional configuration from.
- **--discord**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Run the simulation in a Discord bot.
- **--flavor**
- **Type:** String
- **Default:** ""
- **Description:** Additional flavor text for the generated world.
- **--optional-actions**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Include optional actions in the simulation.
- **--player**
- **Type:** String
- **Description:** The name of the character to play as.
- **--prompts**
- **Type:** String
- **Description:** The file to load game prompts from. Note: More than one argument is allowed.
- **--render**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Run the render thread.
- **--render-generated**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Render entities as they are generated.
- **--rooms**
- **Type:** Integer
- **Description:** The number of rooms to generate.
- **--server**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Run the websocket server.
- **--state**
- **Type:** String
- **Description:** The file to save the world state to. Defaults to `$world.state.json` if not set.
- **--turns**
- **Type:** Integer or "inf"
- **Default:** 10
- **Description:** The number of simulation turns to run.
- **--systems**
- **Type:** String
- **Description:** Extra systems to run in the simulation. Note: More than one argument is allowed.
- **--theme**
- **Type:** String
- **Default:** "fantasy"
- **Description:** The theme of the generated world.
- **--world**
- **Type:** String
- **Default:** "world"
- **Description:** The file to save the generated world to.
- **--world-template**
- **Type:** String
- **Description:** The template file to load the world prompt from.

View File

@ -2,16 +2,14 @@ from taleweave.context import (
action_context,
get_agent_for_character,
get_current_turn,
get_game_config,
get_prompt,
)
from taleweave.errors import ActionError
from taleweave.models.config import DEFAULT_CONFIG
from taleweave.models.planning import CalendarEvent
from taleweave.utils.planning import get_recent_notes
from taleweave.utils.prompt import format_prompt
character_config = DEFAULT_CONFIG.world.character
def take_note(fact: str):
"""
@ -22,11 +20,13 @@ def take_note(fact: str):
fact: The fact to remember.
"""
config = get_game_config()
with action_context() as (_, action_character):
if fact in action_character.planner.notes:
raise ActionError(get_prompt("action_take_note_error_duplicate"))
if len(action_character.planner.notes) >= character_config.note_limit:
if len(action_character.planner.notes) >= config.world.character.note_limit:
raise ActionError(get_prompt("action_take_note_error_limit"))
action_character.planner.notes.append(fact)
@ -103,6 +103,8 @@ def summarize_notes(limit: int) -> str:
limit: The maximum number of notes to keep.
"""
config = get_game_config()
with action_context() as (_, action_character):
notes = action_character.planner.notes
if len(notes) == 0:
@ -120,11 +122,11 @@ def summarize_notes(limit: int) -> str:
)
new_notes = [note.strip() for note in summary.split("\n") if note.strip()]
if len(new_notes) > character_config.note_limit:
if len(new_notes) > config.world.character.note_limit:
raise ActionError(
format_prompt(
"action_summarize_notes_error_limit",
limit=character_config.note_limit,
limit=config.world.character.note_limit,
)
)
@ -165,7 +167,9 @@ def check_calendar(count: int):
count: The number of upcoming events to read. 5 is usually a good number.
"""
count = min(count, character_config.event_limit)
config = get_game_config()
count = min(count, config.world.character.event_limit)
current_turn = get_current_turn()
with action_context() as (_, action_character):

View File

@ -11,10 +11,11 @@ from taleweave.context import (
broadcast,
get_character_agent_for_name,
get_current_world,
get_game_config,
set_character_agent,
subscribe,
)
from taleweave.models.config import DEFAULT_CONFIG, DiscordBotConfig
from taleweave.models.config import DiscordBotConfig
from taleweave.models.event import (
ActionEvent,
GameEvent,
@ -38,7 +39,6 @@ from taleweave.utils.prompt import format_prompt
logger = getLogger(__name__)
client = None
bot_config: DiscordBotConfig = DEFAULT_CONFIG.bot.discord
active_tasks = set()
event_messages: Dict[int, str | GameEvent] = {}
@ -78,21 +78,24 @@ class AdventureClient(Client):
if message.author == self.user:
return
config = get_game_config()
author = message.author
channel = message.channel
user_name = author.name # include nick
if message.content.startswith(
bot_config.command_prefix + bot_config.name_command
config.bot.discord.command_prefix + config.bot.discord.name_command
):
world = get_current_world()
if world:
world_message = format_prompt(
"discord_world_active", bot_name=bot_config.name_title, world=world
"discord_world_active",
bot_name=config.bot.discord.name_title,
world=world,
)
else:
world_message = format_prompt(
"discord_world_none", bot_name=bot_config.name_title
"discord_world_none", bot_name=config.bot.discord.name_title
)
await message.channel.send(world_message)
@ -100,7 +103,7 @@ class AdventureClient(Client):
if message.content.startswith("!help"):
await message.channel.send(
format_prompt("discord_help", bot_name=bot_config.name_command)
format_prompt("discord_help", bot_name=config.bot.discord.name_command)
)
return
@ -172,14 +175,11 @@ class AdventureClient(Client):
def launch_bot(config: DiscordBotConfig):
global bot_config
global client
bot_config = config
# message contents need to be enabled for multi-server bots
intents = Intents.default()
if bot_config.content_intent:
if config.content_intent:
intents.message_content = True
client = AdventureClient(intents=intents)
@ -246,12 +246,14 @@ def get_active_channels():
if not client:
return []
config = get_game_config()
# return client.private_channels
return [
channel
for guild in client.guilds
for channel in guild.text_channels
if channel.name in bot_config.channels
if channel.name in config.bot.discord.channels
]

View File

@ -18,6 +18,7 @@ from packit.agent import Agent
from pyee.base import EventEmitter
from taleweave.game_system import GameSystem
from taleweave.models.config import DEFAULT_CONFIG, Config
from taleweave.models.entity import Character, Room, World
from taleweave.models.event import GameEvent, StatusEvent
from taleweave.models.prompt import PromptLibrary
@ -34,6 +35,7 @@ dungeon_master: Agent | None = None
# game context
event_emitter = EventEmitter()
game_config: Config = DEFAULT_CONFIG
game_systems: List[GameSystem] = []
prompt_library: PromptLibrary = PromptLibrary(prompts={})
system_data: Dict[str, Any] = {}
@ -160,6 +162,10 @@ def get_dungeon_master() -> Agent:
return dungeon_master
def get_game_config() -> Config:
return game_config
def get_game_systems() -> List[GameSystem]:
return game_systems
@ -172,6 +178,10 @@ def get_prompt_library() -> PromptLibrary:
return prompt_library
def get_system_config(system: str) -> Any | None:
return game_config.systems.data.get(system)
def get_system_data(system: str) -> Any | None:
return system_data.get(system)
@ -209,6 +219,11 @@ def set_dungeon_master(agent):
dungeon_master = agent
def set_game_config(config: Config):
global game_config
game_config = config
def set_game_systems(systems: Sequence[GameSystem]):
global game_systems
game_systems = list(systems)

View File

@ -7,9 +7,14 @@ from packit.loops import loop_retry
from packit.results import enum_result, int_result
from packit.utils import could_be_json
from taleweave.context import broadcast, get_prompt, set_current_world, set_system_data
from taleweave.context import (
broadcast,
get_game_config,
get_prompt,
set_current_world,
set_system_data,
)
from taleweave.game_system import GameSystem
from taleweave.models.config import DEFAULT_CONFIG, WorldConfig
from taleweave.models.effect import (
EffectPattern,
FloatEffectPattern,
@ -33,7 +38,10 @@ from taleweave.utils.string import normalize_name
logger = getLogger(__name__)
world_config: WorldConfig = DEFAULT_CONFIG.world
def get_world_config():
config = get_game_config()
return config.world
def duplicate_name_parser(existing_names: List[str]):
@ -112,6 +120,7 @@ def generate_room(
actions = {}
room = Room(name=name, description=desc, items=[], characters=[], actions=actions)
world_config = get_world_config()
item_count = resolve_int_range(world_config.size.room_items) or 0
broadcast_generated(
format_prompt(
@ -276,6 +285,7 @@ def generate_item(
item = Item(name=name, description=desc, actions=actions)
generate_system_attributes(agent, world, item, systems)
world_config = get_world_config()
effect_count = resolve_int_range(world_config.size.item_effects) or 0
broadcast_generated(
message=format_prompt(
@ -343,6 +353,7 @@ def generate_character(
generate_system_attributes(agent, world, character, systems)
# generate the character's inventory
world_config = get_world_config()
item_count = resolve_int_range(world_config.size.character_items) or 0
broadcast_generated(
message=format_prompt(
@ -499,6 +510,7 @@ def link_rooms(
rooms: List[Room] | None = None,
) -> None:
rooms = rooms or world.rooms
world_config = get_world_config()
for room in rooms:
num_portals = resolve_int_range(world_config.size.portals) or 0
@ -550,6 +562,7 @@ def generate_world(
systems: List[GameSystem],
room_count: int | None = None,
) -> World:
world_config = get_world_config()
room_count = room_count or resolve_int_range(world_config.size.rooms) or 0
broadcast_generated(message=format_prompt("world_generate_world_broadcast_theme"))

View File

@ -27,7 +27,7 @@ except Exception as err:
logger = logger_with_colors(__name__) # , level="DEBUG")
load_dotenv(environ.get("ADVENTURE_ENV", ".env"), override=True)
load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True)
if True:
from taleweave.context import (
@ -35,6 +35,7 @@ if True:
get_system_data,
set_current_turn,
set_dungeon_master,
set_game_config,
set_system_data,
subscribe,
)
@ -312,6 +313,7 @@ def main():
if args.config:
with open(args.config, "r") as f:
config = Config(**load_yaml(f))
set_game_config(config)
else:
config = DEFAULT_CONFIG

View File

@ -1,6 +1,6 @@
from typing import Dict, List
from .base import IntRange, dataclass
from .base import Attributes, IntRange, dataclass
@dataclass
@ -44,6 +44,15 @@ class ServerConfig:
websocket: WebsocketServerConfig
@dataclass
class SystemsConfig:
"""
Configuration for the game systems.
"""
data: Attributes
@dataclass
class WorldCharacterConfig:
conversation_limit: int
@ -80,6 +89,7 @@ class Config:
bot: BotConfig
render: RenderConfig
server: ServerConfig
systems: SystemsConfig
world: WorldConfig
@ -107,6 +117,7 @@ DEFAULT_CONFIG = Config(
steps=30,
),
server=ServerConfig(websocket=WebsocketServerConfig(host="localhost", port=8001)),
systems=SystemsConfig(data={}),
world=WorldConfig(
character=WorldCharacterConfig(
conversation_limit=2,

View File

@ -60,6 +60,7 @@ class PromptEvent(BaseModel):
A prompt for a character to take an action.
"""
actions: Dict[str, Any]
prompt: str
room: Room
character: Character

View File

@ -197,15 +197,20 @@ class RemotePlayer(BasePlayer):
Ask the player for input.
"""
actions = {}
formatted_prompt = prompt.format(**kwargs)
if toolbox:
actions = toolbox.list_definitions()
formatted_prompt += self.format_psuedo_functions(toolbox)
self.memory.append(HumanMessage(content=formatted_prompt))
with action_context() as (current_room, current_character):
prompt_event = PromptEvent(
prompt=formatted_prompt, room=current_room, character=current_character
actions=actions,
prompt=formatted_prompt,
room=current_room,
character=current_character,
)
try:

View File

@ -15,9 +15,9 @@ from fnvhash import fnv1a_32
from jinja2 import Environment, FileSystemLoader, select_autoescape
from PIL import Image
from taleweave.context import broadcast
from taleweave.models.base import uuid
from taleweave.models.config import DEFAULT_CONFIG, RenderConfig
from taleweave.context import broadcast, get_game_config
from taleweave.models.base import IntRange, uuid
from taleweave.models.config import RenderConfig
from taleweave.models.entity import WorldEntity
from taleweave.models.event import (
ActionEvent,
@ -36,7 +36,6 @@ logger = getLogger(__name__)
server_address = environ["COMFY_API"]
client_id = uuid()
render_config: RenderConfig = DEFAULT_CONFIG.render
# requests to generate images for game events
@ -44,12 +43,17 @@ render_queue: Queue[GameEvent | WorldEntity] = Queue()
render_thread: Thread | None = None
def generate_cfg():
return resolve_int_range(render_config.cfg)
def get_render_config():
config = get_game_config()
return config.render
def generate_steps():
return resolve_int_range(render_config.steps)
def generate_cfg(cfg: int | IntRange):
return resolve_int_range(cfg)
def generate_steps(steps: int | IntRange):
return resolve_int_range(steps)
def generate_batches(
@ -148,9 +152,10 @@ def generate_image_tool(prompt, count, size="landscape"):
def generate_images(
prompt: str, count: int, size="landscape", prefix="output"
) -> List[str]:
cfg = generate_cfg()
render_config = get_render_config()
cfg = generate_cfg(render_config.cfg)
dims = render_config.sizes[size]
steps = generate_steps()
steps = generate_steps(render_config.steps)
seed = randint(0, 10000000)
checkpoint = choice(render_config.checkpoints)
logger.info(
@ -248,6 +253,8 @@ def get_image_prefix(event: GameEvent | WorldEntity) -> str:
def render_loop():
render_config = get_render_config()
while True:
event = render_queue.get()
prefix = get_image_prefix(event)
@ -317,13 +324,8 @@ def render_generated(event: GameEvent):
def launch_render(config: RenderConfig):
global render_config
global render_thread
# update the config
logger.info("updating render config: %s", config)
render_config = config
# start the render thread
logger.info("launching render thread")
render_thread = Thread(target=render_loop, daemon=True)

View File

@ -16,10 +16,11 @@ from taleweave.context import (
broadcast,
get_character_agent_for_name,
get_current_world,
get_game_config,
set_character_agent,
subscribe,
)
from taleweave.models.config import DEFAULT_CONFIG, WebsocketServerConfig
from taleweave.models.config import WebsocketServerConfig
from taleweave.models.entity import World, WorldEntity
from taleweave.models.event import (
GameEvent,
@ -47,7 +48,6 @@ last_snapshot: str | None = None
player_names: Dict[str, str] = {}
recent_events: MutableSequence[GameEvent] = deque(maxlen=100)
recent_json: MutableSequence[str] = deque(maxlen=100)
server_config: WebsocketServerConfig = DEFAULT_CONFIG.server.websocket
def get_player_name(client_id: str) -> str:
@ -59,16 +59,15 @@ async def handler(websocket):
logger.info("client connected, given id: %s", id)
connected.add(websocket)
async def next_turn(character: str, prompt: str) -> None:
async def next_turn(event: PromptEvent) -> None:
await websocket.send(
dumps(
{
# TODO: these should be fields in the PromptEvent
"type": "prompt",
"client": id,
"character": character,
"prompt": prompt,
"actions": [],
"type": event.type,
"client": id, # TODO: this should be a field in the PromptEvent
"character": event.character,
"prompt": event.prompt,
"actions": event.actions,
}
),
)
@ -77,7 +76,7 @@ async def handler(websocket):
# TODO: nothing about this is good
player = get_player(id)
if player and player.name == event.character.name:
asyncio.run(next_turn(event.character.name, event.prompt))
asyncio.run(next_turn(event))
return True
return False
@ -303,10 +302,6 @@ def send_and_append(id: str, message: Dict):
def launch_server(config: WebsocketServerConfig):
global socket_thread
global server_config
logger.info("configuring websocket server: %s", config)
server_config = config
def run_sockets():
asyncio.run(server_main())
@ -321,7 +316,11 @@ def launch_server(config: WebsocketServerConfig):
async def server_main():
async with websockets.serve(handler, server_config.host, server_config.port):
config = get_game_config()
async with websockets.serve(
handler, config.server.websocket.host, config.server.websocket.port
):
logger.info("websocket server started")
await asyncio.Future() # run forever

View File

@ -36,6 +36,7 @@ from taleweave.context import (
get_character_for_agent,
get_current_turn,
get_current_world,
get_game_config,
get_prompt,
set_current_character,
set_current_room,
@ -44,7 +45,6 @@ from taleweave.context import (
set_game_systems,
)
from taleweave.game_system import GameSystem
from taleweave.models.config import DEFAULT_CONFIG
from taleweave.models.entity import Character, Room, World
from taleweave.models.event import ActionEvent, ResultEvent
from taleweave.utils.conversation import make_keyword_condition, summarize_room
@ -57,9 +57,6 @@ from taleweave.utils.world import describe_entity, format_attributes
logger = getLogger(__name__)
turn_config = DEFAULT_CONFIG.world.turn
def world_result_parser(value, agent, **kwargs):
current_world = get_current_world()
if not current_world:
@ -121,10 +118,11 @@ def prompt_character_action(
event = ActionEvent.from_json(value, room, character)
else:
# TODO: this path should be removed and throw
logger.warning(
"invalid action, emitting as result event - this is a bug somewhere"
)
event = ResultEvent(value, room, character)
# logger.warning(
# "invalid action, emitting as result event - this is a bug somewhere"
# )
# event = ResultEvent(value, room, character)
raise ValueError("invalid non-JSON action")
broadcast(event)
@ -198,7 +196,8 @@ def prompt_character_planning(
current_turn: int,
max_steps: int | None = None,
) -> str:
max_steps = max_steps or turn_config.planning_steps
config = get_game_config()
max_steps = max_steps or config.world.turn.planning_steps
notes_prompt, events_prompt = get_notes_events(character, current_turn)

View File

@ -8,8 +8,7 @@ from packit.conditions import condition_or, condition_threshold, make_flag_condi
from packit.results import multi_function_or_str_result
from packit.utils import could_be_json
from taleweave.context import broadcast
from taleweave.models.config import DEFAULT_CONFIG
from taleweave.context import broadcast, get_game_config
from taleweave.models.entity import Character, Room
from taleweave.models.event import ReplyEvent
from taleweave.utils.prompt import format_str
@ -19,9 +18,6 @@ from .string import and_list, normalize_name
logger = getLogger(__name__)
character_config = DEFAULT_CONFIG.world.character
def make_keyword_condition(end_message: str, keywords=["end", "stop"]):
set_end, condition_end = make_flag_condition()
@ -99,8 +95,10 @@ def loop_conversation(
Loop through a conversation between a series of agents, using metadata from their characters.
"""
config = get_game_config()
if max_length is None:
max_length = character_config.conversation_limit
max_length = config.world.character.conversation_limit
if len(characters) != len(agents):
raise ValueError("The number of characters and agents must match.")

0
tests/utils/__init__.py Normal file
View File

View File

@ -0,0 +1,30 @@
from unittest import TestCase
from taleweave.models.entity import Room, World
from taleweave.utils.search import find_room
class TestFindRoom(TestCase):
def test_existing_room(self):
world = World(name="Test World", rooms=[], theme="testing", order=[])
room = Room(
name="Test Room",
description="A test room.",
characters=[],
items=[],
portals=[],
)
world.rooms.append(room)
self.assertEqual(find_room(world, "Test Room"), room)
def test_missing_room(self):
world = World(name="Test World", rooms=[], theme="testing", order=[])
room = Room(
name="Test Room",
description="A test room.",
characters=[],
items=[],
portals=[],
)
world.rooms.append(room)
self.assertEqual(find_room(world, "Missing Room"), None)