1
0
Fork 0
taleweave-ai/adventure/context.py

250 lines
5.5 KiB
Python
Raw Normal View History

from contextlib import contextmanager
2024-05-18 21:58:11 +00:00
from logging import getLogger
from types import UnionType
from typing import (
Any,
2024-05-18 21:58:11 +00:00
Callable,
Dict,
List,
Sequence,
Tuple,
Type,
Union,
get_args,
get_origin,
)
2024-05-02 11:56:57 +00:00
from packit.agent import Agent
from pyee.base import EventEmitter
2024-05-02 11:56:57 +00:00
from adventure.game_system import GameSystem
from adventure.models.entity import Character, Room, World
from adventure.models.event import GameEvent
from adventure.utils.string import normalize_name
2024-05-02 11:56:57 +00:00
2024-05-18 21:58:11 +00:00
# Channel name used for legacy plain-string broadcasts (see broadcast/subscribe).
STRING_EVENT_TYPE = "message"

logger = getLogger(__name__)

# world context: the step counter plus the world/room/character that action
# functions operate on (each is None until explicitly set)
current_step = 0
current_world: World | None = None
current_room: Room | None = None
current_character: Character | None = None
dungeon_master: Agent | None = None

# game context: shared event bus, registered game systems, and per-system data
event_emitter = EventEmitter()
game_systems: List[GameSystem] = []
system_data: Dict[str, Any] = {}

# TODO: where should this one go?
# (character, agent) pairs keyed by the name passed to set_character_agent
character_agents: Dict[str, Tuple[Character, Agent]] = {}
def get_event_name(event: GameEvent | Type[GameEvent]):
    """Return the emitter channel name for a game event instance or class.

    The channel is the event's ``type`` attribute prefixed with ``"event:"``.
    """
    event_type = event.type
    return "event:{}".format(event_type)
2024-05-02 11:56:57 +00:00
def broadcast(message: str | GameEvent):
    """Emit ``message`` on the module-level event emitter.

    ``GameEvent`` instances are emitted on their per-type channel (from
    ``get_event_name``). Anything else is treated as a legacy string message:
    it is emitted on ``STRING_EVENT_TYPE`` and a deprecation warning is logged.
    """
    if not isinstance(message, GameEvent):
        logger.warning("broadcasting a string message is deprecated")
        event_emitter.emit(STRING_EVENT_TYPE, message)
        return

    channel = get_event_name(message)
    logger.debug(f"broadcasting {channel}")
    event_emitter.emit(channel, message)
2024-05-18 21:58:11 +00:00
def is_union(type_: Type | UnionType):
origin = get_origin(type_)
return origin is UnionType or origin is Union
def subscribe(
    event_type: Type[str] | Type[GameEvent] | UnionType,
    callback: Callable[[GameEvent], None],
):
    """Register ``callback`` to run when events of ``event_type`` are broadcast.

    ``event_type`` may be:

    * ``str``: subscribes to legacy string broadcasts on ``STRING_EVENT_TYPE``;
    * a ``GameEvent`` type: subscribes to that type's channel (``get_event_name``);
    * a union of the above: the callback is subscribed to each member separately.
    """
    if is_union(event_type):
        for member_type in get_args(event_type):
            subscribe(member_type, callback)
        return

    if event_type is str:
        event_name = STRING_EVENT_TYPE
    else:
        event_name = get_event_name(event_type)

    # Not every callable has __name__ (e.g. functools.partial objects or
    # instances with __call__). Don't fail a subscription just to build a
    # debug message.
    callback_name = getattr(callback, "__name__", repr(callback))
    logger.debug("subscribing %s to %s", callback_name, event_type)
    event_emitter.on(event_name, callback)
2024-05-09 02:11:16 +00:00
def has_dungeon_master():
    """Report whether a dungeon master agent has been set (i.e. is not None)."""
    return dungeon_master is not None
# region context manager
2024-05-18 21:58:11 +00:00
@contextmanager
def action_context():
    """Context manager that yields the current ``(room, character)`` pair.

    The pair comes from ``get_action_context``, which raises ``ValueError`` on
    entry if either is unset. Nothing is done on exit.
    """
    yield get_action_context()
2024-05-18 21:58:11 +00:00
@contextmanager
def world_context():
    """Context manager that yields the current ``(world, room, character)`` triple.

    The triple comes from ``get_world_context``, which raises ``ValueError`` on
    entry if any of them is unset. Nothing is done on exit.
    """
    yield get_world_context()
# endregion
2024-05-09 02:11:16 +00:00
# region context getters
def get_action_context() -> Tuple[Room, Character]:
    """Return the current room and character for use by action functions.

    Returns:
        A ``(current_room, current_character)`` tuple.

    Raises:
        ValueError: if the current room or the current character has not been
            set (the room is checked first).
    """
    # The globals are typed ``X | None``; compare against None explicitly (as
    # has_dungeon_master does) rather than relying on the objects' truthiness.
    if current_room is None:
        raise ValueError("The current room must be set before calling action functions")
    if current_character is None:
        raise ValueError(
            "The current character must be set before calling action functions"
        )
    return (current_room, current_character)
