diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml
index dffb40ea..f4789e50 100644
--- a/src/primaite/config/_package_data/data_manipulation.yaml
+++ b/src/primaite/config/_package_data/data_manipulation.yaml
@@ -73,7 +73,14 @@ agents:

       reward_function:
         reward_components:
-          - type: DUMMY
+          - type: WEBPAGE_UNAVAILABLE_PENALTY
+            weight: 0.25
+            options:
+              node_hostname: client_2
+          - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
+            weight: 0.05
+            options:
+              node_hostname: client_2

   - ref: client_1_green_user
     team: GREEN
@@ -116,7 +123,14 @@ agents:

       reward_function:
         reward_components:
-          - type: DUMMY
+          - type: WEBPAGE_UNAVAILABLE_PENALTY
+            weight: 0.25
+            options:
+              node_hostname: client_1
+          - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
+            weight: 0.05
+            options:
+              node_hostname: client_1

@@ -696,22 +710,15 @@ agents:
             node_hostname: database_server
             folder_name: database
             file_name: database.db
-          - type: WEBPAGE_UNAVAILABLE_PENALTY
-            weight: 0.25
+          - type: SHARED_REWARD
+            weight: 1.0
             options:
-              node_hostname: client_1
-          - type: WEBPAGE_UNAVAILABLE_PENALTY
-            weight: 0.25
+              agent_name: client_1_green_user
+          - type: SHARED_REWARD
+            weight: 1.0
             options:
-              node_hostname: client_2
-          - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
-            weight: 0.05
-            options:
-              node_hostname: client_1
-          - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
-            weight: 0.05
-            options:
-              node_hostname: client_2
+              agent_name: client_2_green_user

     agent_settings:
diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py
index 6ab5aa42..86a61535 100644
--- a/src/primaite/game/agent/rewards.py
+++ b/src/primaite/game/agent/rewards.py
@@ -26,7 +26,9 @@ the structure:
 ```
 """
 from abc import abstractmethod
-from typing import Dict, List, Tuple, Type, TYPE_CHECKING
+from typing import Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
+
+from typing_extensions import Never

 from primaite import getLogger
 from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
@@ -214,18 +216,29 @@ class WebpageUnavailablePenalty(AbstractReward):
         :param node_hostname: Hostname of the node which has the web browser.
         :type node_hostname: str
         """
-        self._node = node_hostname
-        self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
+        self._node: str = node_hostname
+        self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
+        self._last_request_failed: bool = False

     def calculate(
         self, state: Dict, last_action_response: "AgentActionHistoryItem"
     ) -> float:  # todo maybe make last_action_response optional?
         """
-        Calculate the reward based on current simulation state.
+        Calculate the reward based on the current simulation state and the most recent agent action.

-        :param state: The current state of the simulation.
-        :type state: Dict
+        When the green agent requests to execute the browser application and that request fails, this reward
+        component keeps track of that information. In that case, it doesn't matter whether the last webpage
+        had a 200 status code, because there has been an unsuccessful request since.
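+
+        A sketch of the YAML that enables this component for a green agent (hostname as used in the scenario
+        config above; the weight is illustrative)::
+
+            reward_components:
+              - type: WEBPAGE_UNAVAILABLE_PENALTY
+                weight: 0.25
+                options:
+                  node_hostname: client_2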
""" + if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: + self._last_request_failed = last_action_response.response.status != "success" + + # if agent couldn't even get as far as sending the request (because for example the node was off), then + # apply a penalty + if self._last_request_failed: + return -1.0 + + # If the last request did actually go through, then check if the webpage also loaded web_browser_state = access_from_nested_dict(state, self.location_in_state) if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state: _LOGGER.info( @@ -265,20 +278,28 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): :param node_hostname: Hostname of the node where the database client sits. :type node_hostname: str """ - self._node = node_hostname - self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] + self._node: str = node_hostname + self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] + self._last_request_failed: bool = False - def calculate( - self, state: Dict, last_action_response: "AgentActionHistoryItem" - ) -> float: # todo maybe make last_action_response optional? + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """ - Calculate the reward based on current simulation state. + Calculate the reward based on current simulation state, and the recent agent action. - :param state: The current state of the simulation. - :type state: Dict + When the green agent requests to execute the database client application, and that request fails, this reward + component will keep track of that information. In that case, it doesn't matter whether the last successful + request returned was able to connect to the database server, because there has been an unsuccessful request + since. """ - if last_action_response.request == ["network", "node", "client_2", "application", "DatabaseClient", "execute"]: - pass # TODO + if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: + self._last_request_failed = last_action_response.response.status != "success" + + # if agent couldn't even get as far as sending the request (because for example the node was off), then + # apply a penalty + if self._last_request_failed: + return -1.0 + + # If the last request was actually sent, then check if the connection was established. db_state = access_from_nested_dict(state, self.location_in_state) if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") @@ -301,6 +322,52 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): return cls(node_hostname=node_hostname) +class SharedReward(AbstractReward): + """Adds another agent's reward to the overall reward.""" + + def __init__(self, agent_name: Optional[str] = None) -> None: + """ + Initialise the shared reward. + + The agent_ref is a placeholder value. It starts off as none, but it must be set before this reward can work + correctly. + + :param agent_name: The name whose reward is an input + :type agent_ref: Optional[str] + """ + self.agent_name = agent_name + """Agent whose reward to track.""" + + def default_callback() -> Never: + """ + Default callback to prevent calling this reward until it's properly initialised. 
+
+    def __init__(self, agent_name: Optional[str] = None) -> None:
+        """
+        Initialise the shared reward.
+
+        The agent_name starts off as None as a placeholder, but it must be set before this reward can work
+        correctly.
+
+        :param agent_name: The name of the agent whose reward is used as input.
+        :type agent_name: Optional[str]
+        """
+        self.agent_name = agent_name
+        """Agent whose reward to track."""
+
+        def default_callback() -> Never:
+            """
+            Default callback to prevent calling this reward until it's properly initialised.
+
+            SharedReward should not be used until the game layer replaces self.callback with a reference to the
+            function that retrieves the desired agent's reward. Therefore, we define this default callback that
+            raises an error.
+            """
+            raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
+
+        self.callback: Callable[[], float] = default_callback
+        """Zero-argument callable that retrieves the tracked agent's current reward."""
+
+    def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
+        """Access the tracked agent's reward via the callback and return it."""
+        return self.callback()
+
+    @classmethod
+    def from_config(cls, config: Dict) -> "SharedReward":
+        """
+        Build the SharedReward object from config.
+
+        :param config: Configuration dictionary.
+        :type config: Dict
+        """
+        agent_name = config.get("agent_name")
+        return cls(agent_name=agent_name)
+
+
 class RewardFunction:
     """Manages the reward function for the agent."""

@@ -310,6 +377,7 @@ class RewardFunction:
         "WEB_SERVER_404_PENALTY": WebServer404Penalty,
         "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
         "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
+        "SHARED_REWARD": SharedReward,
     }
     """List of reward class identifiers."""

diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py
index 1cc8cfed..e766bcd3 100644
--- a/src/primaite/game/game.py
+++ b/src/primaite/game/game.py
@@ -9,8 +9,9 @@ from primaite.game.agent.actions import ActionManager
 from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
 from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
 from primaite.game.agent.observations import ObservationManager
-from primaite.game.agent.rewards import RewardFunction
+from primaite.game.agent.rewards import RewardFunction, SharedReward
 from primaite.game.agent.scripted_agents import ProbabilisticAgent
+from primaite.game.science import graph_has_cycle, topological_sort
 from primaite.simulator.network.hardware.base import NodeOperatingState
 from primaite.simulator.network.hardware.nodes.host.computer import Computer
 from primaite.simulator.network.hardware.nodes.host.host_node import NIC
@@ -110,6 +111,9 @@ class PrimaiteGame:
         self.save_step_metadata: bool = False
         """Whether to save the RL agents' action, environment state, and other data at every single step."""

+        self._reward_calculation_order: List[str] = list(self.agents)
+        """Agent order for reward evaluation, as some rewards can depend on other agents' rewards."""
+
     def step(self):
         """
         Perform one step of the simulation/agent loop.
@@ -148,10 +152,11 @@ class PrimaiteGame:

     def update_agents(self, state: Dict) -> None:
         """Update agents' observations and rewards based on the current state."""
-        for agent_name, agent in self.agents.items():
+        for agent_name in self._reward_calculation_order:
+            agent = self.agents[agent_name]
             if self.step_counter > 0:  # can't get reward before first action
                 agent.update_reward(state=state)
-            agent.update_observation(state=state)
+            agent.update_observation(state=state)  # observation order doesn't matter, so reuse the reward order
             agent.reward_function.total_reward += agent.reward_function.current_reward

     def apply_agent_actions(self) -> None:
@@ -443,7 +448,51 @@ class PrimaiteGame:
                     raise ValueError(msg)
                 game.agents[agent_cfg["ref"]] = new_agent

+        # Validate that agents sharing rewards aren't forming an infinite loop, and wire up the callbacks.
+        game.setup_reward_sharing()
+
         # Set the NMNE capture config
         set_nmne_config(network_config.get("nmne_config", {}))

         return game
+
+    def setup_reward_sharing(self):
+        """Do the necessary setup to enable reward sharing between agents.
+
+        This method ensures that there are no cycles in the reward sharing. A cycle would arise, for example, if
+        agent_1 depended on agent_2 and agent_2 depended on agent_1; it would cause an infinite loop.
+
+        Also, SharedReward requires us to pass it a callback method that provides the reward of the agent who is
+        sharing their reward. This callback is provided by this setup method.
+
+        Finally, this method sorts the agents into the order in which rewards will be evaluated, to make sure that
+        any reward that relies on the value of another reward is evaluated later.
+
+        :raises RuntimeError: If the reward sharing is specified with a cyclic dependency.
+        """
+        # Construct the dependency graph of reward sharing between agents.
+        graph = {}
+        for name, agent in self.agents.items():
+            graph[name] = set()
+            for comp, _ in agent.reward_function.reward_components:
+                if isinstance(comp, SharedReward):
+                    graph[name].add(comp.agent_name)
+
+                    # While constructing the graph, we might as well set up the reward sharing itself. The agent
+                    # name is bound as a default argument: a plain closure over `comp` would be late-binding, so
+                    # every callback would end up reading the reward of the last component seen by this loop. The
+                    # lookup of current_reward still happens at call time, so the callback always fetches the
+                    # most recent value.
+                    comp.callback = (
+                        lambda agent_name=comp.agent_name: self.agents[agent_name].reward_function.current_reward
+                    )
+
+        # Make sure the graph is acyclic. Otherwise, we would enter an infinite loop of reward sharing.
+        if graph_has_cycle(graph):
+            raise RuntimeError(
+                "Detected cycle in agent reward sharing. Check the agent reward function "
+                "configuration: reward sharing can only go one way."
+            )
+
+        # Sort the agents so that rewards which depend on other rewards are always evaluated later. Edges point
+        # from the dependent agent to its dependency, and topological_sort puts dependents first, so the order is
+        # reversed here.
+        self._reward_calculation_order = list(reversed(topological_sort(graph)))
diff --git a/src/primaite/game/science.py b/src/primaite/game/science.py
index 19a86237..801ef269 100644
--- a/src/primaite/game/science.py
+++ b/src/primaite/game/science.py
@@ -1,4 +1,5 @@
 from random import random
+from typing import Any, Iterable, List, Mapping


 def simulate_trial(p_of_success: float) -> bool:
@@ -14,3 +15,81 @@ def simulate_trial(p_of_success: float) -> bool:
     :returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False.
     """
     return random() < p_of_success
+
+
+def graph_has_cycle(graph: Mapping[Any, Iterable[Any]]) -> bool:
+    """Detect cycles in a directed graph.
+
+    Provide the graph as a mapping that describes which nodes each node links to. For example,
+    {0: {1, 2}, 1: {2, 3}, 3: {0}} contains the cycle 0 -> 1 -> 3 -> 0, whereas
+    {'a': ('b', 'c'), 'c': ('b',)} contains no cycle.
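+
+    A doctest-style sketch of the expected behaviour (the graphs are illustrative)::
+
+        >>> graph_has_cycle({0: {1, 2}, 1: {2, 3}, 3: {0}})
+        True
+        >>> graph_has_cycle({'a': ('b', 'c'), 'c': ('b',)})
+        False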
+
+    :param graph: A mapping from each node to the collection of nodes it links to.
+    :type graph: Mapping[Any, Iterable[Any]]
+    :return: Whether the graph has any cycles.
+    :rtype: bool
+    """
+    visited = set()
+    currently_visiting = set()
+
+    def depth_first_search(node: Any) -> bool:
+        """Perform depth-first search (DFS) traversal to detect cycles starting from a given node."""
+        if node in currently_visiting:
+            return True  # Cycle detected
+        if node in visited:
+            return False  # Already visited, no need to explore further
+
+        visited.add(node)
+        currently_visiting.add(node)
+
+        for neighbour in graph.get(node, []):
+            if depth_first_search(neighbour):
+                return True  # Cycle detected
+
+        currently_visiting.remove(node)
+        return False
+
+    # Start DFS traversal from each node
+    for node in graph:
+        if depth_first_search(node):
+            return True  # Cycle detected
+
+    return False  # No cycles found
+
+
+def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> List[Any]:
+    """
+    Perform topological sorting on a directed graph.
+
+    This guarantees that if there's a directed edge from node A to node B, then A appears before B.
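+
+    A doctest-style sketch (the graph is illustrative; ties between independent nodes follow insertion order)::
+
+        >>> topological_sort({'a': ['b'], 'b': ['c'], 'c': []})
+        ['a', 'b', 'c']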
+
+    :param graph: A mapping representing the directed graph, where keys are nodes and values are the nodes
+        reachable via an outgoing edge.
+    :type graph: Mapping[Any, Iterable[Any]]
+
+    :return: A topologically sorted list of nodes.
+    :rtype: List[Any]
+    """
+    visited: set[Any] = set()
+    stack: list[Any] = []
+
+    def dfs(node: Any) -> None:
+        """
+        Depth-first search traversal to visit nodes and their neighbours.
+
+        :param node: The current node to visit.
+        :type node: Any
+        """
+        if node in visited:
+            return
+        visited.add(node)
+        for neighbour in graph.get(node, []):
+            dfs(neighbour)
+        stack.append(node)
+
+    # Perform DFS traversal from each node
+    for node in graph:
+        dfs(node)
+
+    # Reverse the stack so that each node precedes the nodes it points to.
+    return stack[::-1]
diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py
index 87289c43..ef77c63d 100644
--- a/src/primaite/session/io.py
+++ b/src/primaite/session/io.py
@@ -77,16 +77,17 @@ class PrimaiteIO:
         :type episode: int
         """
         data = {}
-        longest_history = max([len(hist) for hist in agent_actions])
+        longest_history = max([len(hist) for hist in agent_actions.values()])
         for i in range(longest_history):
-            data[i] = {"timestep": i, "episode": episode, **{name: acts[i] for name, acts in agent_actions.items()}}
+            data[i] = {"timestep": i, "episode": episode}
+            data[i].update({name: acts[i] for name, acts in agent_actions.items() if len(acts) > i})

         path = self.generate_agent_actions_save_path(episode=episode)
         path.parent.mkdir(exist_ok=True, parents=True)
         path.touch()
         _LOGGER.info(f"Saving agent action log to {path}")
         with open(path, "w") as file:
-            json.dump(data, fp=file, indent=1)
+            json.dump(data, fp=file, indent=1, default=lambda x: x.model_dump())

     @classmethod
     def from_config(cls, config: Dict) -> "PrimaiteIO":