make sure config is used consistently, start adding tests
This commit is contained in:
parent
be37d58ade
commit
ef8529ef62
|
@ -5,3 +5,5 @@ venv/
|
|||
client/node_modules/
|
||||
client/out/
|
||||
taleweave/custom_*
|
||||
.coverage
|
||||
coverage.*
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
# TaleWeave AI Command Line Options
|
||||
|
||||
The following command line arguments are available when launching the TaleWeave AI engine:
|
||||
|
||||
- **--actions**
|
||||
- **Type:** String
|
||||
- **Description:** Additional actions to include in the simulation. Note: More than one argument is allowed.
|
||||
|
||||
- **--add-rooms**
|
||||
- **Type:** Integer
|
||||
- **Default:** 0
|
||||
- **Description:** The number of new rooms to generate before starting the simulation.
|
||||
|
||||
- **--config**
|
||||
- **Type:** String
|
||||
- **Description:** The file to load additional configuration from.
|
||||
|
||||
- **--discord**
|
||||
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
|
||||
- **Description:** Run the simulation in a Discord bot.
|
||||
|
||||
- **--flavor**
|
||||
- **Type:** String
|
||||
- **Default:** ""
|
||||
- **Description:** Additional flavor text for the generated world.
|
||||
|
||||
- **--optional-actions**
|
||||
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
|
||||
- **Description:** Include optional actions in the simulation.
|
||||
|
||||
- **--player**
|
||||
- **Type:** String
|
||||
- **Description:** The name of the character to play as.
|
||||
|
||||
- **--prompts**
|
||||
- **Type:** String
|
||||
- **Description:** The file to load game prompts from. Note: More than one argument is allowed.
|
||||
|
||||
- **--render**
|
||||
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
|
||||
- **Description:** Run the render thread.
|
||||
|
||||
- **--render-generated**
|
||||
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
|
||||
- **Description:** Render entities as they are generated.
|
||||
|
||||
- **--rooms**
|
||||
- **Type:** Integer
|
||||
- **Description:** The number of rooms to generate.
|
||||
|
||||
- **--server**
|
||||
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
|
||||
- **Description:** Run the websocket server.
|
||||
|
||||
- **--state**
|
||||
- **Type:** String
|
||||
- **Description:** The file to save the world state to. Defaults to `$world.state.json` if not set.
|
||||
|
||||
- **--turns**
|
||||
- **Type:** Integer or "inf"
|
||||
- **Default:** 10
|
||||
- **Description:** The number of simulation turns to run.
|
||||
|
||||
- **--systems**
|
||||
- **Type:** String
|
||||
- **Description:** Extra systems to run in the simulation. Note: More than one argument is allowed.
|
||||
|
||||
- **--theme**
|
||||
- **Type:** String
|
||||
- **Default:** "fantasy"
|
||||
- **Description:** The theme of the generated world.
|
||||
|
||||
- **--world**
|
||||
- **Type:** String
|
||||
- **Default:** "world"
|
||||
- **Description:** The file to save the generated world to.
|
||||
|
||||
- **--world-template**
|
||||
- **Type:** String
|
||||
- **Description:** The template file to load the world prompt from.
|
|
@ -2,16 +2,14 @@ from taleweave.context import (
|
|||
action_context,
|
||||
get_agent_for_character,
|
||||
get_current_turn,
|
||||
get_game_config,
|
||||
get_prompt,
|
||||
)
|
||||
from taleweave.errors import ActionError
|
||||
from taleweave.models.config import DEFAULT_CONFIG
|
||||
from taleweave.models.planning import CalendarEvent
|
||||
from taleweave.utils.planning import get_recent_notes
|
||||
from taleweave.utils.prompt import format_prompt
|
||||
|
||||
character_config = DEFAULT_CONFIG.world.character
|
||||
|
||||
|
||||
def take_note(fact: str):
|
||||
"""
|
||||
|
@ -22,11 +20,13 @@ def take_note(fact: str):
|
|||
fact: The fact to remember.
|
||||
"""
|
||||
|
||||
config = get_game_config()
|
||||
|
||||
with action_context() as (_, action_character):
|
||||
if fact in action_character.planner.notes:
|
||||
raise ActionError(get_prompt("action_take_note_error_duplicate"))
|
||||
|
||||
if len(action_character.planner.notes) >= character_config.note_limit:
|
||||
if len(action_character.planner.notes) >= config.world.character.note_limit:
|
||||
raise ActionError(get_prompt("action_take_note_error_limit"))
|
||||
|
||||
action_character.planner.notes.append(fact)
|
||||
|
@ -103,6 +103,8 @@ def summarize_notes(limit: int) -> str:
|
|||
limit: The maximum number of notes to keep.
|
||||
"""
|
||||
|
||||
config = get_game_config()
|
||||
|
||||
with action_context() as (_, action_character):
|
||||
notes = action_character.planner.notes
|
||||
if len(notes) == 0:
|
||||
|
@ -120,11 +122,11 @@ def summarize_notes(limit: int) -> str:
|
|||
)
|
||||
|
||||
new_notes = [note.strip() for note in summary.split("\n") if note.strip()]
|
||||
if len(new_notes) > character_config.note_limit:
|
||||
if len(new_notes) > config.world.character.note_limit:
|
||||
raise ActionError(
|
||||
format_prompt(
|
||||
"action_summarize_notes_error_limit",
|
||||
limit=character_config.note_limit,
|
||||
limit=config.world.character.note_limit,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -165,7 +167,9 @@ def check_calendar(count: int):
|
|||
count: The number of upcoming events to read. 5 is usually a good number.
|
||||
"""
|
||||
|
||||
count = min(count, character_config.event_limit)
|
||||
config = get_game_config()
|
||||
|
||||
count = min(count, config.world.character.event_limit)
|
||||
current_turn = get_current_turn()
|
||||
|
||||
with action_context() as (_, action_character):
|
||||
|
|
|
@ -11,10 +11,11 @@ from taleweave.context import (
|
|||
broadcast,
|
||||
get_character_agent_for_name,
|
||||
get_current_world,
|
||||
get_game_config,
|
||||
set_character_agent,
|
||||
subscribe,
|
||||
)
|
||||
from taleweave.models.config import DEFAULT_CONFIG, DiscordBotConfig
|
||||
from taleweave.models.config import DiscordBotConfig
|
||||
from taleweave.models.event import (
|
||||
ActionEvent,
|
||||
GameEvent,
|
||||
|
@ -38,7 +39,6 @@ from taleweave.utils.prompt import format_prompt
|
|||
|
||||
logger = getLogger(__name__)
|
||||
client = None
|
||||
bot_config: DiscordBotConfig = DEFAULT_CONFIG.bot.discord
|
||||
|
||||
active_tasks = set()
|
||||
event_messages: Dict[int, str | GameEvent] = {}
|
||||
|
@ -78,21 +78,24 @@ class AdventureClient(Client):
|
|||
if message.author == self.user:
|
||||
return
|
||||
|
||||
config = get_game_config()
|
||||
author = message.author
|
||||
channel = message.channel
|
||||
user_name = author.name # include nick
|
||||
|
||||
if message.content.startswith(
|
||||
bot_config.command_prefix + bot_config.name_command
|
||||
config.bot.discord.command_prefix + config.bot.discord.name_command
|
||||
):
|
||||
world = get_current_world()
|
||||
if world:
|
||||
world_message = format_prompt(
|
||||
"discord_world_active", bot_name=bot_config.name_title, world=world
|
||||
"discord_world_active",
|
||||
bot_name=config.bot.discord.name_title,
|
||||
world=world,
|
||||
)
|
||||
else:
|
||||
world_message = format_prompt(
|
||||
"discord_world_none", bot_name=bot_config.name_title
|
||||
"discord_world_none", bot_name=config.bot.discord.name_title
|
||||
)
|
||||
|
||||
await message.channel.send(world_message)
|
||||
|
@ -100,7 +103,7 @@ class AdventureClient(Client):
|
|||
|
||||
if message.content.startswith("!help"):
|
||||
await message.channel.send(
|
||||
format_prompt("discord_help", bot_name=bot_config.name_command)
|
||||
format_prompt("discord_help", bot_name=config.bot.discord.name_command)
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -172,14 +175,11 @@ class AdventureClient(Client):
|
|||
|
||||
|
||||
def launch_bot(config: DiscordBotConfig):
|
||||
global bot_config
|
||||
global client
|
||||
|
||||
bot_config = config
|
||||
|
||||
# message contents need to be enabled for multi-server bots
|
||||
intents = Intents.default()
|
||||
if bot_config.content_intent:
|
||||
if config.content_intent:
|
||||
intents.message_content = True
|
||||
|
||||
client = AdventureClient(intents=intents)
|
||||
|
@ -246,12 +246,14 @@ def get_active_channels():
|
|||
if not client:
|
||||
return []
|
||||
|
||||
config = get_game_config()
|
||||
|
||||
# return client.private_channels
|
||||
return [
|
||||
channel
|
||||
for guild in client.guilds
|
||||
for channel in guild.text_channels
|
||||
if channel.name in bot_config.channels
|
||||
if channel.name in config.bot.discord.channels
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ from packit.agent import Agent
|
|||
from pyee.base import EventEmitter
|
||||
|
||||
from taleweave.game_system import GameSystem
|
||||
from taleweave.models.config import DEFAULT_CONFIG, Config
|
||||
from taleweave.models.entity import Character, Room, World
|
||||
from taleweave.models.event import GameEvent, StatusEvent
|
||||
from taleweave.models.prompt import PromptLibrary
|
||||
|
@ -34,6 +35,7 @@ dungeon_master: Agent | None = None
|
|||
|
||||
# game context
|
||||
event_emitter = EventEmitter()
|
||||
game_config: Config = DEFAULT_CONFIG
|
||||
game_systems: List[GameSystem] = []
|
||||
prompt_library: PromptLibrary = PromptLibrary(prompts={})
|
||||
system_data: Dict[str, Any] = {}
|
||||
|
@ -160,6 +162,10 @@ def get_dungeon_master() -> Agent:
|
|||
return dungeon_master
|
||||
|
||||
|
||||
def get_game_config() -> Config:
|
||||
return game_config
|
||||
|
||||
|
||||
def get_game_systems() -> List[GameSystem]:
|
||||
return game_systems
|
||||
|
||||
|
@ -172,6 +178,10 @@ def get_prompt_library() -> PromptLibrary:
|
|||
return prompt_library
|
||||
|
||||
|
||||
def get_system_config(system: str) -> Any | None:
|
||||
return game_config.systems.data.get(system)
|
||||
|
||||
|
||||
def get_system_data(system: str) -> Any | None:
|
||||
return system_data.get(system)
|
||||
|
||||
|
@ -209,6 +219,11 @@ def set_dungeon_master(agent):
|
|||
dungeon_master = agent
|
||||
|
||||
|
||||
def set_game_config(config: Config):
|
||||
global game_config
|
||||
game_config = config
|
||||
|
||||
|
||||
def set_game_systems(systems: Sequence[GameSystem]):
|
||||
global game_systems
|
||||
game_systems = list(systems)
|
||||
|
|
|
@ -7,9 +7,14 @@ from packit.loops import loop_retry
|
|||
from packit.results import enum_result, int_result
|
||||
from packit.utils import could_be_json
|
||||
|
||||
from taleweave.context import broadcast, get_prompt, set_current_world, set_system_data
|
||||
from taleweave.context import (
|
||||
broadcast,
|
||||
get_game_config,
|
||||
get_prompt,
|
||||
set_current_world,
|
||||
set_system_data,
|
||||
)
|
||||
from taleweave.game_system import GameSystem
|
||||
from taleweave.models.config import DEFAULT_CONFIG, WorldConfig
|
||||
from taleweave.models.effect import (
|
||||
EffectPattern,
|
||||
FloatEffectPattern,
|
||||
|
@ -33,7 +38,10 @@ from taleweave.utils.string import normalize_name
|
|||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
world_config: WorldConfig = DEFAULT_CONFIG.world
|
||||
|
||||
def get_world_config():
|
||||
config = get_game_config()
|
||||
return config.world
|
||||
|
||||
|
||||
def duplicate_name_parser(existing_names: List[str]):
|
||||
|
@ -112,6 +120,7 @@ def generate_room(
|
|||
actions = {}
|
||||
room = Room(name=name, description=desc, items=[], characters=[], actions=actions)
|
||||
|
||||
world_config = get_world_config()
|
||||
item_count = resolve_int_range(world_config.size.room_items) or 0
|
||||
broadcast_generated(
|
||||
format_prompt(
|
||||
|
@ -276,6 +285,7 @@ def generate_item(
|
|||
item = Item(name=name, description=desc, actions=actions)
|
||||
generate_system_attributes(agent, world, item, systems)
|
||||
|
||||
world_config = get_world_config()
|
||||
effect_count = resolve_int_range(world_config.size.item_effects) or 0
|
||||
broadcast_generated(
|
||||
message=format_prompt(
|
||||
|
@ -343,6 +353,7 @@ def generate_character(
|
|||
generate_system_attributes(agent, world, character, systems)
|
||||
|
||||
# generate the character's inventory
|
||||
world_config = get_world_config()
|
||||
item_count = resolve_int_range(world_config.size.character_items) or 0
|
||||
broadcast_generated(
|
||||
message=format_prompt(
|
||||
|
@ -499,6 +510,7 @@ def link_rooms(
|
|||
rooms: List[Room] | None = None,
|
||||
) -> None:
|
||||
rooms = rooms or world.rooms
|
||||
world_config = get_world_config()
|
||||
|
||||
for room in rooms:
|
||||
num_portals = resolve_int_range(world_config.size.portals) or 0
|
||||
|
@ -550,6 +562,7 @@ def generate_world(
|
|||
systems: List[GameSystem],
|
||||
room_count: int | None = None,
|
||||
) -> World:
|
||||
world_config = get_world_config()
|
||||
room_count = room_count or resolve_int_range(world_config.size.rooms) or 0
|
||||
|
||||
broadcast_generated(message=format_prompt("world_generate_world_broadcast_theme"))
|
||||
|
|
|
@ -27,7 +27,7 @@ except Exception as err:
|
|||
|
||||
logger = logger_with_colors(__name__) # , level="DEBUG")
|
||||
|
||||
load_dotenv(environ.get("ADVENTURE_ENV", ".env"), override=True)
|
||||
load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True)
|
||||
|
||||
if True:
|
||||
from taleweave.context import (
|
||||
|
@ -35,6 +35,7 @@ if True:
|
|||
get_system_data,
|
||||
set_current_turn,
|
||||
set_dungeon_master,
|
||||
set_game_config,
|
||||
set_system_data,
|
||||
subscribe,
|
||||
)
|
||||
|
@ -312,6 +313,7 @@ def main():
|
|||
if args.config:
|
||||
with open(args.config, "r") as f:
|
||||
config = Config(**load_yaml(f))
|
||||
set_game_config(config)
|
||||
else:
|
||||
config = DEFAULT_CONFIG
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Dict, List
|
||||
|
||||
from .base import IntRange, dataclass
|
||||
from .base import Attributes, IntRange, dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -44,6 +44,15 @@ class ServerConfig:
|
|||
websocket: WebsocketServerConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemsConfig:
|
||||
"""
|
||||
Configuration for the game systems.
|
||||
"""
|
||||
|
||||
data: Attributes
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorldCharacterConfig:
|
||||
conversation_limit: int
|
||||
|
@ -80,6 +89,7 @@ class Config:
|
|||
bot: BotConfig
|
||||
render: RenderConfig
|
||||
server: ServerConfig
|
||||
systems: SystemsConfig
|
||||
world: WorldConfig
|
||||
|
||||
|
||||
|
@ -107,6 +117,7 @@ DEFAULT_CONFIG = Config(
|
|||
steps=30,
|
||||
),
|
||||
server=ServerConfig(websocket=WebsocketServerConfig(host="localhost", port=8001)),
|
||||
systems=SystemsConfig(data={}),
|
||||
world=WorldConfig(
|
||||
character=WorldCharacterConfig(
|
||||
conversation_limit=2,
|
||||
|
|
|
@ -60,6 +60,7 @@ class PromptEvent(BaseModel):
|
|||
A prompt for a character to take an action.
|
||||
"""
|
||||
|
||||
actions: Dict[str, Any]
|
||||
prompt: str
|
||||
room: Room
|
||||
character: Character
|
||||
|
|
|
@ -197,15 +197,20 @@ class RemotePlayer(BasePlayer):
|
|||
Ask the player for input.
|
||||
"""
|
||||
|
||||
actions = {}
|
||||
formatted_prompt = prompt.format(**kwargs)
|
||||
if toolbox:
|
||||
actions = toolbox.list_definitions()
|
||||
formatted_prompt += self.format_psuedo_functions(toolbox)
|
||||
|
||||
self.memory.append(HumanMessage(content=formatted_prompt))
|
||||
|
||||
with action_context() as (current_room, current_character):
|
||||
prompt_event = PromptEvent(
|
||||
prompt=formatted_prompt, room=current_room, character=current_character
|
||||
actions=actions,
|
||||
prompt=formatted_prompt,
|
||||
room=current_room,
|
||||
character=current_character,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
@ -15,9 +15,9 @@ from fnvhash import fnv1a_32
|
|||
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
||||
from PIL import Image
|
||||
|
||||
from taleweave.context import broadcast
|
||||
from taleweave.models.base import uuid
|
||||
from taleweave.models.config import DEFAULT_CONFIG, RenderConfig
|
||||
from taleweave.context import broadcast, get_game_config
|
||||
from taleweave.models.base import IntRange, uuid
|
||||
from taleweave.models.config import RenderConfig
|
||||
from taleweave.models.entity import WorldEntity
|
||||
from taleweave.models.event import (
|
||||
ActionEvent,
|
||||
|
@ -36,7 +36,6 @@ logger = getLogger(__name__)
|
|||
|
||||
server_address = environ["COMFY_API"]
|
||||
client_id = uuid()
|
||||
render_config: RenderConfig = DEFAULT_CONFIG.render
|
||||
|
||||
|
||||
# requests to generate images for game events
|
||||
|
@ -44,12 +43,17 @@ render_queue: Queue[GameEvent | WorldEntity] = Queue()
|
|||
render_thread: Thread | None = None
|
||||
|
||||
|
||||
def generate_cfg():
|
||||
return resolve_int_range(render_config.cfg)
|
||||
def get_render_config():
|
||||
config = get_game_config()
|
||||
return config.render
|
||||
|
||||
|
||||
def generate_steps():
|
||||
return resolve_int_range(render_config.steps)
|
||||
def generate_cfg(cfg: int | IntRange):
|
||||
return resolve_int_range(cfg)
|
||||
|
||||
|
||||
def generate_steps(steps: int | IntRange):
|
||||
return resolve_int_range(steps)
|
||||
|
||||
|
||||
def generate_batches(
|
||||
|
@ -148,9 +152,10 @@ def generate_image_tool(prompt, count, size="landscape"):
|
|||
def generate_images(
|
||||
prompt: str, count: int, size="landscape", prefix="output"
|
||||
) -> List[str]:
|
||||
cfg = generate_cfg()
|
||||
render_config = get_render_config()
|
||||
cfg = generate_cfg(render_config.cfg)
|
||||
dims = render_config.sizes[size]
|
||||
steps = generate_steps()
|
||||
steps = generate_steps(render_config.steps)
|
||||
seed = randint(0, 10000000)
|
||||
checkpoint = choice(render_config.checkpoints)
|
||||
logger.info(
|
||||
|
@ -248,6 +253,8 @@ def get_image_prefix(event: GameEvent | WorldEntity) -> str:
|
|||
|
||||
|
||||
def render_loop():
|
||||
render_config = get_render_config()
|
||||
|
||||
while True:
|
||||
event = render_queue.get()
|
||||
prefix = get_image_prefix(event)
|
||||
|
@ -317,13 +324,8 @@ def render_generated(event: GameEvent):
|
|||
|
||||
|
||||
def launch_render(config: RenderConfig):
|
||||
global render_config
|
||||
global render_thread
|
||||
|
||||
# update the config
|
||||
logger.info("updating render config: %s", config)
|
||||
render_config = config
|
||||
|
||||
# start the render thread
|
||||
logger.info("launching render thread")
|
||||
render_thread = Thread(target=render_loop, daemon=True)
|
||||
|
|
|
@ -16,10 +16,11 @@ from taleweave.context import (
|
|||
broadcast,
|
||||
get_character_agent_for_name,
|
||||
get_current_world,
|
||||
get_game_config,
|
||||
set_character_agent,
|
||||
subscribe,
|
||||
)
|
||||
from taleweave.models.config import DEFAULT_CONFIG, WebsocketServerConfig
|
||||
from taleweave.models.config import WebsocketServerConfig
|
||||
from taleweave.models.entity import World, WorldEntity
|
||||
from taleweave.models.event import (
|
||||
GameEvent,
|
||||
|
@ -47,7 +48,6 @@ last_snapshot: str | None = None
|
|||
player_names: Dict[str, str] = {}
|
||||
recent_events: MutableSequence[GameEvent] = deque(maxlen=100)
|
||||
recent_json: MutableSequence[str] = deque(maxlen=100)
|
||||
server_config: WebsocketServerConfig = DEFAULT_CONFIG.server.websocket
|
||||
|
||||
|
||||
def get_player_name(client_id: str) -> str:
|
||||
|
@ -59,16 +59,15 @@ async def handler(websocket):
|
|||
logger.info("client connected, given id: %s", id)
|
||||
connected.add(websocket)
|
||||
|
||||
async def next_turn(character: str, prompt: str) -> None:
|
||||
async def next_turn(event: PromptEvent) -> None:
|
||||
await websocket.send(
|
||||
dumps(
|
||||
{
|
||||
# TODO: these should be fields in the PromptEvent
|
||||
"type": "prompt",
|
||||
"client": id,
|
||||
"character": character,
|
||||
"prompt": prompt,
|
||||
"actions": [],
|
||||
"type": event.type,
|
||||
"client": id, # TODO: this should be a field in the PromptEvent
|
||||
"character": event.character,
|
||||
"prompt": event.prompt,
|
||||
"actions": event.actions,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
@ -77,7 +76,7 @@ async def handler(websocket):
|
|||
# TODO: nothing about this is good
|
||||
player = get_player(id)
|
||||
if player and player.name == event.character.name:
|
||||
asyncio.run(next_turn(event.character.name, event.prompt))
|
||||
asyncio.run(next_turn(event))
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -303,10 +302,6 @@ def send_and_append(id: str, message: Dict):
|
|||
|
||||
def launch_server(config: WebsocketServerConfig):
|
||||
global socket_thread
|
||||
global server_config
|
||||
|
||||
logger.info("configuring websocket server: %s", config)
|
||||
server_config = config
|
||||
|
||||
def run_sockets():
|
||||
asyncio.run(server_main())
|
||||
|
@ -321,7 +316,11 @@ def launch_server(config: WebsocketServerConfig):
|
|||
|
||||
|
||||
async def server_main():
|
||||
async with websockets.serve(handler, server_config.host, server_config.port):
|
||||
config = get_game_config()
|
||||
|
||||
async with websockets.serve(
|
||||
handler, config.server.websocket.host, config.server.websocket.port
|
||||
):
|
||||
logger.info("websocket server started")
|
||||
await asyncio.Future() # run forever
|
||||
|
||||
|
|
|
@ -36,6 +36,7 @@ from taleweave.context import (
|
|||
get_character_for_agent,
|
||||
get_current_turn,
|
||||
get_current_world,
|
||||
get_game_config,
|
||||
get_prompt,
|
||||
set_current_character,
|
||||
set_current_room,
|
||||
|
@ -44,7 +45,6 @@ from taleweave.context import (
|
|||
set_game_systems,
|
||||
)
|
||||
from taleweave.game_system import GameSystem
|
||||
from taleweave.models.config import DEFAULT_CONFIG
|
||||
from taleweave.models.entity import Character, Room, World
|
||||
from taleweave.models.event import ActionEvent, ResultEvent
|
||||
from taleweave.utils.conversation import make_keyword_condition, summarize_room
|
||||
|
@ -57,9 +57,6 @@ from taleweave.utils.world import describe_entity, format_attributes
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
turn_config = DEFAULT_CONFIG.world.turn
|
||||
|
||||
|
||||
def world_result_parser(value, agent, **kwargs):
|
||||
current_world = get_current_world()
|
||||
if not current_world:
|
||||
|
@ -121,10 +118,11 @@ def prompt_character_action(
|
|||
event = ActionEvent.from_json(value, room, character)
|
||||
else:
|
||||
# TODO: this path should be removed and throw
|
||||
logger.warning(
|
||||
"invalid action, emitting as result event - this is a bug somewhere"
|
||||
)
|
||||
event = ResultEvent(value, room, character)
|
||||
# logger.warning(
|
||||
# "invalid action, emitting as result event - this is a bug somewhere"
|
||||
# )
|
||||
# event = ResultEvent(value, room, character)
|
||||
raise ValueError("invalid non-JSON action")
|
||||
|
||||
broadcast(event)
|
||||
|
||||
|
@ -198,7 +196,8 @@ def prompt_character_planning(
|
|||
current_turn: int,
|
||||
max_steps: int | None = None,
|
||||
) -> str:
|
||||
max_steps = max_steps or turn_config.planning_steps
|
||||
config = get_game_config()
|
||||
max_steps = max_steps or config.world.turn.planning_steps
|
||||
|
||||
notes_prompt, events_prompt = get_notes_events(character, current_turn)
|
||||
|
||||
|
|
|
@ -8,8 +8,7 @@ from packit.conditions import condition_or, condition_threshold, make_flag_condi
|
|||
from packit.results import multi_function_or_str_result
|
||||
from packit.utils import could_be_json
|
||||
|
||||
from taleweave.context import broadcast
|
||||
from taleweave.models.config import DEFAULT_CONFIG
|
||||
from taleweave.context import broadcast, get_game_config
|
||||
from taleweave.models.entity import Character, Room
|
||||
from taleweave.models.event import ReplyEvent
|
||||
from taleweave.utils.prompt import format_str
|
||||
|
@ -19,9 +18,6 @@ from .string import and_list, normalize_name
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
character_config = DEFAULT_CONFIG.world.character
|
||||
|
||||
|
||||
def make_keyword_condition(end_message: str, keywords=["end", "stop"]):
|
||||
set_end, condition_end = make_flag_condition()
|
||||
|
||||
|
@ -99,8 +95,10 @@ def loop_conversation(
|
|||
Loop through a conversation between a series of agents, using metadata from their characters.
|
||||
"""
|
||||
|
||||
config = get_game_config()
|
||||
|
||||
if max_length is None:
|
||||
max_length = character_config.conversation_limit
|
||||
max_length = config.world.character.conversation_limit
|
||||
|
||||
if len(characters) != len(agents):
|
||||
raise ValueError("The number of characters and agents must match.")
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
from unittest import TestCase
|
||||
|
||||
from taleweave.models.entity import Room, World
|
||||
from taleweave.utils.search import find_room
|
||||
|
||||
|
||||
class TestFindRoom(TestCase):
|
||||
def test_existing_room(self):
|
||||
world = World(name="Test World", rooms=[], theme="testing", order=[])
|
||||
room = Room(
|
||||
name="Test Room",
|
||||
description="A test room.",
|
||||
characters=[],
|
||||
items=[],
|
||||
portals=[],
|
||||
)
|
||||
world.rooms.append(room)
|
||||
self.assertEqual(find_room(world, "Test Room"), room)
|
||||
|
||||
def test_missing_room(self):
|
||||
world = World(name="Test World", rooms=[], theme="testing", order=[])
|
||||
room = Room(
|
||||
name="Test Room",
|
||||
description="A test room.",
|
||||
characters=[],
|
||||
items=[],
|
||||
portals=[],
|
||||
)
|
||||
world.rooms.append(room)
|
||||
self.assertEqual(find_room(world, "Missing Room"), None)
|
Loading…
Reference in New Issue