1
0
Fork 0

make memory configurable, consistently truncate discord messages, fix action prompt
Run Docker Build / build (push) Successful in 17s Details
Run Python Build / build (push) Successful in 27s Details

This commit is contained in:
Sean Sube 2024-06-02 20:00:17 -05:00
parent a9705727f0
commit f25dd57e67
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
9 changed files with 64 additions and 45 deletions

View File

@ -69,7 +69,7 @@ export interface StringParameter {
export interface NumberParameter {
type: 'number';
default?: number;
enum?: Array<string>;
enum?: Array<number>;
}
export type Parameter = BooleanParameter | NumberParameter | StringParameter;

View File

@ -129,7 +129,7 @@ export function enumerateSignificantParameterValues(name: string, world: World)
}
}
export function convertSignificantParameter(name: string, parameter: Parameter, world: Maybe<World>): Parameter {
export function convertSignificantParameter<T extends Parameter>(name: string, parameter: T, world: Maybe<World>): T {
if (parameter.type === 'boolean') {
return parameter;
}
@ -154,15 +154,27 @@ export function formatAction(action: string, parameters: Record<string, boolean
return `~${action}:${Object.entries(parameters).map(([name, value]) => `${name}=${value}`).join(',')}`;
}
export function getEnumOrDefault<T>(defaultValue: Maybe<T>, enumValues: Maybe<Array<T>>, evenMoreDefault: T): T {
if (doesExist(defaultValue)) {
return defaultValue;
}
if (doesExist(enumValues)) {
return enumValues[0];
}
return evenMoreDefault;
}
export function makeDefaultParameterValues(parameters: Record<string, Parameter>) {
return Object.entries(parameters).reduce((acc, [name, parameter]) => {
switch (parameter.type) {
case 'boolean':
return { ...acc, [name]: mustDefault(parameter.default, false) };
return { ...acc, [name]: getEnumOrDefault(parameter.default, [], false) };
case 'number':
return { ...acc, [name]: mustDefault(parameter.default, 0) };
return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, 0) };
case 'string':
return { ...acc, [name]: mustDefault(parameter.default, '') };
return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, '') };
default:
return acc;
}

View File

@ -5,6 +5,7 @@ from taleweave.context import (
broadcast,
get_agent_for_character,
get_character_agent_for_name,
get_game_config,
get_prompt,
world_context,
)
@ -22,8 +23,6 @@ from taleweave.utils.string import normalize_name
logger = getLogger(__name__)
MAX_CONVERSATION_STEPS = 2
def action_examine(target: str) -> str:
"""
@ -173,7 +172,8 @@ def action_ask(character: str, question: str) -> str:
character: The name of the character to ask. You cannot ask yourself questions.
question: The question to ask them.
"""
# capture references to the current character and room, because they will be overwritten
config = get_game_config()
with action_context() as (action_room, action_character):
# sanity checks
question_character, question_agent = get_character_agent_for_name(character)
@ -216,7 +216,7 @@ def action_ask(character: str, question: str) -> str:
end_prompt,
echo_function=action_tell.__name__,
echo_parameter="message",
max_length=MAX_CONVERSATION_STEPS,
max_length=config.world.character.conversation_limit,
)
if result:
@ -233,7 +233,7 @@ def action_tell(character: str, message: str) -> str:
character: The name of the character to tell. You cannot talk to yourself.
message: The message to tell them.
"""
# capture references to the current character and room, because they will be overwritten
config = get_game_config()
with action_context() as (action_room, action_character):
# sanity checks
@ -268,7 +268,7 @@ def action_tell(character: str, message: str) -> str:
end_prompt,
echo_function=action_tell.__name__,
echo_parameter="message",
max_length=MAX_CONVERSATION_STEPS,
max_length=config.world.character.conversation_limit,
)
if result:

View File

@ -323,6 +323,12 @@ async def broadcast_event(message: str | GameEvent):
event_messages[event_message.id] = message
def truncate(text: str, length: int = 1000) -> str:
if len(text) > length:
return text[:length] + "..."
return text
def embed_from_event(event: GameEvent) -> Embed | None:
if isinstance(event, GenerateEvent):
return embed_from_generate(event)
@ -357,7 +363,7 @@ def embed_from_action(event: ActionEvent):
def embed_from_reply(event: ReplyEvent):
reply_embed = Embed(title=event.room.name, description=event.speaker.name)
reply_embed.add_field(name="Reply", value=event.text)
reply_embed.add_field(name="Reply", value=truncate(event.text))
return reply_embed
@ -367,12 +373,8 @@ def embed_from_generate(event: GenerateEvent) -> Embed:
def embed_from_result(event: ResultEvent):
text = event.result
if len(text) > 1000:
text = text[:1000] + "..."
result_embed = Embed(title=event.room.name, description=event.character.name)
result_embed.add_field(name="Result", value=text)
result_embed.add_field(name="Result", value=truncate(event.result))
return result_embed
@ -384,14 +386,14 @@ def embed_from_player(event: PlayerEvent):
title = format_prompt("discord_leave_title", event=event)
description = format_prompt("discord_leave_result", event=event)
player_embed = Embed(title=title, description=description)
player_embed = Embed(title=title, description=truncate(description))
return player_embed
def embed_from_prompt(event: PromptEvent):
# TODO: ping the player
prompt_embed = Embed(title=event.room.name, description=event.character.name)
prompt_embed.add_field(name="Prompt", value=event.prompt)
prompt_embed.add_field(name="Prompt", value=truncate(event.prompt))
return prompt_embed
@ -400,5 +402,5 @@ def embed_from_status(event: StatusEvent):
title=event.room.name if event.room else "",
description=event.character.name if event.character else "",
)
status_embed.add_field(name="Status", value=event.text)
status_embed.add_field(name="Status", value=truncate(event.text))
return status_embed

View File

@ -416,7 +416,7 @@ def main():
set_dungeon_master(world_builder)
# start the sim
logger.debug("simulating world: %s", world)
logger.debug("simulating world: %s", world.name)
simulate_world(
world,
turns=args.turns,

View File

@ -44,6 +44,7 @@ from taleweave.context import (
set_current_world,
set_game_systems,
)
from taleweave.errors import ActionError
from taleweave.game_system import GameSystem
from taleweave.models.entity import Character, Room, World
from taleweave.models.event import ActionEvent, ResultEvent
@ -117,12 +118,9 @@ def prompt_character_action(
# TODO: only emit valid actions that parse and run correctly, and try to avoid parsing the JSON twice
event = ActionEvent.from_json(value, room, character)
else:
# TODO: this path should be removed and throw
# logger.warning(
# "invalid action, emitting as result event - this is a bug somewhere"
# )
# event = ResultEvent(value, room, character)
raise ValueError("invalid non-JSON action")
raise ActionError(
"Your last reply was not valid JSON. Please try again and reply with a valid function call in JSON format."
)
broadcast(event)
@ -216,14 +214,14 @@ def prompt_character_planning(
while not stop_condition(current=i):
result = loop_retry(
agent,
get_prompt("world_simulate_character_planning"),
context={
"event_count": event_count,
"events_prompt": events_prompt,
"note_count": note_count,
"notes_prompt": notes_prompt,
"room_summary": summarize_room(room, character),
},
format_prompt(
"world_simulate_character_planning",
event_count=event_count,
events_prompt=events_prompt,
note_count=note_count,
notes_prompt=notes_prompt,
room_summary=summarize_room(room, character),
),
result_parser=result_parser,
stop_condition=stop_condition,
toolbox=planner_toolbox,

View File

@ -7,12 +7,14 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, System
from packit.agent import Agent, agent_easy_connect
from pydantic import RootModel
from taleweave.context import get_all_character_agents, set_character_agent
from taleweave.context import (
get_all_character_agents,
get_game_config,
set_character_agent,
)
from taleweave.models.entity import World
from taleweave.player import LocalPlayer
MEMORY_LIMIT = 25 # 10
def create_agents(
world: World,
@ -69,6 +71,7 @@ def snapshot_world(world: World, turn: int):
def restore_memory(
data: Sequence[str | Dict[str, str]]
) -> deque[str | AIMessage | HumanMessage | SystemMessage]:
config = get_game_config()
memories = []
for memory in data:
@ -85,7 +88,7 @@ def restore_memory(
elif memory_type == "ai":
memories.append(AIMessage(content=memory_content))
return deque(memories, maxlen=MEMORY_LIMIT)
return deque(memories, maxlen=config.world.character.memory_limit)
def save_world(world, filename):

View File

@ -5,6 +5,7 @@ from taleweave.context import get_current_world, get_prompt_library, subscribe
from taleweave.game_system import FormatPerspective, GameSystem
from taleweave.models.entity import Character, Room, World, WorldEntity
from taleweave.models.event import ActionEvent, GameEvent
from taleweave.utils.prompt import format_str
from taleweave.utils.search import find_containing_room
logger = getLogger(__name__)
@ -22,7 +23,7 @@ def create_turn_digest(
if prompt_key in library.prompts:
try:
template = library.prompts[prompt_key]
message = template.format(event=event)
message = format_str(template, event=event)
messages.append(message)
except Exception:
logger.exception("error formatting digest event: %s", event)

View File

@ -3,10 +3,17 @@ from logging import getLogger
from jinja2 import Environment
from taleweave.context import get_prompt_library
from taleweave.utils.string import and_list, or_list
from taleweave.utils.world import describe_entity, name_entity
logger = getLogger(__name__)
jinja_env = Environment()
jinja_env.filters["describe"] = describe_entity
jinja_env.filters["name"] = name_entity
jinja_env.filters["and_list"] = and_list
jinja_env.filters["or_list"] = or_list
def format_prompt(prompt_key: str, **kwargs) -> str:
try:
@ -19,9 +26,5 @@ def format_prompt(prompt_key: str, **kwargs) -> str:
def format_str(template_str: str, **kwargs) -> str:
env = Environment()
env.filters["describe"] = describe_entity
env.filters["name"] = name_entity
template = env.from_string(template_str)
template = jinja_env.from_string(template_str)
return template.render(**kwargs)