diff --git a/.gitignore b/.gitignore index b65d1fd8..4bb700b2 100644 --- a/.gitignore +++ b/.gitignore @@ -138,4 +138,8 @@ dmypy.json # Cython debug symbols cython_debug/ +# IDE .idea/ + +# outputs +src/primaite/outputs/ diff --git a/README.md b/README.md index 78f36fba..f7c6efd7 100644 --- a/README.md +++ b/README.md @@ -1 +1,64 @@ # PrimAITE + +## Getting Started with PrimAITE + +### Pre-Requisites + +In order to get **PrimAITE** installed, you will need to have the following installed: + +- `python3.8+` +- `python3-pip` +- `virtualenv` + +**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS. + +### Installation from source +#### 1. Navigate to the PrimAITE folder and create a new python virtual environment (venv) + +```unix +python3 -m venv +``` + +#### 2. Activate the venv + +##### Unix +```bash +source /bin/activate +``` + +##### Windows +```powershell +.\\Scripts\activate +``` + +#### 3. Install `primaite` into the venv along with all of it's dependencies + +```bash +python3 -m pip install -e . +``` + +### Development Installation +To install the development dependencies, postfix the command in step 3 above with the `[dev]` extra. Example: + +```bash +python3 -m pip install -e .[dev] +``` + +## Building documentation +The PrimAITE documentation can be built with the following commands: + +##### Unix +```bash +cd docs +make html +``` + +##### Windows +```powershell +cd docs +.\make.bat html +``` + +This will build the documentation as a collection of HTML files which uses the Read The Docs sphinx theme. Other build +options are available but may require additional dependencies such as LaTeX and PDF. Please refer to the Sphinx documentation +for your specific output requirements. diff --git a/docs/source/config.rst b/docs/source/config.rst index 71ade6c5..c80baa3c 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -82,203 +82,203 @@ The environment config file consists of the following attributes: Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment. -* **Generic [all_ok]** [int] +* **Generic [all_ok]** [float] The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken) -* **Node Hardware State [off_should_be_on]** [int] +* **Node Hardware State [off_should_be_on]** [float] The score to give when the node should be on, but is off -* **Node Hardware State [off_should_be_resetting]** [int] +* **Node Hardware State [off_should_be_resetting]** [float] The score to give when the node should be resetting, but is off -* **Node Hardware State [on_should_be_off]** [int] +* **Node Hardware State [on_should_be_off]** [float] The score to give when the node should be off, but is on -* **Node Hardware State [on_should_be_resetting]** [int] +* **Node Hardware State [on_should_be_resetting]** [float] The score to give when the node should be resetting, but is on -* **Node Hardware State [resetting_should_be_on]** [int] +* **Node Hardware State [resetting_should_be_on]** [float] The score to give when the node should be on, but is resetting -* **Node Hardware State [resetting_should_be_off]** [int] +* **Node Hardware State [resetting_should_be_off]** [float] The score to give when the node should be off, but is resetting -* **Node Hardware State [resetting]** [int] +* **Node Hardware State [resetting]** [float] The score to give when the node is resetting -* **Node Operating System or Service State [good_should_be_patching]** [int] +* **Node Operating System or Service State [good_should_be_patching]** [float] The score to give when the state should be patching, but is good -* **Node Operating System or Service State [good_should_be_compromised]** [int] +* **Node Operating System or Service State [good_should_be_compromised]** [float] The score to give when the state should be compromised, but is good -* **Node Operating System or Service State [good_should_be_overwhelmed]** [int] +* **Node Operating System or Service State [good_should_be_overwhelmed]** [float] The score to give when the state should be overwhelmed, but is good -* **Node Operating System or Service State [patching_should_be_good]** [int] +* **Node Operating System or Service State [patching_should_be_good]** [float] The score to give when the state should be good, but is patching -* **Node Operating System or Service State [patching_should_be_compromised]** [int] +* **Node Operating System or Service State [patching_should_be_compromised]** [float] The score to give when the state should be compromised, but is patching -* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int] +* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float] The score to give when the state should be overwhelmed, but is patching -* **Node Operating System or Service State [patching]** [int] +* **Node Operating System or Service State [patching]** [float] The score to give when the state is patching -* **Node Operating System or Service State [compromised_should_be_good]** [int] +* **Node Operating System or Service State [compromised_should_be_good]** [float] The score to give when the state should be good, but is compromised -* **Node Operating System or Service State [compromised_should_be_patching]** [int] +* **Node Operating System or Service State [compromised_should_be_patching]** [float] The score to give when the state should be patching, but is compromised -* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int] +* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [float] The score to give when the state should be overwhelmed, but is compromised -* **Node Operating System or Service State [compromised]** [int] +* **Node Operating System or Service State [compromised]** [float] The score to give when the state is compromised -* **Node Operating System or Service State [overwhelmed_should_be_good]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_good]** [float] The score to give when the state should be good, but is overwhelmed -* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float] The score to give when the state should be patching, but is overwhelmed -* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float] The score to give when the state should be compromised, but is overwhelmed -* **Node Operating System or Service State [overwhelmed]** [int] +* **Node Operating System or Service State [overwhelmed]** [float] The score to give when the state is overwhelmed -* **Node File System State [good_should_be_repairing]** [int] +* **Node File System State [good_should_be_repairing]** [float] The score to give when the state should be repairing, but is good -* **Node File System State [good_should_be_restoring]** [int] +* **Node File System State [good_should_be_restoring]** [float] The score to give when the state should be restoring, but is good -* **Node File System State [good_should_be_corrupt]** [int] +* **Node File System State [good_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is good -* **Node File System State [good_should_be_destroyed]** [int] +* **Node File System State [good_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is good -* **Node File System State [repairing_should_be_good]** [int] +* **Node File System State [repairing_should_be_good]** [float] The score to give when the state should be good, but is repairing -* **Node File System State [repairing_should_be_restoring]** [int] +* **Node File System State [repairing_should_be_restoring]** [float] The score to give when the state should be restoring, but is repairing -* **Node File System State [repairing_should_be_corrupt]** [int] +* **Node File System State [repairing_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is repairing -* **Node File System State [repairing_should_be_destroyed]** [int] +* **Node File System State [repairing_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is repairing -* **Node File System State [repairing]** [int] +* **Node File System State [repairing]** [float] The score to give when the state is repairing -* **Node File System State [restoring_should_be_good]** [int] +* **Node File System State [restoring_should_be_good]** [float] The score to give when the state should be good, but is restoring -* **Node File System State [restoring_should_be_repairing]** [int] +* **Node File System State [restoring_should_be_repairing]** [float] The score to give when the state should be repairing, but is restoring -* **Node File System State [restoring_should_be_corrupt]** [int] +* **Node File System State [restoring_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is restoring -* **Node File System State [restoring_should_be_destroyed]** [int] +* **Node File System State [restoring_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is restoring -* **Node File System State [restoring]** [int] +* **Node File System State [restoring]** [float] The score to give when the state is restoring -* **Node File System State [corrupt_should_be_good]** [int] +* **Node File System State [corrupt_should_be_good]** [float] The score to give when the state should be good, but is corrupt -* **Node File System State [corrupt_should_be_repairing]** [int] +* **Node File System State [corrupt_should_be_repairing]** [float] The score to give when the state should be repairing, but is corrupt -* **Node File System State [corrupt_should_be_restoring]** [int] +* **Node File System State [corrupt_should_be_restoring]** [float] The score to give when the state should be restoring, but is corrupt -* **Node File System State [corrupt_should_be_destroyed]** [int] +* **Node File System State [corrupt_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is corrupt -* **Node File System State [corrupt]** [int] +* **Node File System State [corrupt]** [float] The score to give when the state is corrupt -* **Node File System State [destroyed_should_be_good]** [int] +* **Node File System State [destroyed_should_be_good]** [float] The score to give when the state should be good, but is destroyed -* **Node File System State [destroyed_should_be_repairing]** [int] +* **Node File System State [destroyed_should_be_repairing]** [float] The score to give when the state should be repairing, but is destroyed -* **Node File System State [destroyed_should_be_restoring]** [int] +* **Node File System State [destroyed_should_be_restoring]** [float] The score to give when the state should be restoring, but is destroyed -* **Node File System State [destroyed_should_be_corrupt]** [int] +* **Node File System State [destroyed_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is destroyed -* **Node File System State [destroyed]** [int] +* **Node File System State [destroyed]** [float] The score to give when the state is destroyed -* **Node File System State [scanning]** [int] +* **Node File System State [scanning]** [float] The score to give when the state is scanning -* **IER Status [red_ier_running]** [int] +* **IER Status [red_ier_running]** [float] The score to give when a red agent IER is permitted to run -* **IER Status [green_ier_blocked]** [int] +* **IER Status [green_ier_blocked]** [float] The score to give when a green agent IER is prevented from running @@ -308,6 +308,14 @@ Rewards are calculated based on the difference between the current state and ref The number of steps to take when scanning the file system +* **deterministic** [bool] + + Set to true if the agent evaluation should be deterministic. Default is ``False`` + +* **seed** [int] + + Seed used in the randomisation in agent training. Default is ``None`` + The Lay Down Config ******************* diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 7e90724a..78ea8a36 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -10,19 +10,19 @@ class AccessControlList: def __init__(self): """Init.""" - self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules + self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules - def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): - """ - Checks for IP address matches. + def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool: + """Checks for IP address matches. - Args: - _rule: The rule being checked - _source_ip_address: the source IP address to compare - _dest_ip_address: the destination IP address to compare - - Returns: - True if match; False otherwise. + :param _rule: The rule object to check + :type _rule: ACLRule + :param _source_ip_address: Source IP address to compare + :type _source_ip_address: str + :param _dest_ip_address: Destination IP address to compare + :type _dest_ip_address: str + :return: True if there is a match, otherwise False. + :rtype: bool """ if ( (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address) @@ -34,7 +34,7 @@ class AccessControlList: else: return False - def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port): + def is_blocked(self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool: """ Checks for rules that block a protocol / port. diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 685fe776..32118597 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -257,10 +257,19 @@ class AgentSessionABC(ABC): raise FileNotFoundError(msg) pass + @property + def _saved_agent_path(self) -> Path: + file_name = ( + f"{self._training_config.agent_framework}_" + f"{self._training_config.agent_identifier}_" + f"{self.timestamp_str}.zip" + ) + return self.learning_path / file_name + @abstractmethod def save(self): """Save the agent.""" - self._agent.save(self.session_path) + pass @abstractmethod def export(self): diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 263ccbdc..f8c571c9 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -1,5 +1,9 @@ +from typing import Any, Dict, List, Union + import numpy as np +from primaite.acl.access_control_list import AccessControlList +from primaite.acl.acl_rule import ACLRule from primaite.agents.agent import HardCodedAgentSessionABC from primaite.agents.utils import ( get_new_action, @@ -7,13 +11,17 @@ from primaite.agents.utils import ( transform_action_acl_enum, transform_change_obs_readable, ) +from primaite.common.custom_typing import NodeUnion from primaite.common.enums import HardCodedAgentView +from primaite.nodes.active_node import ActiveNode +from primaite.nodes.service_node import ServiceNode +from primaite.pol.ier import IER class HardCodedACLAgent(HardCodedAgentSessionABC): """An Agent Session class that implements a deterministic ACL agent.""" - def _calculate_action(self, obs): + def _calculate_action(self, obs: np.ndarray) -> int: if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC: # Basic view action using only the current observation return self._calculate_action_basic_view(obs) @@ -22,12 +30,19 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): # history and reward feedback return self._calculate_action_full_view(obs) - def get_blocked_green_iers(self, green_iers, acl, nodes): - """ - Get blocked green IERs. + def get_blocked_green_iers( + self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[Any, Any]: + """Get blocked green IERs. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param green_iers: Green IERs to check for being + :type green_iers: Dict[str, IER] + :param acl: Firewall rules + :type acl: AccessControlList + :param nodes: Nodes in the network + :type nodes: Dict[str,NodeUnion] + :return: Same as `green_iers` input dict, but filtered to only contain the blocked ones. + :rtype: Dict[str, IER] """ blocked_green_iers = {} @@ -45,12 +60,17 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_green_iers - def get_matching_acl_rules_for_ier(self, ier, acl, nodes): - """ - Get matching ACL rules for an IER. + def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]): + """Get list of ACL rules which are relevant to an IER. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param ier: Information Exchange Request to query against the ACL list + :type ier: IER + :param acl: Firewall rules + :type acl: AccessControlList + :param nodes: Nodes in the network + :type nodes: Dict[str,NodeUnion] + :return: _description_ + :rtype: _type_ """ source_node_id = ier.get_source_node_id() source_node_address = nodes[source_node_id].ip_address @@ -58,11 +78,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): dest_node_address = nodes[dest_node_id].ip_address protocol = ier.get_protocol() # e.g. 'TCP' port = ier.get_port() - matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) return matching_rules - def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): + def get_blocking_acl_rules_for_ier( + self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[str, Any]: """ Get blocking ACL rules for an IER. @@ -70,8 +91,14 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): Can return empty dict but IER can still be blocked by default (No ALLOW rule, therefore blocked). - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param ier: Information Exchange Request to query against the ACL list + :type ier: IER + :param acl: Firewall rules + :type acl: AccessControlList + :param nodes: Nodes in the network + :type nodes: Dict[str,NodeUnion] + :return: _description_ + :rtype: _type_ """ matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) @@ -82,12 +109,19 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_rules - def get_allow_acl_rules_for_ier(self, ier, acl, nodes): - """ - Get all allowing ACL rules for an IER. + def get_allow_acl_rules_for_ier( + self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[str, Any]: + """Get all allowing ACL rules for an IER. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param ier: Information Exchange Request to query against the ACL list + :type ier: IER + :param acl: Firewall rules + :type acl: AccessControlList + :param nodes: Nodes in the network + :type nodes: Dict[str,NodeUnion] + :return: _description_ + :rtype: _type_ """ matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) @@ -100,19 +134,32 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_matching_acl_rules( self, - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ): - """ - Get matching ACL rules. + source_node_id: str, + dest_node_id: str, + protocol: str, + port: str, + acl: AccessControlList, + nodes: Dict[str, Union[ServiceNode, ActiveNode]], + services_list: List[str], + ) -> Dict[str, ACLRule]: + """Filter ACL rules to only those which are relevant to the specified nodes. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param source_node_id: Source node + :type source_node_id: str + :param dest_node_id: Destination nodes + :type dest_node_id: str + :param protocol: Network protocol + :type protocol: str + :param port: Network port + :type port: str + :param acl: Access Control list which will be filtered + :type acl: AccessControlList + :param nodes: The environment's node directory. + :type nodes: Dict[str, Union[ServiceNode, ActiveNode]] + :param services_list: List of services registered for the environment. + :type services_list: List[str] + :return: Filtered version of 'acl' + :rtype: Dict[str, ACLRule] """ if source_node_id != "ANY": source_node_address = nodes[str(source_node_id)].ip_address @@ -132,19 +179,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_allow_acl_rules( self, - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ): - """ - Get the ALLOW ACL rules. + source_node_id: int, + dest_node_id: str, + protocol: int, + port: str, + acl: AccessControlList, + nodes: Dict[str, NodeUnion], + services_list: List[str], + ) -> Dict[str, ACLRule]: + """List ALLOW rules relating to specified nodes. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param source_node_id: Source node id + :type source_node_id: int + :param dest_node_id: Destination node + :type dest_node_id: str + :param protocol: Network protocol + :type protocol: int + :param port: Port + :type port: str + :param acl: Firewall ruleset which is applied to the network + :type acl: AccessControlList + :param nodes: The simulation's node store + :type nodes: Dict[str, NodeUnion] + :param services_list: Services list + :type services_list: List[str] + :return: Filtered ACL Rule directory which includes only those rules which affect the specified source and + desination nodes + :rtype: Dict[str, ACLRule] """ matching_rules = self.get_matching_acl_rules( source_node_id, @@ -165,19 +226,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_deny_acl_rules( self, - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ): - """ - Get the DENY ACL rules. + source_node_id: int, + dest_node_id: str, + protocol: int, + port: str, + acl: AccessControlList, + nodes: Dict[str, NodeUnion], + services_list: List[str], + ) -> Dict[str, ACLRule]: + """List DENY rules relating to specified nodes. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param source_node_id: Source node id + :type source_node_id: int + :param dest_node_id: Destination node + :type dest_node_id: str + :param protocol: Network protocol + :type protocol: int + :param port: Port + :type port: str + :param acl: Firewall ruleset which is applied to the network + :type acl: AccessControlList + :param nodes: The simulation's node store + :type nodes: Dict[str, NodeUnion] + :param services_list: Services list + :type services_list: List[str] + :return: Filtered ACL Rule directory which includes only those rules which affect the specified source and + desination nodes + :rtype: Dict[str, ACLRule] """ matching_rules = self.get_matching_acl_rules( source_node_id, @@ -196,7 +271,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return allowed_rules - def _calculate_action_full_view(self, obs): + def _calculate_action_full_view(self, obs: np.ndarray) -> int: """ Calculate a good acl-based action for the blue agent to take. @@ -224,8 +299,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing an overwhelmed state, so we don't do this. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: current observation from the gym environment + :type obs: np.ndarray + :return: Optimal action to take in the environment (chosen from the discrete action space) + :rtype: int """ # obs = convert_to_old_obs(obs) r_obs = transform_change_obs_readable(obs) @@ -361,7 +438,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): action = get_new_action(action, self._env.action_dict) return action - def _calculate_action_basic_view(self, obs): + def _calculate_action_basic_view(self, obs: np.ndarray) -> int: """Calculate a good acl-based action for the blue agent to take. Uses ONLY information from the current observation with NO knowledge @@ -379,8 +456,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): Currently, a deny rule does not overwrite an allow rule. The allow rules must be deleted. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: current observation from the gym environment + :type obs: np.ndarray + :return: Optimal action to take in the environment (chosen from the discrete action space) + :rtype: int """ action_dict = self._env.action_dict r_obs = transform_change_obs_readable(obs) diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index 310fc178..c00cf421 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -1,3 +1,5 @@ +import numpy as np + from primaite.agents.agent import HardCodedAgentSessionABC from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable @@ -5,12 +7,14 @@ from primaite.agents.utils import get_new_action, transform_action_node_enum, tr class HardCodedNodeAgent(HardCodedAgentSessionABC): """An Agent Session class that implements a deterministic Node agent.""" - def _calculate_action(self, obs): + def _calculate_action(self, obs: np.ndarray) -> int: """ Calculate a good node-based action for the blue agent to take. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: current observation from the gym environment + :type obs: np.ndarray + :return: Optimal action to take in the environment (chosen from the discrete action space) + :rtype: int """ action_dict = self._env.action_dict r_obs = transform_change_obs_readable(obs) diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index d851ba9c..0bc41762 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,9 +1,11 @@ from __future__ import annotations import json +import shutil from datetime import datetime from pathlib import Path from typing import Union +from uuid import uuid4 from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.a2c import A2CConfig @@ -106,6 +108,7 @@ class RLlibAgent(AgentSessionABC): timestamp_str=self.timestamp_str, ), ) + self._agent_config.seed = self._training_config.seed self._agent_config.training(train_batch_size=self._training_config.num_steps) self._agent_config.framework(framework="tf") @@ -120,9 +123,11 @@ class RLlibAgent(AgentSessionABC): def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._current_result["episodes_total"] - if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): - self._agent.save(str(self.checkpoints_path)) + save_checkpoint = False + if checkpoint_n: + save_checkpoint = episode_count % checkpoint_n == 0 + if episode_count and save_checkpoint: + self._agent.save(str(self.checkpoints_path)) def learn( self, @@ -140,9 +145,14 @@ class RLlibAgent(AgentSessionABC): for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() + self.save() self._agent.stop() + super().learn() + # save agent + self.save() + def evaluate( self, **kwargs, @@ -162,9 +172,25 @@ class RLlibAgent(AgentSessionABC): """Load an agent from file.""" raise NotImplementedError - def save(self): + def save(self, overwrite_existing: bool = True): """Save the agent.""" - raise NotImplementedError + # Make temp dir to save in isolation + temp_dir = self.learning_path / str(uuid4()) + temp_dir.mkdir() + + # Save the agent to the temp dir + self._agent.save(str(temp_dir)) + + # Capture the saved Rllib checkpoint inside the temp directory + for file in temp_dir.iterdir(): + checkpoint_dir = file + break + + # Zip the folder + shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa + + # Drop the temp directory + shutil.rmtree(temp_dir) def export(self): """Export the agent to transportable file format.""" diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index f5ac44cb..aa8e312d 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -59,16 +59,19 @@ class SB3Agent(AgentSessionABC): verbose=self.sb3_output_verbose_level, n_steps=self._training_config.num_steps, tensorboard_log=str(self._tensorboard_log_path), + seed=self._training_config.seed, ) def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._env.episode_count - if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): - checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" - self._agent.save(checkpoint_path) - _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") + save_checkpoint = False + if checkpoint_n: + save_checkpoint = episode_count % checkpoint_n == 0 + if episode_count and save_checkpoint: + checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" + self._agent.save(checkpoint_path) + _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") def _get_latest_checkpoint(self): pass @@ -90,25 +93,27 @@ class SB3Agent(AgentSessionABC): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() self._env.reset() + self.save() self._env.close() super().learn() + # save agent + self.save() + def evaluate( self, - deterministic: bool = True, **kwargs, ): """ Evaluate the agent. - :param deterministic: Whether the evaluation is deterministic. :param kwargs: Any agent-specific key-word args to be passed. """ time_steps = self._training_config.num_steps episodes = self._training_config.num_episodes self._env.set_as_eval() self.is_eval = True - if deterministic: + if self._training_config.deterministic: deterministic_str = "deterministic" else: deterministic_str = "non-deterministic" @@ -119,7 +124,7 @@ class SB3Agent(AgentSessionABC): obs = self._env.reset() for step in range(time_steps): - action, _states = self._agent.predict(obs, deterministic=deterministic) + action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic) if isinstance(action, np.ndarray): action = np.int64(action) obs, rewards, done, info = self._env.step(action) @@ -134,7 +139,7 @@ class SB3Agent(AgentSessionABC): def save(self): """Save the agent.""" - raise NotImplementedError + self._agent.save(self._saved_agent_path) def export(self): """Export the agent to transportable file format.""" diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index acc70cc4..ee681f86 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -1,5 +1,8 @@ +from typing import Dict, List, Union + import numpy as np +from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( HardwareState, LinkStatus, @@ -10,15 +13,17 @@ from primaite.common.enums import ( ) -def transform_action_node_readable(action): - """ - Convert a node action from enumerated format to readable format. +def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]: + """Convert a node action from enumerated format to readable format. example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: The same action list, but with the encodings translated back into meaningful labels + :rtype: List[Union[int,str]] """ action_node_property = NodePOLType(action[1]).name @@ -33,15 +38,18 @@ def transform_action_node_readable(action): return new_action -def transform_action_acl_readable(action): +def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]: """ Transform an ACL action to a more readable format. example: [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: The same action list, but with the encodings translated back into meaningful labels + :rtype: List[Union[int,str]] """ action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"} action_permissions = {0: "DENY", 1: "ALLOW"} @@ -58,7 +66,7 @@ def transform_action_acl_readable(action): return new_action -def is_valid_node_action(action): +def is_valid_node_action(action: List[int]) -> bool: """Is the node action an actual valid action. Only uses information about the action to determine if the action has an effect @@ -67,8 +75,11 @@ def is_valid_node_action(action): - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - Node already being in that state (turning an ON node ON) - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ action_r = transform_action_node_readable(action) @@ -93,7 +104,7 @@ def is_valid_node_action(action): return True -def is_valid_acl_action(action): +def is_valid_acl_action(action: List[int]) -> bool: """ Is the ACL action an actual valid action. @@ -103,8 +114,11 @@ def is_valid_acl_action(action): - Trying to create identical rules - Trying to create a rule which is a subset of another rule (caused by "ANY") - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ action_r = transform_action_acl_readable(action) @@ -126,12 +140,15 @@ def is_valid_acl_action(action): return True -def is_valid_acl_action_extra(action): +def is_valid_acl_action_extra(action: List[int]) -> bool: """ Harsher version of valid acl actions, does not allow action. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ if is_valid_acl_action(action) is False: return False @@ -150,15 +167,16 @@ def is_valid_acl_action_extra(action): return True -def transform_change_obs_readable(obs): - """ - Transform list of transactions to readable list of each observation property. +def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]: + """Transform list of transactions to readable list of each observation property. example: np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: Raw observation from the environment. + :type obs: np.ndarray + :return: The same observation, but the encoded integer values are replaced with readable names. + :rtype: List[List[Union[str, int]]] """ ids = [i for i in obs[:, 0]] operating_states = [HardwareState(i).name for i in obs[:, 1]] @@ -173,14 +191,16 @@ def transform_change_obs_readable(obs): return new_obs -def transform_obs_readable(obs): - """ - Transform observation to readable format. +def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]: + """Transform observation to readable format. + example np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: Raw observation from the environment. + :type obs: np.ndarray + :return: The same observation, but the encoded integer values are replaced with readable names. + :rtype: List[List[Union[str, int]]] """ changed_obs = transform_change_obs_readable(obs) new_obs = list(zip(*changed_obs)) @@ -190,21 +210,23 @@ def transform_obs_readable(obs): return new_obs -def convert_to_new_obs(obs, num_nodes=10): - """ - Convert original gym Box observation space to new multiDiscrete observation space. +def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray: + """Convert original gym Box observation space to new multiDiscrete observation space. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: observation in the 'old' (NodeLinkTable) format + :type obs: np.ndarray + :param num_nodes: number of nodes in the network, defaults to 10 + :type num_nodes: int, optional + :return: reformatted observation + :rtype: np.ndarray """ # Remove ID columns, remove links and flatten to MultiDiscrete observation space new_obs = obs[:num_nodes, 1:].flatten() return new_obs -def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): - """ - Convert to old observation. +def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray: + """Convert to old observation. Links filled with 0's as no information is included in new observation space. @@ -216,8 +238,17 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): [ 3, 1, 1, 1], ... [20, 0, 0, 0]]) - TODO: Add params and return in docstring. - TODO: Typehint params and return. + + :param obs: observation in the 'new' (MultiDiscrete) format + :type obs: np.ndarray + :param num_nodes: number of nodes in the network, defaults to 10 + :type num_nodes: int, optional + :param num_links: number of links in the network, defaults to 10 + :type num_links: int, optional + :param num_services: number of services on the network, defaults to 1 + :type num_services: int, optional + :return: 2-d BOX observation space, in the same format as NodeLinkTable + :rtype: np.ndarray """ # Convert back to more readable, original format reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2) @@ -239,17 +270,28 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): return new_obs -def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): - """ - Return string describing change between two observations. +def describe_obs_change( + obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1 +) -> str: + """Build a string describing the difference between two observations. example: obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]]) output = 'ID 1: SERVICE 2 set to GOOD' - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs1: First observation + :type obs1: np.ndarray + :param obs2: Second observation + :type obs2: np.ndarray + :param num_nodes: How many nodes are in the network laydown, defaults to 10 + :type num_nodes: int, optional + :param num_links: How many links are in the network laydown, defaults to 10 + :type num_links: int, optional + :param num_services: How many services are configured for this scenario, defaults to 1 + :type num_services: int, optional + :return: A multi-line string with a human-readable description of the difference. + :rtype: str """ obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services) obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services) @@ -268,7 +310,7 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): return change_string -def _describe_obs_change_helper(obs_change, is_link): +def _describe_obs_change_helper(obs_change: List[int], is_link: bool) -> str: """ Helper funcion to describe what has changed. @@ -277,8 +319,14 @@ def _describe_obs_change_helper(obs_change, is_link): Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.' - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs_change: List of integers generated within the `describe_obs_change` function. It should correspond to one + row of the observation table, and have `-1` at locations where the observation hasn't changed, and the new + status where it has changed. + :type obs_change: List[int] + :param is_link: Whether the row of the observation space corresponds to a link. False means it represents a node. + :type is_link: bool + :return: A human-readable description of the difference between the two observation rows. + :rtype: str """ # Indexes where a change has occured, not including 0th index index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] @@ -304,15 +352,15 @@ def _describe_obs_change_helper(obs_change, is_link): return desc -def transform_action_node_enum(action): - """ - Convert a node action from readable string format, to enumerated format. +def transform_action_node_enum(action: List[Union[str, int]]) -> List[int]: + """Convert a node action from readable string format, to enumerated format. example: [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] - - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Action in 'readable' format + :type action: List[Union[str,int]] + :return: Action with verbs encoded as ints + :rtype: List[int] """ action_node_id = action[0] action_node_property = NodePOLType[action[1]].value @@ -336,63 +384,14 @@ def transform_action_node_enum(action): return new_action -def transform_action_node_readable(action): - """ - Convert a node action from enumerated format to readable format. - - example: - [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - action_node_property = NodePOLType(action[1]).name - - if action_node_property == "OPERATING": - property_action = NodeHardwareAction(action[2]).name - elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1: - property_action = NodeSoftwareAction(action[2]).name - else: - property_action = "NONE" - - new_action = [action[0], action_node_property, property_action, action[3]] - return new_action - - -def node_action_description(action): - """ - Generate string describing a node-based action. - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - if isinstance(action[1], (int, np.int64)): - # transform action to readable format - action = transform_action_node_readable(action) - - node_id = action[0] - node_property = action[1] - property_action = action[2] - service_id = action[3] - - if property_action == "NONE": - return "" - if node_property == "OPERATING" or node_property == "OS": - description = f"NODE {node_id}, {node_property}, SET TO {property_action}" - elif node_property == "SERVICE": - description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" - else: - return "" - - return description - - -def transform_action_acl_enum(action): +def transform_action_acl_enum(action: List[Union[int, str]]) -> np.ndarray: """ Convert acl action from readable str format, to enumerated format. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: ACL-based action expressed as a list of human-readable ints and strings + :type action: List[Union[int,str]] + :return: The same action but encoded to contain only integers. + :rtype: np.ndarray """ action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2} action_permissions = {"DENY": 0, "ALLOW": 1} @@ -410,35 +409,17 @@ def transform_action_acl_enum(action): return new_action -def acl_action_description(action): - """ - Generate string describing an acl-based action. - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - if isinstance(action[0], (int, np.int64)): - # transform action to readable format - action = transform_action_acl_readable(action) - if action[0] == "NONE": - description = "NO ACL RULE APPLIED" - else: - description = ( - f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]}," - f" for protocol/service index {action[4]} on port index {action[5]}" - ) - - return description - - -def get_node_of_ip(ip, node_dict): - """ - Get the node ID of an IP address. +def get_node_of_ip(ip: str, node_dict: Dict[str, NodeUnion]) -> str: + """Get the node ID of an IP address. node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes) - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param ip: The IP address of the node whose ID is required + :type ip: str + :param node_dict: The environment's node registry dictionary + :type node_dict: Dict[str,NodeUnion] + :return: The key from the registry dict that corresponds to the node with the IP adress provided by `ip` + :rtype: str """ for node_key, node_value in node_dict.items(): node_ip = node_value.ip_address @@ -446,104 +427,18 @@ def get_node_of_ip(ip, node_dict): return node_key -def is_valid_node_action(action): - """Is the node action an actual valid action. - - Only uses information about the action to determine if the action has an effect - - Does NOT consider: - - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - - Node already being in that state (turning an ON node ON) - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - action_r = transform_action_node_readable(action) - - node_property = action_r[1] - node_action = action_r[2] - - if node_property == "NONE": - return False - if node_action == "NONE": - return False - if node_property == "OPERATING" and node_action == "PATCHING": - # Operating State cannot PATCH - return False - if node_property != "OPERATING" and node_action not in [ - "NONE", - "PATCHING", - ]: - # Software States can only do Nothing or Patch - return False - return True - - -def is_valid_acl_action(action): - """ - Is the ACL action an actual valid action. - - Only uses information about the action to determine if the action has an effect - - Does NOT consider: - - Trying to create identical rules - - Trying to create a rule which is a subset of another rule (caused by "ANY") - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - action_r = transform_action_acl_readable(action) - - action_decision = action_r[0] - action_permission = action_r[1] - action_source_id = action_r[2] - action_destination_id = action_r[3] - - if action_decision == "NONE": - return False - if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY": - # ACL rule towards itself - return False - if action_permission == "DENY": - # DENY is unnecessary, we can create and delete allow rules instead - # No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY. - return False - - return True - - -def is_valid_acl_action_extra(action): - """ - Harsher version of valid acl actions, does not allow action. - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - if is_valid_acl_action(action) is False: - return False - - action_r = transform_action_acl_readable(action) - action_protocol = action_r[4] - action_port = action_r[5] - - # Don't allow protocols or ports to be ANY - # in the future we might want to do the opposite, and only have ANY option for ports and service - if action_protocol == "ANY": - return False - if action_port == "ANY": - return False - - return True - - -def get_new_action(old_action, action_dict): +def get_new_action(old_action: np.ndarray, action_dict: Dict[int, List]) -> int: """ Get new action (e.g. 32) from old action e.g. [1,1,1,0]. Old_action can be either node or acl action type - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param old_action: Action expressed as a list of choices, eg. [1,1,1,0] + :type old_action: np.ndarray + :param action_dict: Dictionary for translating the multidiscrete actions into the list-based actions. + :type action_dict: Dict[int,List] + :return: Action key correspoinding to the input `old_action` + :rtype: int """ for key, val in action_dict.items(): if list(val) == list(old_action): diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index a638fe14..15adc4dd 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -31,6 +31,16 @@ agent_identifier: PPO # False random_red_agent: False +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: null + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + # Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. # Options are: # "BASIC" (The current observation space only) @@ -83,58 +93,58 @@ sb3_output_verbose_level: NONE # Generic all_ok: 0 # Node Hardware State -off_should_be_on: -10 -off_should_be_resetting: -5 -on_should_be_off: -2 -on_should_be_resetting: -5 -resetting_should_be_on: -5 -resetting_should_be_off: -2 -resetting: -3 +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 # Node Software or Service State -good_should_be_patching: 2 -good_should_be_compromised: 5 -good_should_be_overwhelmed: 5 -patching_should_be_good: -5 -patching_should_be_compromised: 2 -patching_should_be_overwhelmed: 2 -patching: -3 -compromised_should_be_good: -20 -compromised_should_be_patching: -20 -compromised_should_be_overwhelmed: -20 -compromised: -20 -overwhelmed_should_be_good: -20 -overwhelmed_should_be_patching: -20 -overwhelmed_should_be_compromised: -20 -overwhelmed: -20 +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 # Node File System State -good_should_be_repairing: 2 -good_should_be_restoring: 2 -good_should_be_corrupt: 5 -good_should_be_destroyed: 10 -repairing_should_be_good: -5 -repairing_should_be_restoring: 2 -repairing_should_be_corrupt: 2 -repairing_should_be_destroyed: 0 -repairing: -3 -restoring_should_be_good: -10 -restoring_should_be_repairing: -2 -restoring_should_be_corrupt: 1 -restoring_should_be_destroyed: 2 -restoring: -6 -corrupt_should_be_good: -10 -corrupt_should_be_repairing: -10 -corrupt_should_be_restoring: -10 -corrupt_should_be_destroyed: 2 -corrupt: -10 -destroyed_should_be_good: -20 -destroyed_should_be_repairing: -20 -destroyed_should_be_restoring: -20 -destroyed_should_be_corrupt: -20 -destroyed: -20 -scanning: -2 +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 # IER status -red_ier_running: -5 -green_ier_blocked: -10 +red_ier_running: -0.0005 +green_ier_blocked: -0.001 # Patching / Reset durations os_patching_duration: 5 # The time taken to patch the OS diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index bd73f65b..e7b701c7 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -94,64 +94,64 @@ class TrainingConfig: # Reward values # Generic - all_ok: int = 0 + all_ok: float = 0 # Node Hardware State - off_should_be_on: int = -10 - off_should_be_resetting: int = -5 - on_should_be_off: int = -2 - on_should_be_resetting: int = -5 - resetting_should_be_on: int = -5 - resetting_should_be_off: int = -2 - resetting: int = -3 + off_should_be_on: float = -0.001 + off_should_be_resetting: float = -0.0005 + on_should_be_off: float = -0.0002 + on_should_be_resetting: float = -0.0005 + resetting_should_be_on: float = -0.0005 + resetting_should_be_off: float = -0.0002 + resetting: float = -0.0003 # Node Software or Service State - good_should_be_patching: int = 2 - good_should_be_compromised: int = 5 - good_should_be_overwhelmed: int = 5 - patching_should_be_good: int = -5 - patching_should_be_compromised: int = 2 - patching_should_be_overwhelmed: int = 2 - patching: int = -3 - compromised_should_be_good: int = -20 - compromised_should_be_patching: int = -20 - compromised_should_be_overwhelmed: int = -20 - compromised: int = -20 - overwhelmed_should_be_good: int = -20 - overwhelmed_should_be_patching: int = -20 - overwhelmed_should_be_compromised: int = -20 - overwhelmed: int = -20 + good_should_be_patching: float = 0.0002 + good_should_be_compromised: float = 0.0005 + good_should_be_overwhelmed: float = 0.0005 + patching_should_be_good: float = -0.0005 + patching_should_be_compromised: float = 0.0002 + patching_should_be_overwhelmed: float = 0.0002 + patching: float = -0.0003 + compromised_should_be_good: float = -0.002 + compromised_should_be_patching: float = -0.002 + compromised_should_be_overwhelmed: float = -0.002 + compromised: float = -0.002 + overwhelmed_should_be_good: float = -0.002 + overwhelmed_should_be_patching: float = -0.002 + overwhelmed_should_be_compromised: float = -0.002 + overwhelmed: float = -0.002 # Node File System State - good_should_be_repairing: int = 2 - good_should_be_restoring: int = 2 - good_should_be_corrupt: int = 5 - good_should_be_destroyed: int = 10 - repairing_should_be_good: int = -5 - repairing_should_be_restoring: int = 2 - repairing_should_be_corrupt: int = 2 - repairing_should_be_destroyed: int = 0 - repairing: int = -3 - restoring_should_be_good: int = -10 - restoring_should_be_repairing: int = -2 - restoring_should_be_corrupt: int = 1 - restoring_should_be_destroyed: int = 2 - restoring: int = -6 - corrupt_should_be_good: int = -10 - corrupt_should_be_repairing: int = -10 - corrupt_should_be_restoring: int = -10 - corrupt_should_be_destroyed: int = 2 - corrupt: int = -10 - destroyed_should_be_good: int = -20 - destroyed_should_be_repairing: int = -20 - destroyed_should_be_restoring: int = -20 - destroyed_should_be_corrupt: int = -20 - destroyed: int = -20 - scanning: int = -2 + good_should_be_repairing: float = 0.0002 + good_should_be_restoring: float = 0.0002 + good_should_be_corrupt: float = 0.0005 + good_should_be_destroyed: float = 0.001 + repairing_should_be_good: float = -0.0005 + repairing_should_be_restoring: float = 0.0002 + repairing_should_be_corrupt: float = 0.0002 + repairing_should_be_destroyed: float = 0.0000 + repairing: float = -0.0003 + restoring_should_be_good: float = -0.001 + restoring_should_be_repairing: float = -0.0002 + restoring_should_be_corrupt: float = 0.0001 + restoring_should_be_destroyed: float = 0.0002 + restoring: float = -0.0006 + corrupt_should_be_good: float = -0.001 + corrupt_should_be_repairing: float = -0.001 + corrupt_should_be_restoring: float = -0.001 + corrupt_should_be_destroyed: float = 0.0002 + corrupt: float = -0.001 + destroyed_should_be_good: float = -0.002 + destroyed_should_be_repairing: float = -0.002 + destroyed_should_be_restoring: float = -0.002 + destroyed_should_be_corrupt: float = -0.002 + destroyed: float = -0.002 + scanning: float = -0.0002 # IER status - red_ier_running: int = -5 - green_ier_blocked: int = -10 + red_ier_running: float = -0.0005 + green_ier_blocked: float = -0.001 # Patching / Reset durations os_patching_duration: int = 5 @@ -178,6 +178,12 @@ class TrainingConfig: file_system_scanning_limit: int = 5 "The time taken to scan the file system" + deterministic: bool = False + "If true, the training will be deterministic" + + seed: Optional[int] = None + "The random number generator seed to be used while training the agent" + @classmethod def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig: """ diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 03c23f93..3a40066a 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -142,10 +142,10 @@ class Primaite(Env): self.step_info = {} # Total reward - self.total_reward = 0 + self.total_reward: float = 0 # Average reward - self.average_reward = 0 + self.average_reward: float = 0 # Episode count self.episode_count = 0 @@ -283,9 +283,9 @@ class Primaite(Env): self._create_random_red_agent() # Reset counters and totals - self.total_reward = 0 + self.total_reward = 0.0 self.step_count = 0 - self.average_reward = 0 + self.average_reward = 0.0 # Update observations space and return self.update_environent_obs() diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 19094a18..e4353cb9 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -20,7 +20,7 @@ def calculate_reward_function( red_iers, step_count, config_values, -): +) -> float: """ Compares the states of the initial and final nodes/links to get a reward. @@ -33,7 +33,7 @@ def calculate_reward_function( step_count: current step config_values: Config values """ - reward_value = 0 + reward_value: float = 0.0 # For each node, compare hardware state, SoftwareState, service states for node_key, final_node in final_nodes.items(): @@ -94,7 +94,7 @@ def calculate_reward_function( return reward_value -def score_node_operating_state(final_node, initial_node, reference_node, config_values): +def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the hardware state of a node. @@ -104,7 +104,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ reference_node: The node if there had been no red or blue effect config_values: Config values """ - score = 0 + score: float = 0.0 final_node_operating_state = final_node.hardware_state reference_node_operating_state = reference_node.hardware_state @@ -143,7 +143,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ return score -def score_node_os_state(final_node, initial_node, reference_node, config_values): +def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the Software State of a node. @@ -153,7 +153,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) reference_node: The node if there had been no red or blue effect config_values: Config values """ - score = 0 + score: float = 0.0 final_node_os_state = final_node.software_state reference_node_os_state = reference_node.software_state @@ -194,7 +194,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) return score -def score_node_service_state(final_node, initial_node, reference_node, config_values): +def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the service state(s) of a node. @@ -204,7 +204,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va reference_node: The node if there had been no red or blue effect config_values: Config values """ - score = 0 + score: float = 0.0 final_node_services: Dict[str, Service] = final_node.services reference_node_services: Dict[str, Service] = reference_node.services @@ -266,7 +266,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va return score -def score_node_file_system(final_node, initial_node, reference_node, config_values): +def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the file system state of a node. @@ -275,7 +275,7 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu initial_node: The node before red and blue agents take effect reference_node: The node if there had been no red or blue effect """ - score = 0 + score: float = 0.0 final_node_file_system_state = final_node.file_system_state_actual reference_node_file_system_state = reference_node.file_system_state_actual diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index bff19bf8..1a8bd406 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -296,11 +296,17 @@ def apply_red_agent_node_pol( pass -def is_red_ier_incoming(node, iers, node_pol_type): - """ - Checks if the RED IER is incoming. +def is_red_ier_incoming(node: NodeUnion, iers: Dict[str, IER], node_pol_type: NodePOLType) -> bool: + """Checks if the RED IER is incoming. - TODO: Write more descriptive docstring with params and returns. + :param node: Destination node of the IER + :type node: NodeUnion + :param iers: Directory of IERs + :type iers: Dict[str,IER] + :param node_pol_type: Type of Pattern-Of-Life + :type node_pol_type: NodePOLType + :return: Whether the RED IER is incoming. + :rtype: bool """ node_id = node.node_id diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 7db2444a..95be8115 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -31,7 +31,7 @@ class Transaction(object): "The observation space before any actions are taken" self.obs_space_post = None "The observation space after any actions are taken" - self.reward = None + self.reward: float = None "The reward value" self.action_space = None "The action space invoked by the agent" diff --git a/tests/config/ppo_not_seeded_training_config.yaml b/tests/config/ppo_not_seeded_training_config.yaml new file mode 100644 index 00000000..23cff44e --- /dev/null +++ b/tests/config/ppo_not_seeded_training_config.yaml @@ -0,0 +1,155 @@ +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: SB3 + +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: PPO + +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False +random_red_agent: False + +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: None + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS +# Number of episodes to run per session +num_episodes: 10 + +# Number of time_steps per episode +num_steps: 256 + +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. Default is 10 +checkpoint_every_n_episodes: 0 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN_EVAL + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + +# Reward values +# Generic +all_ok: 0.0000 +# Node Hardware State +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 +# Node Software or Service State +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 +# Node File System State +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 +# IER status +red_ier_running: -0.0005 +green_ier_blocked: -0.001 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/tests/config/ppo_seeded_training_config.yaml similarity index 52% rename from src/primaite/config/_package_data/training/training_config_random_red_agent.yaml rename to tests/config/ppo_seeded_training_config.yaml index 96243daf..181331d9 100644 --- a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml +++ b/tests/config/ppo_seeded_training_config.yaml @@ -1,40 +1,94 @@ -# Main Config File +# Training Config File -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: SB3 + +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: PPO # Sets whether Red Agent POL and IER is randomised. # Options are: # True # False -random_red_agent: True +random_red_agent: False + +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: 67890 + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: True + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL # Sets How the Action Space is defined: # "NODE" # "ACL" # "ANY" node and acl actions action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS # Number of episodes to run per session num_episodes: 10 + # Number of time_steps per episode num_steps: 256 -# Time delay between steps (for generic agents) -time_delay: 10 -# Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING -# Determine whether to load an agent from file -load_agent: False -# File path and file name of agent if you're loading one in -agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. Default is 10 +checkpoint_every_n_episodes: 0 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN_EVAL # Environment config values # The high value for the observation space observation_space_high_value: 1000000000 +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + # Reward values # Generic all_ok: 0 diff --git a/tests/conftest.py b/tests/conftest.py index af76b314..388bc034 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,7 +58,6 @@ class TempPrimaiteSession(PrimaiteSession): def __exit__(self, type, value, tb): shutil.rmtree(self.session_path) - shutil.rmtree(self.session_path.parent) _LOGGER.debug(f"Deleted temp session directory: {self.session_path}") diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py index feff52f6..90c0cb5d 100644 --- a/tests/mock_and_patch/get_session_path_mock.py +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -1,6 +1,7 @@ import tempfile from datetime import datetime from pathlib import Path +from uuid import uuid4 from primaite import getLogger @@ -14,9 +15,7 @@ def get_temp_session_path(session_timestamp: datetime) -> Path: :param session_timestamp: This is the datetime that the session started. :return: The session directory path. """ - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path + session_path = Path(tempfile.gettempdir()) / "primaite" / str(uuid4()) session_path.mkdir(exist_ok=True, parents=True) _LOGGER.debug(f"Created temp session directory: {session_path}") return session_path diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py index ae0b0870..75ea5882 100644 --- a/tests/test_primaite_session.py +++ b/tests/test_primaite_session.py @@ -33,6 +33,9 @@ def test_primaite_session(temp_primaite_session): # Check that the network png file exists assert (session_path / f"network_{session.timestamp_str}.png").exists() + # Check that the saved agent exists + assert session._agent_session._saved_agent_path.exists() + # Check that both the transactions and av reward csv files exist for file in session.learning_path.iterdir(): if file.suffix == ".csv": diff --git a/tests/test_seeding_and_deterministic_session.py b/tests/test_seeding_and_deterministic_session.py new file mode 100644 index 00000000..34cb43fb --- /dev/null +++ b/tests/test_seeding_and_deterministic_session.py @@ -0,0 +1,49 @@ +import pytest as pytest + +from primaite.config.lay_down_config import dos_very_basic_config_path +from tests import TEST_CONFIG_ROOT + + +@pytest.mark.parametrize( + "temp_primaite_session", + [[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]], + indirect=True, +) +def test_seeded_learning(temp_primaite_session): + """Test running seeded learning produces the same output when ran twice.""" + expected_mean_reward_per_episode = { + 1: -90.703125, + 2: -91.15234375, + 3: -87.5, + 4: -92.2265625, + 5: -94.6875, + 6: -91.19140625, + 7: -88.984375, + 8: -88.3203125, + 9: -112.79296875, + 10: -100.01953125, + } + with temp_primaite_session as session: + assert session._training_config.seed == 67890, ( + "Expected output is based upon a agent that was trained with " "seed 67890" + ) + session.learn() + actual_mean_reward_per_episode = session.learn_av_reward_per_episode() + + assert actual_mean_reward_per_episode == expected_mean_reward_per_episode + + +@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL " "knowledge to investigate further.") +@pytest.mark.parametrize( + "temp_primaite_session", + [[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]], + indirect=True, +) +def test_deterministic_evaluation(temp_primaite_session): + """Test running deterministic evaluation gives same av eward per episode.""" + with temp_primaite_session as session: + # do stuff + session.learn() + session.evaluate() + eval_mean_reward = session.eval_av_reward_per_episode_csv() + assert len(set(eval_mean_reward.values())) == 1