diff --git a/adventure/main.py b/adventure/main.py index 93f5837..48524e8 100644 --- a/adventure/main.py +++ b/adventure/main.py @@ -202,6 +202,8 @@ def load_or_generate_world( world_state_file = args.state or (args.world + ".state.json") memory = {} + step = 0 + if path.exists(world_state_file): logger.info(f"loading world state from {world_state_file}") with open(world_state_file, "r") as f: @@ -211,6 +213,7 @@ def load_or_generate_world( load_or_initialize_system_data(args, systems, state.world) memory = state.memory + step = state.step world = state.world elif path.exists(world_file): logger.info(f"loading world from {world_file}") @@ -240,7 +243,7 @@ def load_or_generate_world( save_system_data(args, systems) create_agents(world, memory=memory, players=players) - return (world, world_state_file) + return (world, world_state_file, step) def main(): @@ -320,7 +323,7 @@ def main(): # load or generate the world world_prompt = get_world_prompt(args) - world, world_state_file = load_or_generate_world( + world, world_state_file, world_step = load_or_generate_world( args, players, extra_systems, world_prompt=world_prompt ) @@ -333,7 +336,7 @@ def main(): # hack: send a snapshot to the websocket server if args.server: - server_system(world, 0) + server_system(world, world_step) # create the DM llm = agent_easy_connect() diff --git a/adventure/utils/conversation.py b/adventure/utils/conversation.py index 8d787ab..5bdacee 100644 --- a/adventure/utils/conversation.py +++ b/adventure/utils/conversation.py @@ -4,7 +4,7 @@ from logging import getLogger from typing import List from packit.agent import Agent -from packit.conditions import condition_and, condition_threshold, make_flag_condition +from packit.conditions import condition_or, condition_threshold, make_flag_condition from packit.results import multi_function_or_str_result from packit.utils import could_be_json @@ -30,16 +30,28 @@ def make_keyword_condition(end_message: str, keywords=["end", "stop"]): set_end() return end_message - # sometimes the models will make up a tool named after the keyword - keyword_functions = [f'"function": "{kw}"' for kw in keywords] - if could_be_json(normalized_value) and any( - kw in normalized_value for kw in keyword_functions - ): - logger.debug( - f"found keyword function, setting stop condition: {normalized_value}" - ) - set_end() - return end_message + for keyword in keywords: + if keyword == normalized_value: + logger.debug( + f"found keyword, setting stop condition: {normalized_value}" + ) + set_end() + return end_message + + if normalized_value.endswith(keyword): + logger.debug( + f"found keyword at end of string, setting stop condition: {normalized_value}" + ) + set_end() + return value[: -len(keyword)].strip() + + keyword_function = f'"function": "{keyword}"' + if could_be_json(normalized_value) and keyword_function in normalized_value: + logger.debug( + f"found keyword function, setting stop condition: {normalized_value}" + ) + set_end() + return end_message return multi_function_or_str_result(value, **kwargs) @@ -115,7 +127,7 @@ def loop_conversation( _, condition_end, parse_end = make_keyword_condition(end_message) stop_length = partial(condition_threshold, max=max_length) - stop_condition = condition_and(condition_end, stop_length) + stop_condition = condition_or(condition_end, stop_length) def result_parser(value: str, **kwargs) -> str: value = parse_end(value, **kwargs) @@ -150,7 +162,9 @@ def loop_conversation( prompt, response=response, summary=summary, last_actor=last_actor ) response = result_parser(response) - broadcast(f"{actor.name} responds: {response}") + + logger.info(f"{actor.name} response: {response}") + broadcast(f"{actor.name} responds to {last_actor.name}: {response}") # increment the step counter i += 1