1
0
Fork 0

make sure config is used consistently, start adding tests
Run Docker Build / build (push) Successful in 13s Details
Run Python Build / build (push) Successful in 23s Details

This commit is contained in:
Sean Sube 2024-06-01 14:46:11 -05:00
parent be37d58ade
commit ef8529ef62
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
16 changed files with 232 additions and 69 deletions

2
.gitignore vendored
View File

@ -5,3 +5,5 @@ venv/
client/node_modules/ client/node_modules/
client/out/ client/out/
taleweave/custom_* taleweave/custom_*
.coverage
coverage.*

80
docs/cli.md Normal file
View File

@ -0,0 +1,80 @@
# TaleWeave AI Command Line Options
The following command line arguments are available when launching the TaleWeave AI engine:
- **--actions**
- **Type:** String
- **Description:** Additional actions to include in the simulation. Note: More than one argument is allowed.
- **--add-rooms**
- **Type:** Integer
- **Default:** 0
- **Description:** The number of new rooms to generate before starting the simulation.
- **--config**
- **Type:** String
- **Description:** The file to load additional configuration from.
- **--discord**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Run the simulation in a Discord bot.
- **--flavor**
- **Type:** String
- **Default:** ""
- **Description:** Additional flavor text for the generated world.
- **--optional-actions**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Include optional actions in the simulation.
- **--player**
- **Type:** String
- **Description:** The name of the character to play as.
- **--prompts**
- **Type:** String
- **Description:** The file to load game prompts from. Note: More than one argument is allowed.
- **--render**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Run the render thread.
- **--render-generated**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Render entities as they are generated.
- **--rooms**
- **Type:** Integer
- **Description:** The number of rooms to generate.
- **--server**
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
- **Description:** Run the websocket server.
- **--state**
- **Type:** String
- **Description:** The file to save the world state to. Defaults to `$world.state.json` if not set.
- **--turns**
- **Type:** Integer or "inf"
- **Default:** 10
- **Description:** The number of simulation turns to run.
- **--systems**
- **Type:** String
- **Description:** Extra systems to run in the simulation. Note: More than one argument is allowed.
- **--theme**
- **Type:** String
- **Default:** "fantasy"
- **Description:** The theme of the generated world.
- **--world**
- **Type:** String
- **Default:** "world"
- **Description:** The file to save the generated world to.
- **--world-template**
- **Type:** String
- **Description:** The template file to load the world prompt from.

View File

@ -2,16 +2,14 @@ from taleweave.context import (
action_context, action_context,
get_agent_for_character, get_agent_for_character,
get_current_turn, get_current_turn,
get_game_config,
get_prompt, get_prompt,
) )
from taleweave.errors import ActionError from taleweave.errors import ActionError
from taleweave.models.config import DEFAULT_CONFIG
from taleweave.models.planning import CalendarEvent from taleweave.models.planning import CalendarEvent
from taleweave.utils.planning import get_recent_notes from taleweave.utils.planning import get_recent_notes
from taleweave.utils.prompt import format_prompt from taleweave.utils.prompt import format_prompt
character_config = DEFAULT_CONFIG.world.character
def take_note(fact: str): def take_note(fact: str):
""" """
@ -22,11 +20,13 @@ def take_note(fact: str):
fact: The fact to remember. fact: The fact to remember.
""" """
config = get_game_config()
with action_context() as (_, action_character): with action_context() as (_, action_character):
if fact in action_character.planner.notes: if fact in action_character.planner.notes:
raise ActionError(get_prompt("action_take_note_error_duplicate")) raise ActionError(get_prompt("action_take_note_error_duplicate"))
if len(action_character.planner.notes) >= character_config.note_limit: if len(action_character.planner.notes) >= config.world.character.note_limit:
raise ActionError(get_prompt("action_take_note_error_limit")) raise ActionError(get_prompt("action_take_note_error_limit"))
action_character.planner.notes.append(fact) action_character.planner.notes.append(fact)
@ -103,6 +103,8 @@ def summarize_notes(limit: int) -> str:
limit: The maximum number of notes to keep. limit: The maximum number of notes to keep.
""" """
config = get_game_config()
with action_context() as (_, action_character): with action_context() as (_, action_character):
notes = action_character.planner.notes notes = action_character.planner.notes
if len(notes) == 0: if len(notes) == 0:
@ -120,11 +122,11 @@ def summarize_notes(limit: int) -> str:
) )
new_notes = [note.strip() for note in summary.split("\n") if note.strip()] new_notes = [note.strip() for note in summary.split("\n") if note.strip()]
if len(new_notes) > character_config.note_limit: if len(new_notes) > config.world.character.note_limit:
raise ActionError( raise ActionError(
format_prompt( format_prompt(
"action_summarize_notes_error_limit", "action_summarize_notes_error_limit",
limit=character_config.note_limit, limit=config.world.character.note_limit,
) )
) )
@ -165,7 +167,9 @@ def check_calendar(count: int):
count: The number of upcoming events to read. 5 is usually a good number. count: The number of upcoming events to read. 5 is usually a good number.
""" """
count = min(count, character_config.event_limit) config = get_game_config()
count = min(count, config.world.character.event_limit)
current_turn = get_current_turn() current_turn = get_current_turn()
with action_context() as (_, action_character): with action_context() as (_, action_character):

View File

@ -11,10 +11,11 @@ from taleweave.context import (
broadcast, broadcast,
get_character_agent_for_name, get_character_agent_for_name,
get_current_world, get_current_world,
get_game_config,
set_character_agent, set_character_agent,
subscribe, subscribe,
) )
from taleweave.models.config import DEFAULT_CONFIG, DiscordBotConfig from taleweave.models.config import DiscordBotConfig
from taleweave.models.event import ( from taleweave.models.event import (
ActionEvent, ActionEvent,
GameEvent, GameEvent,
@ -38,7 +39,6 @@ from taleweave.utils.prompt import format_prompt
logger = getLogger(__name__) logger = getLogger(__name__)
client = None client = None
bot_config: DiscordBotConfig = DEFAULT_CONFIG.bot.discord
active_tasks = set() active_tasks = set()
event_messages: Dict[int, str | GameEvent] = {} event_messages: Dict[int, str | GameEvent] = {}
@ -78,21 +78,24 @@ class AdventureClient(Client):
if message.author == self.user: if message.author == self.user:
return return
config = get_game_config()
author = message.author author = message.author
channel = message.channel channel = message.channel
user_name = author.name # include nick user_name = author.name # include nick
if message.content.startswith( if message.content.startswith(
bot_config.command_prefix + bot_config.name_command config.bot.discord.command_prefix + config.bot.discord.name_command
): ):
world = get_current_world() world = get_current_world()
if world: if world:
world_message = format_prompt( world_message = format_prompt(
"discord_world_active", bot_name=bot_config.name_title, world=world "discord_world_active",
bot_name=config.bot.discord.name_title,
world=world,
) )
else: else:
world_message = format_prompt( world_message = format_prompt(
"discord_world_none", bot_name=bot_config.name_title "discord_world_none", bot_name=config.bot.discord.name_title
) )
await message.channel.send(world_message) await message.channel.send(world_message)
@ -100,7 +103,7 @@ class AdventureClient(Client):
if message.content.startswith("!help"): if message.content.startswith("!help"):
await message.channel.send( await message.channel.send(
format_prompt("discord_help", bot_name=bot_config.name_command) format_prompt("discord_help", bot_name=config.bot.discord.name_command)
) )
return return
@ -172,14 +175,11 @@ class AdventureClient(Client):
def launch_bot(config: DiscordBotConfig): def launch_bot(config: DiscordBotConfig):
global bot_config
global client global client
bot_config = config
# message contents need to be enabled for multi-server bots # message contents need to be enabled for multi-server bots
intents = Intents.default() intents = Intents.default()
if bot_config.content_intent: if config.content_intent:
intents.message_content = True intents.message_content = True
client = AdventureClient(intents=intents) client = AdventureClient(intents=intents)
@ -246,12 +246,14 @@ def get_active_channels():
if not client: if not client:
return [] return []
config = get_game_config()
# return client.private_channels # return client.private_channels
return [ return [
channel channel
for guild in client.guilds for guild in client.guilds
for channel in guild.text_channels for channel in guild.text_channels
if channel.name in bot_config.channels if channel.name in config.bot.discord.channels
] ]

View File

@ -18,6 +18,7 @@ from packit.agent import Agent
from pyee.base import EventEmitter from pyee.base import EventEmitter
from taleweave.game_system import GameSystem from taleweave.game_system import GameSystem
from taleweave.models.config import DEFAULT_CONFIG, Config
from taleweave.models.entity import Character, Room, World from taleweave.models.entity import Character, Room, World
from taleweave.models.event import GameEvent, StatusEvent from taleweave.models.event import GameEvent, StatusEvent
from taleweave.models.prompt import PromptLibrary from taleweave.models.prompt import PromptLibrary
@ -34,6 +35,7 @@ dungeon_master: Agent | None = None
# game context # game context
event_emitter = EventEmitter() event_emitter = EventEmitter()
game_config: Config = DEFAULT_CONFIG
game_systems: List[GameSystem] = [] game_systems: List[GameSystem] = []
prompt_library: PromptLibrary = PromptLibrary(prompts={}) prompt_library: PromptLibrary = PromptLibrary(prompts={})
system_data: Dict[str, Any] = {} system_data: Dict[str, Any] = {}
@ -160,6 +162,10 @@ def get_dungeon_master() -> Agent:
return dungeon_master return dungeon_master
def get_game_config() -> Config:
return game_config
def get_game_systems() -> List[GameSystem]: def get_game_systems() -> List[GameSystem]:
return game_systems return game_systems
@ -172,6 +178,10 @@ def get_prompt_library() -> PromptLibrary:
return prompt_library return prompt_library
def get_system_config(system: str) -> Any | None:
return game_config.systems.data.get(system)
def get_system_data(system: str) -> Any | None: def get_system_data(system: str) -> Any | None:
return system_data.get(system) return system_data.get(system)
@ -209,6 +219,11 @@ def set_dungeon_master(agent):
dungeon_master = agent dungeon_master = agent
def set_game_config(config: Config):
global game_config
game_config = config
def set_game_systems(systems: Sequence[GameSystem]): def set_game_systems(systems: Sequence[GameSystem]):
global game_systems global game_systems
game_systems = list(systems) game_systems = list(systems)

View File

@ -7,9 +7,14 @@ from packit.loops import loop_retry
from packit.results import enum_result, int_result from packit.results import enum_result, int_result
from packit.utils import could_be_json from packit.utils import could_be_json
from taleweave.context import broadcast, get_prompt, set_current_world, set_system_data from taleweave.context import (
broadcast,
get_game_config,
get_prompt,
set_current_world,
set_system_data,
)
from taleweave.game_system import GameSystem from taleweave.game_system import GameSystem
from taleweave.models.config import DEFAULT_CONFIG, WorldConfig
from taleweave.models.effect import ( from taleweave.models.effect import (
EffectPattern, EffectPattern,
FloatEffectPattern, FloatEffectPattern,
@ -33,7 +38,10 @@ from taleweave.utils.string import normalize_name
logger = getLogger(__name__) logger = getLogger(__name__)
world_config: WorldConfig = DEFAULT_CONFIG.world
def get_world_config():
config = get_game_config()
return config.world
def duplicate_name_parser(existing_names: List[str]): def duplicate_name_parser(existing_names: List[str]):
@ -112,6 +120,7 @@ def generate_room(
actions = {} actions = {}
room = Room(name=name, description=desc, items=[], characters=[], actions=actions) room = Room(name=name, description=desc, items=[], characters=[], actions=actions)
world_config = get_world_config()
item_count = resolve_int_range(world_config.size.room_items) or 0 item_count = resolve_int_range(world_config.size.room_items) or 0
broadcast_generated( broadcast_generated(
format_prompt( format_prompt(
@ -276,6 +285,7 @@ def generate_item(
item = Item(name=name, description=desc, actions=actions) item = Item(name=name, description=desc, actions=actions)
generate_system_attributes(agent, world, item, systems) generate_system_attributes(agent, world, item, systems)
world_config = get_world_config()
effect_count = resolve_int_range(world_config.size.item_effects) or 0 effect_count = resolve_int_range(world_config.size.item_effects) or 0
broadcast_generated( broadcast_generated(
message=format_prompt( message=format_prompt(
@ -343,6 +353,7 @@ def generate_character(
generate_system_attributes(agent, world, character, systems) generate_system_attributes(agent, world, character, systems)
# generate the character's inventory # generate the character's inventory
world_config = get_world_config()
item_count = resolve_int_range(world_config.size.character_items) or 0 item_count = resolve_int_range(world_config.size.character_items) or 0
broadcast_generated( broadcast_generated(
message=format_prompt( message=format_prompt(
@ -499,6 +510,7 @@ def link_rooms(
rooms: List[Room] | None = None, rooms: List[Room] | None = None,
) -> None: ) -> None:
rooms = rooms or world.rooms rooms = rooms or world.rooms
world_config = get_world_config()
for room in rooms: for room in rooms:
num_portals = resolve_int_range(world_config.size.portals) or 0 num_portals = resolve_int_range(world_config.size.portals) or 0
@ -550,6 +562,7 @@ def generate_world(
systems: List[GameSystem], systems: List[GameSystem],
room_count: int | None = None, room_count: int | None = None,
) -> World: ) -> World:
world_config = get_world_config()
room_count = room_count or resolve_int_range(world_config.size.rooms) or 0 room_count = room_count or resolve_int_range(world_config.size.rooms) or 0
broadcast_generated(message=format_prompt("world_generate_world_broadcast_theme")) broadcast_generated(message=format_prompt("world_generate_world_broadcast_theme"))

View File

@ -27,7 +27,7 @@ except Exception as err:
logger = logger_with_colors(__name__) # , level="DEBUG") logger = logger_with_colors(__name__) # , level="DEBUG")
load_dotenv(environ.get("ADVENTURE_ENV", ".env"), override=True) load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True)
if True: if True:
from taleweave.context import ( from taleweave.context import (
@ -35,6 +35,7 @@ if True:
get_system_data, get_system_data,
set_current_turn, set_current_turn,
set_dungeon_master, set_dungeon_master,
set_game_config,
set_system_data, set_system_data,
subscribe, subscribe,
) )
@ -312,6 +313,7 @@ def main():
if args.config: if args.config:
with open(args.config, "r") as f: with open(args.config, "r") as f:
config = Config(**load_yaml(f)) config = Config(**load_yaml(f))
set_game_config(config)
else: else:
config = DEFAULT_CONFIG config = DEFAULT_CONFIG

View File

@ -1,6 +1,6 @@
from typing import Dict, List from typing import Dict, List
from .base import IntRange, dataclass from .base import Attributes, IntRange, dataclass
@dataclass @dataclass
@ -44,6 +44,15 @@ class ServerConfig:
websocket: WebsocketServerConfig websocket: WebsocketServerConfig
@dataclass
class SystemsConfig:
"""
Configuration for the game systems.
"""
data: Attributes
@dataclass @dataclass
class WorldCharacterConfig: class WorldCharacterConfig:
conversation_limit: int conversation_limit: int
@ -80,6 +89,7 @@ class Config:
bot: BotConfig bot: BotConfig
render: RenderConfig render: RenderConfig
server: ServerConfig server: ServerConfig
systems: SystemsConfig
world: WorldConfig world: WorldConfig
@ -107,6 +117,7 @@ DEFAULT_CONFIG = Config(
steps=30, steps=30,
), ),
server=ServerConfig(websocket=WebsocketServerConfig(host="localhost", port=8001)), server=ServerConfig(websocket=WebsocketServerConfig(host="localhost", port=8001)),
systems=SystemsConfig(data={}),
world=WorldConfig( world=WorldConfig(
character=WorldCharacterConfig( character=WorldCharacterConfig(
conversation_limit=2, conversation_limit=2,

View File

@ -60,6 +60,7 @@ class PromptEvent(BaseModel):
A prompt for a character to take an action. A prompt for a character to take an action.
""" """
actions: Dict[str, Any]
prompt: str prompt: str
room: Room room: Room
character: Character character: Character

View File

@ -197,15 +197,20 @@ class RemotePlayer(BasePlayer):
Ask the player for input. Ask the player for input.
""" """
actions = {}
formatted_prompt = prompt.format(**kwargs) formatted_prompt = prompt.format(**kwargs)
if toolbox: if toolbox:
actions = toolbox.list_definitions()
formatted_prompt += self.format_psuedo_functions(toolbox) formatted_prompt += self.format_psuedo_functions(toolbox)
self.memory.append(HumanMessage(content=formatted_prompt)) self.memory.append(HumanMessage(content=formatted_prompt))
with action_context() as (current_room, current_character): with action_context() as (current_room, current_character):
prompt_event = PromptEvent( prompt_event = PromptEvent(
prompt=formatted_prompt, room=current_room, character=current_character actions=actions,
prompt=formatted_prompt,
room=current_room,
character=current_character,
) )
try: try:

View File

@ -15,9 +15,9 @@ from fnvhash import fnv1a_32
from jinja2 import Environment, FileSystemLoader, select_autoescape from jinja2 import Environment, FileSystemLoader, select_autoescape
from PIL import Image from PIL import Image
from taleweave.context import broadcast from taleweave.context import broadcast, get_game_config
from taleweave.models.base import uuid from taleweave.models.base import IntRange, uuid
from taleweave.models.config import DEFAULT_CONFIG, RenderConfig from taleweave.models.config import RenderConfig
from taleweave.models.entity import WorldEntity from taleweave.models.entity import WorldEntity
from taleweave.models.event import ( from taleweave.models.event import (
ActionEvent, ActionEvent,
@ -36,7 +36,6 @@ logger = getLogger(__name__)
server_address = environ["COMFY_API"] server_address = environ["COMFY_API"]
client_id = uuid() client_id = uuid()
render_config: RenderConfig = DEFAULT_CONFIG.render
# requests to generate images for game events # requests to generate images for game events
@ -44,12 +43,17 @@ render_queue: Queue[GameEvent | WorldEntity] = Queue()
render_thread: Thread | None = None render_thread: Thread | None = None
def generate_cfg(): def get_render_config():
return resolve_int_range(render_config.cfg) config = get_game_config()
return config.render
def generate_steps(): def generate_cfg(cfg: int | IntRange):
return resolve_int_range(render_config.steps) return resolve_int_range(cfg)
def generate_steps(steps: int | IntRange):
return resolve_int_range(steps)
def generate_batches( def generate_batches(
@ -148,9 +152,10 @@ def generate_image_tool(prompt, count, size="landscape"):
def generate_images( def generate_images(
prompt: str, count: int, size="landscape", prefix="output" prompt: str, count: int, size="landscape", prefix="output"
) -> List[str]: ) -> List[str]:
cfg = generate_cfg() render_config = get_render_config()
cfg = generate_cfg(render_config.cfg)
dims = render_config.sizes[size] dims = render_config.sizes[size]
steps = generate_steps() steps = generate_steps(render_config.steps)
seed = randint(0, 10000000) seed = randint(0, 10000000)
checkpoint = choice(render_config.checkpoints) checkpoint = choice(render_config.checkpoints)
logger.info( logger.info(
@ -248,6 +253,8 @@ def get_image_prefix(event: GameEvent | WorldEntity) -> str:
def render_loop(): def render_loop():
render_config = get_render_config()
while True: while True:
event = render_queue.get() event = render_queue.get()
prefix = get_image_prefix(event) prefix = get_image_prefix(event)
@ -317,13 +324,8 @@ def render_generated(event: GameEvent):
def launch_render(config: RenderConfig): def launch_render(config: RenderConfig):
global render_config
global render_thread global render_thread
# update the config
logger.info("updating render config: %s", config)
render_config = config
# start the render thread # start the render thread
logger.info("launching render thread") logger.info("launching render thread")
render_thread = Thread(target=render_loop, daemon=True) render_thread = Thread(target=render_loop, daemon=True)

View File

@ -16,10 +16,11 @@ from taleweave.context import (
broadcast, broadcast,
get_character_agent_for_name, get_character_agent_for_name,
get_current_world, get_current_world,
get_game_config,
set_character_agent, set_character_agent,
subscribe, subscribe,
) )
from taleweave.models.config import DEFAULT_CONFIG, WebsocketServerConfig from taleweave.models.config import WebsocketServerConfig
from taleweave.models.entity import World, WorldEntity from taleweave.models.entity import World, WorldEntity
from taleweave.models.event import ( from taleweave.models.event import (
GameEvent, GameEvent,
@ -47,7 +48,6 @@ last_snapshot: str | None = None
player_names: Dict[str, str] = {} player_names: Dict[str, str] = {}
recent_events: MutableSequence[GameEvent] = deque(maxlen=100) recent_events: MutableSequence[GameEvent] = deque(maxlen=100)
recent_json: MutableSequence[str] = deque(maxlen=100) recent_json: MutableSequence[str] = deque(maxlen=100)
server_config: WebsocketServerConfig = DEFAULT_CONFIG.server.websocket
def get_player_name(client_id: str) -> str: def get_player_name(client_id: str) -> str:
@ -59,16 +59,15 @@ async def handler(websocket):
logger.info("client connected, given id: %s", id) logger.info("client connected, given id: %s", id)
connected.add(websocket) connected.add(websocket)
async def next_turn(character: str, prompt: str) -> None: async def next_turn(event: PromptEvent) -> None:
await websocket.send( await websocket.send(
dumps( dumps(
{ {
# TODO: these should be fields in the PromptEvent "type": event.type,
"type": "prompt", "client": id, # TODO: this should be a field in the PromptEvent
"client": id, "character": event.character,
"character": character, "prompt": event.prompt,
"prompt": prompt, "actions": event.actions,
"actions": [],
} }
), ),
) )
@ -77,7 +76,7 @@ async def handler(websocket):
# TODO: nothing about this is good # TODO: nothing about this is good
player = get_player(id) player = get_player(id)
if player and player.name == event.character.name: if player and player.name == event.character.name:
asyncio.run(next_turn(event.character.name, event.prompt)) asyncio.run(next_turn(event))
return True return True
return False return False
@ -303,10 +302,6 @@ def send_and_append(id: str, message: Dict):
def launch_server(config: WebsocketServerConfig): def launch_server(config: WebsocketServerConfig):
global socket_thread global socket_thread
global server_config
logger.info("configuring websocket server: %s", config)
server_config = config
def run_sockets(): def run_sockets():
asyncio.run(server_main()) asyncio.run(server_main())
@ -321,7 +316,11 @@ def launch_server(config: WebsocketServerConfig):
async def server_main(): async def server_main():
async with websockets.serve(handler, server_config.host, server_config.port): config = get_game_config()
async with websockets.serve(
handler, config.server.websocket.host, config.server.websocket.port
):
logger.info("websocket server started") logger.info("websocket server started")
await asyncio.Future() # run forever await asyncio.Future() # run forever

View File

@ -36,6 +36,7 @@ from taleweave.context import (
get_character_for_agent, get_character_for_agent,
get_current_turn, get_current_turn,
get_current_world, get_current_world,
get_game_config,
get_prompt, get_prompt,
set_current_character, set_current_character,
set_current_room, set_current_room,
@ -44,7 +45,6 @@ from taleweave.context import (
set_game_systems, set_game_systems,
) )
from taleweave.game_system import GameSystem from taleweave.game_system import GameSystem
from taleweave.models.config import DEFAULT_CONFIG
from taleweave.models.entity import Character, Room, World from taleweave.models.entity import Character, Room, World
from taleweave.models.event import ActionEvent, ResultEvent from taleweave.models.event import ActionEvent, ResultEvent
from taleweave.utils.conversation import make_keyword_condition, summarize_room from taleweave.utils.conversation import make_keyword_condition, summarize_room
@ -57,9 +57,6 @@ from taleweave.utils.world import describe_entity, format_attributes
logger = getLogger(__name__) logger = getLogger(__name__)
turn_config = DEFAULT_CONFIG.world.turn
def world_result_parser(value, agent, **kwargs): def world_result_parser(value, agent, **kwargs):
current_world = get_current_world() current_world = get_current_world()
if not current_world: if not current_world:
@ -121,10 +118,11 @@ def prompt_character_action(
event = ActionEvent.from_json(value, room, character) event = ActionEvent.from_json(value, room, character)
else: else:
# TODO: this path should be removed and throw # TODO: this path should be removed and throw
logger.warning( # logger.warning(
"invalid action, emitting as result event - this is a bug somewhere" # "invalid action, emitting as result event - this is a bug somewhere"
) # )
event = ResultEvent(value, room, character) # event = ResultEvent(value, room, character)
raise ValueError("invalid non-JSON action")
broadcast(event) broadcast(event)
@ -198,7 +196,8 @@ def prompt_character_planning(
current_turn: int, current_turn: int,
max_steps: int | None = None, max_steps: int | None = None,
) -> str: ) -> str:
max_steps = max_steps or turn_config.planning_steps config = get_game_config()
max_steps = max_steps or config.world.turn.planning_steps
notes_prompt, events_prompt = get_notes_events(character, current_turn) notes_prompt, events_prompt = get_notes_events(character, current_turn)

View File

@ -8,8 +8,7 @@ from packit.conditions import condition_or, condition_threshold, make_flag_condi
from packit.results import multi_function_or_str_result from packit.results import multi_function_or_str_result
from packit.utils import could_be_json from packit.utils import could_be_json
from taleweave.context import broadcast from taleweave.context import broadcast, get_game_config
from taleweave.models.config import DEFAULT_CONFIG
from taleweave.models.entity import Character, Room from taleweave.models.entity import Character, Room
from taleweave.models.event import ReplyEvent from taleweave.models.event import ReplyEvent
from taleweave.utils.prompt import format_str from taleweave.utils.prompt import format_str
@ -19,9 +18,6 @@ from .string import and_list, normalize_name
logger = getLogger(__name__) logger = getLogger(__name__)
character_config = DEFAULT_CONFIG.world.character
def make_keyword_condition(end_message: str, keywords=["end", "stop"]): def make_keyword_condition(end_message: str, keywords=["end", "stop"]):
set_end, condition_end = make_flag_condition() set_end, condition_end = make_flag_condition()
@ -99,8 +95,10 @@ def loop_conversation(
Loop through a conversation between a series of agents, using metadata from their characters. Loop through a conversation between a series of agents, using metadata from their characters.
""" """
config = get_game_config()
if max_length is None: if max_length is None:
max_length = character_config.conversation_limit max_length = config.world.character.conversation_limit
if len(characters) != len(agents): if len(characters) != len(agents):
raise ValueError("The number of characters and agents must match.") raise ValueError("The number of characters and agents must match.")

0
tests/utils/__init__.py Normal file
View File

View File

@ -0,0 +1,30 @@
from unittest import TestCase
from taleweave.models.entity import Room, World
from taleweave.utils.search import find_room
class TestFindRoom(TestCase):
def test_existing_room(self):
world = World(name="Test World", rooms=[], theme="testing", order=[])
room = Room(
name="Test Room",
description="A test room.",
characters=[],
items=[],
portals=[],
)
world.rooms.append(room)
self.assertEqual(find_room(world, "Test Room"), room)
def test_missing_room(self):
world = World(name="Test World", rooms=[], theme="testing", order=[])
room = Room(
name="Test Room",
description="A test room.",
characters=[],
items=[],
portals=[],
)
world.rooms.append(room)
self.assertEqual(find_room(world, "Missing Room"), None)