1
0
Fork 0

make memory configurable, consistently truncate discord messages, fix action prompt
Run Docker Build / build (push) Successful in 17s Details
Run Python Build / build (push) Successful in 27s Details

This commit is contained in:
Sean Sube 2024-06-02 20:00:17 -05:00
parent a9705727f0
commit f25dd57e67
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
9 changed files with 64 additions and 45 deletions

View File

@ -69,7 +69,7 @@ export interface StringParameter {
export interface NumberParameter { export interface NumberParameter {
type: 'number'; type: 'number';
default?: number; default?: number;
enum?: Array<string>; enum?: Array<number>;
} }
export type Parameter = BooleanParameter | NumberParameter | StringParameter; export type Parameter = BooleanParameter | NumberParameter | StringParameter;

View File

@ -129,7 +129,7 @@ export function enumerateSignificantParameterValues(name: string, world: World)
} }
} }
export function convertSignificantParameter(name: string, parameter: Parameter, world: Maybe<World>): Parameter { export function convertSignificantParameter<T extends Parameter>(name: string, parameter: T, world: Maybe<World>): T {
if (parameter.type === 'boolean') { if (parameter.type === 'boolean') {
return parameter; return parameter;
} }
@ -154,15 +154,27 @@ export function formatAction(action: string, parameters: Record<string, boolean
return `~${action}:${Object.entries(parameters).map(([name, value]) => `${name}=${value}`).join(',')}`; return `~${action}:${Object.entries(parameters).map(([name, value]) => `${name}=${value}`).join(',')}`;
} }
export function getEnumOrDefault<T>(defaultValue: Maybe<T>, enumValues: Maybe<Array<T>>, evenMoreDefault: T): T {
if (doesExist(defaultValue)) {
return defaultValue;
}
if (doesExist(enumValues)) {
return enumValues[0];
}
return evenMoreDefault;
}
export function makeDefaultParameterValues(parameters: Record<string, Parameter>) { export function makeDefaultParameterValues(parameters: Record<string, Parameter>) {
return Object.entries(parameters).reduce((acc, [name, parameter]) => { return Object.entries(parameters).reduce((acc, [name, parameter]) => {
switch (parameter.type) { switch (parameter.type) {
case 'boolean': case 'boolean':
return { ...acc, [name]: mustDefault(parameter.default, false) }; return { ...acc, [name]: getEnumOrDefault(parameter.default, [], false) };
case 'number': case 'number':
return { ...acc, [name]: mustDefault(parameter.default, 0) }; return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, 0) };
case 'string': case 'string':
return { ...acc, [name]: mustDefault(parameter.default, '') }; return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, '') };
default: default:
return acc; return acc;
} }

View File

@ -5,6 +5,7 @@ from taleweave.context import (
broadcast, broadcast,
get_agent_for_character, get_agent_for_character,
get_character_agent_for_name, get_character_agent_for_name,
get_game_config,
get_prompt, get_prompt,
world_context, world_context,
) )
@ -22,8 +23,6 @@ from taleweave.utils.string import normalize_name
logger = getLogger(__name__) logger = getLogger(__name__)
MAX_CONVERSATION_STEPS = 2
def action_examine(target: str) -> str: def action_examine(target: str) -> str:
""" """
@ -173,7 +172,8 @@ def action_ask(character: str, question: str) -> str:
character: The name of the character to ask. You cannot ask yourself questions. character: The name of the character to ask. You cannot ask yourself questions.
question: The question to ask them. question: The question to ask them.
""" """
# capture references to the current character and room, because they will be overwritten config = get_game_config()
with action_context() as (action_room, action_character): with action_context() as (action_room, action_character):
# sanity checks # sanity checks
question_character, question_agent = get_character_agent_for_name(character) question_character, question_agent = get_character_agent_for_name(character)
@ -216,7 +216,7 @@ def action_ask(character: str, question: str) -> str:
end_prompt, end_prompt,
echo_function=action_tell.__name__, echo_function=action_tell.__name__,
echo_parameter="message", echo_parameter="message",
max_length=MAX_CONVERSATION_STEPS, max_length=config.world.character.conversation_limit,
) )
if result: if result:
@ -233,7 +233,7 @@ def action_tell(character: str, message: str) -> str:
character: The name of the character to tell. You cannot talk to yourself. character: The name of the character to tell. You cannot talk to yourself.
message: The message to tell them. message: The message to tell them.
""" """
# capture references to the current character and room, because they will be overwritten config = get_game_config()
with action_context() as (action_room, action_character): with action_context() as (action_room, action_character):
# sanity checks # sanity checks
@ -268,7 +268,7 @@ def action_tell(character: str, message: str) -> str:
end_prompt, end_prompt,
echo_function=action_tell.__name__, echo_function=action_tell.__name__,
echo_parameter="message", echo_parameter="message",
max_length=MAX_CONVERSATION_STEPS, max_length=config.world.character.conversation_limit,
) )
if result: if result:

