diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 7e90724a..f5f70d6c 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -10,19 +10,19 @@ class AccessControlList: def __init__(self): """Init.""" - self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules + self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules - def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): - """ - Checks for IP address matches. + def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool: + """Checks for IP address matches. - Args: - _rule: The rule being checked - _source_ip_address: the source IP address to compare - _dest_ip_address: the destination IP address to compare - - Returns: - True if match; False otherwise. + :param _rule: The rule object to check + :type _rule: ACLRule + :param _source_ip_address: Source IP address to compare + :type _source_ip_address: str + :param _dest_ip_address: Destination IP address to compare + :type _dest_ip_address: str + :return: True if there is a match, otherwise False. + :rtype: bool """ if ( (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address) diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 430c8b54..a724e10e 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -1,7 +1,9 @@ -from typing import Any, Dict -import numpy as np -from primaite.acl.access_control_list import AccessControlList +from typing import Any, Dict, List, Union +import numpy as np + +from primaite.acl.access_control_list import AccessControlList +from primaite.acl.acl_rule import ACLRule from primaite.agents.agent import HardCodedAgentSessionABC from primaite.agents.utils import ( get_new_action, @@ -11,13 +13,15 @@ from primaite.agents.utils import ( ) from primaite.common.custom_typing import NodeUnion from primaite.common.enums import HardCodedAgentView +from primaite.nodes.active_node import ActiveNode +from primaite.nodes.service_node import ServiceNode from primaite.pol.ier import IER class HardCodedACLAgent(HardCodedAgentSessionABC): """An Agent Session class that implements a deterministic ACL agent.""" - def _calculate_action(self, obs): + def _calculate_action(self, obs: np.ndarray) -> int: if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC: # Basic view action using only the current observation return self._calculate_action_basic_view(obs) @@ -26,7 +30,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): # history and reward feedback return self._calculate_action_full_view(obs) - def get_blocked_green_iers(self, green_iers:Dict[str, IER], acl:AccessControlList, nodes:Dict[str,NodeUnion]): + def get_blocked_green_iers( + self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[Any, Any]: """Get blocked green IERs. :param green_iers: Green IERs to check for being @@ -54,8 +60,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_green_iers - def get_matching_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion]): - """Get list of ACL rules which are relevant to an IER + def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]): + """Get list of ACL rules which are relevant to an IER. :param ier: Information Exchange Request to query against the ACL list :type ier: IER @@ -76,7 +82,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) return matching_rules - def get_blocking_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]: + def get_blocking_acl_rules_for_ier( + self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[str, Any]: """ Get blocking ACL rules for an IER. @@ -102,9 +110,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_rules - def get_allow_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]: - """ - Get all allowing ACL rules for an IER. + def get_allow_acl_rules_for_ier( + self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[str, Any]: + """Get all allowing ACL rules for an IER. :param ier: Information Exchange Request to query against the ACL list :type ier: IER @@ -126,19 +135,32 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_matching_acl_rules( self, - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ): - """ - Get matching ACL rules. + source_node_id: str, + dest_node_id: str, + protocol: str, + port: str, + acl: AccessControlList, + nodes: Dict[str, Union[ServiceNode, ActiveNode]], + services_list: List[str], + ) -> Dict[str, ACLRule]: + """Filter ACL rules to only those which are relevant to the specified nodes. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param source_node_id: Source node + :type source_node_id: str + :param dest_node_id: Destination nodes + :type dest_node_id: str + :param protocol: Network protocol + :type protocol: str + :param port: Network port + :type port: str + :param acl: Access Control list which will be filtered + :type acl: AccessControlList + :param nodes: The environment's node directory. + :type nodes: Dict[str, Union[ServiceNode, ActiveNode]] + :param services_list: List of services registered for the environment. + :type services_list: List[str] + :return: Filtered version of 'acl' + :rtype: Dict[str, ACLRule] """ if source_node_id != "ANY": source_node_address = nodes[str(source_node_id)].ip_address @@ -158,19 +180,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_allow_acl_rules( self, - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ): - """ - Get the ALLOW ACL rules. + source_node_id: int, + dest_node_id: str, + protocol: int, + port: str, + acl: AccessControlList, + nodes: Dict[str, NodeUnion], + services_list: List[str], + ) -> Dict[str, ACLRule]: + """List ALLOW rules relating to specified nodes. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param source_node_id: Source node id + :type source_node_id: int + :param dest_node_id: Destination node + :type dest_node_id: str + :param protocol: Network protocol + :type protocol: int + :param port: Port + :type port: str + :param acl: Firewall ruleset which is applied to the network + :type acl: AccessControlList + :param nodes: The simulation's node store + :type nodes: Dict[str, NodeUnion] + :param services_list: Services list + :type services_list: List[str] + :return: Filtered ACL Rule directory which includes only those rules which affect the specified source and + desination nodes + :rtype: Dict[str, ACLRule] """ matching_rules = self.get_matching_acl_rules( source_node_id, @@ -191,19 +227,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_deny_acl_rules( self, - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ): - """ - Get the DENY ACL rules. + source_node_id: int, + dest_node_id: str, + protocol: int, + port: str, + acl: AccessControlList, + nodes: Dict[str, NodeUnion], + services_list: List[str], + ) -> Dict[str, ACLRule]: + """List DENY rules relating to specified nodes. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param source_node_id: Source node id + :type source_node_id: int + :param dest_node_id: Destination node + :type dest_node_id: str + :param protocol: Network protocol + :type protocol: int + :param port: Port + :type port: str + :param acl: Firewall ruleset which is applied to the network + :type acl: AccessControlList + :param nodes: The simulation's node store + :type nodes: Dict[str, NodeUnion] + :param services_list: Services list + :type services_list: List[str] + :return: Filtered ACL Rule directory which includes only those rules which affect the specified source and + desination nodes + :rtype: Dict[str, ACLRule] """ matching_rules = self.get_matching_acl_rules( source_node_id, @@ -222,7 +272,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return allowed_rules - def _calculate_action_full_view(self, obs): + def _calculate_action_full_view(self, obs: np.ndarray) -> int: """ Calculate a good acl-based action for the blue agent to take. @@ -250,8 +300,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing an overwhelmed state, so we don't do this. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: current observation from the gym environment + :type obs: np.ndarray + :return: Optimal action to take in the environment (chosen from the discrete action space) + :rtype: int """ # obs = convert_to_old_obs(obs) r_obs = transform_change_obs_readable(obs) @@ -387,7 +439,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): action = get_new_action(action, self._env.action_dict) return action - def _calculate_action_basic_view(self, obs): + def _calculate_action_basic_view(self, obs: np.ndarray) -> int: """Calculate a good acl-based action for the blue agent to take. Uses ONLY information from the current observation with NO knowledge @@ -405,8 +457,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): Currently, a deny rule does not overwrite an allow rule. The allow rules must be deleted. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: current observation from the gym environment + :type obs: np.ndarray + :return: Optimal action to take in the environment (chosen from the discrete action space) + :rtype: int """ action_dict = self._env.action_dict r_obs = transform_change_obs_readable(obs) diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index 310fc178..c00cf421 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -1,3 +1,5 @@ +import numpy as np + from primaite.agents.agent import HardCodedAgentSessionABC from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable @@ -5,12 +7,14 @@ from primaite.agents.utils import get_new_action, transform_action_node_enum, tr class HardCodedNodeAgent(HardCodedAgentSessionABC): """An Agent Session class that implements a deterministic Node agent.""" - def _calculate_action(self, obs): + def _calculate_action(self, obs: np.ndarray) -> int: """ Calculate a good node-based action for the blue agent to take. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: current observation from the gym environment + :type obs: np.ndarray + :return: Optimal action to take in the environment (chosen from the discrete action space) + :rtype: int """ action_dict = self._env.action_dict r_obs = transform_change_obs_readable(obs) diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index b5a3c673..bffdbf43 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -1,3 +1,5 @@ +from typing import List, Union + import numpy as np from primaite.common.enums import ( @@ -10,15 +12,17 @@ from primaite.common.enums import ( ) -def transform_action_node_readable(action): - """ - Convert a node action from enumerated format to readable format. +def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]: + """Convert a node action from enumerated format to readable format. example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: The same action list, but with the encodings translated back into meaningful labels + :rtype: List[Union[int,str]] """ action_node_property = NodePOLType(action[1]).name @@ -33,15 +37,18 @@ def transform_action_node_readable(action): return new_action -def transform_action_acl_readable(action): +def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]: """ Transform an ACL action to a more readable format. example: [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: The same action list, but with the encodings translated back into meaningful labels + :rtype: List[Union[int,str]] """ action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"} action_permissions = {0: "DENY", 1: "ALLOW"} @@ -58,7 +65,7 @@ def transform_action_acl_readable(action): return new_action -def is_valid_node_action(action): +def is_valid_node_action(action: List[int]) -> bool: """Is the node action an actual valid action. Only uses information about the action to determine if the action has an effect @@ -67,8 +74,11 @@ def is_valid_node_action(action): - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - Node already being in that state (turning an ON node ON) - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ action_r = transform_action_node_readable(action) @@ -93,7 +103,7 @@ def is_valid_node_action(action): return True -def is_valid_acl_action(action): +def is_valid_acl_action(action: List[int]) -> bool: """ Is the ACL action an actual valid action. @@ -103,8 +113,11 @@ def is_valid_acl_action(action): - Trying to create identical rules - Trying to create a rule which is a subset of another rule (caused by "ANY") - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ action_r = transform_action_acl_readable(action) @@ -126,12 +139,15 @@ def is_valid_acl_action(action): return True -def is_valid_acl_action_extra(action): +def is_valid_acl_action_extra(action: List[int]) -> bool: """ Harsher version of valid acl actions, does not allow action. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ if is_valid_acl_action(action) is False: return False @@ -150,15 +166,16 @@ def is_valid_acl_action_extra(action): return True -def transform_change_obs_readable(obs): - """ - Transform list of transactions to readable list of each observation property. +def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]: + """Transform list of transactions to readable list of each observation property. example: np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: Raw observation from the environment. + :type obs: np.ndarray + :return: The same observation, but the encoded integer values are replaced with readable names. + :rtype: List[List[Union[str, int]]] """ ids = [i for i in obs[:, 0]] operating_states = [HardwareState(i).name for i in obs[:, 1]] @@ -174,14 +191,16 @@ def transform_change_obs_readable(obs): return new_obs -def transform_obs_readable(obs): - """ - Transform observation to readable format. +def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]: + """Transform observation to readable format. + example np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: Raw observation from the environment. + :type obs: np.ndarray + :return: The same observation, but the encoded integer values are replaced with readable names. + :rtype: List[List[Union[str, int]]] """ changed_obs = transform_change_obs_readable(obs) new_obs = list(zip(*changed_obs)) @@ -191,21 +210,23 @@ def transform_obs_readable(obs): return new_obs -def convert_to_new_obs(obs, num_nodes=10): - """ - Convert original gym Box observation space to new multiDiscrete observation space. +def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray: + """Convert original gym Box observation space to new multiDiscrete observation space. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: observation in the 'old' (NodeLinkTable) format + :type obs: np.ndarray + :param num_nodes: number of nodes in the network, defaults to 10 + :type num_nodes: int, optional + :return: reformatted observation + :rtype: np.ndarray """ # Remove ID columns, remove links and flatten to MultiDiscrete observation space new_obs = obs[:num_nodes, 1:].flatten() return new_obs -def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): - """ - Convert to old observation. +def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray: + """Convert to old observation. Links filled with 0's as no information is included in new observation space. @@ -217,8 +238,17 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): [ 3, 1, 1, 1], ... [20, 0, 0, 0]]) - TODO: Add params and return in docstring. - TODO: Typehint params and return. + + :param obs: observation in the 'new' (MultiDiscrete) format + :type obs: np.ndarray + :param num_nodes: number of nodes in the network, defaults to 10 + :type num_nodes: int, optional + :param num_links: number of links in the network, defaults to 10 + :type num_links: int, optional + :param num_services: number of services on the network, defaults to 1 + :type num_services: int, optional + :return: 2-d BOX observation space, in the same format as NodeLinkTable + :rtype: np.ndarray """ # Convert back to more readable, original format reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2) @@ -240,17 +270,28 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): return new_obs -def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): - """ - Return string describing change between two observations. +def describe_obs_change( + obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1 +) -> str: + """Build a string describing the difference between two observations. example: obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]]) output = 'ID 1: SERVICE 2 set to GOOD' - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs1: First observation + :type obs1: np.ndarray + :param obs2: Second observation + :type obs2: np.ndarray + :param num_nodes: How many nodes are in the network laydown, defaults to 10 + :type num_nodes: int, optional + :param num_links: How many links are in the network laydown, defaults to 10 + :type num_links: int, optional + :param num_services: How many services are configured for this scenario, defaults to 1 + :type num_services: int, optional + :return: A multi-line string with a human-readable description of the difference. + :rtype: str """ obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services) obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)