1
0
Fork 0

add use action and remote players

This commit is contained in:
Sean Sube 2024-05-05 09:14:54 -05:00
parent 16525ac635
commit f15390bd72
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 334 additions and 137 deletions

View File

@ -17,8 +17,14 @@ OPPOSITE_DIRECTIONS = {
}
GenerateCallback = Callable[[str], None]
def generate_room(
agent: Agent, world_theme: str, existing_rooms: List[str], callback
agent: Agent,
world_theme: str,
callback: GenerateCallback | None = None,
existing_rooms: List[str] = [],
) -> Room:
def unique_name(name: str, **kwargs):
if name in existing_rooms:
@ -37,7 +43,10 @@ def generate_room(
},
result_parser=unique_name,
)
if callable(callback):
callback(f"Generating room: {name}")
desc = agent(
"Generate a detailed description of the {name} area. What does it look like? "
"What does it smell like? What can be seen or heard?",
@ -56,10 +65,10 @@ def generate_room(
def generate_item(
agent: Agent,
world_theme: str,
existing_items: List[str],
callback,
callback: Callable[[str], None] | None = None,
dest_room: str | None = None,
dest_actor: str | None = None,
existing_items: List[str] = [],
) -> Item:
if dest_actor:
dest_note = "The item will be held by the {dest_actor} character"
@ -87,7 +96,10 @@ def generate_item(
},
result_parser=unique_name,
)
if callable(callback):
callback(f"Generating item: {name}")
desc = agent(
"Generate a detailed description of the {name} item. What does it look like? What is it made of? What does it do?",
name=name,
@ -99,7 +111,11 @@ def generate_item(
def generate_actor(
agent: Agent, world_theme: str, dest_room: str, existing_actors: List[str], callback
agent: Agent,
world_theme: str,
dest_room: str,
callback: GenerateCallback | None = None,
existing_actors: List[str] = [],
) -> Actor:
def unique_name(name: str, **kwargs):
if name in existing_actors:
@ -120,7 +136,10 @@ def generate_actor(
},
result_parser=unique_name,
)
if callable(callback):
callback(f"Generating actor: {name}")
description = agent(
"Generate a detailed description of the {name} character. What do they look like? What are they wearing? "
"What are they doing? Describe their appearance from the perspective of an outside observer."
@ -147,9 +166,11 @@ def generate_world(
theme: str,
room_count: int | None = None,
max_rooms: int = 5,
callback: Callable[[str], None] = lambda x: None,
callback: Callable[[str], None] | None = None,
) -> World:
room_count = room_count or randint(3, max_rooms)
if callable(callback):
callback(f"Generating a {theme} with {room_count} rooms")
existing_actors: List[str] = []
@ -159,11 +180,15 @@ def generate_world(
# generate the rooms
rooms = []
for i in range(room_count):
room = generate_room(agent, theme, existing_rooms, callback=callback)
room = generate_room(
agent, theme, existing_rooms=existing_rooms, callback=callback
)
rooms.append(room)
existing_rooms.append(room.name)
item_count = randint(0, 3)
if callable(callback):
callback(f"Generating {item_count} items for room: {room.name}")
for j in range(item_count):
@ -178,6 +203,8 @@ def generate_world(
existing_items.append(item.name)
actor_count = randint(0, 3)
if callable(callback):
callback(f"Generating {actor_count} actors for room: {room.name}")
for j in range(actor_count):
@ -193,6 +220,8 @@ def generate_world(
# generate the actor's inventory
item_count = randint(0, 3)
if callable(callback):
callback(f"Generating {item_count} items for actor {actor.name}")
for k in range(item_count):

142
adventure/logic.py Normal file
View File

@ -0,0 +1,142 @@
from logging import getLogger
from random import random
from typing import Callable, Dict, List, Optional
from functools import partial
from rule_engine import Rule
from pydantic import Field
from yaml import Loader, load
from adventure.models import Actor, Item, Room, World, dataclass
from adventure.plugins import get_plugin_function
logger = getLogger(__name__)
@dataclass
class LogicLabel:
backstory: str
description: str
@dataclass
class LogicRule:
chance: float = 1.0
group: Optional[str] = None
match: Optional[Dict[str, str]] = None
remove: Optional[List[str]] = None
rule: Optional[str] = None
set: Optional[Dict[str, str]] = None
trigger: Optional[List[str]] = None
@dataclass
class LogicTable:
rules: List[LogicRule]
labels: Dict[str, Dict[str, LogicLabel]] = Field(default_factory=dict)
LogicTrigger = Callable[[Room | Actor | Item, Dict[str, str]], Dict[str, str]]
TriggerTable = Dict[LogicRule, List[LogicTrigger]]
def update_attributes(
entity: Room | Actor | Item,
attributes: Dict[str, str],
rules: LogicTable,
triggers: TriggerTable,
) -> Dict[str, str]:
entity_type = entity.__class__.__name__.lower()
skip_groups = set()
for rule in rules.rules:
if rule.group:
if rule.group in skip_groups:
logger.debug("skipping logic group: %s", rule.group)
continue
skip_groups.add(rule.group)
typed_attributes = {
**attributes,
"type": entity_type,
}
if rule.rule:
# TODO: pre-compile rules
rule_impl = Rule(rule.rule)
if not rule_impl.matches({
"attributes": typed_attributes,
}):
logger.debug("logic rule did not match attributes: %s", rule.rule)
continue
if rule.match and not(rule.match.items() <= typed_attributes.items()):
logger.debug("logic did not match attributes: %s", rule.match)
continue
logger.info("matched logic: %s", rule.match)
if rule.chance < 1:
if random() > rule.chance:
logger.info("logic skipped by chance: %s", rule.chance)
continue
for key in rule.remove or []:
attributes.pop(key, None)
if rule.set:
attributes.update(rule.set)
logger.info("logic set state: %s", rule.set)
if rule in triggers:
for trigger in triggers[rule]:
attributes = trigger(entity, attributes)
return attributes
def update_logic(world: World, step: int, rules: LogicTable, triggers: TriggerTable) -> None:
for room in world.rooms:
room.attributes = update_attributes(room, room.attributes, rules=rules, triggers=triggers)
for actor in room.actors:
actor.attributes = update_attributes(actor, actor.attributes, rules=rules, triggers=triggers)
for item in actor.items:
item.attributes = update_attributes(item, item.attributes, rules=rules, triggers=triggers)
for item in room.items:
item.attributes = update_attributes(item, item.attributes, rules=rules, triggers=triggers)
logger.info("updated world attributes")
def format_logic(attributes: Dict[str, str], rules: LogicTable, self=True) -> str:
labels = []
for attribute, value in attributes.items():
if attribute in rules.labels and value in rules.labels[attribute]:
label = rules.labels[attribute][value]
if self:
labels.append(label.backstory)
else:
labels.append(label.description)
if len(labels) > 0:
logger.info("adding labels: %s", labels)
return " ".join(labels)
def init_from_file(filename: str):
logger.info("loading logic from file: %s", filename)
with open(filename) as file:
logic_rules = LogicTable(**load(file, Loader=Loader))
logic_triggers = {
rule: [get_plugin_function(trigger) for trigger in rule.trigger]
for rule in logic_rules.rules
if rule.trigger
}
logger.info("initialized logic system")
return (
partial(update_logic, rules=logic_rules, triggers=logic_triggers),
partial(format_logic, rules=logic_rules)
)

View File

@ -228,7 +228,7 @@ def parse_args():
"--systems",
type=str,
nargs="*",
help="Extra logic systems to run in the simulation",
help="Extra systems to run in the simulation",
)
parser.add_argument(
"--theme", type=str, default="fantasy", help="The theme of the generated world"

View File

@ -3,13 +3,15 @@ from typing import Callable, List
from packit.agent import Agent, agent_easy_connect
from adventure.context import broadcast, get_current_context
from adventure.context import broadcast, get_agent_for_actor, get_current_context
from adventure.generate import OPPOSITE_DIRECTIONS, generate_item, generate_room
logger = getLogger(__name__)
llm = agent_easy_connect()
# TODO: provide dungeon master with the world theme
dungeon_master = Agent(
"dungeon master",
"You are the dungeon master in charge of a fantasy world.",
@ -37,7 +39,7 @@ def action_explore(direction: str) -> str:
existing_rooms = [room.name for room in current_world.rooms]
new_room = generate_room(
dungeon_master, current_world.theme, existing_rooms, callback=lambda x: x
dungeon_master, current_world.theme, existing_rooms=existing_rooms
)
current_world.rooms.append(new_room)
@ -68,7 +70,6 @@ def action_search() -> str:
action_world.theme,
existing_items=existing_items,
dest_room=action_room.name,
callback=lambda x: x,
)
action_room.items.append(new_item)
@ -78,6 +79,46 @@ def action_search() -> str:
return f"You search the room and find a new item: {new_item.name}"
def action_use(item: str, target: str) -> str:
"""
Use an item on yourself or another character in the room.
Args:
item: The name of the item to use.
target: The name of the character to use the item on, or "self" to use the item on yourself.
"""
_, action_room, action_actor = get_current_context()
available_items = [item.name for item in action_actor.items] + [item.name for item in action_room.items]
if item not in available_items:
return f"The {item} item is not available to use."
if target == "self":
target_actor = action_actor
target = action_actor.name
else:
target_actor = next(
(actor for actor in action_room.actors if actor.name == target), None
)
if not target_actor:
return f"The {target} character is not in the room."
broadcast(f"{action_actor.name} uses {item} on {target}")
outcome = dungeon_master(
f"{action_actor.name} uses {item} on {target}. {action_actor.description}. {target_actor.description}. What happens? How does {target} react? "
"Specify the outcome of the action. Do not include the question or any JSON. Only include the outcome of the action."
)
broadcast(f"The action resulted in: {outcome}")
# make sure both agents remember the outcome
target_agent = get_agent_for_actor(target_actor)
if target_agent:
target_agent.memory.append(outcome)
return outcome
def init() -> List[Callable]:
"""
Initialize the custom actions.
@ -85,4 +126,5 @@ def init() -> List[Callable]:
return [
action_explore,
action_search,
action_use,
]

View File

@ -1,12 +1,13 @@
from json import dumps
from readline import add_history
from typing import Any, Dict, List, Sequence
from queue import Queue
from typing import Any, Callable, Dict, List, Sequence
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from packit.utils import could_be_json
class LocalPlayer:
class BasePlayer:
"""
A human agent that can interact with the world.
"""
@ -40,18 +41,7 @@ class LocalPlayer:
return self(prompt, **context)
def __call__(self, prompt: str, **kwargs) -> str:
"""
Ask the player for input.
"""
formatted_prompt = prompt.format(**kwargs)
self.memory.append(HumanMessage(content=formatted_prompt))
print(formatted_prompt)
reply = input(">>> ")
reply = reply.strip()
def parse_input(self, reply: str):
# if the reply starts with a tilde, it is a literal response and should be returned without the tilde
if reply.startswith("~"):
reply = reply[1:]
@ -94,3 +84,50 @@ class LocalPlayer:
)
self.memory.append(AIMessage(content=reply_json))
return reply_json
def __call__(self, prompt: str, **kwargs) -> str:
raise NotImplementedError("Subclasses must implement this method")
class LocalPlayer(BasePlayer):
def __call__(self, prompt: str, **kwargs) -> str:
"""
Ask the player for input.
"""
formatted_prompt = prompt.format(**kwargs)
self.memory.append(HumanMessage(content=formatted_prompt))
print(formatted_prompt)
reply = input(">>> ")
reply = reply.strip()
return self.parse_input(reply)
class RemotePlayer(BasePlayer):
input_queue: Queue[str]
send_prompt: Callable[[str, str], bool]
def __init__(self, name: str, backstory: str, send_prompt: Callable[[str, str], bool]) -> None:
super().__init__(name, backstory)
self.input_queue = Queue()
self.send_prompt = send_prompt
def __call__(self, prompt: str, **kwargs) -> str:
"""
Ask the player for input.
"""
formatted_prompt = prompt.format(**kwargs)
self.memory.append(HumanMessage(content=formatted_prompt))
try:
if self.send_prompt(self.name, formatted_prompt):
reply = self.input_queue.get(timeout=60)
return self.parse_input(reply)
except Exception:
pass
# logger.warning("Failed to send prompt to remote player")
return ""

View File

@ -1,18 +1,22 @@
import asyncio
from collections import deque
from json import dumps
from json import dumps, loads
from logging import getLogger
from threading import Thread
from typing import Dict, Tuple
import websockets
from adventure.context import get_actor_agent_for_name
from adventure.models import Actor, Room, World
from adventure.player import RemotePlayer
from adventure.state import snapshot_world, world_json
logger = getLogger(__name__)
connected = set()
recent_events = deque(maxlen=10)
characters: Dict[str, RemotePlayer] = {}
recent_events = deque(maxlen=100)
recent_world = None
@ -20,6 +24,20 @@ async def handler(websocket):
logger.info("Client connected")
connected.add(websocket)
async def next_turn(character: str, prompt: str) -> None:
await websocket.send(connected, dumps({
"type": "turn",
"character": character,
"prompt": prompt,
}))
def sync_turn(character: str, prompt: str) -> bool:
if websocket not in characters:
return False
asyncio.run(next_turn(character, prompt))
return True
try:
if recent_world:
await websocket.send(recent_world)
@ -31,12 +49,44 @@ async def handler(websocket):
while True:
try:
# if this socket is attached to a character and that character's turn is active, wait for input
message = await websocket.recv()
print(message)
logger.info(f"Received message: {message}")
try:
data = loads(message)
if "become" in data:
character = characters.get(websocket)
if character:
del characters[websocket]
character_name = data["become"]
actor, _ = get_actor_agent_for_name(character_name)
if not actor:
logger.error(f"Failed to find actor {character_name}")
continue
if character_name in [player.name for player in characters.values()]:
logger.error(f"Character {character_name} is already in use")
continue
characters[websocket] = RemotePlayer(actor.name, actor.backstory, sync_turn)
logger.info(f"Client {websocket} is now character {character_name}")
elif websocket in characters:
player = characters[websocket]
player.input_queue.put(message)
except Exception:
logger.exception("Failed to parse message")
except websockets.ConnectionClosedOK:
break
connected.remove(websocket)
# TODO: swap out the character for the original agent
if websocket in characters:
del characters[websocket]
logger.info("Client disconnected")

View File

@ -1,103 +0,0 @@
from logging import getLogger
from random import random
from typing import Dict, List, Optional
from pydantic import Field
from yaml import Loader, load
from adventure.models import Actor, Item, Room, World, dataclass
from adventure.plugins import get_plugin_function
logger = getLogger(__name__)
@dataclass
class LogicLabel:
backstory: str
description: str
@dataclass
class LogicRule:
match: Dict[str, str]
chance: float = 1.0
remove: Optional[List[str]] = None
set: Optional[Dict[str, str]] = None
trigger: Optional[List[str]] = None
@dataclass
class LogicTable:
rules: List[LogicRule]
labels: Dict[str, Dict[str, LogicLabel]] = Field(default_factory=dict)
with open("./worlds/logic.yaml") as file:
logic_rules = LogicTable(**load(file, Loader=Loader))
logic_triggers = {
rule: [get_plugin_function(trigger) for trigger in rule.trigger]
for rule in logic_rules.rules
if rule.trigger
}
def update_attributes(
entity: Room | Actor | Item,
attributes: Dict[str, str],
dataset: LogicTable,
) -> Dict[str, str]:
for rule in dataset.rules:
if rule.match.items() <= attributes.items():
logger.info("matched logic: %s", rule.match)
if rule.chance < 1:
if random() > rule.chance:
logger.info("logic skipped by chance: %s", rule.chance)
continue
if rule.set:
attributes.update(rule.set)
logger.info("logic set state: %s", rule.set)
for key in rule.remove or []:
attributes.pop(key, None)
if rule in logic_triggers:
for trigger in logic_triggers[rule]:
attributes = trigger(entity, attributes)
return attributes
def update_logic(world: World, step: int) -> None:
for room in world.rooms:
room.attributes = update_attributes(room, room.attributes, logic_rules)
for actor in room.actors:
actor.attributes = update_attributes(actor, actor.attributes, logic_rules)
for item in actor.items:
item.attributes = update_attributes(item, item.attributes, logic_rules)
for item in room.items:
item.attributes = update_attributes(item, item.attributes, logic_rules)
logger.info("updated world attributes")
def format_logic(attributes: Dict[str, str], self=True) -> str:
labels = []
for attribute, value in attributes.items():
if attribute in logic_rules.labels and value in logic_rules.labels[attribute]:
label = logic_rules.labels[attribute][value]
if self:
labels.append(label.backstory)
else:
labels.append(label.description)
if len(labels) > 0:
logger.info("adding labels: %s", labels)
return " ".join(labels)
def init():
logger.info("initialized logic system")
return (update_logic, format_logic)