add missing fields to config, use config memory limit
This commit is contained in:
parent
f25dd57e67
commit
0927259ece
15
config.yml
15
config.yml
|
@ -1,6 +1,9 @@
|
||||||
bot:
|
bot:
|
||||||
discord:
|
discord:
|
||||||
channels: [taleweave]
|
channels: [taleweave]
|
||||||
|
command_prefix: "!"
|
||||||
|
name_command: taleweave
|
||||||
|
name_title: TaleWeave AI
|
||||||
render:
|
render:
|
||||||
cfg:
|
cfg:
|
||||||
min: 5
|
min: 5
|
||||||
|
@ -8,6 +11,7 @@ render:
|
||||||
checkpoints: [
|
checkpoints: [
|
||||||
"diffusion-sdxl-dynavision-0-5-5-7.safetensors",
|
"diffusion-sdxl-dynavision-0-5-5-7.safetensors",
|
||||||
]
|
]
|
||||||
|
count: 2
|
||||||
path: /tmp/taleweave-images
|
path: /tmp/taleweave-images
|
||||||
sizes:
|
sizes:
|
||||||
landscape:
|
landscape:
|
||||||
|
@ -26,7 +30,14 @@ server:
|
||||||
websocket:
|
websocket:
|
||||||
host: 0.0.0.0
|
host: 0.0.0.0
|
||||||
port: 8001
|
port: 8001
|
||||||
|
systems:
|
||||||
|
data: {}
|
||||||
world:
|
world:
|
||||||
|
character:
|
||||||
|
conversation_limit: 2
|
||||||
|
event_limit: 5
|
||||||
|
memory_limit: 25
|
||||||
|
note_limit: 10
|
||||||
size:
|
size:
|
||||||
character_items:
|
character_items:
|
||||||
min: 0
|
min: 0
|
||||||
|
@ -46,3 +57,7 @@ world:
|
||||||
room_items:
|
room_items:
|
||||||
min: 0
|
min: 0
|
||||||
max: 3
|
max: 3
|
||||||
|
turn:
|
||||||
|
action_retries: 5
|
||||||
|
planning_retries: 3
|
||||||
|
planning_steps: 3
|
|
@ -31,6 +31,7 @@ from taleweave.player import (
|
||||||
RemotePlayer,
|
RemotePlayer,
|
||||||
get_player,
|
get_player,
|
||||||
has_player,
|
has_player,
|
||||||
|
list_players,
|
||||||
remove_player,
|
remove_player,
|
||||||
set_player,
|
set_player,
|
||||||
)
|
)
|
||||||
|
@ -50,7 +51,7 @@ def remove_tags(text: str) -> str:
|
||||||
Remove any <foo> tags.
|
Remove any <foo> tags.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return sub(r"<[^>]*>", "", text)
|
return sub(r"<[^>]*>", "", text).strip()
|
||||||
|
|
||||||
|
|
||||||
class AdventureClient(Client):
|
class AdventureClient(Client):
|
||||||
|
@ -82,8 +83,9 @@ class AdventureClient(Client):
|
||||||
author = message.author
|
author = message.author
|
||||||
channel = message.channel
|
channel = message.channel
|
||||||
user_name = author.name # include nick
|
user_name = author.name # include nick
|
||||||
|
content = remove_tags(message.content)
|
||||||
|
|
||||||
if message.content.startswith(
|
if content.startswith(
|
||||||
config.bot.discord.command_prefix + config.bot.discord.name_command
|
config.bot.discord.command_prefix + config.bot.discord.name_command
|
||||||
):
|
):
|
||||||
world = get_current_world()
|
world = get_current_world()
|
||||||
|
@ -101,14 +103,14 @@ class AdventureClient(Client):
|
||||||
await message.channel.send(world_message)
|
await message.channel.send(world_message)
|
||||||
return
|
return
|
||||||
|
|
||||||
if message.content.startswith("!help"):
|
if content.startswith("!help"):
|
||||||
await message.channel.send(
|
await message.channel.send(
|
||||||
format_prompt("discord_help", bot_name=config.bot.discord.name_command)
|
format_prompt("discord_help", bot_name=config.bot.discord.name_command)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if message.content.startswith("!join"):
|
if content.startswith("!join"):
|
||||||
character_name = remove_tags(message.content).replace("!join", "").strip()
|
character_name = content.replace("!join", "").strip()
|
||||||
if has_player(character_name):
|
if has_player(character_name):
|
||||||
await channel.send(
|
await channel.send(
|
||||||
format_prompt("discord_join_error_taken", character=character_name)
|
format_prompt("discord_join_error_taken", character=character_name)
|
||||||
|
@ -145,9 +147,14 @@ class AdventureClient(Client):
|
||||||
join_event = PlayerEvent("join", character_name, user_name)
|
join_event = PlayerEvent("join", character_name, user_name)
|
||||||
return broadcast(join_event)
|
return broadcast(join_event)
|
||||||
|
|
||||||
|
if content.startswith("!players"):
|
||||||
|
players = list_players()
|
||||||
|
await channel.send(embed=format_players(players))
|
||||||
|
return
|
||||||
|
|
||||||
player = get_player(user_name)
|
player = get_player(user_name)
|
||||||
if isinstance(player, RemotePlayer):
|
if isinstance(player, RemotePlayer):
|
||||||
if message.content.startswith("!leave"):
|
if content.startswith("!leave"):
|
||||||
remove_player(user_name)
|
remove_player(user_name)
|
||||||
|
|
||||||
# revert to LLM agent
|
# revert to LLM agent
|
||||||
|
@ -163,7 +170,6 @@ class AdventureClient(Client):
|
||||||
leave_event = PlayerEvent("leave", player.name, user_name)
|
leave_event = PlayerEvent("leave", player.name, user_name)
|
||||||
return broadcast(leave_event)
|
return broadcast(leave_event)
|
||||||
else:
|
else:
|
||||||
content = remove_tags(message.content)
|
|
||||||
player.input_queue.put(content)
|
player.input_queue.put(content)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"received message from {user_name} for {player.name}: {content}"
|
f"received message from {user_name} for {player.name}: {content}"
|
||||||
|
@ -174,6 +180,14 @@ class AdventureClient(Client):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def format_players(players: Dict[str, str]):
|
||||||
|
player_embed = Embed(title="Players")
|
||||||
|
for player, character in players.items():
|
||||||
|
player_embed.add_field(name=player, value=character)
|
||||||
|
|
||||||
|
return player_embed
|
||||||
|
|
||||||
|
|
||||||
def launch_bot(config: DiscordBotConfig):
|
def launch_bot(config: DiscordBotConfig):
|
||||||
global client
|
global client
|
||||||
|
|
||||||
|
|
|
@ -48,12 +48,7 @@ if True:
|
||||||
from taleweave.models.prompt import PromptLibrary
|
from taleweave.models.prompt import PromptLibrary
|
||||||
from taleweave.plugins import load_plugin
|
from taleweave.plugins import load_plugin
|
||||||
from taleweave.simulate import simulate_world
|
from taleweave.simulate import simulate_world
|
||||||
from taleweave.state import (
|
from taleweave.state import create_agents, save_world, save_world_state
|
||||||
MEMORY_LIMIT,
|
|
||||||
create_agents,
|
|
||||||
save_world,
|
|
||||||
save_world_state,
|
|
||||||
)
|
|
||||||
from taleweave.utils.prompt import format_prompt
|
from taleweave.utils.prompt import format_prompt
|
||||||
|
|
||||||
# start the debugger, if needed
|
# start the debugger, if needed
|
||||||
|
@ -65,9 +60,6 @@ if environ.get("DEBUG", "false").lower() == "true":
|
||||||
debugpy.wait_for_client()
|
debugpy.wait_for_client()
|
||||||
|
|
||||||
|
|
||||||
memory_factory = partial(make_limited_memory, limit=MEMORY_LIMIT)
|
|
||||||
|
|
||||||
|
|
||||||
def int_or_inf(value: str) -> float | int:
|
def int_or_inf(value: str) -> float | int:
|
||||||
if value == "inf":
|
if value == "inf":
|
||||||
return float("inf")
|
return float("inf")
|
||||||
|
@ -241,7 +233,7 @@ def save_system_data(args, systems: List[GameSystem]):
|
||||||
|
|
||||||
|
|
||||||
def load_or_generate_world(
|
def load_or_generate_world(
|
||||||
args, players, systems: List[GameSystem], world_prompt: WorldPrompt
|
args, config: Config, players, systems: List[GameSystem], world_prompt: WorldPrompt
|
||||||
):
|
):
|
||||||
world_file = args.world + ".json"
|
world_file = args.world + ".json"
|
||||||
world_state_file = args.state or (args.world + ".state.json")
|
world_state_file = args.state or (args.world + ".state.json")
|
||||||
|
@ -252,6 +244,9 @@ def load_or_generate_world(
|
||||||
|
|
||||||
# prepare an agent for the world builder
|
# prepare an agent for the world builder
|
||||||
llm = agent_easy_connect()
|
llm = agent_easy_connect()
|
||||||
|
memory_factory = partial(
|
||||||
|
make_limited_memory, limit=config.world.character.memory_limit
|
||||||
|
)
|
||||||
world_builder = Agent(
|
world_builder = Agent(
|
||||||
"World Builder",
|
"World Builder",
|
||||||
format_prompt(
|
format_prompt(
|
||||||
|
@ -388,7 +383,7 @@ def main():
|
||||||
# load or generate the world
|
# load or generate the world
|
||||||
world_prompt = get_world_prompt(args)
|
world_prompt = get_world_prompt(args)
|
||||||
world, world_state_file, world_turn = load_or_generate_world(
|
world, world_state_file, world_turn = load_or_generate_world(
|
||||||
args, players, extra_systems, world_prompt=world_prompt
|
args, config, players, extra_systems, world_prompt=world_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure the snapshot system runs last
|
# make sure the snapshot system runs last
|
||||||
|
@ -404,6 +399,9 @@ def main():
|
||||||
|
|
||||||
# create the DM
|
# create the DM
|
||||||
llm = agent_easy_connect()
|
llm = agent_easy_connect()
|
||||||
|
memory_factory = partial(
|
||||||
|
make_limited_memory, limit=config.world.character.memory_limit
|
||||||
|
)
|
||||||
world_builder = Agent(
|
world_builder = Agent(
|
||||||
"dungeon master",
|
"dungeon master",
|
||||||
format_prompt(
|
format_prompt(
|
||||||
|
|
|
@ -57,6 +57,7 @@ class SystemsConfig:
|
||||||
class WorldCharacterConfig:
|
class WorldCharacterConfig:
|
||||||
conversation_limit: int
|
conversation_limit: int
|
||||||
event_limit: int
|
event_limit: int
|
||||||
|
memory_limit: int
|
||||||
note_limit: int
|
note_limit: int
|
||||||
|
|
||||||
|
|
||||||
|
@ -122,6 +123,7 @@ DEFAULT_CONFIG = Config(
|
||||||
character=WorldCharacterConfig(
|
character=WorldCharacterConfig(
|
||||||
conversation_limit=2,
|
conversation_limit=2,
|
||||||
event_limit=5,
|
event_limit=5,
|
||||||
|
memory_limit=25,
|
||||||
note_limit=10,
|
note_limit=10,
|
||||||
),
|
),
|
||||||
size=WorldSizeConfig(
|
size=WorldSizeConfig(
|
||||||
|
|
Loading…
Reference in New Issue