fix types, rule skip logic
This commit is contained in:
parent
564be90d9f
commit
84499982f0
|
@ -9,13 +9,14 @@ current_world: World | None = None
|
||||||
current_room: Room | None = None
|
current_room: Room | None = None
|
||||||
current_actor: Actor | None = None
|
current_actor: Actor | None = None
|
||||||
current_step = 0
|
current_step = 0
|
||||||
|
dungeon_master: Agent | None = None
|
||||||
|
|
||||||
|
|
||||||
# TODO: where should this one go?
|
# TODO: where should this one go?
|
||||||
actor_agents: Dict[str, Tuple[Actor, Agent]] = {}
|
actor_agents: Dict[str, Tuple[Actor, Agent]] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_current_context():
|
def get_current_context() -> Tuple[World, Room, Actor]:
|
||||||
if not current_world:
|
if not current_world:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The current world must be set before calling action functions"
|
"The current world must be set before calling action functions"
|
||||||
|
@ -30,15 +31,15 @@ def get_current_context():
|
||||||
return (current_world, current_room, current_actor)
|
return (current_world, current_room, current_actor)
|
||||||
|
|
||||||
|
|
||||||
def get_current_world():
|
def get_current_world() -> World | None:
|
||||||
return current_world
|
return current_world
|
||||||
|
|
||||||
|
|
||||||
def get_current_room():
|
def get_current_room() -> Room | None:
|
||||||
return current_room
|
return current_room
|
||||||
|
|
||||||
|
|
||||||
def get_current_actor():
|
def get_current_actor() -> Actor | None:
|
||||||
return current_actor
|
return current_actor
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,7 +47,7 @@ def get_current_broadcast():
|
||||||
return current_broadcast
|
return current_broadcast
|
||||||
|
|
||||||
|
|
||||||
def broadcast(message):
|
def broadcast(message: str):
|
||||||
if current_broadcast:
|
if current_broadcast:
|
||||||
current_broadcast(message)
|
current_broadcast(message)
|
||||||
|
|
||||||
|
@ -56,26 +57,26 @@ def set_current_broadcast(broadcast):
|
||||||
current_broadcast = broadcast
|
current_broadcast = broadcast
|
||||||
|
|
||||||
|
|
||||||
def set_current_world(world):
|
def set_current_world(world: World | None):
|
||||||
global current_world
|
global current_world
|
||||||
current_world = world
|
current_world = world
|
||||||
|
|
||||||
|
|
||||||
def set_current_room(room):
|
def set_current_room(room: Room | None):
|
||||||
global current_room
|
global current_room
|
||||||
current_room = room
|
current_room = room
|
||||||
|
|
||||||
|
|
||||||
def set_current_actor(actor):
|
def set_current_actor(actor: Actor | None):
|
||||||
global current_actor
|
global current_actor
|
||||||
current_actor = actor
|
current_actor = actor
|
||||||
|
|
||||||
|
|
||||||
def get_step():
|
def get_current_step() -> int:
|
||||||
return current_step
|
return current_step
|
||||||
|
|
||||||
|
|
||||||
def set_step(step):
|
def set_current_step(step: int):
|
||||||
global current_step
|
global current_step
|
||||||
current_step = step
|
current_step = step
|
||||||
|
|
||||||
|
@ -119,3 +120,17 @@ def set_actor_agent_for_name(name, actor, agent):
|
||||||
|
|
||||||
def get_all_actor_agents():
|
def get_all_actor_agents():
|
||||||
return list(actor_agents.values())
|
return list(actor_agents.values())
|
||||||
|
|
||||||
|
|
||||||
|
def set_dungeon_master(agent):
|
||||||
|
global dungeon_master
|
||||||
|
dungeon_master = agent
|
||||||
|
|
||||||
|
|
||||||
|
def get_dungeon_master() -> Agent:
|
||||||
|
if not dungeon_master:
|
||||||
|
raise ValueError(
|
||||||
|
"The dungeon master must be set before calling action functions"
|
||||||
|
)
|
||||||
|
|
||||||
|
return dungeon_master
|
||||||
|
|
|
@ -7,7 +7,15 @@ from pydantic import Field
|
||||||
from rule_engine import Rule
|
from rule_engine import Rule
|
||||||
from yaml import Loader, load
|
from yaml import Loader, load
|
||||||
|
|
||||||
from adventure.models import Actor, Item, Room, World, dataclass
|
from adventure.models import (
|
||||||
|
Actor,
|
||||||
|
Attributes,
|
||||||
|
AttributeValue,
|
||||||
|
Item,
|
||||||
|
Room,
|
||||||
|
World,
|
||||||
|
dataclass,
|
||||||
|
)
|
||||||
from adventure.plugins import get_plugin_function
|
from adventure.plugins import get_plugin_function
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -16,47 +24,45 @@ logger = getLogger(__name__)
|
||||||
@dataclass
|
@dataclass
|
||||||
class LogicLabel:
|
class LogicLabel:
|
||||||
backstory: str
|
backstory: str
|
||||||
description: str
|
description: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LogicRule:
|
class LogicRule:
|
||||||
chance: float = 1.0
|
chance: float = 1.0
|
||||||
group: Optional[str] = None
|
group: Optional[str] = None
|
||||||
match: Optional[Dict[str, str]] = None
|
match: Optional[Attributes] = None
|
||||||
remove: Optional[List[str]] = None
|
remove: Optional[List[str]] = None
|
||||||
rule: Optional[str] = None
|
rule: Optional[str] = None
|
||||||
set: Optional[Dict[str, str]] = None
|
set: Optional[Attributes] = None
|
||||||
trigger: Optional[List[str]] = None
|
trigger: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LogicTable:
|
class LogicTable:
|
||||||
rules: List[LogicRule]
|
rules: List[LogicRule]
|
||||||
labels: Dict[str, Dict[str, LogicLabel]] = Field(default_factory=dict)
|
labels: Dict[str, Dict[AttributeValue, LogicLabel]] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
LogicTrigger = Callable[[Room | Actor | Item, Dict[str, str]], Dict[str, str]]
|
LogicTrigger = Callable[[Room | Actor | Item, Attributes], Attributes]
|
||||||
TriggerTable = Dict[LogicRule, List[LogicTrigger]]
|
TriggerTable = Dict[LogicRule, List[LogicTrigger]]
|
||||||
|
|
||||||
|
|
||||||
def update_attributes(
|
def update_attributes(
|
||||||
entity: Room | Actor | Item,
|
entity: Room | Actor | Item,
|
||||||
attributes: Dict[str, str],
|
attributes: Attributes,
|
||||||
rules: LogicTable,
|
rules: LogicTable,
|
||||||
triggers: TriggerTable,
|
triggers: TriggerTable,
|
||||||
) -> Dict[str, str]:
|
) -> Attributes:
|
||||||
entity_type = entity.__class__.__name__.lower()
|
entity_type = entity.__class__.__name__.lower()
|
||||||
skip_groups = set()
|
skip_groups = set()
|
||||||
|
|
||||||
for rule in rules.rules:
|
for rule in rules.rules:
|
||||||
if rule.group:
|
if rule.group:
|
||||||
if rule.group in skip_groups:
|
if rule.group in skip_groups:
|
||||||
logger.debug("skipping logic group: %s", rule.group)
|
logger.debug("already ran a rule from group %s, skipping", rule.group)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
skip_groups.add(rule.group)
|
|
||||||
|
|
||||||
typed_attributes = {
|
typed_attributes = {
|
||||||
**attributes,
|
**attributes,
|
||||||
"type": entity_type,
|
"type": entity_type,
|
||||||
|
@ -83,6 +89,9 @@ def update_attributes(
|
||||||
logger.info("logic skipped by chance: %s", rule.chance)
|
logger.info("logic skipped by chance: %s", rule.chance)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if rule.group:
|
||||||
|
skip_groups.add(rule.group)
|
||||||
|
|
||||||
for key in rule.remove or []:
|
for key in rule.remove or []:
|
||||||
attributes.pop(key, None)
|
attributes.pop(key, None)
|
||||||
|
|
||||||
|
@ -120,7 +129,7 @@ def update_logic(
|
||||||
logger.info("updated world attributes")
|
logger.info("updated world attributes")
|
||||||
|
|
||||||
|
|
||||||
def format_logic(attributes: Dict[str, str], rules: LogicTable, self=True) -> str:
|
def format_logic(attributes: Attributes, rules: LogicTable, self=True) -> str:
|
||||||
labels = []
|
labels = []
|
||||||
|
|
||||||
for attribute, value in attributes.items():
|
for attribute, value in attributes.items():
|
||||||
|
@ -128,8 +137,10 @@ def format_logic(attributes: Dict[str, str], rules: LogicTable, self=True) -> st
|
||||||
label = rules.labels[attribute][value]
|
label = rules.labels[attribute][value]
|
||||||
if self:
|
if self:
|
||||||
labels.append(label.backstory)
|
labels.append(label.backstory)
|
||||||
else:
|
elif label.description:
|
||||||
labels.append(label.description)
|
labels.append(label.description)
|
||||||
|
else:
|
||||||
|
logger.debug("label has no relevant description: %s", label)
|
||||||
|
|
||||||
if len(labels) > 0:
|
if len(labels) > 0:
|
||||||
logger.info("adding labels: %s", labels)
|
logger.info("adding labels: %s", labels)
|
||||||
|
|
|
@ -39,12 +39,12 @@ if True:
|
||||||
from adventure.context import (
|
from adventure.context import (
|
||||||
get_actor_agent_for_name,
|
get_actor_agent_for_name,
|
||||||
get_actor_for_agent,
|
get_actor_for_agent,
|
||||||
|
get_current_step,
|
||||||
get_current_world,
|
get_current_world,
|
||||||
get_step,
|
|
||||||
set_current_actor,
|
set_current_actor,
|
||||||
set_current_room,
|
set_current_room,
|
||||||
|
set_current_step,
|
||||||
set_current_world,
|
set_current_world,
|
||||||
set_step,
|
|
||||||
)
|
)
|
||||||
from adventure.generate import generate_world
|
from adventure.generate import generate_world
|
||||||
from adventure.models import Actor, Room, World, WorldState
|
from adventure.models import Actor, Room, World, WorldState
|
||||||
|
@ -114,7 +114,7 @@ def simulate_world(
|
||||||
|
|
||||||
# simulate each actor
|
# simulate each actor
|
||||||
for i in range(steps):
|
for i in range(steps):
|
||||||
current_step = get_step()
|
current_step = get_current_step()
|
||||||
logger.info(f"Simulating step {current_step}")
|
logger.info(f"Simulating step {current_step}")
|
||||||
for actor_name in world.order:
|
for actor_name in world.order:
|
||||||
actor, agent = get_actor_agent_for_name(actor_name)
|
actor, agent = get_actor_agent_for_name(actor_name)
|
||||||
|
@ -155,8 +155,8 @@ def simulate_world(
|
||||||
"You can take the following actions: {actions}. "
|
"You can take the following actions: {actions}. "
|
||||||
"You can move in the following directions: {directions}. "
|
"You can move in the following directions: {directions}. "
|
||||||
"What will you do next? Reply with a JSON function call, calling one of the actions."
|
"What will you do next? Reply with a JSON function call, calling one of the actions."
|
||||||
"You can only take one action per turn. Pick the most important action and save the rest for later."
|
"You can only perform one action per turn. What is your next action?"
|
||||||
"What is your action?"
|
# Pick the most important action and save the rest for later."
|
||||||
),
|
),
|
||||||
context={
|
context={
|
||||||
"actions": action_names,
|
"actions": action_names,
|
||||||
|
@ -181,7 +181,7 @@ def simulate_world(
|
||||||
for system_update, _ in systems:
|
for system_update, _ in systems:
|
||||||
system_update(world, current_step)
|
system_update(world, current_step)
|
||||||
|
|
||||||
set_step(current_step + 1)
|
set_current_step(current_step + 1)
|
||||||
|
|
||||||
|
|
||||||
# main
|
# main
|
||||||
|
@ -277,7 +277,7 @@ def main():
|
||||||
with open(world_state_file, "r") as f:
|
with open(world_state_file, "r") as f:
|
||||||
state = WorldState(**load(f))
|
state = WorldState(**load(f))
|
||||||
|
|
||||||
set_step(state.step)
|
set_current_step(state.step)
|
||||||
|
|
||||||
memory = state.memory
|
memory = state.memory
|
||||||
world = state.world
|
world = state.world
|
||||||
|
|
|
@ -9,6 +9,8 @@ else:
|
||||||
|
|
||||||
|
|
||||||
Actions = Dict[str, Callable]
|
Actions = Dict[str, Callable]
|
||||||
|
AttributeValue = bool | int | str
|
||||||
|
Attributes = Dict[str, AttributeValue]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -16,7 +18,7 @@ class Item:
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
actions: Actions = Field(default_factory=dict)
|
actions: Actions = Field(default_factory=dict)
|
||||||
attributes: Dict[str, str] = Field(default_factory=dict)
|
attributes: Attributes = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -26,7 +28,7 @@ class Actor:
|
||||||
description: str
|
description: str
|
||||||
actions: Actions = Field(default_factory=dict)
|
actions: Actions = Field(default_factory=dict)
|
||||||
items: List[Item] = Field(default_factory=list)
|
items: List[Item] = Field(default_factory=list)
|
||||||
attributes: Dict[str, str] = Field(default_factory=dict)
|
attributes: Attributes = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -37,7 +39,7 @@ class Room:
|
||||||
items: List[Item] = Field(default_factory=list)
|
items: List[Item] = Field(default_factory=list)
|
||||||
actors: List[Actor] = Field(default_factory=list)
|
actors: List[Actor] = Field(default_factory=list)
|
||||||
actions: Actions = Field(default_factory=dict)
|
actions: Actions = Field(default_factory=dict)
|
||||||
attributes: Dict[str, str] = Field(default_factory=dict)
|
attributes: Attributes = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -117,7 +117,11 @@ class RemotePlayer(BasePlayer):
|
||||||
send_prompt: Callable[[str, str], bool]
|
send_prompt: Callable[[str, str], bool]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name: str, backstory: str, send_prompt: Callable[[str, str], bool], fallback_agent = None
|
self,
|
||||||
|
name: str,
|
||||||
|
backstory: str,
|
||||||
|
send_prompt: Callable[[str, str], bool],
|
||||||
|
fallback_agent=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(name, backstory)
|
super().__init__(name, backstory)
|
||||||
self.fallback_agent = fallback_agent
|
self.fallback_agent = fallback_agent
|
||||||
|
|
|
@ -77,6 +77,13 @@ async def handler(websocket):
|
||||||
logger.error(f"Failed to find actor {character_name}")
|
logger.error(f"Failed to find actor {character_name}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# prevent any recursive fallback bugs
|
||||||
|
if isinstance(llm_agent, RemotePlayer):
|
||||||
|
logger.warning(
|
||||||
|
"patching recursive fallback for %s", character_name
|
||||||
|
)
|
||||||
|
llm_agent = llm_agent.fallback_agent
|
||||||
|
|
||||||
if character_name in [
|
if character_name in [
|
||||||
player.name for player in characters.values()
|
player.name for player in characters.values()
|
||||||
]:
|
]:
|
||||||
|
@ -84,7 +91,9 @@ async def handler(websocket):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# player_name = data["player"]
|
# player_name = data["player"]
|
||||||
player = RemotePlayer(actor.name, actor.backstory, sync_turn, fallback_agent=llm_agent)
|
player = RemotePlayer(
|
||||||
|
actor.name, actor.backstory, sync_turn, fallback_agent=llm_agent
|
||||||
|
)
|
||||||
characters[id] = player
|
characters[id] = player
|
||||||
logger.info(f"Client {id} is now character {character_name}")
|
logger.info(f"Client {id} is now character {character_name}")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue