From 759965587982931bd039df6a4084ff7aa364cbdd Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 11 Mar 2024 20:10:08 +0000 Subject: [PATCH 01/10] Add agent action history --- .../game/agent/data_manipulation_bot.py | 5 +- src/primaite/game/agent/interface.py | 37 ++++++++++++-- src/primaite/game/agent/rewards.py | 37 ++++++++++---- src/primaite/game/game.py | 51 +++++++++---------- src/primaite/interface/request.py | 4 +- .../Data-Manipulation-E2E-Demonstration.ipynb | 10 ++-- src/primaite/session/environment.py | 25 ++++----- src/primaite/session/io.py | 41 +++------------ src/primaite/simulator/core.py | 4 +- 9 files changed, 110 insertions(+), 104 deletions(-) diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index 16453433..d3ec19cb 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -14,7 +14,7 @@ class DataManipulationAgent(AbstractScriptedAgent): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.reset_agent_for_episode() + self.setup_agent() def _set_next_execution_timestep(self, timestep: int) -> None: """Set the next execution timestep with a configured random variance. @@ -43,9 +43,8 @@ class DataManipulationAgent(AbstractScriptedAgent): return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} - def reset_agent_for_episode(self) -> None: + def setup_agent(self) -> None: """Set the next execution timestep when the episode resets.""" - super().reset_agent_for_episode() self._select_start_node() self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 88848479..0531b25f 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,6 +1,6 @@ """Interface for agents.""" from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium.core import ActType, ObsType from pydantic import BaseModel, model_validator @@ -8,11 +8,31 @@ from pydantic import BaseModel, model_validator from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.interface.request import RequestFormat, RequestResponse if TYPE_CHECKING: pass +class AgentActionHistoryItem(BaseModel): + """One entry of an agent's action log - what the agent did and how the simulator responded in 1 step.""" + + timestep: int + """Timestep of this action.""" + + action: str + """CAOS Action name.""" + + parameters: Dict[str, Any] + """CAOS parameters for the given action.""" + + request: RequestFormat + """The request that was sent to the simulation based on the CAOS action chosen.""" + + response: RequestResponse + """The response sent back by the simulator for this action.""" + + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" @@ -90,6 +110,7 @@ class AbstractAgent(ABC): self.observation_manager: Optional[ObservationManager] = observation_space self.reward_function: Optional[RewardFunction] = reward_function self.agent_settings = agent_settings or AgentSettings() + self.action_history: List[AgentActionHistoryItem] = [] def update_observation(self, state: Dict) -> ObsType: """ @@ -109,7 +130,7 @@ class AbstractAgent(ABC): :return: Reward from the state. :rtype: float """ - return self.reward_function.update(state) + return self.reward_function.update(state=state, last_action_response=self.action_history[-1]) @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: @@ -138,9 +159,15 @@ class AbstractAgent(ABC): request = self.action_manager.form_request(action_identifier=action, action_options=options) return request - def reset_agent_for_episode(self) -> None: - """Agent reset logic should go here.""" - pass + def process_action_response( + self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse + ) -> None: + """Process the response from the most recent action.""" + self.action_history.append( + AgentActionHistoryItem( + timestep=timestep, action=action, parameters=parameters, request=request, response=response + ) + ) class AbstractScriptedAgent(AbstractAgent): diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 8c8e36ad..6ab5aa42 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -26,11 +26,14 @@ the structure: ``` """ from abc import abstractmethod -from typing import Dict, List, Tuple, Type +from typing import Dict, List, Tuple, Type, TYPE_CHECKING from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +if TYPE_CHECKING: + from primaite.game.agent.interface import AgentActionHistoryItem + _LOGGER = getLogger(__name__) @@ -38,7 +41,9 @@ class AbstractReward: """Base class for reward function components.""" @abstractmethod - def calculate(self, state: Dict) -> float: + def calculate( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """Calculate the reward for the current state.""" return 0.0 @@ -58,7 +63,9 @@ class AbstractReward: class DummyReward(AbstractReward): """Dummy reward function component which always returns 0.""" - def calculate(self, state: Dict) -> float: + def calculate( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """Calculate the reward for the current state.""" return 0.0 @@ -98,7 +105,9 @@ class DatabaseFileIntegrity(AbstractReward): file_name, ] - def calculate(self, state: Dict) -> float: + def calculate( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -153,7 +162,9 @@ class WebServer404Penalty(AbstractReward): """ self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] - def calculate(self, state: Dict) -> float: + def calculate( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -206,7 +217,9 @@ class WebpageUnavailablePenalty(AbstractReward): self._node = node_hostname self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"] - def calculate(self, state: Dict) -> float: + def calculate( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """ Calculate the reward based on current simulation state. @@ -255,13 +268,17 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): self._node = node_hostname self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] - def calculate(self, state: Dict) -> float: + def calculate( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """ Calculate the reward based on current simulation state. :param state: The current state of the simulation. :type state: Dict """ + if last_action_response.request == ["network", "node", "client_2", "application", "DatabaseClient", "execute"]: + pass # TODO db_state = access_from_nested_dict(state, self.location_in_state) if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") @@ -313,7 +330,9 @@ class RewardFunction: """ self.reward_components.append((component, weight)) - def update(self, state: Dict) -> float: + def update( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """Calculate the overall reward for the current state. :param state: The current state of the simulation. @@ -323,7 +342,7 @@ class RewardFunction: for comp_and_weight in self.reward_components: comp = comp_and_weight[0] weight = comp_and_weight[1] - total += weight * comp.calculate(state=state) + total += weight * comp.calculate(state=state, last_action_response=last_action_response) self.current_reward = total return self.current_reward diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index c94cb3ad..1cc8cfed 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,6 +1,6 @@ """PrimAITE game - Encapsulates the simulation and agents.""" from ipaddress import IPv4Address -from typing import Dict, List, Tuple +from typing import Dict, List from pydantic import BaseModel, ConfigDict @@ -130,49 +130,44 @@ class PrimaiteGame: """ _LOGGER.debug(f"Stepping. Step counter: {self.step_counter}") - # Get the current state of the simulation - sim_state = self.get_sim_state() - - # Update agents' observations and rewards based on the current state - self.update_agents(sim_state) - # Apply all actions to simulation as requests - self.apply_agent_actions() + action_data = self.apply_agent_actions() # Advance timestep self.advance_timestep() + # Get the current state of the simulation + sim_state = self.get_sim_state() + + # Update agents' observations and rewards based on the current state, and the response from the last action + self.update_agents(state=sim_state, action_data=action_data) + def get_sim_state(self) -> Dict: """Get the current state of the simulation.""" return self.simulation.describe_state() def update_agents(self, state: Dict) -> None: """Update agents' observations and rewards based on the current state.""" - for _, agent in self.agents.items(): - agent.update_observation(state) - agent.update_reward(state) + for agent_name, agent in self.agents.items(): + if self.step_counter > 0: # can't get reward before first action + agent.update_reward(state=state) + agent.update_observation(state=state) agent.reward_function.total_reward += agent.reward_function.current_reward - def apply_agent_actions(self) -> Dict[str, Tuple[str, Dict]]: - """ - Apply all actions to simulation as requests. - - :return: A recap of each agent's actions, in CAOS format. - :rtype: Dict[str, Tuple[str, Dict]] - - """ - agent_actions = {} + def apply_agent_actions(self) -> None: + """Apply all actions to simulation as requests.""" for _, agent in self.agents.items(): obs = agent.observation_manager.current_observation - action_choice, options = agent.get_action(obs, timestep=self.step_counter) - request = agent.format_request(action_choice, options) + action_choice, parameters = agent.get_action(obs, timestep=self.step_counter) + request = agent.format_request(action_choice, parameters) response = self.simulation.apply_request(request) - agent_actions[agent.agent_name] = { - "action": action_choice, - "parameters": options, - "response": response.model_dump(), - } - return agent_actions + agent.process_action_response( + timestep=self.step_counter, + action=action_choice, + parameters=parameters, + request=request, + response=response, + ) def advance_timestep(self) -> None: """Advance timestep.""" diff --git a/src/primaite/interface/request.py b/src/primaite/interface/request.py index 8e922ef9..bc076599 100644 --- a/src/primaite/interface/request.py +++ b/src/primaite/interface/request.py @@ -1,7 +1,9 @@ -from typing import Dict, ForwardRef, Literal +from typing import Dict, ForwardRef, List, Literal, Union from pydantic import BaseModel, ConfigDict, StrictBool, validate_call +RequestFormat = List[Union[str, int, float]] + RequestResponse = ForwardRef("RequestResponse") """This makes it possible to type-hint RequestResponse.from_bool return type.""" diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index 1d7cb157..b2522c2b 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -373,7 +373,7 @@ "# Imports\n", "from primaite.config.load import data_manipulation_config_path\n", "from primaite.session.environment import PrimaiteGymEnv\n", - "from primaite.game.game import PrimaiteGame\n", + "from primaite.game.agent.interface import AgentActionHistoryItem\n", "import yaml\n", "from pprint import pprint\n" ] @@ -425,14 +425,14 @@ "source": [ "def friendly_output_red_action(info):\n", " # parse the info dict form step output and write out what the red agent is doing\n", - " red_info = info['agent_actions']['data_manipulation_attacker']\n", - " red_action = red_info['action']\n", + " red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", + " red_action = red_info.action\n", " if red_action == 'DONOTHING':\n", " red_str = 'DO NOTHING'\n", " elif red_action == 'NODE_APPLICATION_EXECUTE':\n", - " client = \"client 1\" if red_info['parameters']['node_id'] == 0 else \"client 2\"\n", + " client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n", " red_str = f\"ATTACK from {client}\"\n", - " return red_str\n" + " return red_str" ] }, { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 87638e7d..64534b04 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -49,23 +49,20 @@ class PrimaiteGymEnv(gymnasium.Env): # make ProxyAgent store the action chosen my the RL policy self.agent.store_action(action) # apply_agent_actions accesses the action we just stored - agent_actions = self.game.apply_agent_actions() + self.game.apply_agent_actions() self.game.advance_timestep() state = self.game.get_sim_state() - self.game.update_agents(state) - next_obs = self._get_obs() + next_obs = self._get_obs() # this doesn't update observation, just gets the current observation reward = self.agent.reward_function.current_reward terminated = False truncated = self.game.calculate_truncated() - info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience. + info = { + "agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()} + } # tell us what all the agents did for convenience. if self.game.save_step_metadata: self._write_step_metadata_json(action, state, reward) - if self.io.settings.save_agent_actions: - self.io.store_agent_actions( - agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter - ) return next_obs, reward, terminated, truncated, info def _write_step_metadata_json(self, action: int, state: Dict, reward: int): @@ -91,13 +88,13 @@ class PrimaiteGymEnv(gymnasium.Env): f"avg. reward: {self.agent.reward_function.total_reward}" ) if self.io.settings.save_agent_actions: - self.io.write_agent_actions(episode=self.episode_counter) - self.io.clear_agent_actions() + all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} + self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 state = self.game.get_sim_state() - self.game.update_agents(state) + self.game.update_agents(state=state) next_obs = self._get_obs() info = {} return next_obs, info @@ -217,7 +214,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) - agent_actions = self.game.apply_agent_actions() + self.game.apply_agent_actions() # 2. Advance timestep self.game.advance_timestep() @@ -236,10 +233,6 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: self._write_step_metadata_json(actions, state, rewards) - if self.io.settings.save_agent_actions: - self.io.store_agent_actions( - agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter - ) return next_obs, rewards, terminateds, truncateds, infos def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index ed2b4d62..87289c43 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -48,8 +48,6 @@ class PrimaiteIO: SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs - self.agent_action_log: List[Dict] = [] - def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: """Create a folder for the session and return the path to it.""" if timestamp is None: @@ -72,48 +70,23 @@ class PrimaiteIO: """Return the path where agent actions will be saved.""" return self.session_path / "agent_actions" / f"episode_{episode}.json" - def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None: - """Cache agent actions for a particular step. - - :param agent_actions: Dictionary describing actions for any agents that acted in this timestep. The expected - format contains agent identifiers as keys. The keys should map to a tuple of [CAOS action, parameters] - CAOS action is a string representing one the CAOS actions. - parameters is a dict of parameter names and values for that particular CAOS action. - For example: - { - 'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}), - 'defender': ('DO_NOTHING', {}) - } - :type agent_actions: Dict - :param timestep: Simulation timestep when these actions occurred. - :type timestep: int - """ - self.agent_action_log.append( - [ - { - "episode": episode, - "timestep": timestep, - "agent_actions": agent_actions, - } - ] - ) - - def write_agent_actions(self, episode: int) -> None: + def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None: """Take the contents of the agent action log and write it to a file. :param episode: Episode number :type episode: int """ + data = {} + longest_history = max([len(hist) for hist in agent_actions]) + for i in range(longest_history): + data[i] = {"timestep": i, "episode": episode, **{name: acts[i] for name, acts in agent_actions.items()}} + path = self.generate_agent_actions_save_path(episode=episode) path.parent.mkdir(exist_ok=True, parents=True) path.touch() _LOGGER.info(f"Saving agent action log to {path}") with open(path, "w") as file: - json.dump(self.agent_action_log, fp=file, indent=1) - - def clear_agent_actions(self) -> None: - """Reset the agent action log back to an empty dictionary.""" - self.agent_action_log = [] + json.dump(data, fp=file, indent=1) @classmethod def from_config(cls, config: Dict) -> "PrimaiteIO": diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index aeb4e865..6da8a2f8 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -7,12 +7,10 @@ from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field, validate_call from primaite import getLogger -from primaite.interface.request import RequestResponse +from primaite.interface.request import RequestFormat, RequestResponse _LOGGER = getLogger(__name__) -RequestFormat = List[Union[str, int, float]] - class RequestPermissionValidator(BaseModel): """ From c3f1cfb33d3516fcca84b5ffe944b46d080d1cf3 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 11 Mar 2024 22:53:39 +0000 Subject: [PATCH 02/10] Add shared reward --- .../_package_data/data_manipulation.yaml | 39 ++++--- src/primaite/game/agent/rewards.py | 100 +++++++++++++++--- src/primaite/game/game.py | 55 +++++++++- src/primaite/game/science.py | 79 ++++++++++++++ src/primaite/session/io.py | 7 +- 5 files changed, 242 insertions(+), 38 deletions(-) diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index dffb40ea..f4789e50 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -73,7 +73,14 @@ agents: reward_function: reward_components: - - type: DUMMY + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 - ref: client_1_green_user team: GREEN @@ -116,7 +123,14 @@ agents: reward_function: reward_components: - - type: DUMMY + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 @@ -696,22 +710,15 @@ agents: node_hostname: database_server folder_name: database file_name: database.db - - type: WEBPAGE_UNAVAILABLE_PENALTY - weight: 0.25 + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: client_1 - - type: WEBPAGE_UNAVAILABLE_PENALTY - weight: 0.25 + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: client_2 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY - weight: 0.05 - options: - node_hostname: client_1 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY - weight: 0.05 - options: - node_hostname: client_2 + agent_name: client_2_green_user + agent_settings: diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 6ab5aa42..86a61535 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -26,7 +26,9 @@ the structure: ``` """ from abc import abstractmethod -from typing import Dict, List, Tuple, Type, TYPE_CHECKING +from typing import Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING + +from typing_extensions import Never from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE @@ -214,18 +216,29 @@ class WebpageUnavailablePenalty(AbstractReward): :param node_hostname: Hostname of the node which has the web browser. :type node_hostname: str """ - self._node = node_hostname - self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"] + self._node: str = node_hostname + self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] + self._last_request_failed: bool = False def calculate( self, state: Dict, last_action_response: "AgentActionHistoryItem" ) -> float: # todo maybe make last_action_response optional? """ - Calculate the reward based on current simulation state. + Calculate the reward based on current simulation state, and the recent agent action. - :param state: The current state of the simulation. - :type state: Dict + When the green agent requests to execute the browser application, and that request fails, this reward + component will keep track of that information. In that case, it doesn't matter whether the last webpage + had a 200 status code, because there has been an unsuccessful request since. """ + if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: + self._last_request_failed = last_action_response.response.status != "success" + + # if agent couldn't even get as far as sending the request (because for example the node was off), then + # apply a penalty + if self._last_request_failed: + return -1.0 + + # If the last request did actually go through, then check if the webpage also loaded web_browser_state = access_from_nested_dict(state, self.location_in_state) if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state: _LOGGER.info( @@ -265,20 +278,28 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): :param node_hostname: Hostname of the node where the database client sits. :type node_hostname: str """ - self._node = node_hostname - self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] + self._node: str = node_hostname + self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] + self._last_request_failed: bool = False - def calculate( - self, state: Dict, last_action_response: "AgentActionHistoryItem" - ) -> float: # todo maybe make last_action_response optional? + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """ - Calculate the reward based on current simulation state. + Calculate the reward based on current simulation state, and the recent agent action. - :param state: The current state of the simulation. - :type state: Dict + When the green agent requests to execute the database client application, and that request fails, this reward + component will keep track of that information. In that case, it doesn't matter whether the last successful + request returned was able to connect to the database server, because there has been an unsuccessful request + since. """ - if last_action_response.request == ["network", "node", "client_2", "application", "DatabaseClient", "execute"]: - pass # TODO + if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: + self._last_request_failed = last_action_response.response.status != "success" + + # if agent couldn't even get as far as sending the request (because for example the node was off), then + # apply a penalty + if self._last_request_failed: + return -1.0 + + # If the last request was actually sent, then check if the connection was established. db_state = access_from_nested_dict(state, self.location_in_state) if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") @@ -301,6 +322,52 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): return cls(node_hostname=node_hostname) +class SharedReward(AbstractReward): + """Adds another agent's reward to the overall reward.""" + + def __init__(self, agent_name: Optional[str] = None) -> None: + """ + Initialise the shared reward. + + The agent_ref is a placeholder value. It starts off as none, but it must be set before this reward can work + correctly. + + :param agent_name: The name whose reward is an input + :type agent_ref: Optional[str] + """ + self.agent_name = agent_name + """Agent whose reward to track.""" + + def default_callback() -> Never: + """ + Default callback to prevent calling this reward until it's properly initialised. + + SharedReward should not be used until the game layer replaces self.callback with a reference to the + function that retrieves the desired agent's reward. Therefore, we define this default callback that raises + an error. + """ + raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") + + self.callback: Callable[[], float] = default_callback + """Method that retrieves an agent's current reward given the agent's name.""" + + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + """Simply access the other agent's reward and return it.""" + print(self.callback(), self.agent_name) + return self.callback() + + @classmethod + def from_config(cls, config: Dict) -> "SharedReward": + """ + Build the SharedReward object from config. + + :param config: Configuration dictionary + :type config: Dict + """ + agent_name = config.get("agent_name") + return cls(agent_name=agent_name) + + class RewardFunction: """Manages the reward function for the agent.""" @@ -310,6 +377,7 @@ class RewardFunction: "WEB_SERVER_404_PENALTY": WebServer404Penalty, "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty, "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty, + "SHARED_REWARD": SharedReward, } """List of reward class identifiers.""" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 1cc8cfed..e766bcd3 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -9,8 +9,9 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent from primaite.game.agent.observations import ObservationManager -from primaite.game.agent.rewards import RewardFunction +from primaite.game.agent.rewards import RewardFunction, SharedReward from primaite.game.agent.scripted_agents import ProbabilisticAgent +from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC @@ -110,6 +111,9 @@ class PrimaiteGame: self.save_step_metadata: bool = False """Whether to save the RL agents' action, environment state, and other data at every single step.""" + self._reward_calculation_order: List[str] = [name for name in self.agents] + """Agent order for reward evaluation, as some rewards can be dependent on other agents' rewards.""" + def step(self): """ Perform one step of the simulation/agent loop. @@ -148,10 +152,11 @@ class PrimaiteGame: def update_agents(self, state: Dict) -> None: """Update agents' observations and rewards based on the current state.""" - for agent_name, agent in self.agents.items(): + for agent_name in self._reward_calculation_order: + agent = self.agents[agent_name] if self.step_counter > 0: # can't get reward before first action agent.update_reward(state=state) - agent.update_observation(state=state) + agent.update_observation(state=state) # order of this doesn't matter so just use reward order agent.reward_function.total_reward += agent.reward_function.current_reward def apply_agent_actions(self) -> None: @@ -443,7 +448,51 @@ class PrimaiteGame: raise ValueError(msg) game.agents[agent_cfg["ref"]] = new_agent + # Validate that if any agents are sharing rewards, they aren't forming an infinite loop. + game.setup_reward_sharing() + # Set the NMNE capture config set_nmne_config(network_config.get("nmne_config", {})) return game + + def setup_reward_sharing(self): + """Do necessary setup to enable reward sharing between agents. + + This method ensures that there are no cycles in the reward sharing. A cycle would be for example if agent_1 + depends on agent_2 and agent_2 depends on agent_1. It would cause an infinite loop. + + Also, SharedReward requires us to pass it a callback method that will provide the reward of the agent who is + sharing their reward. This callback is provided by this setup method. + + Finally, this method sorts the agents in order in which rewards will be evaluated to make sure that any rewards + that rely on the value of another reward are evaluated later. + + :raises RuntimeError: If the reward sharing is specified with a cyclic dependency. + """ + # construct dependency graph in the reward sharing between agents. + graph = {} + for name, agent in self.agents.items(): + graph[name] = set() + for comp, weight in agent.reward_function.reward_components: + if isinstance(comp, SharedReward): + comp: SharedReward + graph[name].add(comp.agent_name) + + # while constructing the graph, we might as well set up the reward sharing itself. + comp.callback = lambda: self.agents[comp.agent_name].reward_function.current_reward + # TODO: make sure this lambda is working like I think it does -> it goes to the agent and fetches + # the most recent value of current_reward, NOT just simply caching the reward value at the time this + # callback method is defined. + + # make sure the graph is acyclic. Otherwise we will enter an infinite loop of reward sharing. + if graph_has_cycle(graph): + raise RuntimeError( + ( + "Detected cycle in agent reward sharing. Check the agent reward function ", + "configuration: reward sharing can only go one way.", + ) + ) + + # sort the agents so the rewards that depend on other rewards are always evaluated later + self._reward_calculation_order = topological_sort(graph) diff --git a/src/primaite/game/science.py b/src/primaite/game/science.py index 19a86237..801ef269 100644 --- a/src/primaite/game/science.py +++ b/src/primaite/game/science.py @@ -1,4 +1,5 @@ from random import random +from typing import Any, Iterable, Mapping def simulate_trial(p_of_success: float) -> bool: @@ -14,3 +15,81 @@ def simulate_trial(p_of_success: float) -> bool: :returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False. """ return random() < p_of_success + + +def graph_has_cycle(graph: Mapping[Any, Iterable[Any]]) -> bool: + """Detect cycles in a directed graph. + + Provide the graph as a dictionary that describes which nodes are linked. For example: + {0: {1,2}, 1:{2,3}, 3:{0}} here there's a cycle 0 -> 1 -> 3 -> 0 + {'a': ('b','c'), c:('b')} here there is no cycle + + :param graph: a mapping from node to a set of nodes to which it is connected. + :type graph: Mapping[Any, Iterable[Any]] + :return: Whether the graph has any cycles + :rtype: bool + """ + visited = set() + currently_visiting = set() + + def depth_first_search(node: Any) -> bool: + """Perform depth-first search (DFS) traversal to detect cycles starting from a given node.""" + if node in currently_visiting: + return True # Cycle detected + if node in visited: + return False # Already visited, no need to explore further + + visited.add(node) + currently_visiting.add(node) + + for neighbour in graph.get(node, []): + if depth_first_search(neighbour): + return True # Cycle detected + + currently_visiting.remove(node) + return False + + # Start DFS traversal from each node + for node in graph: + if depth_first_search(node): + return True # Cycle detected + + return False # No cycles found + + +def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> Iterable[Any]: + """ + Perform topological sorting on a directed graph. + + This guarantees that if there's a directed edge from node A to node B, then A appears before B. + + :param graph: A dictionary representing the directed graph, where keys are node identifiers + and values are lists of outgoing edges from each node. + :type graph: dict[int, list[Any]] + + :return: A topologically sorted list of node identifiers. + :rtype: list[Any] + """ + visited: set[Any] = set() + stack: list[Any] = [] + + def dfs(node: Any) -> None: + """ + Depth-first search traversal to visit nodes and their neighbors. + + :param node: The current node to visit. + :type node: Any + """ + if node in visited: + return + visited.add(node) + for neighbour in graph.get(node, []): + dfs(neighbour) + stack.append(node) + + # Perform DFS traversal from each node + for node in graph: + dfs(node) + + # Reverse the stack and return it. + return stack[::-1] diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 87289c43..ef77c63d 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -77,16 +77,17 @@ class PrimaiteIO: :type episode: int """ data = {} - longest_history = max([len(hist) for hist in agent_actions]) + longest_history = max([len(hist) for hist in agent_actions.values()]) for i in range(longest_history): - data[i] = {"timestep": i, "episode": episode, **{name: acts[i] for name, acts in agent_actions.items()}} + data[i] = {"timestep": i, "episode": episode} + data[i].update({name: acts[i] for name, acts in agent_actions.items() if len(acts) > i}) path = self.generate_agent_actions_save_path(episode=episode) path.parent.mkdir(exist_ok=True, parents=True) path.touch() _LOGGER.info(f"Saving agent action log to {path}") with open(path, "w") as file: - json.dump(data, fp=file, indent=1) + json.dump(data, fp=file, indent=1, default=lambda x: x.model_dump()) @classmethod def from_config(cls, config: Dict) -> "PrimaiteIO": From 03ee976a2d66d40e35a20940c41f1aae88c5a9e2 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 12 Mar 2024 11:00:55 +0000 Subject: [PATCH 03/10] remove extra print statement --- src/primaite/game/agent/rewards.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 86a61535..3d61c0b4 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -353,7 +353,6 @@ class SharedReward(AbstractReward): def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Simply access the other agent's reward and return it.""" - print(self.callback(), self.agent_name) return self.callback() @classmethod From 24fdb8dc17f2a7608e7876a58afa054993e30851 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 12 Mar 2024 11:40:26 +0000 Subject: [PATCH 04/10] Fix minor reward sharing bugs --- src/primaite/game/agent/rewards.py | 8 ++--- src/primaite/game/game.py | 5 +-- src/primaite/game/science.py | 3 +- .../Data-Manipulation-E2E-Demonstration.ipynb | 32 +++++++++++-------- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 3d61c0b4..a2ffd875 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -230,7 +230,7 @@ class WebpageUnavailablePenalty(AbstractReward): component will keep track of that information. In that case, it doesn't matter whether the last webpage had a 200 status code, because there has been an unsuccessful request since. """ - if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: + if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]: self._last_request_failed = last_action_response.response.status != "success" # if agent couldn't even get as far as sending the request (because for example the node was off), then @@ -338,7 +338,7 @@ class SharedReward(AbstractReward): self.agent_name = agent_name """Agent whose reward to track.""" - def default_callback() -> Never: + def default_callback(agent_name: str) -> Never: """ Default callback to prevent calling this reward until it's properly initialised. @@ -348,12 +348,12 @@ class SharedReward(AbstractReward): """ raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") - self.callback: Callable[[], float] = default_callback + self.callback: Callable[[str], float] = default_callback """Method that retrieves an agent's current reward given the agent's name.""" def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Simply access the other agent's reward and return it.""" - return self.callback() + return self.callback(self.agent_name) @classmethod def from_config(cls, config: Dict) -> "SharedReward": diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index e766bcd3..ac23610c 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -480,10 +480,7 @@ class PrimaiteGame: graph[name].add(comp.agent_name) # while constructing the graph, we might as well set up the reward sharing itself. - comp.callback = lambda: self.agents[comp.agent_name].reward_function.current_reward - # TODO: make sure this lambda is working like I think it does -> it goes to the agent and fetches - # the most recent value of current_reward, NOT just simply caching the reward value at the time this - # callback method is defined. + comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward # make sure the graph is acyclic. Otherwise we will enter an infinite loop of reward sharing. if graph_has_cycle(graph): diff --git a/src/primaite/game/science.py b/src/primaite/game/science.py index 801ef269..908b326f 100644 --- a/src/primaite/game/science.py +++ b/src/primaite/game/science.py @@ -91,5 +91,4 @@ def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> Iterable[Any]: for node in graph: dfs(node) - # Reverse the stack and return it. - return stack[::-1] + return stack diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index b2522c2b..946202b6 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -450,7 +450,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now the reward is -1, let's have a look at blue agent's observation." + "Now the reward is -0.8, let's have a look at blue agent's observation." ] }, { @@ -510,9 +510,9 @@ "source": [ "obs, reward, terminated, truncated, info = env.step(13) # patch the database\n", "print(f\"step: {env.game.step_counter}\")\n", - "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_1_green_user']['action']}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_2_green_user']['action']}\" )\n", + "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_1_green_user'].action}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_2_green_user'].action}\" )\n", "print(f\"Blue reward:{reward}\" )" ] }, @@ -533,9 +533,9 @@ "metadata": {}, "outputs": [], "source": [ - "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", + "obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", "print(f\"step: {env.game.step_counter}\")\n", - "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n", + "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n", "print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n", "print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n", "print(f\"Blue reward:{reward:.2f}\" )" @@ -557,17 +557,19 @@ "outputs": [], "source": [ "env.step(13) # Patch the database\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", "\n", "env.step(50) # Block client 1\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", "\n", "env.step(51) # Block client 2\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", "\n", - "for step in range(30):\n", + "while abs(reward - 0.8) > 1e-5:\n", " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", - " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )" + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", + " if env.game.step_counter > 10000:\n", + " break # make sure there's no infinite loop if something went wrong" ] }, { @@ -617,17 +619,19 @@ " if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n", " # client 1 has NMNEs, let's block it\n", " obs, reward, terminated, truncated, info = env.step(50) # block client 1\n", + " print(\"blocking client 1\")\n", " break\n", " elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n", " # client 2 has NMNEs, so let's block it\n", " obs, reward, terminated, truncated, info = env.step(51) # block client 2\n", + " print(\"blocking client 2\")\n", " break\n", " if tries>100:\n", " print(\"Error: NMNE never increased\")\n", " break\n", "\n", "env.step(13) # Patch the database\n", - "..." + "print()\n" ] }, { @@ -646,14 +650,14 @@ "\n", "for step in range(40):\n", " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", - " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )" + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Reset the environment, you can rerun the other cells to verify that the attack works the same every episode." + "Reset the environment, you can rerun the other cells to verify that the attack works the same every episode. (except the red agent will move between `client_1` and `client_2`.)" ] }, { From 045f46740702918a711003f2b9859988def28dbd Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 12 Mar 2024 11:51:17 +0000 Subject: [PATCH 05/10] Update marl config --- .../_package_data/data_manipulation_marl.yaml | 58 +++++++++---------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/src/primaite/config/_package_data/data_manipulation_marl.yaml b/src/primaite/config/_package_data/data_manipulation_marl.yaml index f7288cb0..be53d2c5 100644 --- a/src/primaite/config/_package_data/data_manipulation_marl.yaml +++ b/src/primaite/config/_package_data/data_manipulation_marl.yaml @@ -75,7 +75,14 @@ agents: reward_function: reward_components: - - type: DUMMY + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 - ref: client_1_green_user team: GREEN @@ -118,7 +125,14 @@ agents: reward_function: reward_components: - - type: DUMMY + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 @@ -700,22 +714,14 @@ agents: node_hostname: database_server folder_name: database file_name: database.db - - type: WEBPAGE_UNAVAILABLE_PENALTY - weight: 0.25 + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: client_1 - - type: WEBPAGE_UNAVAILABLE_PENALTY - weight: 0.25 + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: client_2 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY - weight: 0.05 - options: - node_hostname: client_1 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY - weight: 0.05 - options: - node_hostname: client_2 + agent_name: client_2_green_user agent_settings: @@ -1259,22 +1265,14 @@ agents: node_hostname: database_server folder_name: database file_name: database.db - - type: WEBPAGE_UNAVAILABLE_PENALTY - weight: 0.25 + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: client_1 - - type: WEBPAGE_UNAVAILABLE_PENALTY - weight: 0.25 + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: client_2 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY - weight: 0.05 - options: - node_hostname: client_1 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY - weight: 0.05 - options: - node_hostname: client_2 + agent_name: client_2_green_user agent_settings: From 6dedb910990df2b289f31162b43e678d12cf0d12 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 13 Mar 2024 09:17:29 +0000 Subject: [PATCH 06/10] Remove redundant TODOs --- src/primaite/game/agent/interface.py | 2 -- src/primaite/game/agent/rewards.py | 24 ++++++------------------ 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 0531b25f..91fa03d4 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -141,8 +141,6 @@ class AbstractAgent(ABC): :param obs: Observation of the environment. :type obs: ObsType - :param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted? - :type reward: float, optional :param timestep: The current timestep in the simulation, used for non-RL agents. Optional :type timestep: int :return: Action to be taken in the environment. diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index a2ffd875..d8cb1328 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -43,9 +43,7 @@ class AbstractReward: """Base class for reward function components.""" @abstractmethod - def calculate( - self, state: Dict, last_action_response: "AgentActionHistoryItem" - ) -> float: # todo maybe make last_action_response optional? + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Calculate the reward for the current state.""" return 0.0 @@ -65,9 +63,7 @@ class AbstractReward: class DummyReward(AbstractReward): """Dummy reward function component which always returns 0.""" - def calculate( - self, state: Dict, last_action_response: "AgentActionHistoryItem" - ) -> float: # todo maybe make last_action_response optional? + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Calculate the reward for the current state.""" return 0.0 @@ -107,9 +103,7 @@ class DatabaseFileIntegrity(AbstractReward): file_name, ] - def calculate( - self, state: Dict, last_action_response: "AgentActionHistoryItem" - ) -> float: # todo maybe make last_action_response optional? + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -164,9 +158,7 @@ class WebServer404Penalty(AbstractReward): """ self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] - def calculate( - self, state: Dict, last_action_response: "AgentActionHistoryItem" - ) -> float: # todo maybe make last_action_response optional? + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -220,9 +212,7 @@ class WebpageUnavailablePenalty(AbstractReward): self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] self._last_request_failed: bool = False - def calculate( - self, state: Dict, last_action_response: "AgentActionHistoryItem" - ) -> float: # todo maybe make last_action_response optional? + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """ Calculate the reward based on current simulation state, and the recent agent action. @@ -397,9 +387,7 @@ class RewardFunction: """ self.reward_components.append((component, weight)) - def update( - self, state: Dict, last_action_response: "AgentActionHistoryItem" - ) -> float: # todo maybe make last_action_response optional? + def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Calculate the overall reward for the current state. :param state: The current state of the simulation. From 10ee9b300fe8e6596c25fbad6ebd689ac1bc4aaf Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 13 Mar 2024 12:08:20 +0000 Subject: [PATCH 07/10] Update docs on rewards --- docs/source/game_layer.rst | 72 ++++++++++++++++++++++++++++++-------- 1 file changed, 58 insertions(+), 14 deletions(-) diff --git a/docs/source/game_layer.rst b/docs/source/game_layer.rst index 39ab7bde..ba400ac2 100644 --- a/docs/source/game_layer.rst +++ b/docs/source/game_layer.rst @@ -6,19 +6,12 @@ The Primaite codebase consists of two main modules: * ``simulator``: The simulation logic including the network topology, the network state, and behaviour of various hardware and software classes. * ``game``: The agent-training infrastructure which helps reinforcement learning agents interface with the simulation. This includes the observation, action, and rewards, for RL agents, but also scripted deterministic agents. The game layer orchestrates all the interactions between modules. - The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API. - -.. - TODO: write up these APIs and link them here. - - -Game layer ----------- +The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API. The game layer is responsible for managing agents and getting them to interface with the simulator correctly. It consists of several components: PrimAITE Session -^^^^^^^^^^^^^^^^ +================ .. admonition:: Deprecated :class: deprecated @@ -28,7 +21,7 @@ PrimAITE Session ``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types. Agents -^^^^^^ +====== All agents inherit from the :py:class:`primaite.game.agent.interface.AbstractAgent` class, which mandates that they have an ObservationManager, ActionManager, and RewardManager. The agent behaviour depends on the type of agent, but there are two main types: @@ -39,16 +32,67 @@ All agents inherit from the :py:class:`primaite.game.agent.interface.AbstractAge TODO: add seed to stochastic scripted agents Observations -^^^^^^^^^^^^^^^^^^ +============ An agent's observations are managed by the ``ObservationManager`` class. It generates observations based on the current simulation state dictionary. It also provides the observation space during initial setup. The data is formatted so it's compatible with ``Gymnasium.spaces``. Observation spaces are composed of one or more components which are defined by the ``AbstractObservation`` base class. Actions -^^^^^^^ +======= An agent's actions are managed by the ``ActionManager``. It converts actions selected by agents (which are typically integers chosen from a ``gymnasium.spaces.Discrete`` space) into simulation-friendly requests. It also provides the action space during initial setup. Action spaces are composed of one or more components which are defined by the ``AbstractAction`` base class. Rewards -^^^^^^^ +======= -An agent's reward function is managed by the ``RewardManager``. It calculates rewards based on the simulation state (in a way similar to observations). Rewards can be defined as a weighted sum of small reward components. For example, an agents reward can be based on the uptime of a database service plus the loss rate of packets between clients and a web server. The reward components are defined by the AbstractReward base class. +An agent's reward function is managed by the ``RewardManager``. It calculates rewards based on the simulation state (in a way similar to observations). Rewards can be defined as a weighted sum of small reward components. For example, an agents reward can be based on the uptime of a database service plus the loss rate of packets between clients and a web server. + +Reward Components +----------------- + +Currently implemented are reward components tailored to the data manipulation scenario. View the full API and description of how they work here: :py:module:`primaite.game.agent.reward`. + +Reward Sharing +-------------- + +An agent's reward can be based on rewards of other agents. This is particularly useful for modelling a situation where the blue agent's job is to protect the ability of green agents to perform their pattern-of-life. This can be configured in the YAML file this way: + +```yaml +green_agent_1: # this agent sometimes tries to access the webpage, and sometimes the database + # actions, observations, and agent settings go here + reward_function: + reward_components: + + # When the webpage loads, the reward goes up by 0.25 when it fails to load, it goes down to -0.25 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + + # When the database is reachable, the reward goes up by 0.05, when it is unreachable it goes down to -0.05 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + +blue_agent: + # actions, observations, and agent settings go here + reward_function: + reward_components: + + # When the database file is in a good state, blue's reward is 0.4, when it's in a corrupted state the reward is -0.4 + - type: DATABASE_FILE_INTEGRITY + weight: 0.40 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + # The green's reward is added onto the blue's reward. + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user + +``` + +When defining agent reward sharing, users must be careful to avoid circular references, as that would lead to an infinite calculation loop. PrimAITE will prevent circular dependencies and provide a helpful error message if they are detected in the yaml. From f438acf745c54137b08b41175b05d91e517d7db4 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 13 Mar 2024 14:01:17 +0000 Subject: [PATCH 08/10] Add shared reward test --- tests/assets/configs/shared_rewards.yaml | 956 ++++++++++++++++++ .../game_layer/test_rewards.py | 27 + 2 files changed, 983 insertions(+) create mode 100644 tests/assets/configs/shared_rewards.yaml diff --git a/tests/assets/configs/shared_rewards.yaml b/tests/assets/configs/shared_rewards.yaml new file mode 100644 index 00000000..91ff20e7 --- /dev/null +++ b/tests/assets/configs/shared_rewards.yaml @@ -0,0 +1,956 @@ +training_config: + rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + save_agent_actions: true + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: true + + +game: + max_episode_length: 256 + ports: + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + + + + + - ref: data_manipulation_attacker + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: UC2BlueObservation + options: + num_services_per_node: 1 + num_folders_per_node: 1 + num_files_per_folder: 1 + num_nics_per_node: 2 + nodes: + - node_hostname: domain_controller + services: + - service_name: DNSServer + - node_hostname: web_server + services: + - service_name: WebServer + - node_hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - node_hostname: backup_server + - node_hostname: security_suite + - node_hostname: client_1 + - node_hostname: client_2 + links: + - link_ref: router_1___switch_1 + - link_ref: router_1___switch_2 + - link_ref: switch_1___domain_controller + - link_ref: switch_1___web_server + - link_ref: switch_1___database_server + - link_ref: switch_1___backup_server + - link_ref: switch_1___security_suite + - link_ref: switch_2___client_1 + - link_ref: switch_2___client_2 + - link_ref: switch_2___security_suite + acl: + options: + max_acl_rules: 10 + router_hostname: router_1 + ip_address_order: + - node_hostname: domain_controller + nic_num: 1 + - node_hostname: web_server + nic_num: 1 + - node_hostname: database_server + nic_num: 1 + - node_hostname: backup_server + nic_num: 1 + - node_hostname: security_suite + nic_num: 1 + - node_hostname: client_1 + nic_num: 1 + - node_hostname: client_2 + nic_num: 1 + - node_hostname: security_suite + nic_num: 2 + ics: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_SERVICE_SCAN + - type: NODE_SERVICE_STOP + - type: NODE_SERVICE_START + - type: NODE_SERVICE_PAUSE + - type: NODE_SERVICE_RESUME + - type: NODE_SERVICE_RESTART + - type: NODE_SERVICE_DISABLE + - type: NODE_SERVICE_ENABLE + - type: NODE_SERVICE_PATCH + - type: NODE_FILE_SCAN + - type: NODE_FILE_CHECKHASH + - type: NODE_FILE_DELETE + - type: NODE_FILE_REPAIR + - type: NODE_FILE_RESTORE + - type: NODE_FOLDER_SCAN + - type: NODE_FOLDER_CHECKHASH + - type: NODE_FOLDER_REPAIR + - type: NODE_FOLDER_RESTORE + - type: NODE_OS_SCAN + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: NODE_RESET + - type: NETWORK_ACL_ADDRULE + options: + target_router_hostname: router_1 + - type: NETWORK_ACL_REMOVERULE + options: + target_router_hostname: router_1 + - type: NETWORK_NIC_ENABLE + - type: NETWORK_NIC_DISABLE + + action_map: + 0: + action: DONOTHING + options: {} + # scan webapp service + 1: + action: NODE_SERVICE_SCAN + options: + node_id: 1 + service_id: 0 + # stop webapp service + 2: + action: NODE_SERVICE_STOP + options: + node_id: 1 + service_id: 0 + # start webapp service + 3: + action: "NODE_SERVICE_START" + options: + node_id: 1 + service_id: 0 + 4: + action: "NODE_SERVICE_PAUSE" + options: + node_id: 1 + service_id: 0 + 5: + action: "NODE_SERVICE_RESUME" + options: + node_id: 1 + service_id: 0 + 6: + action: "NODE_SERVICE_RESTART" + options: + node_id: 1 + service_id: 0 + 7: + action: "NODE_SERVICE_DISABLE" + options: + node_id: 1 + service_id: 0 + 8: + action: "NODE_SERVICE_ENABLE" + options: + node_id: 1 + service_id: 0 + 9: # check database.db file + action: "NODE_FILE_SCAN" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 10: + action: "NODE_FILE_CHECKHASH" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 11: + action: "NODE_FILE_DELETE" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 12: + action: "NODE_FILE_REPAIR" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 13: + action: "NODE_SERVICE_PATCH" + options: + node_id: 2 + service_id: 0 + 14: + action: "NODE_FOLDER_SCAN" + options: + node_id: 2 + folder_id: 0 + 15: + action: "NODE_FOLDER_CHECKHASH" + options: + node_id: 2 + folder_id: 0 + 16: + action: "NODE_FOLDER_REPAIR" + options: + node_id: 2 + folder_id: 0 + 17: + action: "NODE_FOLDER_RESTORE" + options: + node_id: 2 + folder_id: 0 + 18: + action: "NODE_OS_SCAN" + options: + node_id: 0 + 19: + action: "NODE_SHUTDOWN" + options: + node_id: 0 + 20: + action: NODE_STARTUP + options: + node_id: 0 + 21: + action: NODE_RESET + options: + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" + action: "NETWORK_ACL_ADDRULE" + options: + position: 2 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "NETWORK_ACL_ADDRULE" + options: + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "NETWORK_ACL_ADDRULE" + options: + position: 4 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 50: # old action num: 26 + action: "NETWORK_ACL_ADDRULE" + options: + position: 5 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 51: # old action num: 27 + action: "NETWORK_ACL_ADDRULE" + options: + position: 6 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 52: # old action num: 28 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 0 + 53: # old action num: 29 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 1 + 54: # old action num: 30 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 2 + 55: # old action num: 31 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 3 + 56: # old action num: 32 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 4 + 57: # old action num: 33 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 5 + 58: # old action num: 34 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 6 + 59: # old action num: 35 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 7 + 60: # old action num: 36 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 8 + 61: # old action num: 37 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 9 + 62: # old action num: 38 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 2 + nic_id: 0 + 68: # old action num: 44 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 3 + nic_id: 0 + 69: # old action num: 45 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 3 + nic_id: 0 + 70: # old action num: 46 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 73: # old action num: 49 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 74: # old action num: 50 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 5 + nic_id: 0 + 75: # old action num: 51 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 5 + nic_id: 0 + 76: # old action num: 52 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 6 + nic_id: 0 + 77: # old action num: 53 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 6 + nic_id: 0 + + + + options: + nodes: + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + ip_address_order: + - node_name: domain_controller + nic_num: 1 + - node_name: web_server + nic_num: 1 + - node_name: database_server + nic_num: 1 + - node_name: backup_server + nic_num: 1 + - node_name: security_suite + nic_num: 1 + - node_name: client_1 + nic_num: 1 + - node_name: client_2 + nic_num: 1 + - node_name: security_suite + nic_num: 2 + + + reward_function: + reward_components: + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user + + + + agent_settings: + flatten_obs: true + + + + + +simulation: + network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE + nodes: + + - ref: router_1 + hostname: router_1 + type: router + num_ports: 5 + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + acl: + 18: + action: PERMIT + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + 19: + action: PERMIT + src_port: DNS + dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + + - ref: switch_1 + hostname: switch_1 + type: switch + num_ports: 8 + + - ref: switch_2 + hostname: switch_2 + type: switch + num_ports: 8 + + - ref: domain_controller + hostname: domain_controller + type: server + ip_address: 192.168.1.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - ref: domain_controller_dns_server + type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.12 # web server + + - ref: web_server + hostname: web_server + type: server + ip_address: 192.168.1.12 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - ref: web_server_web_service + type: WebServer + applications: + - ref: web_server_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + + + - ref: database_server + hostname: database_server + type: server + ip_address: 192.168.1.14 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - ref: database_service + type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - ref: database_ftp_client + type: FTPClient + + - ref: backup_server + hostname: backup_server + type: server + ip_address: 192.168.1.16 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - ref: backup_service + type: FTPServer + + - ref: security_suite + hostname: security_suite + type: server + ip_address: 192.168.1.110 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + network_interfaces: + 2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot + ip_address: 192.168.10.110 + subnet_mask: 255.255.255.0 + + - ref: client_1 + hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - ref: client_1_web_browser + type: WebBrowser + options: + target_url: http://arcd.com/users/ + - ref: client_1_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - ref: client_1_dns_client + type: DNSClient + + - ref: client_2 + hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - ref: client_2_web_browser + type: WebBrowser + options: + target_url: http://arcd.com/users/ + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - ref: client_2_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - ref: client_2_dns_client + type: DNSClient + + + + links: + - ref: router_1___switch_1 + endpoint_a_ref: router_1 + endpoint_a_port: 1 + endpoint_b_ref: switch_1 + endpoint_b_port: 8 + - ref: router_1___switch_2 + endpoint_a_ref: router_1 + endpoint_a_port: 2 + endpoint_b_ref: switch_2 + endpoint_b_port: 8 + - ref: switch_1___domain_controller + endpoint_a_ref: switch_1 + endpoint_a_port: 1 + endpoint_b_ref: domain_controller + endpoint_b_port: 1 + - ref: switch_1___web_server + endpoint_a_ref: switch_1 + endpoint_a_port: 2 + endpoint_b_ref: web_server + endpoint_b_port: 1 + - ref: switch_1___database_server + endpoint_a_ref: switch_1 + endpoint_a_port: 3 + endpoint_b_ref: database_server + endpoint_b_port: 1 + - ref: switch_1___backup_server + endpoint_a_ref: switch_1 + endpoint_a_port: 4 + endpoint_b_ref: backup_server + endpoint_b_port: 1 + - ref: switch_1___security_suite + endpoint_a_ref: switch_1 + endpoint_a_port: 7 + endpoint_b_ref: security_suite + endpoint_b_port: 1 + - ref: switch_2___client_1 + endpoint_a_ref: switch_2 + endpoint_a_port: 1 + endpoint_b_ref: client_1 + endpoint_b_port: 1 + - ref: switch_2___client_2 + endpoint_a_ref: switch_2 + endpoint_a_port: 2 + endpoint_b_ref: client_2 + endpoint_b_port: 1 + - ref: switch_2___security_suite + endpoint_a_ref: switch_2 + endpoint_a_port: 7 + endpoint_b_ref: security_suite + endpoint_b_port: 2 diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 8edbf0ac..56ba2b8f 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,10 +1,15 @@ +import yaml + from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty +from primaite.game.game import PrimaiteGame +from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.services.database.database_service import DatabaseService +from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent @@ -80,3 +85,25 @@ def test_uc2_rewards(game_and_agent): state = game.get_sim_state() reward_value = comp.calculate(state) assert reward_value == -1.0 + + +def test_shared_reward(): + CFG_PATH = TEST_ASSETS_ROOT / "configs/shared_rewards.yaml" + with open(CFG_PATH, "r") as f: + cfg = yaml.safe_load(f) + + env = PrimaiteGymEnv(game_config=cfg) + + env.reset() + + order = env.game._reward_calculation_order + assert order.index("defender") > order.index("client_1_green_user") + assert order.index("defender") > order.index("client_2_green_user") + + for step in range(256): + act = env.action_space.sample() + env.step(act) + g1_reward = env.game.agents["client_1_green_user"].reward_function.current_reward + g2_reward = env.game.agents["client_2_green_user"].reward_function.current_reward + blue_reward = env.game.agents["defender"].reward_function.current_reward + assert blue_reward == g1_reward + g2_reward From d33c80d0d61153a7fb599fefa6d209d3b0e602fe Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 14 Mar 2024 14:33:04 +0000 Subject: [PATCH 09/10] Minor fixes --- src/primaite/game/game.py | 9 +++++++-- src/primaite/session/environment.py | 4 ++-- tests/conftest.py | 2 ++ .../game_layer/test_rewards.py | 17 ++++++++++++++--- 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 84e5e7df..05b76679 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -139,8 +139,12 @@ class PrimaiteGame: """ _LOGGER.debug(f"Stepping. Step counter: {self.step_counter}") + if self.step_counter == 0: + state = self.get_sim_state() + for agent in self.agents.values(): + agent.update_observation(state=state) # Apply all actions to simulation as requests - action_data = self.apply_agent_actions() + self.apply_agent_actions() # Advance timestep self.advance_timestep() @@ -149,7 +153,7 @@ class PrimaiteGame: sim_state = self.get_sim_state() # Update agents' observations and rewards based on the current state, and the response from the last action - self.update_agents(state=sim_state, action_data=action_data) + self.update_agents(state=sim_state) def get_sim_state(self) -> Dict: """Get the current state of the simulation.""" @@ -458,6 +462,7 @@ class PrimaiteGame: # Set the NMNE capture config set_nmne_config(network_config.get("nmne_config", {})) + game.update_agents(game.get_sim_state()) return game diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 64534b04..1795f14b 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -189,8 +189,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" if self.io.settings.save_agent_actions: - self.io.write_agent_actions(episode=self.episode_counter) - self.io.clear_agent_actions() + all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} + self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 diff --git a/tests/conftest.py b/tests/conftest.py index 20600e73..3a9e2655 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -531,4 +531,6 @@ def game_and_agent(): game.agents["test_agent"] = test_agent + game.setup_reward_sharing() + return (game, test_agent) diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 56ba2b8f..cfd013bc 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,5 +1,6 @@ import yaml +from primaite.game.agent.interface import AgentActionHistoryItem from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv @@ -66,13 +67,18 @@ def test_uc2_rewards(game_and_agent): comp = GreenAdminDatabaseUnreachablePenalty("client_1") - db_client.apply_request( + response = db_client.apply_request( [ "execute", ] ) state = game.get_sim_state() - reward_value = comp.calculate(state) + reward_value = comp.calculate( + state, + last_action_response=AgentActionHistoryItem( + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + ), + ) assert reward_value == 1.0 router.acl.remove_rule(position=2) @@ -83,7 +89,12 @@ def test_uc2_rewards(game_and_agent): ] ) state = game.get_sim_state() - reward_value = comp.calculate(state) + reward_value = comp.calculate( + state, + last_action_response=AgentActionHistoryItem( + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + ), + ) assert reward_value == -1.0 From a9bf0981e6624569eb75f2ef0c54451c18ccd27f Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 15 Mar 2024 09:22:55 +0000 Subject: [PATCH 10/10] Doc fixes --- docs/source/game_layer.rst | 4 +--- src/primaite/game/agent/rewards.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/source/game_layer.rst b/docs/source/game_layer.rst index ba400ac2..af3eadc6 100644 --- a/docs/source/game_layer.rst +++ b/docs/source/game_layer.rst @@ -26,10 +26,8 @@ Agents All agents inherit from the :py:class:`primaite.game.agent.interface.AbstractAgent` class, which mandates that they have an ObservationManager, ActionManager, and RewardManager. The agent behaviour depends on the type of agent, but there are two main types: * RL agents action during each step is decided by an appropriate RL algorithm. The agent within PrimAITE just acts to format and forward actions decided by an RL policy. -* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed will be settable. +* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed is settable. -.. - TODO: add seed to stochastic scripted agents Observations ============ diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index d8cb1328..52bed9e2 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -319,11 +319,11 @@ class SharedReward(AbstractReward): """ Initialise the shared reward. - The agent_ref is a placeholder value. It starts off as none, but it must be set before this reward can work + The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work correctly. :param agent_name: The name whose reward is an input - :type agent_ref: Optional[str] + :type agent_name: Optional[str] """ self.agent_name = agent_name """Agent whose reward to track."""