View File

@ -323,6 +323,12 @@ async def broadcast_event(message: str | GameEvent):
event_messages[event_message.id] = message event_messages[event_message.id] = message
def truncate(text: str, length: int = 1000) -> str:
if len(text) > length:
return text[:length] + "..."
return text
def embed_from_event(event: GameEvent) -> Embed | None: def embed_from_event(event: GameEvent) -> Embed | None:
if isinstance(event, GenerateEvent): if isinstance(event, GenerateEvent):
return embed_from_generate(event) return embed_from_generate(event)
@ -357,7 +363,7 @@ def embed_from_action(event: ActionEvent):
def embed_from_reply(event: ReplyEvent): def embed_from_reply(event: ReplyEvent):
reply_embed = Embed(title=event.room.name, description=event.speaker.name) reply_embed = Embed(title=event.room.name, description=event.speaker.name)
reply_embed.add_field(name="Reply", value=event.text) reply_embed.add_field(name="Reply", value=truncate(event.text))
return reply_embed return reply_embed
@ -367,12 +373,8 @@ def embed_from_generate(event: GenerateEvent) -> Embed:
def embed_from_result(event: ResultEvent): def embed_from_result(event: ResultEvent):
text = event.result
if len(text) > 1000:
text = text[:1000] + "..."
result_embed = Embed(title=event.room.name, description=event.character.name) result_embed = Embed(title=event.room.name, description=event.character.name)
result_embed.add_field(name="Result", value=text) result_embed.add_field(name="Result", value=truncate(event.result))
return result_embed return result_embed
@ -384,14 +386,14 @@ def embed_from_player(event: PlayerEvent):
title = format_prompt("discord_leave_title", event=event) title = format_prompt("discord_leave_title", event=event)
description = format_prompt("discord_leave_result", event=event) description = format_prompt("discord_leave_result", event=event)
player_embed = Embed(title=title, description=description) player_embed = Embed(title=title, description=truncate(description))
return player_embed return player_embed
def embed_from_prompt(event: PromptEvent): def embed_from_prompt(event: PromptEvent):
# TODO: ping the player # TODO: ping the player
prompt_embed = Embed(title=event.room.name, description=event.character.name) prompt_embed = Embed(title=event.room.name, description=event.character.name)
prompt_embed.add_field(name="Prompt", value=event.prompt) prompt_embed.add_field(name="Prompt", value=truncate(event.prompt))
return prompt_embed return prompt_embed
@ -400,5 +402,5 @@ def embed_from_status(event: StatusEvent):
title=event.room.name if event.room else "", title=event.room.name if event.room else "",
description=event.character.name if event.character else "", description=event.character.name if event.character else "",
) )
status_embed.add_field(name="Status", value=event.text) status_embed.add_field(name="Status", value=truncate(event.text))
return status_embed return status_embed

View File

