From 3ced1a19137ae40d287b8b785cf6d2aa0b5af70d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 5 Jul 2023 09:54:50 +0100 Subject: [PATCH 1/8] Update some param descriptions for hardcoded agent --- src/primaite/agents/hardcoded_acl.py | 60 ++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 263ccbdc..430c8b54 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -1,4 +1,6 @@ +from typing import Any, Dict import numpy as np +from primaite.acl.access_control_list import AccessControlList from primaite.agents.agent import HardCodedAgentSessionABC from primaite.agents.utils import ( @@ -7,7 +9,9 @@ from primaite.agents.utils import ( transform_action_acl_enum, transform_change_obs_readable, ) +from primaite.common.custom_typing import NodeUnion from primaite.common.enums import HardCodedAgentView +from primaite.pol.ier import IER class HardCodedACLAgent(HardCodedAgentSessionABC): @@ -22,12 +26,17 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): # history and reward feedback return self._calculate_action_full_view(obs) - def get_blocked_green_iers(self, green_iers, acl, nodes): - """ - Get blocked green IERs. + def get_blocked_green_iers(self, green_iers:Dict[str, IER], acl:AccessControlList, nodes:Dict[str,NodeUnion]): + """Get blocked green IERs. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param green_iers: Green IERs to check for being + :type green_iers: Dict[str, IER] + :param acl: Firewall rules + :type acl: AccessControlList + :param nodes: Nodes in the network + :type nodes: Dict[str,NodeUnion] + :return: Same as `green_iers` input dict, but filtered to only contain the blocked ones. + :rtype: Dict[str, IER] """ blocked_green_iers = {} @@ -45,12 +54,17 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_green_iers - def get_matching_acl_rules_for_ier(self, ier, acl, nodes): - """ - Get matching ACL rules for an IER. + def get_matching_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion]): + """Get list of ACL rules which are relevant to an IER - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param ier: Information Exchange Request to query against the ACL list + :type ier: IER + :param acl: Firewall rules + :type acl: AccessControlList + :param nodes: Nodes in the network + :type nodes: Dict[str,NodeUnion] + :return: _description_ + :rtype: _type_ """ source_node_id = ier.get_source_node_id() source_node_address = nodes[source_node_id].ip_address @@ -58,11 +72,11 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): dest_node_address = nodes[dest_node_id].ip_address protocol = ier.get_protocol() # e.g. 'TCP' port = ier.get_port() - + # I can't find where this function 'get_relevant_rules' is defined... matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) return matching_rules - def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): + def get_blocking_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]: """ Get blocking ACL rules for an IER. @@ -70,8 +84,14 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): Can return empty dict but IER can still be blocked by default (No ALLOW rule, therefore blocked). - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param ier: Information Exchange Request to query against the ACL list + :type ier: IER + :param acl: Firewall rules + :type acl: AccessControlList + :param nodes: Nodes in the network + :type nodes: Dict[str,NodeUnion] + :return: _description_ + :rtype: _type_ """ matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) @@ -82,12 +102,18 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_rules - def get_allow_acl_rules_for_ier(self, ier, acl, nodes): + def get_allow_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]: """ Get all allowing ACL rules for an IER. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param ier: Information Exchange Request to query against the ACL list + :type ier: IER + :param acl: Firewall rules + :type acl: AccessControlList + :param nodes: Nodes in the network + :type nodes: Dict[str,NodeUnion] + :return: _description_ + :rtype: _type_ """ matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) From b3d4eb4ec0920201291daa8c1feecc17ffb5baca Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 5 Jul 2023 13:58:46 +0100 Subject: [PATCH 2/8] Changed hardcoded agent helper for new obs space --- src/primaite/agents/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index 8c59faf7..b5a3c673 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -165,7 +165,8 @@ def transform_change_obs_readable(obs): os_states = [SoftwareState(i).name for i in obs[:, 2]] new_obs = [ids, operating_states, os_states] - for service in range(3, obs.shape[1]): + # changed range(3,...) to range(4,...) because we added file system which was new since ADSP + for service in range(4, obs.shape[1]): # Links bit/s don't have a service state service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]] new_obs.append(service_states) From 171b5cb58e0e56f242d18cc98f423a98e6291007 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 5 Jul 2023 14:10:52 +0100 Subject: [PATCH 3/8] Imported ADSP function for ACL --- src/primaite/acl/access_control_list.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 3b0e9234..c155deed 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -116,3 +116,27 @@ class AccessControlList: rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) hash_value = hash(rule) return hash_value + + def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port): + """Get all ACL rules that relate to the given arguments + + :param _source_ip_address: the source IP address to check + :param _dest_ip_address: the destination IP address to check + :param _protocol: the protocol to check + :param _port: the port to check + :return: Dictionary of all ACL rules that relate to the given arguments + :rtype: Dict[str, ACLRule] + """ + relevant_rules = {} + + for rule_key, rule_value in self.acl.items(): + if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address): + if ( + rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY" or _protocol == "ANY" + ) and ( + str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY" or str(_port) == "ANY" + ): + # There's a matching rule. + relevant_rules[rule_key] = rule_value + + return relevant_rules From 7482aead76d702b70527e550995fc3a7c25a9b98 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 5 Jul 2023 14:50:03 +0100 Subject: [PATCH 4/8] typo --- src/primaite/acl/access_control_list.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index c155deed..7e90724a 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -118,7 +118,7 @@ class AccessControlList: return hash_value def get_relevant_rules(self, _source_ip_address, _dest_ip_address, _protocol, _port): - """Get all ACL rules that relate to the given arguments + """Get all ACL rules that relate to the given arguments. :param _source_ip_address: the source IP address to check :param _dest_ip_address: the destination IP address to check From 5c167293e3833dd558294cec3dafd19c8e915c96 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 5 Jul 2023 16:19:43 +0100 Subject: [PATCH 5/8] Add docstrings and type hints. --- src/primaite/acl/access_control_list.py | 22 ++-- src/primaite/agents/hardcoded_acl.py | 160 ++++++++++++++++-------- src/primaite/agents/hardcoded_node.py | 10 +- src/primaite/agents/utils.py | 125 +++++++++++------- 4 files changed, 208 insertions(+), 109 deletions(-) diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 7e90724a..f5f70d6c 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -10,19 +10,19 @@ class AccessControlList: def __init__(self): """Init.""" - self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules + self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules - def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): - """ - Checks for IP address matches. + def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool: + """Checks for IP address matches. - Args: - _rule: The rule being checked - _source_ip_address: the source IP address to compare - _dest_ip_address: the destination IP address to compare - - Returns: - True if match; False otherwise. + :param _rule: The rule object to check + :type _rule: ACLRule + :param _source_ip_address: Source IP address to compare + :type _source_ip_address: str + :param _dest_ip_address: Destination IP address to compare + :type _dest_ip_address: str + :return: True if there is a match, otherwise False. + :rtype: bool """ if ( (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address) diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 430c8b54..a724e10e 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -1,7 +1,9 @@ -from typing import Any, Dict -import numpy as np -from primaite.acl.access_control_list import AccessControlList +from typing import Any, Dict, List, Union +import numpy as np + +from primaite.acl.access_control_list import AccessControlList +from primaite.acl.acl_rule import ACLRule from primaite.agents.agent import HardCodedAgentSessionABC from primaite.agents.utils import ( get_new_action, @@ -11,13 +13,15 @@ from primaite.agents.utils import ( ) from primaite.common.custom_typing import NodeUnion from primaite.common.enums import HardCodedAgentView +from primaite.nodes.active_node import ActiveNode +from primaite.nodes.service_node import ServiceNode from primaite.pol.ier import IER class HardCodedACLAgent(HardCodedAgentSessionABC): """An Agent Session class that implements a deterministic ACL agent.""" - def _calculate_action(self, obs): + def _calculate_action(self, obs: np.ndarray) -> int: if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC: # Basic view action using only the current observation return self._calculate_action_basic_view(obs) @@ -26,7 +30,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): # history and reward feedback return self._calculate_action_full_view(obs) - def get_blocked_green_iers(self, green_iers:Dict[str, IER], acl:AccessControlList, nodes:Dict[str,NodeUnion]): + def get_blocked_green_iers( + self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[Any, Any]: """Get blocked green IERs. :param green_iers: Green IERs to check for being @@ -54,8 +60,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_green_iers - def get_matching_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion]): - """Get list of ACL rules which are relevant to an IER + def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]): + """Get list of ACL rules which are relevant to an IER. :param ier: Information Exchange Request to query against the ACL list :type ier: IER @@ -76,7 +82,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) return matching_rules - def get_blocking_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]: + def get_blocking_acl_rules_for_ier( + self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[str, Any]: """ Get blocking ACL rules for an IER. @@ -102,9 +110,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_rules - def get_allow_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]: - """ - Get all allowing ACL rules for an IER. + def get_allow_acl_rules_for_ier( + self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] + ) -> Dict[str, Any]: + """Get all allowing ACL rules for an IER. :param ier: Information Exchange Request to query against the ACL list :type ier: IER @@ -126,19 +135,32 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_matching_acl_rules( self, - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ): - """ - Get matching ACL rules. + source_node_id: str, + dest_node_id: str, + protocol: str, + port: str, + acl: AccessControlList, + nodes: Dict[str, Union[ServiceNode, ActiveNode]], + services_list: List[str], + ) -> Dict[str, ACLRule]: + """Filter ACL rules to only those which are relevant to the specified nodes. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param source_node_id: Source node + :type source_node_id: str + :param dest_node_id: Destination nodes + :type dest_node_id: str + :param protocol: Network protocol + :type protocol: str + :param port: Network port + :type port: str + :param acl: Access Control list which will be filtered + :type acl: AccessControlList + :param nodes: The environment's node directory. + :type nodes: Dict[str, Union[ServiceNode, ActiveNode]] + :param services_list: List of services registered for the environment. + :type services_list: List[str] + :return: Filtered version of 'acl' + :rtype: Dict[str, ACLRule] """ if source_node_id != "ANY": source_node_address = nodes[str(source_node_id)].ip_address @@ -158,19 +180,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_allow_acl_rules( self, - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ): - """ - Get the ALLOW ACL rules. + source_node_id: int, + dest_node_id: str, + protocol: int, + port: str, + acl: AccessControlList, + nodes: Dict[str, NodeUnion], + services_list: List[str], + ) -> Dict[str, ACLRule]: + """List ALLOW rules relating to specified nodes. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param source_node_id: Source node id + :type source_node_id: int + :param dest_node_id: Destination node + :type dest_node_id: str + :param protocol: Network protocol + :type protocol: int + :param port: Port + :type port: str + :param acl: Firewall ruleset which is applied to the network + :type acl: AccessControlList + :param nodes: The simulation's node store + :type nodes: Dict[str, NodeUnion] + :param services_list: Services list + :type services_list: List[str] + :return: Filtered ACL Rule directory which includes only those rules which affect the specified source and + desination nodes + :rtype: Dict[str, ACLRule] """ matching_rules = self.get_matching_acl_rules( source_node_id, @@ -191,19 +227,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_deny_acl_rules( self, - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ): - """ - Get the DENY ACL rules. + source_node_id: int, + dest_node_id: str, + protocol: int, + port: str, + acl: AccessControlList, + nodes: Dict[str, NodeUnion], + services_list: List[str], + ) -> Dict[str, ACLRule]: + """List DENY rules relating to specified nodes. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param source_node_id: Source node id + :type source_node_id: int + :param dest_node_id: Destination node + :type dest_node_id: str + :param protocol: Network protocol + :type protocol: int + :param port: Port + :type port: str + :param acl: Firewall ruleset which is applied to the network + :type acl: AccessControlList + :param nodes: The simulation's node store + :type nodes: Dict[str, NodeUnion] + :param services_list: Services list + :type services_list: List[str] + :return: Filtered ACL Rule directory which includes only those rules which affect the specified source and + desination nodes + :rtype: Dict[str, ACLRule] """ matching_rules = self.get_matching_acl_rules( source_node_id, @@ -222,7 +272,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return allowed_rules - def _calculate_action_full_view(self, obs): + def _calculate_action_full_view(self, obs: np.ndarray) -> int: """ Calculate a good acl-based action for the blue agent to take. @@ -250,8 +300,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing an overwhelmed state, so we don't do this. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: current observation from the gym environment + :type obs: np.ndarray + :return: Optimal action to take in the environment (chosen from the discrete action space) + :rtype: int """ # obs = convert_to_old_obs(obs) r_obs = transform_change_obs_readable(obs) @@ -387,7 +439,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): action = get_new_action(action, self._env.action_dict) return action - def _calculate_action_basic_view(self, obs): + def _calculate_action_basic_view(self, obs: np.ndarray) -> int: """Calculate a good acl-based action for the blue agent to take. Uses ONLY information from the current observation with NO knowledge @@ -405,8 +457,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): Currently, a deny rule does not overwrite an allow rule. The allow rules must be deleted. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: current observation from the gym environment + :type obs: np.ndarray + :return: Optimal action to take in the environment (chosen from the discrete action space) + :rtype: int """ action_dict = self._env.action_dict r_obs = transform_change_obs_readable(obs) diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index 310fc178..c00cf421 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -1,3 +1,5 @@ +import numpy as np + from primaite.agents.agent import HardCodedAgentSessionABC from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable @@ -5,12 +7,14 @@ from primaite.agents.utils import get_new_action, transform_action_node_enum, tr class HardCodedNodeAgent(HardCodedAgentSessionABC): """An Agent Session class that implements a deterministic Node agent.""" - def _calculate_action(self, obs): + def _calculate_action(self, obs: np.ndarray) -> int: """ Calculate a good node-based action for the blue agent to take. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: current observation from the gym environment + :type obs: np.ndarray + :return: Optimal action to take in the environment (chosen from the discrete action space) + :rtype: int """ action_dict = self._env.action_dict r_obs = transform_change_obs_readable(obs) diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index b5a3c673..bffdbf43 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -1,3 +1,5 @@ +from typing import List, Union + import numpy as np from primaite.common.enums import ( @@ -10,15 +12,17 @@ from primaite.common.enums import ( ) -def transform_action_node_readable(action): - """ - Convert a node action from enumerated format to readable format. +def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]: + """Convert a node action from enumerated format to readable format. example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: The same action list, but with the encodings translated back into meaningful labels + :rtype: List[Union[int,str]] """ action_node_property = NodePOLType(action[1]).name @@ -33,15 +37,18 @@ def transform_action_node_readable(action): return new_action -def transform_action_acl_readable(action): +def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]: """ Transform an ACL action to a more readable format. example: [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: The same action list, but with the encodings translated back into meaningful labels + :rtype: List[Union[int,str]] """ action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"} action_permissions = {0: "DENY", 1: "ALLOW"} @@ -58,7 +65,7 @@ def transform_action_acl_readable(action): return new_action -def is_valid_node_action(action): +def is_valid_node_action(action: List[int]) -> bool: """Is the node action an actual valid action. Only uses information about the action to determine if the action has an effect @@ -67,8 +74,11 @@ def is_valid_node_action(action): - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - Node already being in that state (turning an ON node ON) - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ action_r = transform_action_node_readable(action) @@ -93,7 +103,7 @@ def is_valid_node_action(action): return True -def is_valid_acl_action(action): +def is_valid_acl_action(action: List[int]) -> bool: """ Is the ACL action an actual valid action. @@ -103,8 +113,11 @@ def is_valid_acl_action(action): - Trying to create identical rules - Trying to create a rule which is a subset of another rule (caused by "ANY") - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ action_r = transform_action_acl_readable(action) @@ -126,12 +139,15 @@ def is_valid_acl_action(action): return True -def is_valid_acl_action_extra(action): +def is_valid_acl_action_extra(action: List[int]) -> bool: """ Harsher version of valid acl actions, does not allow action. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Agent action, formatted as a list of ints, for more information check out + `primaite.environment.primaite_env.Primaite` + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ if is_valid_acl_action(action) is False: return False @@ -150,15 +166,16 @@ def is_valid_acl_action_extra(action): return True -def transform_change_obs_readable(obs): - """ - Transform list of transactions to readable list of each observation property. +def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]: + """Transform list of transactions to readable list of each observation property. example: np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: Raw observation from the environment. + :type obs: np.ndarray + :return: The same observation, but the encoded integer values are replaced with readable names. + :rtype: List[List[Union[str, int]]] """ ids = [i for i in obs[:, 0]] operating_states = [HardwareState(i).name for i in obs[:, 1]] @@ -174,14 +191,16 @@ def transform_change_obs_readable(obs): return new_obs -def transform_obs_readable(obs): - """ - Transform observation to readable format. +def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]: + """Transform observation to readable format. + example np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: Raw observation from the environment. + :type obs: np.ndarray + :return: The same observation, but the encoded integer values are replaced with readable names. + :rtype: List[List[Union[str, int]]] """ changed_obs = transform_change_obs_readable(obs) new_obs = list(zip(*changed_obs)) @@ -191,21 +210,23 @@ def transform_obs_readable(obs): return new_obs -def convert_to_new_obs(obs, num_nodes=10): - """ - Convert original gym Box observation space to new multiDiscrete observation space. +def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray: + """Convert original gym Box observation space to new multiDiscrete observation space. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs: observation in the 'old' (NodeLinkTable) format + :type obs: np.ndarray + :param num_nodes: number of nodes in the network, defaults to 10 + :type num_nodes: int, optional + :return: reformatted observation + :rtype: np.ndarray """ # Remove ID columns, remove links and flatten to MultiDiscrete observation space new_obs = obs[:num_nodes, 1:].flatten() return new_obs -def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): - """ - Convert to old observation. +def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray: + """Convert to old observation. Links filled with 0's as no information is included in new observation space. @@ -217,8 +238,17 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): [ 3, 1, 1, 1], ... [20, 0, 0, 0]]) - TODO: Add params and return in docstring. - TODO: Typehint params and return. + + :param obs: observation in the 'new' (MultiDiscrete) format + :type obs: np.ndarray + :param num_nodes: number of nodes in the network, defaults to 10 + :type num_nodes: int, optional + :param num_links: number of links in the network, defaults to 10 + :type num_links: int, optional + :param num_services: number of services on the network, defaults to 1 + :type num_services: int, optional + :return: 2-d BOX observation space, in the same format as NodeLinkTable + :rtype: np.ndarray """ # Convert back to more readable, original format reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2) @@ -240,17 +270,28 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): return new_obs -def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): - """ - Return string describing change between two observations. +def describe_obs_change( + obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1 +) -> str: + """Build a string describing the difference between two observations. example: obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]]) output = 'ID 1: SERVICE 2 set to GOOD' - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs1: First observation + :type obs1: np.ndarray + :param obs2: Second observation + :type obs2: np.ndarray + :param num_nodes: How many nodes are in the network laydown, defaults to 10 + :type num_nodes: int, optional + :param num_links: How many links are in the network laydown, defaults to 10 + :type num_links: int, optional + :param num_services: How many services are configured for this scenario, defaults to 1 + :type num_services: int, optional + :return: A multi-line string with a human-readable description of the difference. + :rtype: str """ obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services) obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services) From b426d5802e6769f3dab3abca409271fc4fe94abd Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 5 Jul 2023 16:46:23 +0100 Subject: [PATCH 6/8] Updated docstrings --- src/primaite/agents/utils.py | 132 ++++++++++++++++-------------- src/primaite/pol/red_agent_pol.py | 14 +++- 2 files changed, 82 insertions(+), 64 deletions(-) diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index bffdbf43..58b422d0 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import Dict, List, Union import numpy as np @@ -346,15 +346,15 @@ def _describe_obs_change_helper(obs_change, is_link): return desc -def transform_action_node_enum(action): - """ - Convert a node action from readable string format, to enumerated format. +def transform_action_node_enum(action: List[Union[str, int]]) -> List[int]: + """Convert a node action from readable string format, to enumerated format. example: [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] - - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Action in 'readable' format + :type action: List[Union[str,int]] + :return: Action with verbs encoded as ints + :rtype: List[int] """ action_node_id = action[0] action_node_property = NodePOLType[action[1]].value @@ -378,15 +378,16 @@ def transform_action_node_enum(action): return new_action -def transform_action_node_readable(action): - """ - Convert a node action from enumerated format to readable format. +def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]: + """Convert a node action from enumerated format to readable format. example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Raw action with integer encodings + :type action: List[int] + :return: Human-redable version of the action + :rtype: List[Union[int,str]] """ action_node_property = NodePOLType(action[1]).name @@ -401,32 +402,33 @@ def transform_action_node_readable(action): return new_action -def node_action_description(action): - """ - Generate string describing a node-based action. +# unused +# def node_action_description(action): +# """ +# Generate string describing a node-based action. - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - if isinstance(action[1], (int, np.int64)): - # transform action to readable format - action = transform_action_node_readable(action) +# TO#DO: Add params and return in docstring. +# TO#DO: Typehint params and return. +# """ +# if isinstance(action[1], (int, np.int64)): +# # transform action to readable format +# action = transform_action_node_readable(action) - node_id = action[0] - node_property = action[1] - property_action = action[2] - service_id = action[3] +# node_id = action[0] +# node_property = action[1] +# property_action = action[2] +# service_id = action[3] - if property_action == "NONE": - return "" - if node_property == "OPERATING" or node_property == "OS": - description = f"NODE {node_id}, {node_property}, SET TO {property_action}" - elif node_property == "SERVICE": - description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" - else: - return "" +# if property_action == "NONE": +# return "" +# if node_property == "OPERATING" or node_property == "OS": +# description = f"NODE {node_id}, {node_property}, SET TO {property_action}" +# elif node_property == "SERVICE": +# description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" +# else: +# return "" - return description +# return description def transform_action_acl_enum(action): @@ -452,25 +454,26 @@ def transform_action_acl_enum(action): return new_action -def acl_action_description(action): - """ - Generate string describing an acl-based action. +# unused +# def acl_action_description(action): +# """ +# Generate string describing an acl-based action. - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - if isinstance(action[0], (int, np.int64)): - # transform action to readable format - action = transform_action_acl_readable(action) - if action[0] == "NONE": - description = "NO ACL RULE APPLIED" - else: - description = ( - f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]}," - f" for protocol/service index {action[4]} on port index {action[5]}" - ) +# TODO: Add params and return in docstring. +# TODO: Typehint params and return. +# """ +# if isinstance(action[0], (int, np.int64)): +# # transform action to readable format +# action = transform_action_acl_readable(action) +# if action[0] == "NONE": +# description = "NO ACL RULE APPLIED" +# else: +# description = ( +# f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]}," +# f" for protocol/service index {action[4]} on port index {action[5]}" +# ) - return description +# return description def get_node_of_ip(ip, node_dict): @@ -521,7 +524,7 @@ def is_valid_node_action(action): return True -def is_valid_acl_action(action): +def is_valid_acl_action(action: List[int]) -> bool: """ Is the ACL action an actual valid action. @@ -531,8 +534,11 @@ def is_valid_acl_action(action): - Trying to create identical rules - Trying to create a rule which is a subset of another rule (caused by "ANY") - TODO: Add params and return in docstring. - TODO: Typehint params and return. + + :param action: Action to check + :type action: List[int] + :return: Whether the action is valid + :rtype: bool """ action_r = transform_action_acl_readable(action) @@ -554,12 +560,14 @@ def is_valid_acl_action(action): return True -def is_valid_acl_action_extra(action): +def is_valid_acl_action_extra(action: List[int]) -> bool: """ Harsher version of valid acl actions, does not allow action. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: Input action + :type action: List[int] + :return: Whether the action is a valid ACL action + :rtype: bool """ if is_valid_acl_action(action) is False: return False @@ -578,14 +586,18 @@ def is_valid_acl_action_extra(action): return True -def get_new_action(old_action, action_dict): +def get_new_action(old_action: np.ndarray, action_dict: Dict[int, List]) -> int: """ Get new action (e.g. 32) from old action e.g. [1,1,1,0]. Old_action can be either node or acl action type - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param old_action: Action expressed as a list of choices, eg. [1,1,1,0] + :type old_action: np.ndarray + :param action_dict: Dictionary for translating the multidiscrete actions into the list-based actions. + :type action_dict: Dict[int,List] + :return: Action key correspoinding to the input `old_action` + :rtype: int """ for key, val in action_dict.items(): if list(val) == list(old_action): diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index bff19bf8..1a8bd406 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -296,11 +296,17 @@ def apply_red_agent_node_pol( pass -def is_red_ier_incoming(node, iers, node_pol_type): - """ - Checks if the RED IER is incoming. +def is_red_ier_incoming(node: NodeUnion, iers: Dict[str, IER], node_pol_type: NodePOLType) -> bool: + """Checks if the RED IER is incoming. - TODO: Write more descriptive docstring with params and returns. + :param node: Destination node of the IER + :type node: NodeUnion + :param iers: Directory of IERs + :type iers: Dict[str,IER] + :param node_pol_type: Type of Pattern-Of-Life + :type node_pol_type: NodePOLType + :return: Whether the RED IER is incoming. + :rtype: bool """ node_id = node.node_id From c38dda34b9b3f4917c921876d9b2fc7f71939220 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 6 Jul 2023 10:23:14 +0100 Subject: [PATCH 7/8] Removed duplicated function definitions --- src/primaite/agents/utils.py | 204 ++++------------------------------- 1 file changed, 23 insertions(+), 181 deletions(-) diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index 58b422d0..5f5261e0 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -2,6 +2,7 @@ from typing import Dict, List, Union import numpy as np +from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( HardwareState, LinkStatus, @@ -310,7 +311,7 @@ def describe_obs_change( return change_string -def _describe_obs_change_helper(obs_change, is_link): +def _describe_obs_change_helper(obs_change: List[int], is_link: bool) -> str: """ Helper funcion to describe what has changed. @@ -319,8 +320,14 @@ def _describe_obs_change_helper(obs_change, is_link): Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.' - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param obs_change: List of integers generated within the `describe_obs_change` function. It should correspond to one + row of the observation table, and have `-1` at locations where the observation hasn't changed, and the new + status where it has changed. + :type obs_change: List[int] + :param is_link: Whether the row of the observation space corresponds to a link. False means it represents a node. + :type is_link: bool + :return: A human-readable description of the difference between the two observation rows. + :rtype: str """ # Indexes where a change has occured, not including 0th index index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] @@ -378,65 +385,14 @@ def transform_action_node_enum(action: List[Union[str, int]]) -> List[int]: return new_action -def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]: - """Convert a node action from enumerated format to readable format. - - example: - [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] - - :param action: Raw action with integer encodings - :type action: List[int] - :return: Human-redable version of the action - :rtype: List[Union[int,str]] - """ - action_node_property = NodePOLType(action[1]).name - - if action_node_property == "OPERATING": - property_action = NodeHardwareAction(action[2]).name - elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1: - property_action = NodeSoftwareAction(action[2]).name - else: - property_action = "NONE" - - new_action = [action[0], action_node_property, property_action, action[3]] - return new_action - - -# unused -# def node_action_description(action): -# """ -# Generate string describing a node-based action. - -# TO#DO: Add params and return in docstring. -# TO#DO: Typehint params and return. -# """ -# if isinstance(action[1], (int, np.int64)): -# # transform action to readable format -# action = transform_action_node_readable(action) - -# node_id = action[0] -# node_property = action[1] -# property_action = action[2] -# service_id = action[3] - -# if property_action == "NONE": -# return "" -# if node_property == "OPERATING" or node_property == "OS": -# description = f"NODE {node_id}, {node_property}, SET TO {property_action}" -# elif node_property == "SERVICE": -# description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" -# else: -# return "" - -# return description - - -def transform_action_acl_enum(action): +def transform_action_acl_enum(action: List[Union[int, str]]) -> np.ndarray: """ Convert acl action from readable str format, to enumerated format. - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param action: ACL-based action expressed as a list of human-readable ints and strings + :type action: List[Union[int,str]] + :return: The same action but encoded to contain only integers. + :rtype: np.ndarray """ action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2} action_permissions = {"DENY": 0, "ALLOW": 1} @@ -454,36 +410,17 @@ def transform_action_acl_enum(action): return new_action -# unused -# def acl_action_description(action): -# """ -# Generate string describing an acl-based action. - -# TODO: Add params and return in docstring. -# TODO: Typehint params and return. -# """ -# if isinstance(action[0], (int, np.int64)): -# # transform action to readable format -# action = transform_action_acl_readable(action) -# if action[0] == "NONE": -# description = "NO ACL RULE APPLIED" -# else: -# description = ( -# f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]}," -# f" for protocol/service index {action[4]} on port index {action[5]}" -# ) - -# return description - - -def get_node_of_ip(ip, node_dict): - """ - Get the node ID of an IP address. +def get_node_of_ip(ip: str, node_dict: Dict[str, NodeUnion]) -> str: + """Get the node ID of an IP address. node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes) - TODO: Add params and return in docstring. - TODO: Typehint params and return. + :param ip: The IP address of the node whose ID is required + :type ip: str + :param node_dict: The environment's node registry dictionary + :type node_dict: Dict[str,NodeUnion] + :return: The key from the registry dict that corresponds to the node with the IP adress provided by `ip` + :rtype: str """ for node_key, node_value in node_dict.items(): node_ip = node_value.ip_address @@ -491,101 +428,6 @@ def get_node_of_ip(ip, node_dict): return node_key -def is_valid_node_action(action): - """Is the node action an actual valid action. - - Only uses information about the action to determine if the action has an effect - - Does NOT consider: - - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - - Node already being in that state (turning an ON node ON) - - TODO: Add params and return in docstring. - TODO: Typehint params and return. - """ - action_r = transform_action_node_readable(action) - - node_property = action_r[1] - node_action = action_r[2] - - if node_property == "NONE": - return False - if node_action == "NONE": - return False - if node_property == "OPERATING" and node_action == "PATCHING": - # Operating State cannot PATCH - return False - if node_property != "OPERATING" and node_action not in [ - "NONE", - "PATCHING", - ]: - # Software States can only do Nothing or Patch - return False - return True - - -def is_valid_acl_action(action: List[int]) -> bool: - """ - Is the ACL action an actual valid action. - - Only uses information about the action to determine if the action has an effect - - Does NOT consider: - - Trying to create identical rules - - Trying to create a rule which is a subset of another rule (caused by "ANY") - - - :param action: Action to check - :type action: List[int] - :return: Whether the action is valid - :rtype: bool - """ - action_r = transform_action_acl_readable(action) - - action_decision = action_r[0] - action_permission = action_r[1] - action_source_id = action_r[2] - action_destination_id = action_r[3] - - if action_decision == "NONE": - return False - if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY": - # ACL rule towards itself - return False - if action_permission == "DENY": - # DENY is unnecessary, we can create and delete allow rules instead - # No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY. - return False - - return True - - -def is_valid_acl_action_extra(action: List[int]) -> bool: - """ - Harsher version of valid acl actions, does not allow action. - - :param action: Input action - :type action: List[int] - :return: Whether the action is a valid ACL action - :rtype: bool - """ - if is_valid_acl_action(action) is False: - return False - - action_r = transform_action_acl_readable(action) - action_protocol = action_r[4] - action_port = action_r[5] - - # Don't allow protocols or ports to be ANY - # in the future we might want to do the opposite, and only have ANY option for ports and service - if action_protocol == "ANY": - return False - if action_port == "ANY": - return False - - return True - - def get_new_action(old_action: np.ndarray, action_dict: Dict[int, List]) -> int: """ Get new action (e.g. 32) from old action e.g. [1,1,1,0]. From 87bdaa1ec3f24adbd9b38d22c3a6c69359b6ffbd Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 6 Jul 2023 10:34:27 +0100 Subject: [PATCH 8/8] Updated documentation --- src/primaite/acl/access_control_list.py | 2 +- src/primaite/agents/hardcoded_acl.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index f5f70d6c..78ea8a36 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -34,7 +34,7 @@ class AccessControlList: else: return False - def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port): + def is_blocked(self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool: """ Checks for rules that block a protocol / port. diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index a724e10e..f8c571c9 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -78,7 +78,6 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): dest_node_address = nodes[dest_node_id].ip_address protocol = ier.get_protocol() # e.g. 'TCP' port = ier.get_port() - # I can't find where this function 'get_relevant_rules' is defined... matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) return matching_rules