1
0
Fork 0
taleweave-ai/adventure/logic.py

169 lines
4.7 KiB
Python
Raw Normal View History

2024-05-08 01:40:53 +00:00
from functools import partial, wraps
2024-05-05 14:14:54 +00:00
from logging import getLogger
from random import random
from typing import Callable, Dict, List, Optional
from pydantic import Field
from rule_engine import Rule
2024-05-05 14:14:54 +00:00
from yaml import Loader, load
from adventure.game_system import FormatPerspective, GameSystem
2024-05-19 21:49:02 +00:00
from adventure.models.entity import Attributes, World, WorldEntity, dataclass
2024-05-05 14:14:54 +00:00
from adventure.plugins import get_plugin_function
# Module-level logger; the logic functions below report rule matches, skips and
# attribute changes through it.
logger = getLogger(__name__)
@dataclass
class LogicLabel:
    """Descriptive text attached to any entity that satisfies this label's matcher.

    ``format_logic`` renders matching labels: ``backstory`` is used for the
    second-person perspective and ``description`` (when present) for the
    third-person perspective.
    """

    # Text shown from the second-person perspective.
    backstory: str
    # Optional text shown from the third-person perspective; labels without it
    # contribute nothing in third person.
    description: str | None = None
    # Key/value pairs that must all be present in the entity's attributes
    # (which also expose the entity's ``type`` under the "type" key).
    match: Attributes | None = None
    # rule_engine expression evaluated against {"attributes": ...}.
    rule: str | None = None
2024-05-05 14:14:54 +00:00
@dataclass
class LogicRule:
    """A conditional attribute update applied by ``update_attributes``.

    When an entity satisfies ``match``/``rule``, the rule may roll against
    ``chance``, then removes the ``remove`` keys, merges ``set`` into the
    entity's attributes, and calls each named ``trigger``.
    """

    # Probability of applying once matched; values >= 1 always apply.
    chance: float = 1.0
    # Once a rule from this group applies to an entity, later rules in the same
    # group are skipped for that entity during the same update.
    group: str | None = None
    # Key/value pairs that must all be present in the entity's attributes.
    match: Attributes | None = None
    # Attribute keys to delete when the rule applies (missing keys are ignored).
    remove: list[str] | None = None
    # rule_engine expression evaluated against {"attributes": ...}.
    rule: str | None = None
    # Attributes merged into the entity when the rule applies. NOTE: this field
    # name shadows the builtin ``set`` inside this class body only; it is kept
    # because the field names presumably mirror the logic-file keys.
    set: Attributes | None = None
    # Names of trigger callbacks to invoke with the entity when the rule applies.
    trigger: list[str] | None = None
@dataclass
class LogicTable:
    """The full set of logic loaded from one file: update rules plus labels."""

    # Attribute-update rules, applied in list order by update_attributes.
    rules: list[LogicRule] = Field(default_factory=list)
    # Descriptive labels rendered by format_logic.
    labels: list[LogicLabel] = Field(default_factory=list)
2024-05-05 14:14:54 +00:00
# Callback run when a matched rule lists its name in ``trigger``; it receives the
# entity the rule was applied to (see update_attributes).
LogicTrigger = Callable[[WorldEntity], None]
2024-05-08 01:40:53 +00:00
# Trigger name -> callback. load_logic fills this via get_plugin_function for
# every name listed in a rule's ``trigger`` field.
TriggerTable = Dict[str, LogicTrigger]
2024-05-05 14:14:54 +00:00
2024-05-19 21:49:02 +00:00
# Compiled rule_engine rules keyed by their source text. In this module, rule
# text comes from the LogicTable built by load_logic, so the cache size is
# presumably bounded by the number of rules in the loaded logic file(s).
_compiled_rules: Dict[str, "Rule"] = {}


def _compile_rule(text: str) -> "Rule":
    """Return a compiled ``Rule`` for *text*, compiling it only on first use.

    A rule that fails to compile raises on every call (nothing is cached),
    matching the behavior of compiling it inline.
    """
    compiled = _compiled_rules.get(text)
    if compiled is None:
        compiled = _compiled_rules[text] = Rule(text)
    return compiled


def match_logic(
    entity: WorldEntity,
    matcher: LogicLabel | LogicRule,
) -> bool:
    """Return True if *entity* satisfies both conditions of *matcher*.

    The entity's attributes are extended with its ``type`` under the "type"
    key before matching. Both conditions are optional; a matcher with neither
    matches every entity.

    - ``matcher.rule``: a rule_engine expression evaluated against
      ``{"attributes": <typed attributes>}``.
    - ``matcher.match``: every key/value pair must be present (and equal) in
      the typed attributes.
    """
    typed_attributes = {
        **entity.attributes,
        "type": entity.type,
    }

    if matcher.rule:
        # Rules are evaluated for every entity on every step, so reuse the
        # compiled form instead of re-parsing the expression each time.
        rule_impl = _compile_rule(matcher.rule)
        if not rule_impl.matches({"attributes": typed_attributes}):
            logger.debug("logic rule did not match attributes: %s", matcher.rule)
            return False

    # dict_items supports subset comparison, so this checks every required
    # key/value pair without hashing the values.
    if matcher.match and not (matcher.match.items() <= typed_attributes.items()):
        logger.debug("logic did not match attributes: %s", matcher.match)
        return False

    return True