@ -416,7 +416,7 @@ def main():
set_dungeon_master(world_builder) set_dungeon_master(world_builder)
# start the sim # start the sim
logger.debug("simulating world: %s", world) logger.debug("simulating world: %s", world.name)
simulate_world( simulate_world(
world, world,
turns=args.turns, turns=args.turns,

View File

@ -44,6 +44,7 @@ from taleweave.context import (
set_current_world, set_current_world,
set_game_systems, set_game_systems,
) )
from taleweave.errors import ActionError
from taleweave.game_system import GameSystem from taleweave.game_system import GameSystem
from taleweave.models.entity import Character, Room, World from taleweave.models.entity import Character, Room, World
from taleweave.models.event import ActionEvent, ResultEvent from taleweave.models.event import ActionEvent, ResultEvent
@ -117,12 +118,9 @@ def prompt_character_action(
# TODO: only emit valid actions that parse and run correctly, and try to avoid parsing the JSON twice # TODO: only emit valid actions that parse and run correctly, and try to avoid parsing the JSON twice
event = ActionEvent.from_json(value, room, character) event = ActionEvent.from_json(value, room, character)
else: else:
# TODO: this path should be removed and throw raise ActionError(
# logger.warning( "Your last reply was not valid JSON. Please try again and reply with a valid function call in JSON format."
# "invalid action, emitting as result event - this is a bug somewhere" )
# )
# event = ResultEvent(value, room, character)
raise ValueError("invalid non-JSON action")
broadcast(event) broadcast(event)
@ -216,14 +214,14 @@ def prompt_character_planning(
while not stop_condition(current=i): while not stop_condition(current=i):
result = loop_retry( result = loop_retry(
agent, agent,
get_prompt("world_simulate_character_planning"), format_prompt(
context={ "world_simulate_character_planning",
"event_count": event_count, event_count=event_count,
"events_prompt": events_prompt, events_prompt=events_prompt,
"note_count": note_count, note_count=note_count,
"notes_prompt": notes_prompt, notes_prompt=notes_prompt,
"room_summary": summarize_room(room, character), room_summary=summarize_room(room, character),
}, ),
result_parser=result_parser, result_parser=result_parser,
stop_condition=stop_condition, stop_condition=stop_condition,
toolbox=planner_toolbox, toolbox=planner_toolbox,

View File

@ -7,12 +7,14 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, System
from packit.agent import Agent, agent_easy_connect from packit.agent import Agent, agent_easy_connect
from pydantic import RootModel from pydantic import RootModel
from taleweave.context import get_all_character_agents, set_character_agent from taleweave.context import (
get_all_character_agents,
get_game_config,
set_character_agent,
)
from taleweave.models.entity import World from taleweave.models.entity import World
from taleweave.player import LocalPlayer from taleweave.player import LocalPlayer
MEMORY_LIMIT = 25 # 10
def create_agents( def create_agents(
world: World, world: World,
@ -69,6 +71,7 @@ def snapshot_world(world: World, turn: int):
def restore_memory( def restore_memory(
data: Sequence[str | Dict[str, str]] data: Sequence[str | Dict[str, str]]
) -> deque[str | AIMessage | HumanMessage | SystemMessage]: ) -> deque[str | AIMessage | HumanMessage | SystemMessage]:
config = get_game_config()
memories = [] memories = []
for memory in data: for memory in data:
@ -85,7 +88,7 @@ def restore_memory(
elif memory_type == "ai": elif memory_type == "ai":
memories.append(AIMessage(content=memory_content)) memories.append(AIMessage(content=memory_content))
return deque(memories, maxlen=MEMORY_LIMIT) return deque(memories, maxlen=config.world.character.memory_limit)
def save_world(world, filename): def save_world(world, filename):

View File

@ -5,6 +5,7 @@ from taleweave.context import get_current_world, get_prompt_library, subscribe
from taleweave.game_system import FormatPerspective, GameSystem from taleweave.game_system import FormatPerspective, GameSystem
from taleweave.models.entity import Character, Room, World, WorldEntity from taleweave.models.entity import Character, Room, World, WorldEntity
from taleweave.models.event import ActionEvent, GameEvent from taleweave.models.event import ActionEvent, GameEvent
from taleweave.utils.prompt import format_str
from taleweave.utils.search import find_containing_room from taleweave.utils.search import find_containing_room
logger = getLogger(__name__) logger = getLogger(__name__)
@ -22,7 +23,7 @@ def create_turn_digest(
if prompt_key in library.prompts: if prompt_key in library.prompts:
try: try:
template = library.prompts[prompt_key] template = library.prompts[prompt_key]
message = template.format(event=event) message = format_str(template, event=event)
messages.append(message) messages.append(message)
except Exception: except Exception:
logger.exception("error formatting digest event: %s", event) logger.exception("error formatting digest event: %s", event)

View File

@ -3,10 +3,17 @@ from logging import getLogger
from jinja2 import Environment from jinja2 import Environment
from taleweave.context import get_prompt_library from taleweave.context import get_prompt_library
from taleweave.utils.string import and_list, or_list
from taleweave.utils.world import describe_entity, name_entity from taleweave.utils.world import describe_entity, name_entity
logger = getLogger(__name__) logger = getLogger(__name__)
jinja_env = Environment()
jinja_env.filters["describe"] = describe_entity
jinja_env.filters["name"] = name_entity
jinja_env.filters["and_list"] = and_list
jinja_env.filters["or_list"] = or_list
def format_prompt(prompt_key: str, **kwargs) -> str: def format_prompt(prompt_key: str, **kwargs) -> str:
try: try:
@ -19,9 +26,5 @@ def format_prompt(prompt_key: str, **kwargs) -> str:
def format_str(template_str: str, **kwargs) -> str: def format_str(template_str: str, **kwargs) -> str:
env = Environment() template = jinja_env.from_string(template_str)
env.filters["describe"] = describe_entity
env.filters["name"] = name_entity
template = env.from_string(template_str)
return template.render(**kwargs) return template.render(**kwargs)