diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 3b0e9234..78ea8a36 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -10,19 +10,19 @@ class AccessControlList: def __init__(self): """Init.""" - self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules + self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules - def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): - """ - Checks for IP address matches. + def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool: + """Checks for IP address matches. - Args: - _rule: The rule being checked - _source_ip_address: the source IP address to compare - _dest_ip_address: the destination IP address to compare - - Returns: - True if match; False otherwise. + :param _rule: The rule object to check + :type _rule: ACLRule + :param _source_ip_address: Source IP address to compare + :type _source_ip_address: str + :param _dest_ip_address: Destination IP address to compare + :type _dest_ip_address: str + :return: True if there is a match, otherwise False. + :rtype: bool """ if ( (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address) @@ -34,7 +34,7 @@ class AccessControlList: else: return False - def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port): + def is_blocked(self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool: """ Checks for rules that block a protocol / port. @@ -116,3 +116,27 @@ class AccessControlList: rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) hash_value = hash(rule) return hash_value + + def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port): + """Get all ACL rules that relate to the given arguments. + + :param _source_ip_address: the source IP address to check + :param _dest_ip_address: the destination IP address to check + :param _protocol: the protocol to check + :param _port: the port to check + :return: Dictionary of all ACL rules that relate to the given arguments + :rtype: Dict[str, ACLRule] + """ + relevant_rules = {} + + for rule_key, rule_value in self.acl.items(): + if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address): + if ( + rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY" or _protocol == "ANY" + ) and ( + str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY" or str(_port) == "ANY" + ): + # There's a matching rule. + relevant_rules[rule_key] = rule_value + + return relevant_rules diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 263ccbdc..f8c571c9 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -1,5 +1,9 @@ +from typing import Any, Dict, List, Union + import numpy as np +from primaite.acl.access_control_list import AccessControlList +from primaite.acl.acl_rule import ACLRule from primaite.agents.agent import HardCodedAgentSessionABC from primaite.agents.utils import ( get_new_action, @@ -7,13 +11,17 @@ from primaite.agents.utils import ( transform_action_acl_enum, transform_change_obs_readable, ) +from primaite.common.custom_typing import NodeUnion from primaite.common.enums import HardCodedAgentView +from primaite.nodes.active_node import ActiveNode +from primaite.nodes.service_node import ServiceNode +from primaite.pol.ier import IER class HardCodedACLAgent(HardCodedAgentSessionABC): """An Agent Session class that implements a deterministic ACL agent.""" - def _calculate_action(self, obs): + def _calculate_action(self, obs: np.ndarray) -> int: if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC: # Basic view action using only the current observation return self._calculate_action_basic_view(obs) @@ -22,12 +30,19 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): # history and reward feedback return self._calculate_action_full_view(obs) - def get_blocked_green_iers(self, green_iers, acl, nodes): - """ - Get blocked green IERs. + def get_blocked_green_iers( + self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[Any, Any]: + """Get blocked green IERs. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param green_iers: Green IERs to check for being + :type green_iers: Dict[str, IER] + :param acl: Firewall rules + :type acl: AccessControlList + :param nodes: Nodes in the network + :type nodes: Dict[str,NodeUnion] + :return: Same as `green_iers` input dict, but filtered to only contain the blocked ones. + :rtype: Dict[str, IER] """ blocked_green_iers = {} @@ -45,12 +60,17 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_green_iers - def get_matching_acl_rules_for_ier(self, ier, acl, nodes): - """ - Get matching ACL rules for an IER. + def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]): + """Get list of ACL rules which are relevant to an IER. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param ier: Information Exchange Request to query against the ACL list + :type ier: IER + :param acl: Firewall rules + :type acl: AccessControlList + :param nodes: Nodes in the network + :type nodes: Dict[str,NodeUnion] + :return: _description_ + :rtype: _type_ """ source_node_id = ier.get_source_node_id() source_node_address = nodes[source_node_id].ip_address @@ -58,11 +78,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): dest_node_address = nodes[dest_node_id].ip_address protocol = ier.get_protocol() # e.g. 'TCP' port = ier.get_port() - matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) return matching_rules - def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): + def get_blocking_acl_rules_for_ier( + self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[str, Any]: """ Get blocking ACL rules for an IER. @@ -70,8 +91,14 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): Can return empty dict but IER can still be blocked by default (No ALLOW rule, therefore blocked). - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param ier: Information Exchange Request to query against the ACL list + :type ier: IER + :param acl: Firewall rules + :type acl: AccessControlList + :param nodes: Nodes in the network + :type nodes: Dict[str,NodeUnion] + :return: _description_ + :rtype: _type_ """ matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) @@ -82,12 +109,19 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_rules - def get_allow_acl_rules_for_ier(self, ier, acl, nodes): - """ - Get all allowing ACL rules for an IER. + def get_allow_acl_rules_for_ier( + self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[str, Any]: + """Get all allowing ACL rules for an IER. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param ier: Information Exchange Request to query against the ACL list + :type ier: IER + :param acl: Firewall rules + :type acl: AccessControlList + :param nodes: Nodes in the network + :type nodes: Dict[str,NodeUnion] + :return: _description_ + :rtype: _type_ """ matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) @@ -100,19 +134,32 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_matching_acl_rules( self, - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ): - """ - Get matching ACL rules. + source_node_id: str, + dest_node_id: str, + protocol: str, + port: str, + acl: AccessControlList, + nodes: Dict[str, Union[ServiceNode, ActiveNode]], + services_list: List[str], + ) -> Dict[str, ACLRule]: + """Filter ACL rules to only those which are relevant to the specified nodes. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param source_node_id: Source node + :type source_node_id: str + :param dest_node_id: Destination nodes + :type dest_node_id: str + :param protocol: Network protocol + :type protocol: str + :param port: Network port + :type port: str + :param acl: Access Control list which will be filtered + :type acl: AccessControlList + :param nodes: The environment's node directory. + :type nodes: Dict[str, Union[ServiceNode, ActiveNode]] + :param services_list: List of services registered for the environment. + :type services_list: List[str] + :return: Filtered version of 'acl' + :rtype: Dict[str, ACLRule] """ if source_node_id != "ANY": source_node_address = nodes[str(source_node_id)].ip_address @@ -132,19 +179,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_allow_acl_rules( self, - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ): - """ - Get the ALLOW ACL rules. + source_node_id: int, + dest_node_id: str, + protocol: int, + port: str, + acl: AccessControlList, + nodes: Dict[str, NodeUnion], + services_list: List[str], + ) -> Dict[str, ACLRule]: + """List ALLOW rules relating to specified nodes. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param source_node_id: Source node id + :type source_node_id: int + :param dest_node_id: Destination node + :type dest_node_id: str + :param protocol: Network protocol + :type protocol: int + :param port: Port + :type port: str + :param acl: Firewall ruleset which is applied to the network + :type acl: AccessControlList + :param nodes: The simulation's node store + :type nodes: Dict[str, NodeUnion] + :param services_list: Services list + :type services_list: List[str] + :return: Filtered ACL Rule directory which includes only those rules which affect the specified source and + desination nodes + :rtype: Dict[str, ACLRule] """ matching_rules = self.get_matching_acl_rules( source_node_id, @@ -165,19 +226,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_deny_acl_rules( self, - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ): - """ - Get the DENY ACL rules. + source_node_id: int, + dest_node_id: str, + protocol: int, + port: str, + acl: AccessControlList, + nodes: Dict[str, NodeUnion], + services_list: List[str], + ) -> Dict[str, ACLRule]: + """List DENY rules relating to specified nodes. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param source_node_id: Source node id + :type source_node_id: int + :param dest_node_id: Destination node + :type dest_node_id: str + :param protocol: Network protocol + :type protocol: int + :param port: Port + :type port: str + :param acl: Firewall ruleset which is applied to the network + :type acl: AccessControlList + :param nodes: The simulation's node store + :type nodes: Dict[str, NodeUnion] + :param services_list: Services list + :type services_list: List[str] + :return: Filtered ACL Rule directory which includes only those rules which affect the specified source and + desination nodes + :rtype: Dict[str, ACLRule] """ matching_rules = self.get_matching_acl_rules( source_node_id, @@ -196,7 +271,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return allowed_rules - def _calculate_action_full_view(self, obs): + def _calculate_action_full_view(self, obs: np.ndarray) -> int: """ Calculate a good acl-based action for the blue agent to take. @@ -224,8 +299,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing an overwhelmed state, so we don't do this. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: current observation from the gym environment + :type obs: np.ndarray + :return: Optimal action to take in the environment (chosen from the discrete action space) + :rtype: int """ # obs = convert_to_old_obs(obs) r_obs = transform_change_obs_readable(obs) @@ -361,7 +438,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): action = get_new_action(action, self._env.action_dict) return action - def _calculate_action_basic_view(self, obs): + def _calculate_action_basic_view(self, obs: np.ndarray) -> int: """Calculate a good acl-based action for the blue agent to take. Uses ONLY information from the current observation with NO knowledge @@ -379,8 +456,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): Currently, a deny rule does not overwrite an allow rule. The allow rules must be deleted. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: current observation from the gym environment + :type obs: np.ndarray + :return: Optimal action to take in the environment (chosen from the discrete action space) + :rtype: int """ action_dict = self._env.action_dict r_obs = transform_change_obs_readable(obs) diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index 310fc178..c00cf421 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -1,3 +1,5 @@ +import numpy as np + from primaite.agents.agent import HardCodedAgentSessionABC from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable @@ -5,12 +7,14 @@ from primaite.agents.utils import get_new_action, transform_action_node_enum, tr class HardCodedNodeAgent(HardCodedAgentSessionABC): """An Agent Session class that implements a deterministic Node agent.""" - def _calculate_action(self, obs): + def _calculate_action(self, obs: np.ndarray) -> int: """ Calculate a good node-based action for the blue agent to take. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: current observation from the gym environment + :type obs: np.ndarray + :return: Optimal action to take in the environment (chosen from the discrete action space) + :rtype: int """ action_dict = self._env.action_dict r_obs = transform_change_obs_readable(obs) diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index 8c59faf7..5f5261e0 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -1,5 +1,8 @@ +from typing import Dict, List, Union + import numpy as np +from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( HardwareState, LinkStatus, @@ -10,15 +13,17 @@ from primaite.common.enums import ( ) -def transform_action_node_readable(action): - """ - Convert a node action from enumerated format to readable format. +def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]: + """Convert a node action from enumerated format to readable format. example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: The same action list, but with the encodings translated back into meaningful labels + :rtype: List[Union[int,str]] """ action_node_property = NodePOLType(action[1]).name @@ -33,15 +38,18 @@ def transform_action_node_readable(action): return new_action -def transform_action_acl_readable(action): +def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]: """ Transform an ACL action to a more readable format. example: [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: The same action list, but with the encodings translated back into meaningful labels + :rtype: List[Union[int,str]] """ action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"} action_permissions = {0: "DENY", 1: "ALLOW"} @@ -58,7 +66,7 @@ def transform_action_acl_readable(action): return new_action -def is_valid_node_action(action): +def is_valid_node_action(action: List[int]) -> bool: """Is the node action an actual valid action. Only uses information about the action to determine if the action has an effect @@ -67,8 +75,11 @@ def is_valid_node_action(action): - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - Node already being in that state (turning an ON node ON) - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ action_r = transform_action_node_readable(action) @@ -93,7 +104,7 @@ def is_valid_node_action(action): return True -def is_valid_acl_action(action): +def is_valid_acl_action(action: List[int]) -> bool: """ Is the ACL action an actual valid action. @@ -103,8 +114,11 @@ def is_valid_acl_action(action): - Trying to create identical rules - Trying to create a rule which is a subset of another rule (caused by "ANY") - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ action_r = transform_action_acl_readable(action) @@ -126,12 +140,15 @@ def is_valid_acl_action(action): return True -def is_valid_acl_action_extra(action): +def is_valid_acl_action_extra(action: List[int]) -> bool: """ Harsher version of valid acl actions, does not allow action. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ if is_valid_acl_action(action) is False: return False @@ -150,22 +167,24 @@ def is_valid_acl_action_extra(action): return True -def transform_change_obs_readable(obs): - """ - Transform list of transactions to readable list of each observation property. +def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]: + """Transform list of transactions to readable list of each observation property. example: np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: Raw observation from the environment. + :type obs: np.ndarray + :return: The same observation, but the encoded integer values are replaced with readable names. + :rtype: List[List[Union[str, int]]] """ ids = [i for i in obs[:, 0]] operating_states = [HardwareState(i).name for i in obs[:, 1]] os_states = [SoftwareState(i).name for i in obs[:, 2]] new_obs = [ids, operating_states, os_states] - for service in range(3, obs.shape[1]): + # changed range(3,...) to range(4,...) because we added file system which was new since ADSP + for service in range(4, obs.shape[1]): # Links bit/s don't have a service state service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]] new_obs.append(service_states) @@ -173,14 +192,16 @@ def transform_change_obs_readable(obs): return new_obs -def transform_obs_readable(obs): - """ - Transform observation to readable format. +def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]: + """Transform observation to readable format. + example np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: Raw observation from the environment. + :type obs: np.ndarray + :return: The same observation, but the encoded integer values are replaced with readable names. + :rtype: List[List[Union[str, int]]] """ changed_obs = transform_change_obs_readable(obs) new_obs = list(zip(*changed_obs)) @@ -190,21 +211,23 @@ def transform_obs_readable(obs): return new_obs -def convert_to_new_obs(obs, num_nodes=10): - """ - Convert original gym Box observation space to new multiDiscrete observation space. +def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray: + """Convert original gym Box observation space to new multiDiscrete observation space. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: observation in the 'old' (NodeLinkTable) format + :type obs: np.ndarray + :param num_nodes: number of nodes in the network, defaults to 10 + :type num_nodes: int, optional + :return: reformatted observation + :rtype: np.ndarray """ # Remove ID columns, remove links and flatten to MultiDiscrete observation space new_obs = obs[:num_nodes, 1:].flatten() return new_obs -def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): - """ - Convert to old observation. +def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray: + """Convert to old observation. Links filled with 0's as no information is included in new observation space. @@ -216,8 +239,17 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): [ 3, 1, 1, 1], ... [20, 0, 0, 0]]) - TODO: Add params and return in docstring. - TODO: Typehint params and return. + + :param obs: observation in the 'new' (MultiDiscrete) format + :type obs: np.ndarray + :param num_nodes: number of nodes in the network, defaults to 10 + :type num_nodes: int, optional + :param num_links: number of links in the network, defaults to 10 + :type num_links: int, optional + :param num_services: number of services on the network, defaults to 1 + :type num_services: int, optional + :return: 2-d BOX observation space, in the same format as NodeLinkTable + :rtype: np.ndarray """ # Convert back to more readable, original format reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2) @@ -239,17 +271,28 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): return new_obs -def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): - """ - Return string describing change between two observations. +def describe_obs_change( + obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1 +) -> str: + """Build a string describing the difference between two observations. example: obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]]) output = 'ID 1: SERVICE 2 set to GOOD' - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs1: First observation + :type obs1: np.ndarray + :param obs2: Second observation + :type obs2: np.ndarray + :param num_nodes: How many nodes are in the network laydown, defaults to 10 + :type num_nodes: int, optional + :param num_links: How many links are in the network laydown, defaults to 10 + :type num_links: int, optional + :param num_services: How many services are configured for this scenario, defaults to 1 + :type num_services: int, optional + :return: A multi-line string with a human-readable description of the difference. + :rtype: str """ obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services) obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services) @@ -268,7 +311,7 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): return change_string -def _describe_obs_change_helper(obs_change, is_link): +def _describe_obs_change_helper(obs_change: List[int], is_link: bool) -> str: """ Helper funcion to describe what has changed. @@ -277,8 +320,14 @@ def _describe_obs_change_helper(obs_change, is_link): Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.' - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs_change: List of integers generated within the `describe_obs_change` function. It should correspond to one + row of the observation table, and have `-1` at locations where the observation hasn't changed, and the new + status where it has changed. + :type obs_change: List[int] + :param is_link: Whether the row of the observation space corresponds to a link. False means it represents a node. + :type is_link: bool + :return: A human-readable description of the difference between the two observation rows. + :rtype: str """ # Indexes where a change has occured, not including 0th index index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] @@ -304,15 +353,15 @@ def _describe_obs_change_helper(obs_change, is_link): return desc -def transform_action_node_enum(action): - """ - Convert a node action from readable string format, to enumerated format. +def transform_action_node_enum(action: List[Union[str, int]]) -> List[int]: + """Convert a node action from readable string format, to enumerated format. example: [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] - - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Action in 'readable' format + :type action: List[Union[str,int]] + :return: Action with verbs encoded as ints + :rtype: List[int] """ action_node_id = action[0] action_node_property = NodePOLType[action[1]].value @@ -336,63 +385,14 @@ def transform_action_node_enum(action): return new_action -def transform_action_node_readable(action): - """ - Convert a node action from enumerated format to readable format. - - example: - [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - action_node_property = NodePOLType(action[1]).name - - if action_node_property == "OPERATING": - property_action = NodeHardwareAction(action[2]).name - elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1: - property_action = NodeSoftwareAction(action[2]).name - else: - property_action = "NONE" - - new_action = [action[0], action_node_property, property_action, action[3]] - return new_action - - -def node_action_description(action): - """ - Generate string describing a node-based action. - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - if isinstance(action[1], (int, np.int64)): - # transform action to readable format - action = transform_action_node_readable(action) - - node_id = action[0] - node_property = action[1] - property_action = action[2] - service_id = action[3] - - if property_action == "NONE": - return "" - if node_property == "OPERATING" or node_property == "OS": - description = f"NODE {node_id}, {node_property}, SET TO {property_action}" - elif node_property == "SERVICE": - description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" - else: - return "" - - return description - - -def transform_action_acl_enum(action): +def transform_action_acl_enum(action: List[Union[int, str]]) -> np.ndarray: """ Convert acl action from readable str format, to enumerated format. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: ACL-based action expressed as a list of human-readable ints and strings + :type action: List[Union[int,str]] + :return: The same action but encoded to contain only integers. + :rtype: np.ndarray """ action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2} action_permissions = {"DENY": 0, "ALLOW": 1} @@ -410,35 +410,17 @@ def transform_action_acl_enum(action): return new_action -def acl_action_description(action): - """ - Generate string describing an acl-based action. - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - if isinstance(action[0], (int, np.int64)): - # transform action to readable format - action = transform_action_acl_readable(action) - if action[0] == "NONE": - description = "NO ACL RULE APPLIED" - else: - description = ( - f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]}," - f" for protocol/service index {action[4]} on port index {action[5]}" - ) - - return description - - -def get_node_of_ip(ip, node_dict): - """ - Get the node ID of an IP address. +def get_node_of_ip(ip: str, node_dict: Dict[str, NodeUnion]) -> str: + """Get the node ID of an IP address. node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes) - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param ip: The IP address of the node whose ID is required + :type ip: str + :param node_dict: The environment's node registry dictionary + :type node_dict: Dict[str,NodeUnion] + :return: The key from the registry dict that corresponds to the node with the IP adress provided by `ip` + :rtype: str """ for node_key, node_value in node_dict.items(): node_ip = node_value.ip_address @@ -446,104 +428,18 @@ def get_node_of_ip(ip, node_dict): return node_key -def is_valid_node_action(action): - """Is the node action an actual valid action. - - Only uses information about the action to determine if the action has an effect - - Does NOT consider: - - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - - Node already being in that state (turning an ON node ON) - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - action_r = transform_action_node_readable(action) - - node_property = action_r[1] - node_action = action_r[2] - - if node_property == "NONE": - return False - if node_action == "NONE": - return False - if node_property == "OPERATING" and node_action == "PATCHING": - # Operating State cannot PATCH - return False - if node_property != "OPERATING" and node_action not in [ - "NONE", - "PATCHING", - ]: - # Software States can only do Nothing or Patch - return False - return True - - -def is_valid_acl_action(action): - """ - Is the ACL action an actual valid action. - - Only uses information about the action to determine if the action has an effect - - Does NOT consider: - - Trying to create identical rules - - Trying to create a rule which is a subset of another rule (caused by "ANY") - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - action_r = transform_action_acl_readable(action) - - action_decision = action_r[0] - action_permission = action_r[1] - action_source_id = action_r[2] - action_destination_id = action_r[3] - - if action_decision == "NONE": - return False - if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY": - # ACL rule towards itself - return False - if action_permission == "DENY": - # DENY is unnecessary, we can create and delete allow rules instead - # No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY. - return False - - return True - - -def is_valid_acl_action_extra(action): - """ - Harsher version of valid acl actions, does not allow action. - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - if is_valid_acl_action(action) is False: - return False - - action_r = transform_action_acl_readable(action) - action_protocol = action_r[4] - action_port = action_r[5] - - # Don't allow protocols or ports to be ANY - # in the future we might want to do the opposite, and only have ANY option for ports and service - if action_protocol == "ANY": - return False - if action_port == "ANY": - return False - - return True - - -def get_new_action(old_action, action_dict): +def get_new_action(old_action: np.ndarray, action_dict: Dict[int, List]) -> int: """ Get new action (e.g. 32) from old action e.g. [1,1,1,0]. Old_action can be either node or acl action type - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param old_action: Action expressed as a list of choices, eg. [1,1,1,0] + :type old_action: np.ndarray + :param action_dict: Dictionary for translating the multidiscrete actions into the list-based actions. + :type action_dict: Dict[int,List] + :return: Action key correspoinding to the input `old_action` + :rtype: int """ for key, val in action_dict.items(): if list(val) == list(old_action): diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index bff19bf8..1a8bd406 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -296,11 +296,17 @@ def apply_red_agent_node_pol( pass -def is_red_ier_incoming(node, iers, node_pol_type): - """ - Checks if the RED IER is incoming. +def is_red_ier_incoming(node: NodeUnion, iers: Dict[str, IER], node_pol_type: NodePOLType) -> bool: + """Checks if the RED IER is incoming. - TODO: Write more descriptive docstring with params and returns. + :param node: Destination node of the IER + :type node: NodeUnion + :param iers: Directory of IERs + :type iers: Dict[str,IER] + :param node_pol_type: Type of Pattern-Of-Life + :type node_pol_type: NodePOLType + :return: Whether the RED IER is incoming. + :rtype: bool """ node_id = node.node_id