2024-05-05 14:14:54 +00:00
def update_attributes(
    entity: WorldEntity,
    rules: LogicTable,
    triggers: TriggerTable,
) -> None:
    """Apply each matching rule in *rules* to *entity*, in list order.

    For every rule that matches (see match_logic):
      1. roll against ``rule.chance`` (only when it is below 1);
      2. claim ``rule.group`` so later rules in that group are skipped;
      3. delete the ``rule.remove`` keys, then merge ``rule.set``;
      4. call each trigger named in ``rule.trigger`` with the entity.

    Triggers with no entry in *triggers* are skipped with a warning.
    The entity's attributes are mutated in place.
    """
    skip_groups: set[str] = set()

    for rule in rules.rules:
        if rule.group and rule.group in skip_groups:
            logger.debug("already ran a rule from group %s, skipping", rule.group)
            continue

        if not match_logic(entity, rule):
            continue

        logger.info("matched logic: %s", rule.match)

        # random() returns a value in [0.0, 1.0). Skipping on `>=` makes the rule
        # apply with probability exactly `chance`; in particular chance=0 never
        # fires (with `>` it could fire when random() returned exactly 0.0).
        if rule.chance < 1 and random() >= rule.chance:
            logger.info("logic skipped by chance: %s", rule.chance)
            continue

        # The group is claimed only once the rule actually applies, so a rule
        # skipped by chance still lets a later rule in the same group run.
        if rule.group:
            skip_groups.add(rule.group)

        for key in rule.remove or []:
            entity.attributes.pop(key, None)

        if rule.set:
            entity.attributes.update(rule.set)
            logger.info("logic set state: %s", rule.set)

        for trigger in rule.trigger or []:
            trigger_fn = triggers.get(trigger)
            if trigger_fn is None:
                # Previously ignored silently; surface misconfigured trigger names.
                logger.warning("no function registered for logic trigger: %s", trigger)
                continue
            trigger_fn(entity)
2024-05-05 14:14:54 +00:00
def update_logic(
    world: World, step: int, rules: LogicTable, triggers: TriggerTable
) -> None:
    """Run the logic rules over every entity in *world*.

    Entities are visited room by room, in this order: the room itself, then
    each actor followed by the items that actor carries, then the items in
    the room. ``step`` is not used here; it is presumably supplied by the
    GameSystem caller as part of the simulate-callback signature.
    """

    def walk_entities():
        # Yield lazily, in the same order as nested loops would, so each
        # collection is read only when the walk reaches it.
        for room in world.rooms:
            yield room
            for actor in room.actors:
                yield actor
                yield from actor.items
            yield from room.items

    for entity in walk_entities():
        update_attributes(entity, rules=rules, triggers=triggers)

    logger.info("updated world attributes")
def format_logic(
    entity: WorldEntity,
    rules: LogicTable,
    perspective: FormatPerspective = FormatPerspective.SECOND_PERSON,
) -> str:
    """Describe *entity* using every label in *rules* that matches it.

    Second-person perspective uses each matching label's ``backstory``.
    Third-person perspective uses ``description`` and skips labels without one.
    Any other perspective contributes nothing. The selected texts are joined
    with single spaces, and an empty string is returned when none apply.
    """
    fragments: List[str] = []
    matching = (label for label in rules.labels if match_logic(entity, label))

    for label in matching:
        if perspective == FormatPerspective.SECOND_PERSON:
            text = label.backstory
        elif perspective == FormatPerspective.THIRD_PERSON and label.description:
            text = label.description
        else:
            logger.debug("label has no relevant description: %s", label)
            continue
        fragments.append(text)

    if fragments:
        logger.debug("adding attribute labels: %s", fragments)

    return " ".join(fragments)
def load_logic(filename: str) -> GameSystem:
    """Load a logic table from a YAML file and wrap it in a ``GameSystem``.

    The file's top-level mapping supplies the ``LogicTable`` fields (``rules``
    and ``labels``); an empty file yields an empty table. Each trigger named
    by any rule is resolved once via ``get_plugin_function``.

    Returns a ``GameSystem`` whose ``simulate`` is ``update_logic`` and whose
    ``format`` is ``format_logic``, both bound to the loaded table.
    Raises whatever ``open``/YAML parsing/``LogicTable`` validation raise for a
    missing or malformed file.
    """
    logger.info("loading logic from file: %s", filename)
    with open(filename, encoding="utf-8") as file:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects from
        # tagged YAML. Only load logic files from trusted sources, or switch to
        # yaml.SafeLoader if these files may be user-supplied.
        data = load(file, Loader=Loader)

    # An empty YAML document parses to None; treat it as a table with no
    # rules or labels instead of failing on LogicTable(**None).
    logic_rules = LogicTable(**(data or {}))

    logic_triggers: TriggerTable = {}
    for rule in logic_rules.rules:
        for trigger in rule.trigger or []:
            # Resolve each plugin function once, even if several rules share it.
            if trigger not in logic_triggers:
                logic_triggers[trigger] = get_plugin_function(trigger)

    logger.info("initialized logic system")
    system_simulate = wraps(update_logic)(
        partial(update_logic, rules=logic_rules, triggers=logic_triggers)
    )
    system_format = wraps(format_logic)(partial(format_logic, rules=logic_rules))
    return GameSystem(
        format=system_format,
        simulate=system_simulate,
    )