make memory configurable, consistently truncate discord messages, fix action prompt
This commit is contained in:
parent
a9705727f0
commit
f25dd57e67
|
@ -69,7 +69,7 @@ export interface StringParameter {
|
|||
export interface NumberParameter {
|
||||
type: 'number';
|
||||
default?: number;
|
||||
enum?: Array<string>;
|
||||
enum?: Array<number>;
|
||||
}
|
||||
|
||||
export type Parameter = BooleanParameter | NumberParameter | StringParameter;
|
||||
|
|
|
@ -129,7 +129,7 @@ export function enumerateSignificantParameterValues(name: string, world: World)
|
|||
}
|
||||
}
|
||||
|
||||
export function convertSignificantParameter(name: string, parameter: Parameter, world: Maybe<World>): Parameter {
|
||||
export function convertSignificantParameter<T extends Parameter>(name: string, parameter: T, world: Maybe<World>): T {
|
||||
if (parameter.type === 'boolean') {
|
||||
return parameter;
|
||||
}
|
||||
|
@ -154,15 +154,27 @@ export function formatAction(action: string, parameters: Record<string, boolean
|
|||
return `~${action}:${Object.entries(parameters).map(([name, value]) => `${name}=${value}`).join(',')}`;
|
||||
}
|
||||
|
||||
export function getEnumOrDefault<T>(defaultValue: Maybe<T>, enumValues: Maybe<Array<T>>, evenMoreDefault: T): T {
|
||||
if (doesExist(defaultValue)) {
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
if (doesExist(enumValues)) {
|
||||
return enumValues[0];
|
||||
}
|
||||
|
||||
return evenMoreDefault;
|
||||
}
|
||||
|
||||
export function makeDefaultParameterValues(parameters: Record<string, Parameter>) {
|
||||
return Object.entries(parameters).reduce((acc, [name, parameter]) => {
|
||||
switch (parameter.type) {
|
||||
case 'boolean':
|
||||
return { ...acc, [name]: mustDefault(parameter.default, false) };
|
||||
return { ...acc, [name]: getEnumOrDefault(parameter.default, [], false) };
|
||||
case 'number':
|
||||
return { ...acc, [name]: mustDefault(parameter.default, 0) };
|
||||
return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, 0) };
|
||||
case 'string':
|
||||
return { ...acc, [name]: mustDefault(parameter.default, '') };
|
||||
return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, '') };
|
||||
default:
|
||||
return acc;
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ from taleweave.context import (
|
|||
broadcast,
|
||||
get_agent_for_character,
|
||||
get_character_agent_for_name,
|
||||
get_game_config,
|
||||
get_prompt,
|
||||
world_context,
|
||||
)
|
||||
|
@ -22,8 +23,6 @@ from taleweave.utils.string import normalize_name
|
|||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
MAX_CONVERSATION_STEPS = 2
|
||||
|
||||
|
||||
def action_examine(target: str) -> str:
|
||||
"""
|
||||
|
@ -173,7 +172,8 @@ def action_ask(character: str, question: str) -> str:
|
|||
character: The name of the character to ask. You cannot ask yourself questions.
|
||||
question: The question to ask them.
|
||||
"""
|
||||
# capture references to the current character and room, because they will be overwritten
|
||||
config = get_game_config()
|
||||
|
||||
with action_context() as (action_room, action_character):
|
||||
# sanity checks
|
||||
question_character, question_agent = get_character_agent_for_name(character)
|
||||
|
@ -216,7 +216,7 @@ def action_ask(character: str, question: str) -> str:
|
|||
end_prompt,
|
||||
echo_function=action_tell.__name__,
|
||||
echo_parameter="message",
|
||||
max_length=MAX_CONVERSATION_STEPS,
|
||||
max_length=config.world.character.conversation_limit,
|
||||
)
|
||||
|
||||
if result:
|
||||
|
@ -233,7 +233,7 @@ def action_tell(character: str, message: str) -> str:
|
|||
character: The name of the character to tell. You cannot talk to yourself.
|
||||
message: The message to tell them.
|
||||
"""
|
||||
# capture references to the current character and room, because they will be overwritten
|
||||
config = get_game_config()
|
||||
|
||||
with action_context() as (action_room, action_character):
|
||||
# sanity checks
|
||||
|
@ -268,7 +268,7 @@ def action_tell(character: str, message: str) -> str:
|
|||
end_prompt,
|
||||
echo_function=action_tell.__name__,
|
||||
echo_parameter="message",
|
||||
max_length=MAX_CONVERSATION_STEPS,
|
||||
max_length=config.world.character.conversation_limit,
|
||||
)
|
||||
|
||||
if result:
|
||||
|
|
|
@ -323,6 +323,12 @@ async def broadcast_event(message: str | GameEvent):
|
|||
event_messages[event_message.id] = message
|
||||
|
||||
|
||||
def truncate(text: str, length: int = 1000) -> str:
|
||||
if len(text) > length:
|
||||
return text[:length] + "..."
|
||||
return text
|
||||
|
||||
|
||||
def embed_from_event(event: GameEvent) -> Embed | None:
|
||||
if isinstance(event, GenerateEvent):
|
||||
return embed_from_generate(event)
|
||||
|
@ -357,7 +363,7 @@ def embed_from_action(event: ActionEvent):
|
|||
|
||||
def embed_from_reply(event: ReplyEvent):
|
||||
reply_embed = Embed(title=event.room.name, description=event.speaker.name)
|
||||
reply_embed.add_field(name="Reply", value=event.text)
|
||||
reply_embed.add_field(name="Reply", value=truncate(event.text))
|
||||
return reply_embed
|
||||
|
||||
|
||||
|
@ -367,12 +373,8 @@ def embed_from_generate(event: GenerateEvent) -> Embed:
|
|||
|
||||
|
||||
def embed_from_result(event: ResultEvent):
|
||||
text = event.result
|
||||
if len(text) > 1000:
|
||||
text = text[:1000] + "..."
|
||||
|
||||
result_embed = Embed(title=event.room.name, description=event.character.name)
|
||||
result_embed.add_field(name="Result", value=text)
|
||||
result_embed.add_field(name="Result", value=truncate(event.result))
|
||||
return result_embed
|
||||
|
||||
|
||||
|
@ -384,14 +386,14 @@ def embed_from_player(event: PlayerEvent):
|
|||
title = format_prompt("discord_leave_title", event=event)
|
||||
description = format_prompt("discord_leave_result", event=event)
|
||||
|
||||
player_embed = Embed(title=title, description=description)
|
||||
player_embed = Embed(title=title, description=truncate(description))
|
||||
return player_embed
|
||||
|
||||
|
||||
def embed_from_prompt(event: PromptEvent):
|
||||
# TODO: ping the player
|
||||
prompt_embed = Embed(title=event.room.name, description=event.character.name)
|
||||
prompt_embed.add_field(name="Prompt", value=event.prompt)
|
||||
prompt_embed.add_field(name="Prompt", value=truncate(event.prompt))
|
||||
return prompt_embed
|
||||
|
||||
|
||||
|
@ -400,5 +402,5 @@ def embed_from_status(event: StatusEvent):
|
|||
title=event.room.name if event.room else "",
|
||||
description=event.character.name if event.character else "",
|
||||
)
|
||||
status_embed.add_field(name="Status", value=event.text)
|
||||
status_embed.add_field(name="Status", value=truncate(event.text))
|
||||
return status_embed
|
||||
|
|
|
@ -416,7 +416,7 @@ def main():
|
|||
set_dungeon_master(world_builder)
|
||||
|
||||
# start the sim
|
||||
logger.debug("simulating world: %s", world)
|
||||
logger.debug("simulating world: %s", world.name)
|
||||
simulate_world(
|
||||
world,
|
||||
turns=args.turns,
|
||||
|
|
|
@ -44,6 +44,7 @@ from taleweave.context import (
|
|||
set_current_world,
|
||||
set_game_systems,
|
||||
)
|
||||
from taleweave.errors import ActionError
|
||||
from taleweave.game_system import GameSystem
|
||||
from taleweave.models.entity import Character, Room, World
|
||||
from taleweave.models.event import ActionEvent, ResultEvent
|
||||
|
@ -117,12 +118,9 @@ def prompt_character_action(
|
|||
# TODO: only emit valid actions that parse and run correctly, and try to avoid parsing the JSON twice
|
||||
event = ActionEvent.from_json(value, room, character)
|
||||
else:
|
||||
# TODO: this path should be removed and throw
|
||||
# logger.warning(
|
||||
# "invalid action, emitting as result event - this is a bug somewhere"
|
||||
# )
|
||||
# event = ResultEvent(value, room, character)
|
||||
raise ValueError("invalid non-JSON action")
|
||||
raise ActionError(
|
||||
"Your last reply was not valid JSON. Please try again and reply with a valid function call in JSON format."
|
||||
)
|
||||
|
||||
broadcast(event)
|
||||
|
||||
|
@ -216,14 +214,14 @@ def prompt_character_planning(
|
|||
while not stop_condition(current=i):
|
||||
result = loop_retry(
|
||||
agent,
|
||||
get_prompt("world_simulate_character_planning"),
|
||||
context={
|
||||
"event_count": event_count,
|
||||
"events_prompt": events_prompt,
|
||||
"note_count": note_count,
|
||||
"notes_prompt": notes_prompt,
|
||||
"room_summary": summarize_room(room, character),
|
||||
},
|
||||
format_prompt(
|
||||
"world_simulate_character_planning",
|
||||
event_count=event_count,
|
||||
events_prompt=events_prompt,
|
||||
note_count=note_count,
|
||||
notes_prompt=notes_prompt,
|
||||
room_summary=summarize_room(room, character),
|
||||
),
|
||||
result_parser=result_parser,
|
||||
stop_condition=stop_condition,
|
||||
toolbox=planner_toolbox,
|
||||
|
|
|
@ -7,12 +7,14 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, System
|
|||
from packit.agent import Agent, agent_easy_connect
|
||||
from pydantic import RootModel
|
||||
|
||||
from taleweave.context import get_all_character_agents, set_character_agent
|
||||
from taleweave.context import (
|
||||
get_all_character_agents,
|
||||
get_game_config,
|
||||
set_character_agent,
|
||||
)
|
||||
from taleweave.models.entity import World
|
||||
from taleweave.player import LocalPlayer
|
||||
|
||||
MEMORY_LIMIT = 25 # 10
|
||||
|
||||
|
||||
def create_agents(
|
||||
world: World,
|
||||
|
@ -69,6 +71,7 @@ def snapshot_world(world: World, turn: int):
|
|||
def restore_memory(
|
||||
data: Sequence[str | Dict[str, str]]
|
||||
) -> deque[str | AIMessage | HumanMessage | SystemMessage]:
|
||||
config = get_game_config()
|
||||
memories = []
|
||||
|
||||
for memory in data:
|
||||
|
@ -85,7 +88,7 @@ def restore_memory(
|
|||
elif memory_type == "ai":
|
||||
memories.append(AIMessage(content=memory_content))
|
||||
|
||||
return deque(memories, maxlen=MEMORY_LIMIT)
|
||||
return deque(memories, maxlen=config.world.character.memory_limit)
|
||||
|
||||
|
||||
def save_world(world, filename):
|
||||
|
|
|
@ -5,6 +5,7 @@ from taleweave.context import get_current_world, get_prompt_library, subscribe
|
|||
from taleweave.game_system import FormatPerspective, GameSystem
|
||||
from taleweave.models.entity import Character, Room, World, WorldEntity
|
||||
from taleweave.models.event import ActionEvent, GameEvent
|
||||
from taleweave.utils.prompt import format_str
|
||||
from taleweave.utils.search import find_containing_room
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -22,7 +23,7 @@ def create_turn_digest(
|
|||
if prompt_key in library.prompts:
|
||||
try:
|
||||
template = library.prompts[prompt_key]
|
||||
message = template.format(event=event)
|
||||
message = format_str(template, event=event)
|
||||
messages.append(message)
|
||||
except Exception:
|
||||
logger.exception("error formatting digest event: %s", event)
|
||||
|
|
|
@ -3,10 +3,17 @@ from logging import getLogger
|
|||
from jinja2 import Environment
|
||||
|
||||
from taleweave.context import get_prompt_library
|
||||
from taleweave.utils.string import and_list, or_list
|
||||
from taleweave.utils.world import describe_entity, name_entity
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
jinja_env = Environment()
|
||||
jinja_env.filters["describe"] = describe_entity
|
||||
jinja_env.filters["name"] = name_entity
|
||||
jinja_env.filters["and_list"] = and_list
|
||||
jinja_env.filters["or_list"] = or_list
|
||||
|
||||
|
||||
def format_prompt(prompt_key: str, **kwargs) -> str:
|
||||
try:
|
||||
|
@ -19,9 +26,5 @@ def format_prompt(prompt_key: str, **kwargs) -> str:
|
|||
|
||||
|
||||
def format_str(template_str: str, **kwargs) -> str:
|
||||
env = Environment()
|
||||
env.filters["describe"] = describe_entity
|
||||
env.filters["name"] = name_entity
|
||||
|
||||
template = env.from_string(template_str)
|
||||
template = jinja_env.from_string(template_str)
|
||||
return template.render(**kwargs)
|
||||
|
|
Loading…
Reference in New Issue