From 5a32bd9fc47d3bcccf06a8dd0fc2d96394fe43bd Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 4 Jun 2024 22:07:26 -0500 Subject: [PATCH] make search helpers more flexible, split up some prompts, pass world to system generators --- prompts/llama-base.yml | 8 ++++- taleweave/actions/base.py | 11 +++--- taleweave/actions/optional.py | 14 ++++++-- taleweave/actions/planning.py | 13 +++++-- taleweave/bot/discord.py | 2 +- taleweave/game_system.py | 4 ++- taleweave/generate.py | 3 +- taleweave/main.py | 6 ++-- taleweave/models/files.py | 5 ++- taleweave/render/comfy.py | 6 ++-- taleweave/server/websocket.py | 3 +- taleweave/systems/digest.py | 2 +- taleweave/systems/quest.py | 2 +- taleweave/systems/weather/__init__.py | 4 +-- taleweave/utils/planning.py | 8 +++-- taleweave/utils/prompt.py | 1 + taleweave/utils/search.py | 50 +++++++++++++++------------ worlds.yml | 8 ++--- 18 files changed, 91 insertions(+), 59 deletions(-) diff --git a/prompts/llama-base.yml b/prompts/llama-base.yml index b629e2a..d1d8bc9 100644 --- a/prompts/llama-base.yml +++ b/prompts/llama-base.yml @@ -124,8 +124,10 @@ prompts: The {{item}} item is not available in your inventory or in the room. action_use_error_target: | The {{target}} is not in the room, so you cannot use the {{item}} item on it. - action_use_broadcast: | + action_use_broadcast_effect: | {{action_character | name}} uses {{item}} on {{target}} and applies the {{effect}} effect. + action_use_broadcast_outcome: | + Using the {{item}} item on {{target}} resulted in: {{outcome}}. action_use_dm_effect: | {{action_character | name}} uses {{item}} on {{target}}. {{item}} can apply any of the following effects: {{effect_names}}. Which effect should be applied? Specify the effect. Do not include the question or any JSON. Only reply with the effect name. @@ -177,6 +179,10 @@ prompts: action_schedule_event_error_name: | The event must have a name. + action_schedule_event_error_limit: | + You have reached the maximum number of events. Please delete or reschedule some of your existing events before adding more. + action_schedule_event_error_duplicate: | + You already have an event with that name. Please choose a unique name for the event. action_schedule_event_result: | You scheduled an event that will happen in {{turns}} turns. diff --git a/taleweave/actions/base.py b/taleweave/actions/base.py index 0378533..0ebf0e7 100644 --- a/taleweave/actions/base.py +++ b/taleweave/actions/base.py @@ -177,22 +177,21 @@ def action_ask(character: str, question: str) -> str: with action_context() as (action_room, action_character): # sanity checks - question_character, question_agent = get_character_agent_for_name(character) - if question_character == action_character: - raise ActionError(format_prompt("action_ask_error_self")) - + question_character = find_character_in_room(action_room, character) if not question_character: raise ActionError( format_prompt("action_ask_error_target", character=character) ) + if question_character == action_character: + raise ActionError(format_prompt("action_ask_error_self")) + + question_agent = get_agent_for_character(question_character) if not question_agent: raise ActionError( format_prompt("action_ask_error_agent", character=character) ) - # TODO: make sure they are in the same room - broadcast( format_prompt( "action_ask_broadcast", diff --git a/taleweave/actions/optional.py b/taleweave/actions/optional.py index 0883ef1..35ebfcc 100644 --- a/taleweave/actions/optional.py +++ b/taleweave/actions/optional.py @@ -214,7 +214,7 @@ def action_use(item: str, target: str) -> str: broadcast( format_prompt( - "action_use_broadcast", + "action_use_broadcast_effect", action_character=action_character, effect=effect, item=item, @@ -233,8 +233,16 @@ def action_use(item: str, target: str) -> str: ) ) broadcast( - f"The action resulted in: {outcome}" - ) # TODO: should this be removed or moved to the prompt library? + format_prompt( + "action_use_broadcast_outcome", + action_character=action_character, + action_item=action_item, + effect=effect, + item=item, + target=target, + outcome=outcome, + ) + ) # make sure both agents remember the outcome target_agent = get_agent_for_character(target_character) diff --git a/taleweave/actions/planning.py b/taleweave/actions/planning.py index 1d052e4..006bacc 100644 --- a/taleweave/actions/planning.py +++ b/taleweave/actions/planning.py @@ -145,15 +145,22 @@ def schedule_event(name: str, turns: int): turns: The number of turns until the event happens. """ - # TODO: check for existing events with the same name - # TODO: limit the number of events that can be scheduled - + config = get_game_config() current_turn = get_current_turn() with action_context() as (_, action_character): if not name: raise ActionError(get_prompt("action_schedule_event_error_name")) + if ( + len(action_character.planner.calendar.events) + >= config.world.character.event_limit + ): + raise ActionError(get_prompt("action_schedule_event_error_limit")) + + if name in [event.name for event in action_character.planner.calendar.events]: + raise ActionError(get_prompt("action_schedule_event_error_duplicate")) + event = CalendarEvent(name, turns + current_turn) action_character.planner.calendar.events.append(event) return format_prompt("action_schedule_event_result", name=name, turns=turns) diff --git a/taleweave/bot/discord.py b/taleweave/bot/discord.py index 15eaa71..d11d73b 100644 --- a/taleweave/bot/discord.py +++ b/taleweave/bot/discord.py @@ -68,7 +68,7 @@ class AdventureClient(Client): message_id = reaction.message.id if message_id not in event_messages: logger.warning(f"message {message_id} not found in event messages") - # TODO: return error message + await reaction.message.add_reaction("❌") return event = event_messages[message_id] diff --git a/taleweave/game_system.py b/taleweave/game_system.py index ab9dcfd..634121e 100644 --- a/taleweave/game_system.py +++ b/taleweave/game_system.py @@ -23,9 +23,11 @@ class SystemFormat(Protocol): class SystemGenerate(Protocol): - def __call__(self, agent: Agent, theme: str, entity: WorldEntity) -> None: + def __call__(self, agent: Agent, world: World, entity: WorldEntity) -> None: """ Generate a new world entity based on the given theme and entity. + + TODO: should this include the WorldPrompt as a parameter? """ ... diff --git a/taleweave/generate.py b/taleweave/generate.py index f6d8111..63fde29 100644 --- a/taleweave/generate.py +++ b/taleweave/generate.py @@ -92,8 +92,7 @@ def generate_system_attributes( ) -> None: for system in systems: if system.generate: - # TODO: pass the whole world - system.generate(agent, world.theme, entity) + system.generate(agent, world, entity) def generate_room( diff --git a/taleweave/main.py b/taleweave/main.py index 920f34f..7f3eb1d 100644 --- a/taleweave/main.py +++ b/taleweave/main.py @@ -45,7 +45,7 @@ if True: from taleweave.models.config import DEFAULT_CONFIG, Config from taleweave.models.entity import World, WorldState from taleweave.models.event import GenerateEvent - from taleweave.models.files import PromptFile, WorldPrompt + from taleweave.models.files import TemplateFile, WorldPrompt from taleweave.models.prompt import PromptLibrary from taleweave.plugins import load_plugin from taleweave.simulate import simulate_world @@ -180,8 +180,8 @@ def get_world_prompt(args) -> WorldPrompt: if args.world_template: prompt_file, prompt_name = args.world_template.split(":") with open(prompt_file, "r") as f: - prompts = PromptFile(**load_yaml(f)) - for prompt in prompts.prompts: + prompts = TemplateFile(**load_yaml(f)) + for prompt in prompts.templates: if prompt.name == prompt_name: return prompt diff --git a/taleweave/models/files.py b/taleweave/models/files.py index 655b13c..b79ceb1 100644 --- a/taleweave/models/files.py +++ b/taleweave/models/files.py @@ -10,7 +10,6 @@ class WorldPrompt: flavor: str = "" -# TODO: rename to WorldTemplates @dataclass -class PromptFile: - prompts: List[WorldPrompt] +class TemplateFile: + templates: List[WorldPrompt] diff --git a/taleweave/render/comfy.py b/taleweave/render/comfy.py index 39ea993..5aeaefa 100644 --- a/taleweave/render/comfy.py +++ b/taleweave/render/comfy.py @@ -274,9 +274,9 @@ def render_loop(): ) if isinstance(event, WorldEntity): - title = event.name # TODO: generate a real title + title = event.name else: - title = event.type + title = event.type # TODO: generate a real title broadcast( RenderEvent( @@ -292,7 +292,7 @@ def render_loop(): if isinstance(event, WorldEntity): logger.info("rendering entity %s", event.name) prompt = prompt_from_entity(event) - title = event.name # TODO: generate a real title + title = event.name else: logger.info("rendering event %s", event.id) prompt = prompt_from_event(event) diff --git a/taleweave/server/websocket.py b/taleweave/server/websocket.py index 0b2ce7d..9e7bc05 100644 --- a/taleweave/server/websocket.py +++ b/taleweave/server/websocket.py @@ -63,8 +63,9 @@ async def handler(websocket): await websocket.send( dumps( { + "id": event.id, "type": event.type, - "client": id, # TODO: this should be a field in the PromptEvent + "client": id, # TODO: should this be a field in the PromptEvent? "character": event.character, "prompt": event.prompt, "actions": event.actions, diff --git a/taleweave/systems/digest.py b/taleweave/systems/digest.py index a24d91f..6e3d206 100644 --- a/taleweave/systems/digest.py +++ b/taleweave/systems/digest.py @@ -127,7 +127,7 @@ def format_digest( return "\n".join(digest) -def generate_digest(agent: Any, theme: str, entity: WorldEntity): +def generate_digest(agent: Any, world: World, entity: WorldEntity): if isinstance(entity, Character): if entity.name not in character_buffers: character_buffers[entity.name] = [] diff --git a/taleweave/systems/quest.py b/taleweave/systems/quest.py index 58475b7..adbb159 100644 --- a/taleweave/systems/quest.py +++ b/taleweave/systems/quest.py @@ -172,7 +172,7 @@ def initialize_quests(world: World) -> QuestData: return QuestData(active={}, available={}, completed={}) -def generate_quests(agent: Agent, theme: str, entity: WorldEntity) -> None: +def generate_quests(agent: Agent, world: World, entity: WorldEntity) -> None: """ Generate new quests for the world. """ diff --git a/taleweave/systems/weather/__init__.py b/taleweave/systems/weather/__init__.py index 3d5b368..33d0ab6 100644 --- a/taleweave/systems/weather/__init__.py +++ b/taleweave/systems/weather/__init__.py @@ -76,10 +76,10 @@ def generate_room_weather(agent: Agent, theme: str, entity: Room) -> None: logger.info(f"generated environment for {entity.name}: {environment}") -def generate_weather(agent: Agent, theme: str, entity: WorldEntity) -> None: +def generate_weather(agent: Agent, world: World, entity: WorldEntity) -> None: if isinstance(entity, Room): if "environment" not in entity.attributes: - generate_room_weather(agent, theme, entity) + generate_room_weather(agent, world.theme, entity) def simulate_weather(world: World, turn: int, data: None = None): diff --git a/taleweave/utils/planning.py b/taleweave/utils/planning.py index f2ad36e..c5ea387 100644 --- a/taleweave/utils/planning.py +++ b/taleweave/utils/planning.py @@ -31,9 +31,13 @@ def get_upcoming_events( """ calendar = character.planner.calendar - # TODO: sort events by turn - return [ + upcoming = [ event for event in calendar.events if event.turn - current_turn <= upcoming_turns ] + + # sort by turn + upcoming.sort(key=lambda event: event.turn) + + return upcoming diff --git a/taleweave/utils/prompt.py b/taleweave/utils/prompt.py index 2dd988c..5c4e836 100644 --- a/taleweave/utils/prompt.py +++ b/taleweave/utils/prompt.py @@ -26,5 +26,6 @@ def format_prompt(prompt_key: str, **kwargs) -> str: def format_str(template_str: str, **kwargs) -> str: + # TODO: cache templates template = jinja_env.from_string(template_str) return template.render(**kwargs) diff --git a/taleweave/utils/search.py b/taleweave/utils/search.py index 4b57acc..66b734c 100644 --- a/taleweave/utils/search.py +++ b/taleweave/utils/search.py @@ -13,6 +13,13 @@ from taleweave.models.entity import ( from .string import normalize_name +def get_entity_name(entity: WorldEntity | str) -> str: + if isinstance(entity, str): + return entity + + return normalize_name(entity.name) + + def find_room(world: World, room_name: str) -> Room | None: for room in world.rooms: if normalize_name(room.name) == normalize_name(room_name): @@ -58,61 +65,60 @@ def find_portal_in_room(room: Room, portal_name: str) -> Portal | None: # TODO: allow item or str def find_item( world: World, - item_name: str, + item: Item | str, include_character_inventory=False, include_item_inventory=False, ) -> Item | None: + item_name = get_entity_name(item) + for room in world.rooms: - item = find_item_in_room( + result = find_item_in_room( room, item_name, include_character_inventory, include_item_inventory ) - if item: - return item + if result: + return result return None def find_item_in_character( - character: Character, item_name: str, include_item_inventory=False + character: Character, item: Item | str, include_item_inventory=False ) -> Item | None: - return find_item_in_container(character, item_name, include_item_inventory) + return find_item_in_container(character, item, include_item_inventory) def find_item_in_container( - container: Character | Item, item_name: str, include_item_inventory=False + container: Room | Character | Item, item: Item | str, include_item_inventory=False ) -> Item | None: + item_name = get_entity_name(item) + for item in container.items: if normalize_name(item.name) == normalize_name(item_name): return item if include_item_inventory: - item = find_item_in_container(item, item_name, include_item_inventory) - if item: - return item + result = find_item_in_container(item, item_name, include_item_inventory) + if result: + return result return None def find_item_in_room( room: Room, - item_name: str, + item: Item | str, include_character_inventory=False, include_item_inventory=False, ) -> Item | None: - for item in room.items: - if normalize_name(item.name) == normalize_name(item_name): - return item - - if include_item_inventory: - item = find_item_in_container(item, item_name, include_item_inventory) - if item: - return item + result = find_item_in_container(room, item, include_item_inventory) + if result: + return result if include_character_inventory: for character in room.characters: - item = find_item_in_character(character, item_name, include_item_inventory) - if item: - return item + result = find_item_in_character(character, item, include_item_inventory) + if result: + return result return None diff --git a/worlds.yml b/worlds.yml index 5f48d89..4c9f6e3 100644 --- a/worlds.yml +++ b/worlds.yml @@ -1,4 +1,4 @@ -prompts: +templates: - name: outback-animals theme: talking animal truckers in the Australian outback flavor: create a fun and happy world where rough and tumble talking animals drive trucks and run saloons in the outback @@ -15,12 +15,12 @@ prompts: theme: opening scenes from Jurassic Park flavor: | follow the script of the film Jurassic Park exactly. do not deviate from the script in any way. - include accurate characters and make sure they will fully utilize all of the actions available to them in this world + include accurate characters and instruct them to utilize all of the actions available to them in this world - name: star-wars - theme: opening scenes from Star Wars + theme: opening scenes from the 1977 film Star Wars flavor: | follow the script of the 1977 film Star Wars exactly. do not deviate from the script in any way. - include accurate characters and make sure they will fully utilize all of the actions available to them in this world + include accurate characters and instruct them to fully utilize all of the actions available to them in this world - name: cyberpunk-utopia theme: wealthy cyberpunk utopia with a dark secret flavor: make a strange and dangerous world where technology is pervasive and scarcity is unheard of - for the upper class, at least