make memory configurable, consistently truncate discord messages, fix action prompt
This commit is contained in:
parent
a9705727f0
commit
f25dd57e67
|
@ -69,7 +69,7 @@ export interface StringParameter {
|
||||||
export interface NumberParameter {
|
export interface NumberParameter {
|
||||||
type: 'number';
|
type: 'number';
|
||||||
default?: number;
|
default?: number;
|
||||||
enum?: Array<string>;
|
enum?: Array<number>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type Parameter = BooleanParameter | NumberParameter | StringParameter;
|
export type Parameter = BooleanParameter | NumberParameter | StringParameter;
|
||||||
|
|
|
@ -129,7 +129,7 @@ export function enumerateSignificantParameterValues(name: string, world: World)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function convertSignificantParameter(name: string, parameter: Parameter, world: Maybe<World>): Parameter {
|
export function convertSignificantParameter<T extends Parameter>(name: string, parameter: T, world: Maybe<World>): T {
|
||||||
if (parameter.type === 'boolean') {
|
if (parameter.type === 'boolean') {
|
||||||
return parameter;
|
return parameter;
|
||||||
}
|
}
|
||||||
|
@ -154,15 +154,27 @@ export function formatAction(action: string, parameters: Record<string, boolean
|
||||||
return `~${action}:${Object.entries(parameters).map(([name, value]) => `${name}=${value}`).join(',')}`;
|
return `~${action}:${Object.entries(parameters).map(([name, value]) => `${name}=${value}`).join(',')}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getEnumOrDefault<T>(defaultValue: Maybe<T>, enumValues: Maybe<Array<T>>, evenMoreDefault: T): T {
|
||||||
|
if (doesExist(defaultValue)) {
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (doesExist(enumValues)) {
|
||||||
|
return enumValues[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
return evenMoreDefault;
|
||||||
|
}
|
||||||
|
|
||||||
export function makeDefaultParameterValues(parameters: Record<string, Parameter>) {
|
export function makeDefaultParameterValues(parameters: Record<string, Parameter>) {
|
||||||
return Object.entries(parameters).reduce((acc, [name, parameter]) => {
|
return Object.entries(parameters).reduce((acc, [name, parameter]) => {
|
||||||
switch (parameter.type) {
|
switch (parameter.type) {
|
||||||
case 'boolean':
|
case 'boolean':
|
||||||
return { ...acc, [name]: mustDefault(parameter.default, false) };
|
return { ...acc, [name]: getEnumOrDefault(parameter.default, [], false) };
|
||||||
case 'number':
|
case 'number':
|
||||||
return { ...acc, [name]: mustDefault(parameter.default, 0) };
|
return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, 0) };
|
||||||
case 'string':
|
case 'string':
|
||||||
return { ...acc, [name]: mustDefault(parameter.default, '') };
|
return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, '') };
|
||||||
default:
|
default:
|
||||||
return acc;
|
return acc;
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ from taleweave.context import (
|
||||||
broadcast,
|
broadcast,
|
||||||
get_agent_for_character,
|
get_agent_for_character,
|
||||||
get_character_agent_for_name,
|
get_character_agent_for_name,
|
||||||
|
get_game_config,
|
||||||
get_prompt,
|
get_prompt,
|
||||||
world_context,
|
world_context,
|
||||||
)
|
)
|
||||||
|
@ -22,8 +23,6 @@ from taleweave.utils.string import normalize_name
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
MAX_CONVERSATION_STEPS = 2
|
|
||||||
|
|
||||||
|
|
||||||
def action_examine(target: str) -> str:
|
def action_examine(target: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -173,7 +172,8 @@ def action_ask(character: str, question: str) -> str:
|
||||||
character: The name of the character to ask. You cannot ask yourself questions.
|
character: The name of the character to ask. You cannot ask yourself questions.
|
||||||
question: The question to ask them.
|
question: The question to ask them.
|
||||||
"""
|
"""
|
||||||
# capture references to the current character and room, because they will be overwritten
|
config = get_game_config()
|
||||||
|
|
||||||
with action_context() as (action_room, action_character):
|
with action_context() as (action_room, action_character):
|
||||||
# sanity checks
|
# sanity checks
|
||||||
question_character, question_agent = get_character_agent_for_name(character)
|
question_character, question_agent = get_character_agent_for_name(character)
|
||||||
|
@ -216,7 +216,7 @@ def action_ask(character: str, question: str) -> str:
|
||||||
end_prompt,
|
end_prompt,
|
||||||
echo_function=action_tell.__name__,
|
echo_function=action_tell.__name__,
|
||||||
echo_parameter="message",
|
echo_parameter="message",
|
||||||
max_length=MAX_CONVERSATION_STEPS,
|
max_length=config.world.character.conversation_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
|
@ -233,7 +233,7 @@ def action_tell(character: str, message: str) -> str:
|
||||||
character: The name of the character to tell. You cannot talk to yourself.
|
character: The name of the character to tell. You cannot talk to yourself.
|
||||||
message: The message to tell them.
|
message: The message to tell them.
|
||||||
"""
|
"""
|
||||||
# capture references to the current character and room, because they will be overwritten
|
config = get_game_config()
|
||||||
|
|
||||||
with action_context() as (action_room, action_character):
|
with action_context() as (action_room, action_character):
|
||||||
# sanity checks
|
# sanity checks
|
||||||
|
@ -268,7 +268,7 @@ def action_tell(character: str, message: str) -> str:
|
||||||
end_prompt,
|
end_prompt,
|
||||||
echo_function=action_tell.__name__,
|
echo_function=action_tell.__name__,
|
||||||
echo_parameter="message",
|
echo_parameter="message",
|
||||||
max_length=MAX_CONVERSATION_STEPS,
|
max_length=config.world.character.conversation_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
|
|
|
@ -323,6 +323,12 @@ async def broadcast_event(message: str | GameEvent):
|
||||||
event_messages[event_message.id] = message
|
event_messages[event_message.id] = message
|
||||||
|
|
||||||
|
|
||||||
|
def truncate(text: str, length: int = 1000) -> str:
|
||||||
|
if len(text) > length:
|
||||||
|
return text[:length] + "..."
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
def embed_from_event(event: GameEvent) -> Embed | None:
|
def embed_from_event(event: GameEvent) -> Embed | None:
|
||||||
if isinstance(event, GenerateEvent):
|
if isinstance(event, GenerateEvent):
|
||||||
return embed_from_generate(event)
|
return embed_from_generate(event)
|
||||||
|
@ -357,7 +363,7 @@ def embed_from_action(event: ActionEvent):
|
||||||
|
|
||||||
def embed_from_reply(event: ReplyEvent):
|
def embed_from_reply(event: ReplyEvent):
|
||||||
reply_embed = Embed(title=event.room.name, description=event.speaker.name)
|
reply_embed = Embed(title=event.room.name, description=event.speaker.name)
|
||||||
reply_embed.add_field(name="Reply", value=event.text)
|
reply_embed.add_field(name="Reply", value=truncate(event.text))
|
||||||
return reply_embed
|
return reply_embed
|
||||||
|
|
||||||
|
|
||||||
|
@ -367,12 +373,8 @@ def embed_from_generate(event: GenerateEvent) -> Embed:
|
||||||
|
|
||||||
|
|
||||||
def embed_from_result(event: ResultEvent):
|
def embed_from_result(event: ResultEvent):
|
||||||
text = event.result
|
|
||||||
if len(text) > 1000:
|
|
||||||
text = text[:1000] + "..."
|
|
||||||
|
|
||||||
result_embed = Embed(title=event.room.name, description=event.character.name)
|
result_embed = Embed(title=event.room.name, description=event.character.name)
|
||||||
result_embed.add_field(name="Result", value=text)
|
result_embed.add_field(name="Result", value=truncate(event.result))
|
||||||
return result_embed
|
return result_embed
|
||||||
|
|
||||||
|
|
||||||
|
@ -384,14 +386,14 @@ def embed_from_player(event: PlayerEvent):
|
||||||
title = format_prompt("discord_leave_title", event=event)
|
title = format_prompt("discord_leave_title", event=event)
|
||||||
description = format_prompt("discord_leave_result", event=event)
|
description = format_prompt("discord_leave_result", event=event)
|
||||||
|
|
||||||
player_embed = Embed(title=title, description=description)
|
player_embed = Embed(title=title, description=truncate(description))
|
||||||
return player_embed
|
return player_embed
|
||||||
|
|
||||||
|
|
||||||
def embed_from_prompt(event: PromptEvent):
|
def embed_from_prompt(event: PromptEvent):
|
||||||
# TODO: ping the player
|
# TODO: ping the player
|
||||||
prompt_embed = Embed(title=event.room.name, description=event.character.name)
|
prompt_embed = Embed(title=event.room.name, description=event.character.name)
|
||||||
prompt_embed.add_field(name="Prompt", value=event.prompt)
|
prompt_embed.add_field(name="Prompt", value=truncate(event.prompt))
|
||||||
return prompt_embed
|
return prompt_embed
|
||||||
|
|
||||||
|
|
||||||
|
@ -400,5 +402,5 @@ def embed_from_status(event: StatusEvent):
|
||||||
title=event.room.name if event.room else "",
|
title=event.room.name if event.room else "",
|
||||||
description=event.character.name if event.character else "",
|
description=event.character.name if event.character else "",
|
||||||
)
|
)
|
||||||
status_embed.add_field(name="Status", value=event.text)
|
status_embed.add_field(name="Status", value=truncate(event.text))
|
||||||
return status_embed
|
return status_embed
|
||||||
|
|
|
@ -416,7 +416,7 @@ def main():
|
||||||
set_dungeon_master(world_builder)
|
set_dungeon_master(world_builder)
|
||||||
|
|
||||||
# start the sim
|
# start the sim
|
||||||
logger.debug("simulating world: %s", world)
|
logger.debug("simulating world: %s", world.name)
|
||||||
simulate_world(
|
simulate_world(
|
||||||
world,
|
world,
|
||||||
turns=args.turns,
|
turns=args.turns,
|
||||||
|
|
|
@ -44,6 +44,7 @@ from taleweave.context import (
|
||||||
set_current_world,
|
set_current_world,
|
||||||
set_game_systems,
|
set_game_systems,
|
||||||
)
|
)
|
||||||
|
from taleweave.errors import ActionError
|
||||||
from taleweave.game_system import GameSystem
|
from taleweave.game_system import GameSystem
|
||||||
from taleweave.models.entity import Character, Room, World
|
from taleweave.models.entity import Character, Room, World
|
||||||
from taleweave.models.event import ActionEvent, ResultEvent
|
from taleweave.models.event import ActionEvent, ResultEvent
|
||||||
|
@ -117,12 +118,9 @@ def prompt_character_action(
|
||||||
# TODO: only emit valid actions that parse and run correctly, and try to avoid parsing the JSON twice
|
# TODO: only emit valid actions that parse and run correctly, and try to avoid parsing the JSON twice
|
||||||
event = ActionEvent.from_json(value, room, character)
|
event = ActionEvent.from_json(value, room, character)
|
||||||
else:
|
else:
|
||||||
# TODO: this path should be removed and throw
|
raise ActionError(
|
||||||
# logger.warning(
|
"Your last reply was not valid JSON. Please try again and reply with a valid function call in JSON format."
|
||||||
# "invalid action, emitting as result event - this is a bug somewhere"
|
)
|
||||||
# )
|
|
||||||
# event = ResultEvent(value, room, character)
|
|
||||||
raise ValueError("invalid non-JSON action")
|
|
||||||
|
|
||||||
broadcast(event)
|
broadcast(event)
|
||||||
|
|
||||||
|
@ -216,14 +214,14 @@ def prompt_character_planning(
|
||||||
while not stop_condition(current=i):
|
while not stop_condition(current=i):
|
||||||
result = loop_retry(
|
result = loop_retry(
|
||||||
agent,
|
agent,
|
||||||
get_prompt("world_simulate_character_planning"),
|
format_prompt(
|
||||||
context={
|
"world_simulate_character_planning",
|
||||||
"event_count": event_count,
|
event_count=event_count,
|
||||||
"events_prompt": events_prompt,
|
events_prompt=events_prompt,
|
||||||
"note_count": note_count,
|
note_count=note_count,
|
||||||
"notes_prompt": notes_prompt,
|
notes_prompt=notes_prompt,
|
||||||
"room_summary": summarize_room(room, character),
|
room_summary=summarize_room(room, character),
|
||||||
},
|
),
|
||||||
result_parser=result_parser,
|
result_parser=result_parser,
|
||||||
stop_condition=stop_condition,
|
stop_condition=stop_condition,
|
||||||
toolbox=planner_toolbox,
|
toolbox=planner_toolbox,
|
||||||
|
|
|
@ -7,12 +7,14 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, System
|
||||||
from packit.agent import Agent, agent_easy_connect
|
from packit.agent import Agent, agent_easy_connect
|
||||||
from pydantic import RootModel
|
from pydantic import RootModel
|
||||||
|
|
||||||
from taleweave.context import get_all_character_agents, set_character_agent
|
from taleweave.context import (
|
||||||
|
get_all_character_agents,
|
||||||
|
get_game_config,
|
||||||
|
set_character_agent,
|
||||||
|
)
|
||||||
from taleweave.models.entity import World
|
from taleweave.models.entity import World
|
||||||
from taleweave.player import LocalPlayer
|
from taleweave.player import LocalPlayer
|
||||||
|
|
||||||
MEMORY_LIMIT = 25 # 10
|
|
||||||
|
|
||||||
|
|
||||||
def create_agents(
|
def create_agents(
|
||||||
world: World,
|
world: World,
|
||||||
|
@ -69,6 +71,7 @@ def snapshot_world(world: World, turn: int):
|
||||||
def restore_memory(
|
def restore_memory(
|
||||||
data: Sequence[str | Dict[str, str]]
|
data: Sequence[str | Dict[str, str]]
|
||||||
) -> deque[str | AIMessage | HumanMessage | SystemMessage]:
|
) -> deque[str | AIMessage | HumanMessage | SystemMessage]:
|
||||||
|
config = get_game_config()
|
||||||
memories = []
|
memories = []
|
||||||
|
|
||||||
for memory in data:
|
for memory in data:
|
||||||
|
@ -85,7 +88,7 @@ def restore_memory(
|
||||||
elif memory_type == "ai":
|
elif memory_type == "ai":
|
||||||
memories.append(AIMessage(content=memory_content))
|
memories.append(AIMessage(content=memory_content))
|
||||||
|
|
||||||
return deque(memories, maxlen=MEMORY_LIMIT)
|
return deque(memories, maxlen=config.world.character.memory_limit)
|
||||||
|
|
||||||
|
|
||||||
def save_world(world, filename):
|
def save_world(world, filename):
|
||||||
|
|
|
@ -5,6 +5,7 @@ from taleweave.context import get_current_world, get_prompt_library, subscribe
|
||||||
from taleweave.game_system import FormatPerspective, GameSystem
|
from taleweave.game_system import FormatPerspective, GameSystem
|
||||||
from taleweave.models.entity import Character, Room, World, WorldEntity
|
from taleweave.models.entity import Character, Room, World, WorldEntity
|
||||||
from taleweave.models.event import ActionEvent, GameEvent
|
from taleweave.models.event import ActionEvent, GameEvent
|
||||||
|
from taleweave.utils.prompt import format_str
|
||||||
from taleweave.utils.search import find_containing_room
|
from taleweave.utils.search import find_containing_room
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -22,7 +23,7 @@ def create_turn_digest(
|
||||||
if prompt_key in library.prompts:
|
if prompt_key in library.prompts:
|
||||||
try:
|
try:
|
||||||
template = library.prompts[prompt_key]
|
template = library.prompts[prompt_key]
|
||||||
message = template.format(event=event)
|
message = format_str(template, event=event)
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error formatting digest event: %s", event)
|
logger.exception("error formatting digest event: %s", event)
|
||||||
|
|
|
@ -3,10 +3,17 @@ from logging import getLogger
|
||||||
from jinja2 import Environment
|
from jinja2 import Environment
|
||||||
|
|
||||||
from taleweave.context import get_prompt_library
|
from taleweave.context import get_prompt_library
|
||||||
|
from taleweave.utils.string import and_list, or_list
|
||||||
from taleweave.utils.world import describe_entity, name_entity
|
from taleweave.utils.world import describe_entity, name_entity
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
jinja_env = Environment()
|
||||||
|
jinja_env.filters["describe"] = describe_entity
|
||||||
|
jinja_env.filters["name"] = name_entity
|
||||||
|
jinja_env.filters["and_list"] = and_list
|
||||||
|
jinja_env.filters["or_list"] = or_list
|
||||||
|
|
||||||
|
|
||||||
def format_prompt(prompt_key: str, **kwargs) -> str:
|
def format_prompt(prompt_key: str, **kwargs) -> str:
|
||||||
try:
|
try:
|
||||||
|
@ -19,9 +26,5 @@ def format_prompt(prompt_key: str, **kwargs) -> str:
|
||||||
|
|
||||||
|
|
||||||
def format_str(template_str: str, **kwargs) -> str:
|
def format_str(template_str: str, **kwargs) -> str:
|
||||||
env = Environment()
|
template = jinja_env.from_string(template_str)
|
||||||
env.filters["describe"] = describe_entity
|
|
||||||
env.filters["name"] = name_entity
|
|
||||||
|
|
||||||
template = env.from_string(template_str)
|
|
||||||
return template.render(**kwargs)
|
return template.render(**kwargs)
|
||||||
|
|
Loading…
Reference in New Issue