1
0
Fork 0

fix types, rule skip logic

This commit is contained in:
Sean Sube 2024-05-05 17:46:24 -05:00
parent 564be90d9f
commit 84499982f0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 76 additions and 35 deletions

View File

@ -9,13 +9,14 @@ current_world: World | None = None
current_room: Room | None = None
current_actor: Actor | None = None
current_step = 0
dungeon_master: Agent | None = None
# TODO: where should this one go?
actor_agents: Dict[str, Tuple[Actor, Agent]] = {}
def get_current_context():
def get_current_context() -> Tuple[World, Room, Actor]:
if not current_world:
raise ValueError(
"The current world must be set before calling action functions"
@ -30,15 +31,15 @@ def get_current_context():
return (current_world, current_room, current_actor)
def get_current_world():
def get_current_world() -> World | None:
return current_world
def get_current_room():
def get_current_room() -> Room | None:
return current_room
def get_current_actor():
def get_current_actor() -> Actor | None:
return current_actor
@ -46,7 +47,7 @@ def get_current_broadcast():
return current_broadcast
def broadcast(message):
def broadcast(message: str):
if current_broadcast:
current_broadcast(message)
@ -56,26 +57,26 @@ def set_current_broadcast(broadcast):
current_broadcast = broadcast
def set_current_world(world):
def set_current_world(world: World | None):
global current_world
current_world = world
def set_current_room(room):
def set_current_room(room: Room | None):
global current_room
current_room = room
def set_current_actor(actor):
def set_current_actor(actor: Actor | None):
global current_actor
current_actor = actor
def get_step():
def get_current_step() -> int:
return current_step
def set_step(step):
def set_current_step(step: int):
global current_step
current_step = step
@ -119,3 +120,17 @@ def set_actor_agent_for_name(name, actor, agent):
def get_all_actor_agents():
return list(actor_agents.values())
def set_dungeon_master(agent):
global dungeon_master
dungeon_master = agent
def get_dungeon_master() -> Agent:
if not dungeon_master:
raise ValueError(
"The dungeon master must be set before calling action functions"
)
return dungeon_master

View File

@ -7,7 +7,15 @@ from pydantic import Field
from rule_engine import Rule
from yaml import Loader, load
from adventure.models import Actor, Item, Room, World, dataclass
from adventure.models import (
Actor,
Attributes,
AttributeValue,
Item,
Room,
World,
dataclass,
)
from adventure.plugins import get_plugin_function
logger = getLogger(__name__)
@ -16,47 +24,45 @@ logger = getLogger(__name__)
@dataclass
class LogicLabel:
backstory: str
description: str
description: str | None = None
@dataclass
class LogicRule:
chance: float = 1.0
group: Optional[str] = None
match: Optional[Dict[str, str]] = None
match: Optional[Attributes] = None
remove: Optional[List[str]] = None
rule: Optional[str] = None
set: Optional[Dict[str, str]] = None
set: Optional[Attributes] = None
trigger: Optional[List[str]] = None
@dataclass
class LogicTable:
rules: List[LogicRule]
labels: Dict[str, Dict[str, LogicLabel]] = Field(default_factory=dict)
labels: Dict[str, Dict[AttributeValue, LogicLabel]] = Field(default_factory=dict)
LogicTrigger = Callable[[Room | Actor | Item, Dict[str, str]], Dict[str, str]]
LogicTrigger = Callable[[Room | Actor | Item, Attributes], Attributes]
TriggerTable = Dict[LogicRule, List[LogicTrigger]]
def update_attributes(
entity: Room | Actor | Item,
attributes: Dict[str, str],
attributes: Attributes,
rules: LogicTable,
triggers: TriggerTable,
) -> Dict[str, str]:
) -> Attributes:
entity_type = entity.__class__.__name__.lower()
skip_groups = set()
for rule in rules.rules:
if rule.group:
if rule.group in skip_groups:
logger.debug("skipping logic group: %s", rule.group)
logger.debug("already ran a rule from group %s, skipping", rule.group)
continue
skip_groups.add(rule.group)
typed_attributes = {
**attributes,
"type": entity_type,
@ -83,6 +89,9 @@ def update_attributes(
logger.info("logic skipped by chance: %s", rule.chance)
continue
if rule.group:
skip_groups.add(rule.group)
for key in rule.remove or []:
attributes.pop(key, None)
@ -120,7 +129,7 @@ def update_logic(
logger.info("updated world attributes")
def format_logic(attributes: Dict[str, str], rules: LogicTable, self=True) -> str:
def format_logic(attributes: Attributes, rules: LogicTable, self=True) -> str:
labels = []
for attribute, value in attributes.items():
@ -128,8 +137,10 @@ def format_logic(attributes: Dict[str, str], rules: LogicTable, self=True) -> st
label = rules.labels[attribute][value]
if self:
labels.append(label.backstory)
else:
elif label.description:
labels.append(label.description)
else:
logger.debug("label has no relevant description: %s", label)
if len(labels) > 0:
logger.info("adding labels: %s", labels)

View File

@ -39,12 +39,12 @@ if True:
from adventure.context import (
get_actor_agent_for_name,
get_actor_for_agent,
get_current_step,
get_current_world,
get_step,
set_current_actor,
set_current_room,
set_current_step,
set_current_world,
set_step,
)
from adventure.generate import generate_world
from adventure.models import Actor, Room, World, WorldState
@ -114,7 +114,7 @@ def simulate_world(
# simulate each actor
for i in range(steps):
current_step = get_step()
current_step = get_current_step()
logger.info(f"Simulating step {current_step}")
for actor_name in world.order:
actor, agent = get_actor_agent_for_name(actor_name)
@ -155,8 +155,8 @@ def simulate_world(
"You can take the following actions: {actions}. "
"You can move in the following directions: {directions}. "
"What will you do next? Reply with a JSON function call, calling one of the actions."
"You can only take one action per turn. Pick the most important action and save the rest for later."
"What is your action?"
"You can only perform one action per turn. What is your next action?"
# Pick the most important action and save the rest for later."
),
context={
"actions": action_names,
@ -181,7 +181,7 @@ def simulate_world(
for system_update, _ in systems:
system_update(world, current_step)
set_step(current_step + 1)
set_current_step(current_step + 1)
# main
@ -277,7 +277,7 @@ def main():
with open(world_state_file, "r") as f:
state = WorldState(**load(f))
set_step(state.step)
set_current_step(state.step)
memory = state.memory
world = state.world

View File

@ -9,6 +9,8 @@ else:
Actions = Dict[str, Callable]
AttributeValue = bool | int | str
Attributes = Dict[str, AttributeValue]
@dataclass
@ -16,7 +18,7 @@ class Item:
name: str
description: str
actions: Actions = Field(default_factory=dict)
attributes: Dict[str, str] = Field(default_factory=dict)
attributes: Attributes = Field(default_factory=dict)
@dataclass
@ -26,7 +28,7 @@ class Actor:
description: str
actions: Actions = Field(default_factory=dict)
items: List[Item] = Field(default_factory=list)
attributes: Dict[str, str] = Field(default_factory=dict)
attributes: Attributes = Field(default_factory=dict)
@dataclass
@ -37,7 +39,7 @@ class Room:
items: List[Item] = Field(default_factory=list)
actors: List[Actor] = Field(default_factory=list)
actions: Actions = Field(default_factory=dict)
attributes: Dict[str, str] = Field(default_factory=dict)
attributes: Attributes = Field(default_factory=dict)
@dataclass

View File

@ -117,7 +117,11 @@ class RemotePlayer(BasePlayer):
send_prompt: Callable[[str, str], bool]
def __init__(
self, name: str, backstory: str, send_prompt: Callable[[str, str], bool], fallback_agent = None
self,
name: str,
backstory: str,
send_prompt: Callable[[str, str], bool],
fallback_agent=None,
) -> None:
super().__init__(name, backstory)
self.fallback_agent = fallback_agent

View File

@ -77,6 +77,13 @@ async def handler(websocket):
logger.error(f"Failed to find actor {character_name}")
continue
# prevent any recursive fallback bugs
if isinstance(llm_agent, RemotePlayer):
logger.warning(
"patching recursive fallback for %s", character_name
)
llm_agent = llm_agent.fallback_agent
if character_name in [
player.name for player in characters.values()
]:
@ -84,7 +91,9 @@ async def handler(websocket):
continue
# player_name = data["player"]
player = RemotePlayer(actor.name, actor.backstory, sync_turn, fallback_agent=llm_agent)
player = RemotePlayer(
actor.name, actor.backstory, sync_turn, fallback_agent=llm_agent
)
characters[id] = player
logger.info(f"Client {id} is now character {character_name}")