diff --git a/adventure/context.py b/adventure/context.py index 50be354..ca069d8 100644 --- a/adventure/context.py +++ b/adventure/context.py @@ -9,13 +9,14 @@ current_world: World | None = None current_room: Room | None = None current_actor: Actor | None = None current_step = 0 +dungeon_master: Agent | None = None # TODO: where should this one go? actor_agents: Dict[str, Tuple[Actor, Agent]] = {} -def get_current_context(): +def get_current_context() -> Tuple[World, Room, Actor]: if not current_world: raise ValueError( "The current world must be set before calling action functions" @@ -30,15 +31,15 @@ def get_current_context(): return (current_world, current_room, current_actor) -def get_current_world(): +def get_current_world() -> World | None: return current_world -def get_current_room(): +def get_current_room() -> Room | None: return current_room -def get_current_actor(): +def get_current_actor() -> Actor | None: return current_actor @@ -46,7 +47,7 @@ def get_current_broadcast(): return current_broadcast -def broadcast(message): +def broadcast(message: str): if current_broadcast: current_broadcast(message) @@ -56,26 +57,26 @@ def set_current_broadcast(broadcast): current_broadcast = broadcast -def set_current_world(world): +def set_current_world(world: World | None): global current_world current_world = world -def set_current_room(room): +def set_current_room(room: Room | None): global current_room current_room = room -def set_current_actor(actor): +def set_current_actor(actor: Actor | None): global current_actor current_actor = actor -def get_step(): +def get_current_step() -> int: return current_step -def set_step(step): +def set_current_step(step: int): global current_step current_step = step @@ -119,3 +120,17 @@ def set_actor_agent_for_name(name, actor, agent): def get_all_actor_agents(): return list(actor_agents.values()) + + +def set_dungeon_master(agent): + global dungeon_master + dungeon_master = agent + + +def get_dungeon_master() -> Agent: + if not dungeon_master: + raise ValueError( + "The dungeon master must be set before calling action functions" + ) + + return dungeon_master diff --git a/adventure/logic.py b/adventure/logic.py index 07e1c48..1b8c927 100644 --- a/adventure/logic.py +++ b/adventure/logic.py @@ -7,7 +7,15 @@ from pydantic import Field from rule_engine import Rule from yaml import Loader, load -from adventure.models import Actor, Item, Room, World, dataclass +from adventure.models import ( + Actor, + Attributes, + AttributeValue, + Item, + Room, + World, + dataclass, +) from adventure.plugins import get_plugin_function logger = getLogger(__name__) @@ -16,47 +24,45 @@ logger = getLogger(__name__) @dataclass class LogicLabel: backstory: str - description: str + description: str | None = None @dataclass class LogicRule: chance: float = 1.0 group: Optional[str] = None - match: Optional[Dict[str, str]] = None + match: Optional[Attributes] = None remove: Optional[List[str]] = None rule: Optional[str] = None - set: Optional[Dict[str, str]] = None + set: Optional[Attributes] = None trigger: Optional[List[str]] = None @dataclass class LogicTable: rules: List[LogicRule] - labels: Dict[str, Dict[str, LogicLabel]] = Field(default_factory=dict) + labels: Dict[str, Dict[AttributeValue, LogicLabel]] = Field(default_factory=dict) -LogicTrigger = Callable[[Room | Actor | Item, Dict[str, str]], Dict[str, str]] +LogicTrigger = Callable[[Room | Actor | Item, Attributes], Attributes] TriggerTable = Dict[LogicRule, List[LogicTrigger]] def update_attributes( entity: Room | Actor | Item, - attributes: Dict[str, str], + attributes: Attributes, rules: LogicTable, triggers: TriggerTable, -) -> Dict[str, str]: +) -> Attributes: entity_type = entity.__class__.__name__.lower() skip_groups = set() for rule in rules.rules: if rule.group: if rule.group in skip_groups: - logger.debug("skipping logic group: %s", rule.group) + logger.debug("already ran a rule from group %s, skipping", rule.group) continue - skip_groups.add(rule.group) - typed_attributes = { **attributes, "type": entity_type, @@ -83,6 +89,9 @@ def update_attributes( logger.info("logic skipped by chance: %s", rule.chance) continue + if rule.group: + skip_groups.add(rule.group) + for key in rule.remove or []: attributes.pop(key, None) @@ -120,7 +129,7 @@ def update_logic( logger.info("updated world attributes") -def format_logic(attributes: Dict[str, str], rules: LogicTable, self=True) -> str: +def format_logic(attributes: Attributes, rules: LogicTable, self=True) -> str: labels = [] for attribute, value in attributes.items(): @@ -128,8 +137,10 @@ def format_logic(attributes: Dict[str, str], rules: LogicTable, self=True) -> st label = rules.labels[attribute][value] if self: labels.append(label.backstory) - else: + elif label.description: labels.append(label.description) + else: + logger.debug("label has no relevant description: %s", label) if len(labels) > 0: logger.info("adding labels: %s", labels) diff --git a/adventure/main.py b/adventure/main.py index d7546f3..3e7c067 100644 --- a/adventure/main.py +++ b/adventure/main.py @@ -39,12 +39,12 @@ if True: from adventure.context import ( get_actor_agent_for_name, get_actor_for_agent, + get_current_step, get_current_world, - get_step, set_current_actor, set_current_room, + set_current_step, set_current_world, - set_step, ) from adventure.generate import generate_world from adventure.models import Actor, Room, World, WorldState @@ -114,7 +114,7 @@ def simulate_world( # simulate each actor for i in range(steps): - current_step = get_step() + current_step = get_current_step() logger.info(f"Simulating step {current_step}") for actor_name in world.order: actor, agent = get_actor_agent_for_name(actor_name) @@ -155,8 +155,8 @@ def simulate_world( "You can take the following actions: {actions}. " "You can move in the following directions: {directions}. " "What will you do next? Reply with a JSON function call, calling one of the actions." - "You can only take one action per turn. Pick the most important action and save the rest for later." - "What is your action?" + "You can only perform one action per turn. What is your next action?" + # Pick the most important action and save the rest for later." ), context={ "actions": action_names, @@ -181,7 +181,7 @@ def simulate_world( for system_update, _ in systems: system_update(world, current_step) - set_step(current_step + 1) + set_current_step(current_step + 1) # main @@ -277,7 +277,7 @@ def main(): with open(world_state_file, "r") as f: state = WorldState(**load(f)) - set_step(state.step) + set_current_step(state.step) memory = state.memory world = state.world diff --git a/adventure/models.py b/adventure/models.py index 9aad3e4..86b8bd8 100644 --- a/adventure/models.py +++ b/adventure/models.py @@ -9,6 +9,8 @@ else: Actions = Dict[str, Callable] +AttributeValue = bool | int | str +Attributes = Dict[str, AttributeValue] @dataclass @@ -16,7 +18,7 @@ class Item: name: str description: str actions: Actions = Field(default_factory=dict) - attributes: Dict[str, str] = Field(default_factory=dict) + attributes: Attributes = Field(default_factory=dict) @dataclass @@ -26,7 +28,7 @@ class Actor: description: str actions: Actions = Field(default_factory=dict) items: List[Item] = Field(default_factory=list) - attributes: Dict[str, str] = Field(default_factory=dict) + attributes: Attributes = Field(default_factory=dict) @dataclass @@ -37,7 +39,7 @@ class Room: items: List[Item] = Field(default_factory=list) actors: List[Actor] = Field(default_factory=list) actions: Actions = Field(default_factory=dict) - attributes: Dict[str, str] = Field(default_factory=dict) + attributes: Attributes = Field(default_factory=dict) @dataclass diff --git a/adventure/player.py b/adventure/player.py index 3f6c9a6..721e1b5 100644 --- a/adventure/player.py +++ b/adventure/player.py @@ -117,7 +117,11 @@ class RemotePlayer(BasePlayer): send_prompt: Callable[[str, str], bool] def __init__( - self, name: str, backstory: str, send_prompt: Callable[[str, str], bool], fallback_agent = None + self, + name: str, + backstory: str, + send_prompt: Callable[[str, str], bool], + fallback_agent=None, ) -> None: super().__init__(name, backstory) self.fallback_agent = fallback_agent diff --git a/adventure/server.py b/adventure/server.py index e0a25c1..8235fd2 100644 --- a/adventure/server.py +++ b/adventure/server.py @@ -77,6 +77,13 @@ async def handler(websocket): logger.error(f"Failed to find actor {character_name}") continue + # prevent any recursive fallback bugs + if isinstance(llm_agent, RemotePlayer): + logger.warning( + "patching recursive fallback for %s", character_name + ) + llm_agent = llm_agent.fallback_agent + if character_name in [ player.name for player in characters.values() ]: @@ -84,7 +91,9 @@ async def handler(websocket): continue # player_name = data["player"] - player = RemotePlayer(actor.name, actor.backstory, sync_turn, fallback_agent=llm_agent) + player = RemotePlayer( + actor.name, actor.backstory, sync_turn, fallback_agent=llm_agent + ) characters[id] = player logger.info(f"Client {id} is now character {character_name}")