make sure config is used consistently, start adding tests
This commit is contained in:
parent
be37d58ade
commit
ef8529ef62
|
@ -5,3 +5,5 @@ venv/
|
||||||
client/node_modules/
|
client/node_modules/
|
||||||
client/out/
|
client/out/
|
||||||
taleweave/custom_*
|
taleweave/custom_*
|
||||||
|
.coverage
|
||||||
|
coverage.*
|
||||||
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
# TaleWeave AI Command Line Options
|
||||||
|
|
||||||
|
The following command line arguments are available when launching the TaleWeave AI engine:
|
||||||
|
|
||||||
|
- **--actions**
|
||||||
|
- **Type:** String
|
||||||
|
- **Description:** Additional actions to include in the simulation. Note: More than one argument is allowed.
|
||||||
|
|
||||||
|
- **--add-rooms**
|
||||||
|
- **Type:** Integer
|
||||||
|
- **Default:** 0
|
||||||
|
- **Description:** The number of new rooms to generate before starting the simulation.
|
||||||
|
|
||||||
|
- **--config**
|
||||||
|
- **Type:** String
|
||||||
|
- **Description:** The file to load additional configuration from.
|
||||||
|
|
||||||
|
- **--discord**
|
||||||
|
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
|
||||||
|
- **Description:** Run the simulation in a Discord bot.
|
||||||
|
|
||||||
|
- **--flavor**
|
||||||
|
- **Type:** String
|
||||||
|
- **Default:** ""
|
||||||
|
- **Description:** Additional flavor text for the generated world.
|
||||||
|
|
||||||
|
- **--optional-actions**
|
||||||
|
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
|
||||||
|
- **Description:** Include optional actions in the simulation.
|
||||||
|
|
||||||
|
- **--player**
|
||||||
|
- **Type:** String
|
||||||
|
- **Description:** The name of the character to play as.
|
||||||
|
|
||||||
|
- **--prompts**
|
||||||
|
- **Type:** String
|
||||||
|
- **Description:** The file to load game prompts from. Note: More than one argument is allowed.
|
||||||
|
|
||||||
|
- **--render**
|
||||||
|
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
|
||||||
|
- **Description:** Run the render thread.
|
||||||
|
|
||||||
|
- **--render-generated**
|
||||||
|
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
|
||||||
|
- **Description:** Render entities as they are generated.
|
||||||
|
|
||||||
|
- **--rooms**
|
||||||
|
- **Type:** Integer
|
||||||
|
- **Description:** The number of rooms to generate.
|
||||||
|
|
||||||
|
- **--server**
|
||||||
|
- **Action:** No options are needed for this argument. Simply passing the argument name is enough to enable this option.
|
||||||
|
- **Description:** Run the websocket server.
|
||||||
|
|
||||||
|
- **--state**
|
||||||
|
- **Type:** String
|
||||||
|
- **Description:** The file to save the world state to. Defaults to `$world.state.json` if not set.
|
||||||
|
|
||||||
|
- **--turns**
|
||||||
|
- **Type:** Integer or "inf"
|
||||||
|
- **Default:** 10
|
||||||
|
- **Description:** The number of simulation turns to run.
|
||||||
|
|
||||||
|
- **--systems**
|
||||||
|
- **Type:** String
|
||||||
|
- **Description:** Extra systems to run in the simulation. Note: More than one argument is allowed.
|
||||||
|
|
||||||
|
- **--theme**
|
||||||
|
- **Type:** String
|
||||||
|
- **Default:** "fantasy"
|
||||||
|
- **Description:** The theme of the generated world.
|
||||||
|
|
||||||
|
- **--world**
|
||||||
|
- **Type:** String
|
||||||
|
- **Default:** "world"
|
||||||
|
- **Description:** The file to save the generated world to.
|
||||||
|
|
||||||
|
- **--world-template**
|
||||||
|
- **Type:** String
|
||||||
|
- **Description:** The template file to load the world prompt from.
|
|
@ -2,16 +2,14 @@ from taleweave.context import (
|
||||||
action_context,
|
action_context,
|
||||||
get_agent_for_character,
|
get_agent_for_character,
|
||||||
get_current_turn,
|
get_current_turn,
|
||||||
|
get_game_config,
|
||||||
get_prompt,
|
get_prompt,
|
||||||
)
|
)
|
||||||
from taleweave.errors import ActionError
|
from taleweave.errors import ActionError
|
||||||
from taleweave.models.config import DEFAULT_CONFIG
|
|
||||||
from taleweave.models.planning import CalendarEvent
|
from taleweave.models.planning import CalendarEvent
|
||||||
from taleweave.utils.planning import get_recent_notes
|
from taleweave.utils.planning import get_recent_notes
|
||||||
from taleweave.utils.prompt import format_prompt
|
from taleweave.utils.prompt import format_prompt
|
||||||
|
|
||||||
character_config = DEFAULT_CONFIG.world.character
|
|
||||||
|
|
||||||
|
|
||||||
def take_note(fact: str):
|
def take_note(fact: str):
|
||||||
"""
|
"""
|
||||||
|
@ -22,11 +20,13 @@ def take_note(fact: str):
|
||||||
fact: The fact to remember.
|
fact: The fact to remember.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
config = get_game_config()
|
||||||
|
|
||||||
with action_context() as (_, action_character):
|
with action_context() as (_, action_character):
|
||||||
if fact in action_character.planner.notes:
|
if fact in action_character.planner.notes:
|
||||||
raise ActionError(get_prompt("action_take_note_error_duplicate"))
|
raise ActionError(get_prompt("action_take_note_error_duplicate"))
|
||||||
|
|
||||||
if len(action_character.planner.notes) >= character_config.note_limit:
|
if len(action_character.planner.notes) >= config.world.character.note_limit:
|
||||||
raise ActionError(get_prompt("action_take_note_error_limit"))
|
raise ActionError(get_prompt("action_take_note_error_limit"))
|
||||||
|
|
||||||
action_character.planner.notes.append(fact)
|
action_character.planner.notes.append(fact)
|
||||||
|
@ -103,6 +103,8 @@ def summarize_notes(limit: int) -> str:
|
||||||
limit: The maximum number of notes to keep.
|
limit: The maximum number of notes to keep.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
config = get_game_config()
|
||||||
|
|
||||||
with action_context() as (_, action_character):
|
with action_context() as (_, action_character):
|
||||||
notes = action_character.planner.notes
|
notes = action_character.planner.notes
|
||||||
if len(notes) == 0:
|
if len(notes) == 0:
|
||||||
|
@ -120,11 +122,11 @@ def summarize_notes(limit: int) -> str:
|
||||||
)
|
)
|
||||||
|
|
||||||
new_notes = [note.strip() for note in summary.split("\n") if note.strip()]
|
new_notes = [note.strip() for note in summary.split("\n") if note.strip()]
|
||||||
if len(new_notes) > character_config.note_limit:
|
if len(new_notes) > config.world.character.note_limit:
|
||||||
raise ActionError(
|
raise ActionError(
|
||||||
format_prompt(
|
format_prompt(
|
||||||
"action_summarize_notes_error_limit",
|
"action_summarize_notes_error_limit",
|
||||||
limit=character_config.note_limit,
|
limit=config.world.character.note_limit,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -165,7 +167,9 @@ def check_calendar(count: int):
|
||||||
count: The number of upcoming events to read. 5 is usually a good number.
|
count: The number of upcoming events to read. 5 is usually a good number.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
count = min(count, character_config.event_limit)
|
config = get_game_config()
|
||||||
|
|
||||||
|
count = min(count, config.world.character.event_limit)
|
||||||
current_turn = get_current_turn()
|
current_turn = get_current_turn()
|
||||||
|
|
||||||
with action_context() as (_, action_character):
|
with action_context() as (_, action_character):
|
||||||
|
|
|
@ -11,10 +11,11 @@ from taleweave.context import (
|
||||||
broadcast,
|
broadcast,
|
||||||
get_character_agent_for_name,
|
get_character_agent_for_name,
|
||||||
get_current_world,
|
get_current_world,
|
||||||
|
get_game_config,
|
||||||
set_character_agent,
|
set_character_agent,
|
||||||
subscribe,
|
subscribe,
|
||||||
)
|
)
|
||||||
from taleweave.models.config import DEFAULT_CONFIG, DiscordBotConfig
|
from taleweave.models.config import DiscordBotConfig
|
||||||
from taleweave.models.event import (
|
from taleweave.models.event import (
|
||||||
ActionEvent,
|
ActionEvent,
|
||||||
GameEvent,
|
GameEvent,
|
||||||
|
@ -38,7 +39,6 @@ from taleweave.utils.prompt import format_prompt
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
client = None
|
client = None
|
||||||
bot_config: DiscordBotConfig = DEFAULT_CONFIG.bot.discord
|
|
||||||
|
|
||||||
active_tasks = set()
|
active_tasks = set()
|
||||||
event_messages: Dict[int, str | GameEvent] = {}
|
event_messages: Dict[int, str | GameEvent] = {}
|
||||||
|
@ -78,21 +78,24 @@ class AdventureClient(Client):
|
||||||
if message.author == self.user:
|
if message.author == self.user:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
config = get_game_config()
|
||||||
author = message.author
|
author = message.author
|
||||||
channel = message.channel
|
channel = message.channel
|
||||||
user_name = author.name # include nick
|
user_name = author.name # include nick
|
||||||
|
|
||||||
if message.content.startswith(
|
if message.content.startswith(
|
||||||
bot_config.command_prefix + bot_config.name_command
|
config.bot.discord.command_prefix + config.bot.discord.name_command
|
||||||
):
|
):
|
||||||
world = get_current_world()
|
world = get_current_world()
|
||||||
if world:
|
if world:
|
||||||
world_message = format_prompt(
|
world_message = format_prompt(
|
||||||
"discord_world_active", bot_name=bot_config.name_title, world=world
|
"discord_world_active",
|
||||||
|
bot_name=config.bot.discord.name_title,
|
||||||
|
world=world,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
world_message = format_prompt(
|
world_message = format_prompt(
|
||||||
"discord_world_none", bot_name=bot_config.name_title
|
"discord_world_none", bot_name=config.bot.discord.name_title
|
||||||
)
|
)
|
||||||
|
|
||||||
await message.channel.send(world_message)
|
await message.channel.send(world_message)
|
||||||
|
@ -100,7 +103,7 @@ class AdventureClient(Client):
|
||||||
|
|
||||||
if message.content.startswith("!help"):
|
if message.content.startswith("!help"):
|
||||||
await message.channel.send(
|
await message.channel.send(
|
||||||
format_prompt("discord_help", bot_name=bot_config.name_command)
|
format_prompt("discord_help", bot_name=config.bot.discord.name_command)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -172,14 +175,11 @@ class AdventureClient(Client):
|
||||||
|
|
||||||
|
|
||||||
def launch_bot(config: DiscordBotConfig):
|
def launch_bot(config: DiscordBotConfig):
|
||||||
global bot_config
|
|
||||||
global client
|
global client
|
||||||
|
|
||||||
bot_config = config
|
|
||||||
|
|
||||||
# message contents need to be enabled for multi-server bots
|
# message contents need to be enabled for multi-server bots
|
||||||
intents = Intents.default()
|
intents = Intents.default()
|
||||||
if bot_config.content_intent:
|
if config.content_intent:
|
||||||
intents.message_content = True
|
intents.message_content = True
|
||||||
|
|
||||||
client = AdventureClient(intents=intents)
|
client = AdventureClient(intents=intents)
|
||||||
|
@ -246,12 +246,14 @@ def get_active_channels():
|
||||||
if not client:
|
if not client:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
config = get_game_config()
|
||||||
|
|
||||||
# return client.private_channels
|
# return client.private_channels
|
||||||
return [
|
return [
|
||||||
channel
|
channel
|
||||||
for guild in client.guilds
|
for guild in client.guilds
|
||||||
for channel in guild.text_channels
|
for channel in guild.text_channels
|
||||||
if channel.name in bot_config.channels
|
if channel.name in config.bot.discord.channels
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ from packit.agent import Agent
|
||||||
from pyee.base import EventEmitter
|
from pyee.base import EventEmitter
|
||||||
|
|
||||||
from taleweave.game_system import GameSystem
|
from taleweave.game_system import GameSystem
|
||||||
|
from taleweave.models.config import DEFAULT_CONFIG, Config
|
||||||
from taleweave.models.entity import Character, Room, World
|
from taleweave.models.entity import Character, Room, World
|
||||||
from taleweave.models.event import GameEvent, StatusEvent
|
from taleweave.models.event import GameEvent, StatusEvent
|
||||||
from taleweave.models.prompt import PromptLibrary
|
from taleweave.models.prompt import PromptLibrary
|
||||||
|
@ -34,6 +35,7 @@ dungeon_master: Agent | None = None
|
||||||
|
|
||||||
# game context
|
# game context
|
||||||
event_emitter = EventEmitter()
|
event_emitter = EventEmitter()
|
||||||
|
game_config: Config = DEFAULT_CONFIG
|
||||||
game_systems: List[GameSystem] = []
|
game_systems: List[GameSystem] = []
|
||||||
prompt_library: PromptLibrary = PromptLibrary(prompts={})
|
prompt_library: PromptLibrary = PromptLibrary(prompts={})
|
||||||
system_data: Dict[str, Any] = {}
|
system_data: Dict[str, Any] = {}
|
||||||
|
@ -160,6 +162,10 @@ def get_dungeon_master() -> Agent:
|
||||||
return dungeon_master
|
return dungeon_master
|
||||||
|
|
||||||
|
|
||||||
|
def get_game_config() -> Config:
|
||||||
|
return game_config
|
||||||
|
|
||||||
|
|
||||||
def get_game_systems() -> List[GameSystem]:
|
def get_game_systems() -> List[GameSystem]:
|
||||||
return game_systems
|
return game_systems
|
||||||
|
|
||||||
|
@ -172,6 +178,10 @@ def get_prompt_library() -> PromptLibrary:
|
||||||
return prompt_library
|
return prompt_library
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_config(system: str) -> Any | None:
|
||||||
|
return game_config.systems.data.get(system)
|
||||||
|
|
||||||
|
|
||||||
def get_system_data(system: str) -> Any | None:
|
def get_system_data(system: str) -> Any | None:
|
||||||
return system_data.get(system)
|
return system_data.get(system)
|
||||||
|
|
||||||
|
@ -209,6 +219,11 @@ def set_dungeon_master(agent):
|
||||||
dungeon_master = agent
|
dungeon_master = agent
|
||||||
|
|
||||||
|
|
||||||
|
def set_game_config(config: Config):
|
||||||
|
global game_config
|
||||||
|
game_config = config
|
||||||
|
|
||||||
|
|
||||||
def set_game_systems(systems: Sequence[GameSystem]):
|
def set_game_systems(systems: Sequence[GameSystem]):
|
||||||
global game_systems
|
global game_systems
|
||||||
game_systems = list(systems)
|
game_systems = list(systems)
|
||||||
|
|
|
@ -7,9 +7,14 @@ from packit.loops import loop_retry
|
||||||
from packit.results import enum_result, int_result
|
from packit.results import enum_result, int_result
|
||||||
from packit.utils import could_be_json
|
from packit.utils import could_be_json
|
||||||
|
|
||||||
from taleweave.context import broadcast, get_prompt, set_current_world, set_system_data
|
from taleweave.context import (
|
||||||
|
broadcast,
|
||||||
|
get_game_config,
|
||||||
|
get_prompt,
|
||||||
|
set_current_world,
|
||||||
|
set_system_data,
|
||||||
|
)
|
||||||
from taleweave.game_system import GameSystem
|
from taleweave.game_system import GameSystem
|
||||||
from taleweave.models.config import DEFAULT_CONFIG, WorldConfig
|
|
||||||
from taleweave.models.effect import (
|
from taleweave.models.effect import (
|
||||||
EffectPattern,
|
EffectPattern,
|
||||||
FloatEffectPattern,
|
FloatEffectPattern,
|
||||||
|
@ -33,7 +38,10 @@ from taleweave.utils.string import normalize_name
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
world_config: WorldConfig = DEFAULT_CONFIG.world
|
|
||||||
|
def get_world_config():
|
||||||
|
config = get_game_config()
|
||||||
|
return config.world
|
||||||
|
|
||||||
|
|
||||||
def duplicate_name_parser(existing_names: List[str]):
|
def duplicate_name_parser(existing_names: List[str]):
|
||||||
|
@ -112,6 +120,7 @@ def generate_room(
|
||||||
actions = {}
|
actions = {}
|
||||||
room = Room(name=name, description=desc, items=[], characters=[], actions=actions)
|
room = Room(name=name, description=desc, items=[], characters=[], actions=actions)
|
||||||
|
|
||||||
|
world_config = get_world_config()
|
||||||
item_count = resolve_int_range(world_config.size.room_items) or 0
|
item_count = resolve_int_range(world_config.size.room_items) or 0
|
||||||
broadcast_generated(
|
broadcast_generated(
|
||||||
format_prompt(
|
format_prompt(
|
||||||
|
@ -276,6 +285,7 @@ def generate_item(
|
||||||
item = Item(name=name, description=desc, actions=actions)
|
item = Item(name=name, description=desc, actions=actions)
|
||||||
generate_system_attributes(agent, world, item, systems)
|
generate_system_attributes(agent, world, item, systems)
|
||||||
|
|
||||||
|
world_config = get_world_config()
|
||||||
effect_count = resolve_int_range(world_config.size.item_effects) or 0
|
effect_count = resolve_int_range(world_config.size.item_effects) or 0
|
||||||
broadcast_generated(
|
broadcast_generated(
|
||||||
message=format_prompt(
|
message=format_prompt(
|
||||||
|
@ -343,6 +353,7 @@ def generate_character(
|
||||||
generate_system_attributes(agent, world, character, systems)
|
generate_system_attributes(agent, world, character, systems)
|
||||||
|
|
||||||
# generate the character's inventory
|
# generate the character's inventory
|
||||||
|
world_config = get_world_config()
|
||||||
item_count = resolve_int_range(world_config.size.character_items) or 0
|
item_count = resolve_int_range(world_config.size.character_items) or 0
|
||||||
broadcast_generated(
|
broadcast_generated(
|
||||||
message=format_prompt(
|
message=format_prompt(
|
||||||
|
@ -499,6 +510,7 @@ def link_rooms(
|
||||||
rooms: List[Room] | None = None,
|
rooms: List[Room] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
rooms = rooms or world.rooms
|
rooms = rooms or world.rooms
|
||||||
|
world_config = get_world_config()
|
||||||
|
|
||||||
for room in rooms:
|
for room in rooms:
|
||||||
num_portals = resolve_int_range(world_config.size.portals) or 0
|
num_portals = resolve_int_range(world_config.size.portals) or 0
|
||||||
|
@ -550,6 +562,7 @@ def generate_world(
|
||||||
systems: List[GameSystem],
|
systems: List[GameSystem],
|
||||||
room_count: int | None = None,
|
room_count: int | None = None,
|
||||||
) -> World:
|
) -> World:
|
||||||
|
world_config = get_world_config()
|
||||||
room_count = room_count or resolve_int_range(world_config.size.rooms) or 0
|
room_count = room_count or resolve_int_range(world_config.size.rooms) or 0
|
||||||
|
|
||||||
broadcast_generated(message=format_prompt("world_generate_world_broadcast_theme"))
|
broadcast_generated(message=format_prompt("world_generate_world_broadcast_theme"))
|
||||||
|
|
|
@ -27,7 +27,7 @@ except Exception as err:
|
||||||
|
|
||||||
logger = logger_with_colors(__name__) # , level="DEBUG")
|
logger = logger_with_colors(__name__) # , level="DEBUG")
|
||||||
|
|
||||||
load_dotenv(environ.get("ADVENTURE_ENV", ".env"), override=True)
|
load_dotenv(environ.get("TALEWEAVE_ENV", ".env"), override=True)
|
||||||
|
|
||||||
if True:
|
if True:
|
||||||
from taleweave.context import (
|
from taleweave.context import (
|
||||||
|
@ -35,6 +35,7 @@ if True:
|
||||||
get_system_data,
|
get_system_data,
|
||||||
set_current_turn,
|
set_current_turn,
|
||||||
set_dungeon_master,
|
set_dungeon_master,
|
||||||
|
set_game_config,
|
||||||
set_system_data,
|
set_system_data,
|
||||||
subscribe,
|
subscribe,
|
||||||
)
|
)
|
||||||
|
@ -312,6 +313,7 @@ def main():
|
||||||
if args.config:
|
if args.config:
|
||||||
with open(args.config, "r") as f:
|
with open(args.config, "r") as f:
|
||||||
config = Config(**load_yaml(f))
|
config = Config(**load_yaml(f))
|
||||||
|
set_game_config(config)
|
||||||
else:
|
else:
|
||||||
config = DEFAULT_CONFIG
|
config = DEFAULT_CONFIG
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from .base import IntRange, dataclass
|
from .base import Attributes, IntRange, dataclass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -44,6 +44,15 @@ class ServerConfig:
|
||||||
websocket: WebsocketServerConfig
|
websocket: WebsocketServerConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SystemsConfig:
|
||||||
|
"""
|
||||||
|
Configuration for the game systems.
|
||||||
|
"""
|
||||||
|
|
||||||
|
data: Attributes
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WorldCharacterConfig:
|
class WorldCharacterConfig:
|
||||||
conversation_limit: int
|
conversation_limit: int
|
||||||
|
@ -80,6 +89,7 @@ class Config:
|
||||||
bot: BotConfig
|
bot: BotConfig
|
||||||
render: RenderConfig
|
render: RenderConfig
|
||||||
server: ServerConfig
|
server: ServerConfig
|
||||||
|
systems: SystemsConfig
|
||||||
world: WorldConfig
|
world: WorldConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -107,6 +117,7 @@ DEFAULT_CONFIG = Config(
|
||||||
steps=30,
|
steps=30,
|
||||||
),
|
),
|
||||||
server=ServerConfig(websocket=WebsocketServerConfig(host="localhost", port=8001)),
|
server=ServerConfig(websocket=WebsocketServerConfig(host="localhost", port=8001)),
|
||||||
|
systems=SystemsConfig(data={}),
|
||||||
world=WorldConfig(
|
world=WorldConfig(
|
||||||
character=WorldCharacterConfig(
|
character=WorldCharacterConfig(
|
||||||
conversation_limit=2,
|
conversation_limit=2,
|
||||||
|
|
|
@ -60,6 +60,7 @@ class PromptEvent(BaseModel):
|
||||||
A prompt for a character to take an action.
|
A prompt for a character to take an action.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
actions: Dict[str, Any]
|
||||||
prompt: str
|
prompt: str
|
||||||
room: Room
|
room: Room
|
||||||
character: Character
|
character: Character
|
||||||
|
|
|
@ -197,15 +197,20 @@ class RemotePlayer(BasePlayer):
|
||||||
Ask the player for input.
|
Ask the player for input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
actions = {}
|
||||||
formatted_prompt = prompt.format(**kwargs)
|
formatted_prompt = prompt.format(**kwargs)
|
||||||
if toolbox:
|
if toolbox:
|
||||||
|
actions = toolbox.list_definitions()
|
||||||
formatted_prompt += self.format_psuedo_functions(toolbox)
|
formatted_prompt += self.format_psuedo_functions(toolbox)
|
||||||
|
|
||||||
self.memory.append(HumanMessage(content=formatted_prompt))
|
self.memory.append(HumanMessage(content=formatted_prompt))
|
||||||
|
|
||||||
with action_context() as (current_room, current_character):
|
with action_context() as (current_room, current_character):
|
||||||
prompt_event = PromptEvent(
|
prompt_event = PromptEvent(
|
||||||
prompt=formatted_prompt, room=current_room, character=current_character
|
actions=actions,
|
||||||
|
prompt=formatted_prompt,
|
||||||
|
room=current_room,
|
||||||
|
character=current_character,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -15,9 +15,9 @@ from fnvhash import fnv1a_32
|
||||||
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from taleweave.context import broadcast
|
from taleweave.context import broadcast, get_game_config
|
||||||
from taleweave.models.base import uuid
|
from taleweave.models.base import IntRange, uuid
|
||||||
from taleweave.models.config import DEFAULT_CONFIG, RenderConfig
|
from taleweave.models.config import RenderConfig
|
||||||
from taleweave.models.entity import WorldEntity
|
from taleweave.models.entity import WorldEntity
|
||||||
from taleweave.models.event import (
|
from taleweave.models.event import (
|
||||||
ActionEvent,
|
ActionEvent,
|
||||||
|
@ -36,7 +36,6 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
server_address = environ["COMFY_API"]
|
server_address = environ["COMFY_API"]
|
||||||
client_id = uuid()
|
client_id = uuid()
|
||||||
render_config: RenderConfig = DEFAULT_CONFIG.render
|
|
||||||
|
|
||||||
|
|
||||||
# requests to generate images for game events
|
# requests to generate images for game events
|
||||||
|
@ -44,12 +43,17 @@ render_queue: Queue[GameEvent | WorldEntity] = Queue()
|
||||||
render_thread: Thread | None = None
|
render_thread: Thread | None = None
|
||||||
|
|
||||||
|
|
||||||
def generate_cfg():
|
def get_render_config():
|
||||||
return resolve_int_range(render_config.cfg)
|
config = get_game_config()
|
||||||
|
return config.render
|
||||||
|
|
||||||
|
|
||||||
def generate_steps():
|
def generate_cfg(cfg: int | IntRange):
|
||||||
return resolve_int_range(render_config.steps)
|
return resolve_int_range(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_steps(steps: int | IntRange):
|
||||||
|
return resolve_int_range(steps)
|
||||||
|
|
||||||
|
|
||||||
def generate_batches(
|
def generate_batches(
|
||||||
|
@ -148,9 +152,10 @@ def generate_image_tool(prompt, count, size="landscape"):
|
||||||
def generate_images(
|
def generate_images(
|
||||||
prompt: str, count: int, size="landscape", prefix="output"
|
prompt: str, count: int, size="landscape", prefix="output"
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
cfg = generate_cfg()
|
render_config = get_render_config()
|
||||||
|
cfg = generate_cfg(render_config.cfg)
|
||||||
dims = render_config.sizes[size]
|
dims = render_config.sizes[size]
|
||||||
steps = generate_steps()
|
steps = generate_steps(render_config.steps)
|
||||||
seed = randint(0, 10000000)
|
seed = randint(0, 10000000)
|
||||||
checkpoint = choice(render_config.checkpoints)
|
checkpoint = choice(render_config.checkpoints)
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -248,6 +253,8 @@ def get_image_prefix(event: GameEvent | WorldEntity) -> str:
|
||||||
|
|
||||||
|
|
||||||
def render_loop():
|
def render_loop():
|
||||||
|
render_config = get_render_config()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
event = render_queue.get()
|
event = render_queue.get()
|
||||||
prefix = get_image_prefix(event)
|
prefix = get_image_prefix(event)
|
||||||
|
@ -317,13 +324,8 @@ def render_generated(event: GameEvent):
|
||||||
|
|
||||||
|
|
||||||
def launch_render(config: RenderConfig):
|
def launch_render(config: RenderConfig):
|
||||||
global render_config
|
|
||||||
global render_thread
|
global render_thread
|
||||||
|
|
||||||
# update the config
|
|
||||||
logger.info("updating render config: %s", config)
|
|
||||||
render_config = config
|
|
||||||
|
|
||||||
# start the render thread
|
# start the render thread
|
||||||
logger.info("launching render thread")
|
logger.info("launching render thread")
|
||||||
render_thread = Thread(target=render_loop, daemon=True)
|
render_thread = Thread(target=render_loop, daemon=True)
|
||||||
|
|
|
@ -16,10 +16,11 @@ from taleweave.context import (
|
||||||
broadcast,
|
broadcast,
|
||||||
get_character_agent_for_name,
|
get_character_agent_for_name,
|
||||||
get_current_world,
|
get_current_world,
|
||||||
|
get_game_config,
|
||||||
set_character_agent,
|
set_character_agent,
|
||||||
subscribe,
|
subscribe,
|
||||||
)
|
)
|
||||||
from taleweave.models.config import DEFAULT_CONFIG, WebsocketServerConfig
|
from taleweave.models.config import WebsocketServerConfig
|
||||||
from taleweave.models.entity import World, WorldEntity
|
from taleweave.models.entity import World, WorldEntity
|
||||||
from taleweave.models.event import (
|
from taleweave.models.event import (
|
||||||
GameEvent,
|
GameEvent,
|
||||||
|
@ -47,7 +48,6 @@ last_snapshot: str | None = None
|
||||||
player_names: Dict[str, str] = {}
|
player_names: Dict[str, str] = {}
|
||||||
recent_events: MutableSequence[GameEvent] = deque(maxlen=100)
|
recent_events: MutableSequence[GameEvent] = deque(maxlen=100)
|
||||||
recent_json: MutableSequence[str] = deque(maxlen=100)
|
recent_json: MutableSequence[str] = deque(maxlen=100)
|
||||||
server_config: WebsocketServerConfig = DEFAULT_CONFIG.server.websocket
|
|
||||||
|
|
||||||
|
|
||||||
def get_player_name(client_id: str) -> str:
|
def get_player_name(client_id: str) -> str:
|
||||||
|
@ -59,16 +59,15 @@ async def handler(websocket):
|
||||||
logger.info("client connected, given id: %s", id)
|
logger.info("client connected, given id: %s", id)
|
||||||
connected.add(websocket)
|
connected.add(websocket)
|
||||||
|
|
||||||
async def next_turn(character: str, prompt: str) -> None:
|
async def next_turn(event: PromptEvent) -> None:
|
||||||
await websocket.send(
|
await websocket.send(
|
||||||
dumps(
|
dumps(
|
||||||
{
|
{
|
||||||
# TODO: these should be fields in the PromptEvent
|
"type": event.type,
|
||||||
"type": "prompt",
|
"client": id, # TODO: this should be a field in the PromptEvent
|
||||||
"client": id,
|
"character": event.character,
|
||||||
"character": character,
|
"prompt": event.prompt,
|
||||||
"prompt": prompt,
|
"actions": event.actions,
|
||||||
"actions": [],
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -77,7 +76,7 @@ async def handler(websocket):
|
||||||
# TODO: nothing about this is good
|
# TODO: nothing about this is good
|
||||||
player = get_player(id)
|
player = get_player(id)
|
||||||
if player and player.name == event.character.name:
|
if player and player.name == event.character.name:
|
||||||
asyncio.run(next_turn(event.character.name, event.prompt))
|
asyncio.run(next_turn(event))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
@ -303,10 +302,6 @@ def send_and_append(id: str, message: Dict):
|
||||||
|
|
||||||
def launch_server(config: WebsocketServerConfig):
|
def launch_server(config: WebsocketServerConfig):
|
||||||
global socket_thread
|
global socket_thread
|
||||||
global server_config
|
|
||||||
|
|
||||||
logger.info("configuring websocket server: %s", config)
|
|
||||||
server_config = config
|
|
||||||
|
|
||||||
def run_sockets():
|
def run_sockets():
|
||||||
asyncio.run(server_main())
|
asyncio.run(server_main())
|
||||||
|
@ -321,7 +316,11 @@ def launch_server(config: WebsocketServerConfig):
|
||||||
|
|
||||||
|
|
||||||
async def server_main():
|
async def server_main():
|
||||||
async with websockets.serve(handler, server_config.host, server_config.port):
|
config = get_game_config()
|
||||||
|
|
||||||
|
async with websockets.serve(
|
||||||
|
handler, config.server.websocket.host, config.server.websocket.port
|
||||||
|
):
|
||||||
logger.info("websocket server started")
|
logger.info("websocket server started")
|
||||||
await asyncio.Future() # run forever
|
await asyncio.Future() # run forever
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,7 @@ from taleweave.context import (
|
||||||
get_character_for_agent,
|
get_character_for_agent,
|
||||||
get_current_turn,
|
get_current_turn,
|
||||||
get_current_world,
|
get_current_world,
|
||||||
|
get_game_config,
|
||||||
get_prompt,
|
get_prompt,
|
||||||
set_current_character,
|
set_current_character,
|
||||||
set_current_room,
|
set_current_room,
|
||||||
|
@ -44,7 +45,6 @@ from taleweave.context import (
|
||||||
set_game_systems,
|
set_game_systems,
|
||||||
)
|
)
|
||||||
from taleweave.game_system import GameSystem
|
from taleweave.game_system import GameSystem
|
||||||
from taleweave.models.config import DEFAULT_CONFIG
|
|
||||||
from taleweave.models.entity import Character, Room, World
|
from taleweave.models.entity import Character, Room, World
|
||||||
from taleweave.models.event import ActionEvent, ResultEvent
|
from taleweave.models.event import ActionEvent, ResultEvent
|
||||||
from taleweave.utils.conversation import make_keyword_condition, summarize_room
|
from taleweave.utils.conversation import make_keyword_condition, summarize_room
|
||||||
|
@ -57,9 +57,6 @@ from taleweave.utils.world import describe_entity, format_attributes
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
turn_config = DEFAULT_CONFIG.world.turn
|
|
||||||
|
|
||||||
|
|
||||||
def world_result_parser(value, agent, **kwargs):
|
def world_result_parser(value, agent, **kwargs):
|
||||||
current_world = get_current_world()
|
current_world = get_current_world()
|
||||||
if not current_world:
|
if not current_world:
|
||||||
|
@ -121,10 +118,11 @@ def prompt_character_action(
|
||||||
event = ActionEvent.from_json(value, room, character)
|
event = ActionEvent.from_json(value, room, character)
|
||||||
else:
|
else:
|
||||||
# TODO: this path should be removed and throw
|
# TODO: this path should be removed and throw
|
||||||
logger.warning(
|
# logger.warning(
|
||||||
"invalid action, emitting as result event - this is a bug somewhere"
|
# "invalid action, emitting as result event - this is a bug somewhere"
|
||||||
)
|
# )
|
||||||
event = ResultEvent(value, room, character)
|
# event = ResultEvent(value, room, character)
|
||||||
|
raise ValueError("invalid non-JSON action")
|
||||||
|
|
||||||
broadcast(event)
|
broadcast(event)
|
||||||
|
|
||||||
|
@ -198,7 +196,8 @@ def prompt_character_planning(
|
||||||
current_turn: int,
|
current_turn: int,
|
||||||
max_steps: int | None = None,
|
max_steps: int | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
max_steps = max_steps or turn_config.planning_steps
|
config = get_game_config()
|
||||||
|
max_steps = max_steps or config.world.turn.planning_steps
|
||||||
|
|
||||||
notes_prompt, events_prompt = get_notes_events(character, current_turn)
|
notes_prompt, events_prompt = get_notes_events(character, current_turn)
|
||||||
|
|
||||||
|
|
|
@ -8,8 +8,7 @@ from packit.conditions import condition_or, condition_threshold, make_flag_condi
|
||||||
from packit.results import multi_function_or_str_result
|
from packit.results import multi_function_or_str_result
|
||||||
from packit.utils import could_be_json
|
from packit.utils import could_be_json
|
||||||
|
|
||||||
from taleweave.context import broadcast
|
from taleweave.context import broadcast, get_game_config
|
||||||
from taleweave.models.config import DEFAULT_CONFIG
|
|
||||||
from taleweave.models.entity import Character, Room
|
from taleweave.models.entity import Character, Room
|
||||||
from taleweave.models.event import ReplyEvent
|
from taleweave.models.event import ReplyEvent
|
||||||
from taleweave.utils.prompt import format_str
|
from taleweave.utils.prompt import format_str
|
||||||
|
@ -19,9 +18,6 @@ from .string import and_list, normalize_name
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
character_config = DEFAULT_CONFIG.world.character
|
|
||||||
|
|
||||||
|
|
||||||
def make_keyword_condition(end_message: str, keywords=["end", "stop"]):
|
def make_keyword_condition(end_message: str, keywords=["end", "stop"]):
|
||||||
set_end, condition_end = make_flag_condition()
|
set_end, condition_end = make_flag_condition()
|
||||||
|
|
||||||
|
@ -99,8 +95,10 @@ def loop_conversation(
|
||||||
Loop through a conversation between a series of agents, using metadata from their characters.
|
Loop through a conversation between a series of agents, using metadata from their characters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
config = get_game_config()
|
||||||
|
|
||||||
if max_length is None:
|
if max_length is None:
|
||||||
max_length = character_config.conversation_limit
|
max_length = config.world.character.conversation_limit
|
||||||
|
|
||||||
if len(characters) != len(agents):
|
if len(characters) != len(agents):
|
||||||
raise ValueError("The number of characters and agents must match.")
|
raise ValueError("The number of characters and agents must match.")
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
from taleweave.models.entity import Room, World
|
||||||
|
from taleweave.utils.search import find_room
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindRoom(TestCase):
|
||||||
|
def test_existing_room(self):
|
||||||
|
world = World(name="Test World", rooms=[], theme="testing", order=[])
|
||||||
|
room = Room(
|
||||||
|
name="Test Room",
|
||||||
|
description="A test room.",
|
||||||
|
characters=[],
|
||||||
|
items=[],
|
||||||
|
portals=[],
|
||||||
|
)
|
||||||
|
world.rooms.append(room)
|
||||||
|
self.assertEqual(find_room(world, "Test Room"), room)
|
||||||
|
|
||||||
|
def test_missing_room(self):
|
||||||
|
world = World(name="Test World", rooms=[], theme="testing", order=[])
|
||||||
|
room = Room(
|
||||||
|
name="Test Room",
|
||||||
|
description="A test room.",
|
||||||
|
characters=[],
|
||||||
|
items=[],
|
||||||
|
portals=[],
|
||||||
|
)
|
||||||
|
world.rooms.append(room)
|
||||||
|
self.assertEqual(find_room(world, "Missing Room"), None)
|
Loading…
Reference in New Issue