def get_world_context() -> Tuple[World, Room, Character]:
    """Return the current world, room, and character.

    Returns:
        A ``(current_world, current_room, current_character)`` tuple.

    Raises:
        ValueError: if the current world, room, or character has not been set
            (checked in that order).
    """
    # Compare against None explicitly; the global is typed ``World | None``.
    if current_world is None:
        raise ValueError(
            "The current world must be set before calling action functions"
        )
    # Room/character validation is exactly what get_action_context does; reuse
    # it instead of duplicating the checks and their error messages.
    room, character = get_action_context()
    return (current_world, room, character)
2024-05-02 11:56:57 +00:00
2024-05-05 22:46:24 +00:00
def get_current_world() -> World | None:
    """Return the active world, or None if no world has been set."""
    world = current_world
    return world
2024-05-05 22:46:24 +00:00
def get_current_room() -> Room | None:
    """Return the active room, or None if no room has been set."""
    room = current_room
    return room
def get_current_character() -> Character | None:
    """Return the active character, or None if no character has been set."""
    character = current_character
    return character
2024-05-02 11:56:57 +00:00
2024-05-09 02:11:16 +00:00
def get_current_step() -> int:
    """Return the current step counter as last stored by set_current_step."""
    step = current_step
    return step
2024-05-04 20:35:42 +00:00
2024-05-09 02:11:16 +00:00
def get_dungeon_master() -> Agent:
    """Return the registered dungeon master agent.

    Raises:
        ValueError: if no dungeon master has been set.
    """
    # Use the same ``is None`` test as has_dungeon_master() so the two can never
    # disagree about whether a dungeon master is set (a falsy agent object
    # would otherwise pass one check and fail the other).
    if dungeon_master is None:
        raise ValueError(
            "The dungeon master must be set before calling action functions"
        )
    return dungeon_master
def get_game_systems() -> List[GameSystem]:
    """Return the registered game systems.

    This is the module's own list object, not a copy, so mutating the result
    changes the shared registry.
    """
    systems = game_systems
    return systems
def get_system_data(system: str) -> Any | None:
return system_data.get(system)
2024-05-09 02:11:16 +00:00
# endregion
# region context setters
2024-05-05 22:46:24 +00:00
def set_current_world(world: World | None):
    """Replace the active world; pass None to clear it."""
    global current_world
    current_world = world
2024-05-05 22:46:24 +00:00
def set_current_room(room: Room | None):
    """Replace the active room; pass None to clear it."""
    global current_room
    current_room = room
def set_current_character(character: Character | None):
    """Replace the active character; pass None to clear it."""
    global current_character
    current_character = character
2024-05-02 11:56:57 +00:00
2024-05-05 22:46:24 +00:00
def set_current_step(step: int):
    """Overwrite the module-level step counter with ``step``."""
    global current_step
    current_step = step
def set_character_agent(name, character, agent):
    """Register (or replace) the ``(character, agent)`` pair stored under ``name``."""
    pair = (character, agent)
    character_agents[name] = pair
2024-05-09 02:11:16 +00:00
def set_dungeon_master(agent):
    """Store ``agent`` as the module-level dungeon master (None clears it)."""
    global dungeon_master
    dungeon_master = agent
def set_game_systems(systems: Sequence[GameSystem]):
    """Replace the registered game systems with a new list copied from ``systems``."""
    global game_systems
    game_systems = [*systems]
def set_system_data(system: str, data: Any):
    """Store ``data`` under the key ``system``, replacing any previous value."""
    entry_key = system
    system_data[entry_key] = data
2024-05-09 02:11:16 +00:00
# endregion
# region search functions
def get_character_for_agent(agent: Agent) -> Character | None:
    """Return the character registered alongside ``agent``, or None if there is none.

    Agents are compared with ``==``; the first match in registration order wins.
    """
    for candidate_character, candidate_agent in character_agents.values():
        if candidate_agent == agent:
            return candidate_character
    return None
def get_agent_for_character(character: Character) -> Agent | None:
    """Return the agent registered alongside ``character``, or None if there is none.

    Characters are compared with ``==``; the first match in registration order wins.
    """
    for candidate_character, candidate_agent in character_agents.values():
        if candidate_character == character:
            return candidate_agent
    return None
def get_character_agent_for_name(
    name: str,
) -> Tuple[Character, Agent] | Tuple[None, None]:
    """Look up a registered character and its agent by name.

    Both ``name`` and each registered character's ``name`` are passed through
    ``normalize_name`` before comparing. Lookups therefore ignore whatever
    differences that helper normalizes away.

    Returns:
        ``(character, agent)`` for the first match in registration order, or
        ``(None, None)`` if no registered character matches.
    """
    # The query side of the comparison does not change between iterations;
    # normalize it once instead of once per registered character.
    target_name = normalize_name(name)
    for character, agent in character_agents.values():
        if normalize_name(character.name) == target_name:
            return (character, agent)
    return (None, None)
def get_all_character_agents():
    """Return a new list of every registered ``(character, agent)`` pair, in registration order."""
    return [*character_agents.values()]
2024-05-05 22:46:24 +00:00
2024-05-09 02:11:16 +00:00
# endregion