Add docstrings and type hints.

This commit is contained in:
Marek Wolan
2023-07-05 16:19:43 +01:00
parent 0ae7158859
commit 5c167293e3
4 changed files with 208 additions and 109 deletions

View File

@@ -10,19 +10,19 @@ class AccessControlList:
def __init__(self):
"""Init."""
self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules
self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
"""
Checks for IP address matches.
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
"""Checks for IP address matches.
Args:
_rule: The rule being checked
_source_ip_address: the source IP address to compare
_dest_ip_address: the destination IP address to compare
Returns:
True if match; False otherwise.
:param _rule: The rule object to check
:type _rule: ACLRule
:param _source_ip_address: Source IP address to compare
:type _source_ip_address: str
:param _dest_ip_address: Destination IP address to compare
:type _dest_ip_address: str
:return: True if there is a match, otherwise False.
:rtype: bool
"""
if (
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address)

View File

@@ -1,7 +1,9 @@
from typing import Any, Dict
import numpy as np
from primaite.acl.access_control_list import AccessControlList
from typing import Any, Dict, List, Union
import numpy as np
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
@@ -11,13 +13,15 @@ from primaite.agents.utils import (
)
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardCodedAgentView
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
class HardCodedACLAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic ACL agent."""
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> int:
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
# Basic view action using only the current observation
return self._calculate_action_basic_view(obs)
@@ -26,7 +30,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
# history and reward feedback
return self._calculate_action_full_view(obs)
def get_blocked_green_iers(self, green_iers:Dict[str, IER], acl:AccessControlList, nodes:Dict[str,NodeUnion]):
def get_blocked_green_iers(
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[Any, Any]:
"""Get blocked green IERs.
:param green_iers: Green IERs to check for being
@@ -54,8 +60,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_green_iers
def get_matching_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion]):
"""Get list of ACL rules which are relevant to an IER
def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]):
"""Get list of ACL rules which are relevant to an IER.
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
@@ -76,7 +82,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_blocking_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]:
def get_blocking_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, Any]:
"""
Get blocking ACL rules for an IER.
@@ -102,9 +110,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_rules
def get_allow_acl_rules_for_ier(self, ier:IER, acl:AccessControlList, nodes:Dict[str,NodeUnion])->Dict[str,Any]:
"""
Get all allowing ACL rules for an IER.
def get_allow_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, Any]:
"""Get all allowing ACL rules for an IER.
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
@@ -126,19 +135,32 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_matching_acl_rules(
self,
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
):
"""
Get matching ACL rules.
source_node_id: str,
dest_node_id: str,
protocol: str,
port: str,
acl: AccessControlList,
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
services_list: List[str],
) -> Dict[str, ACLRule]:
"""Filter ACL rules to only those which are relevant to the specified nodes.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param source_node_id: Source node
:type source_node_id: str
:param dest_node_id: Destination nodes
:type dest_node_id: str
:param protocol: Network protocol
:type protocol: str
:param port: Network port
:type port: str
:param acl: Access Control list which will be filtered
:type acl: AccessControlList
:param nodes: The environment's node directory.
:type nodes: Dict[str, Union[ServiceNode, ActiveNode]]
:param services_list: List of services registered for the environment.
:type services_list: List[str]
:return: Filtered version of 'acl'
:rtype: Dict[str, ACLRule]
"""
if source_node_id != "ANY":
source_node_address = nodes[str(source_node_id)].ip_address
@@ -158,19 +180,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_allow_acl_rules(
self,
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
):
"""
Get the ALLOW ACL rules.
source_node_id: int,
dest_node_id: str,
protocol: int,
port: str,
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[str, ACLRule]:
"""List ALLOW rules relating to specified nodes.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param source_node_id: Source node id
:type source_node_id: int
:param dest_node_id: Destination node
:type dest_node_id: str
:param protocol: Network protocol
:type protocol: int
:param port: Port
:type port: str
:param acl: Firewall ruleset which is applied to the network
:type acl: AccessControlList
:param nodes: The simulation's node store
:type nodes: Dict[str, NodeUnion]
:param services_list: Services list
:type services_list: List[str]
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
desination nodes
:rtype: Dict[str, ACLRule]
"""
matching_rules = self.get_matching_acl_rules(
source_node_id,
@@ -191,19 +227,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_deny_acl_rules(
self,
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
):
"""
Get the DENY ACL rules.
source_node_id: int,
dest_node_id: str,
protocol: int,
port: str,
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[str, ACLRule]:
"""List DENY rules relating to specified nodes.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param source_node_id: Source node id
:type source_node_id: int
:param dest_node_id: Destination node
:type dest_node_id: str
:param protocol: Network protocol
:type protocol: int
:param port: Port
:type port: str
:param acl: Firewall ruleset which is applied to the network
:type acl: AccessControlList
:param nodes: The simulation's node store
:type nodes: Dict[str, NodeUnion]
:param services_list: Services list
:type services_list: List[str]
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
desination nodes
:rtype: Dict[str, ACLRule]
"""
matching_rules = self.get_matching_acl_rules(
source_node_id,
@@ -222,7 +272,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return allowed_rules
def _calculate_action_full_view(self, obs):
def _calculate_action_full_view(self, obs: np.ndarray) -> int:
"""
Calculate a good acl-based action for the blue agent to take.
@@ -250,8 +300,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing
an overwhelmed state, so we don't do this.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: current observation from the gym environment
:type obs: np.ndarray
:return: Optimal action to take in the environment (chosen from the discrete action space)
:rtype: int
"""
# obs = convert_to_old_obs(obs)
r_obs = transform_change_obs_readable(obs)
@@ -387,7 +439,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
action = get_new_action(action, self._env.action_dict)
return action
def _calculate_action_basic_view(self, obs):
def _calculate_action_basic_view(self, obs: np.ndarray) -> int:
"""Calculate a good acl-based action for the blue agent to take.
Uses ONLY information from the current observation with NO knowledge
@@ -405,8 +457,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
Currently, a deny rule does not overwrite an allow rule. The allow
rules must be deleted.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: current observation from the gym environment
:type obs: np.ndarray
:return: Optimal action to take in the environment (chosen from the discrete action space)
:rtype: int
"""
action_dict = self._env.action_dict
r_obs = transform_change_obs_readable(obs)

View File

@@ -1,3 +1,5 @@
import numpy as np
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
@@ -5,12 +7,14 @@ from primaite.agents.utils import get_new_action, transform_action_node_enum, tr
class HardCodedNodeAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic Node agent."""
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> int:
"""
Calculate a good node-based action for the blue agent to take.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: current observation from the gym environment
:type obs: np.ndarray
:return: Optimal action to take in the environment (chosen from the discrete action space)
:rtype: int
"""
action_dict = self._env.action_dict
r_obs = transform_change_obs_readable(obs)

View File

@@ -1,3 +1,5 @@
from typing import List, Union
import numpy as np
from primaite.common.enums import (
@@ -10,15 +12,17 @@ from primaite.common.enums import (
)
def transform_action_node_readable(action):
"""
Convert a node action from enumerated format to readable format.
def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
"""Convert a node action from enumerated format to readable format.
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: The same action list, but with the encodings translated back into meaningful labels
:rtype: List[Union[int,str]]
"""
action_node_property = NodePOLType(action[1]).name
@@ -33,15 +37,18 @@ def transform_action_node_readable(action):
return new_action
def transform_action_acl_readable(action):
def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]:
"""
Transform an ACL action to a more readable format.
example:
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: The same action list, but with the encodings translated back into meaningful labels
:rtype: List[Union[int,str]]
"""
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
action_permissions = {0: "DENY", 1: "ALLOW"}
@@ -58,7 +65,7 @@ def transform_action_acl_readable(action):
return new_action
def is_valid_node_action(action):
def is_valid_node_action(action: List[int]) -> bool:
"""Is the node action an actual valid action.
Only uses information about the action to determine if the action has an effect
@@ -67,8 +74,11 @@ def is_valid_node_action(action):
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
- Node already being in that state (turning an ON node ON)
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: Whether the action is valid
:rtype: bool
"""
action_r = transform_action_node_readable(action)
@@ -93,7 +103,7 @@ def is_valid_node_action(action):
return True
def is_valid_acl_action(action):
def is_valid_acl_action(action: List[int]) -> bool:
"""
Is the ACL action an actual valid action.
@@ -103,8 +113,11 @@ def is_valid_acl_action(action):
- Trying to create identical rules
- Trying to create a rule which is a subset of another rule (caused by "ANY")
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: Whether the action is valid
:rtype: bool
"""
action_r = transform_action_acl_readable(action)
@@ -126,12 +139,15 @@ def is_valid_acl_action(action):
return True
def is_valid_acl_action_extra(action):
def is_valid_acl_action_extra(action: List[int]) -> bool:
"""
Harsher version of valid acl actions, does not allow action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: Whether the action is valid
:rtype: bool
"""
if is_valid_acl_action(action) is False:
return False
@@ -150,15 +166,16 @@ def is_valid_acl_action_extra(action):
return True
def transform_change_obs_readable(obs):
"""
Transform list of transactions to readable list of each observation property.
def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
"""Transform list of transactions to readable list of each observation property.
example:
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: Raw observation from the environment.
:type obs: np.ndarray
:return: The same observation, but the encoded integer values are replaced with readable names.
:rtype: List[List[Union[str, int]]]
"""
ids = [i for i in obs[:, 0]]
operating_states = [HardwareState(i).name for i in obs[:, 1]]
@@ -174,14 +191,16 @@ def transform_change_obs_readable(obs):
return new_obs
def transform_obs_readable(obs):
"""
Transform observation to readable format.
def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
"""Transform observation to readable format.
example
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: Raw observation from the environment.
:type obs: np.ndarray
:return: The same observation, but the encoded integer values are replaced with readable names.
:rtype: List[List[Union[str, int]]]
"""
changed_obs = transform_change_obs_readable(obs)
new_obs = list(zip(*changed_obs))
@@ -191,21 +210,23 @@ def transform_obs_readable(obs):
return new_obs
def convert_to_new_obs(obs, num_nodes=10):
"""
Convert original gym Box observation space to new multiDiscrete observation space.
def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray:
"""Convert original gym Box observation space to new multiDiscrete observation space.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: observation in the 'old' (NodeLinkTable) format
:type obs: np.ndarray
:param num_nodes: number of nodes in the network, defaults to 10
:type num_nodes: int, optional
:return: reformatted observation
:rtype: np.ndarray
"""
# Remove ID columns, remove links and flatten to MultiDiscrete observation space
new_obs = obs[:num_nodes, 1:].flatten()
return new_obs
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
"""
Convert to old observation.
def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray:
"""Convert to old observation.
Links filled with 0's as no information is included in new observation space.
@@ -217,8 +238,17 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
[ 3, 1, 1, 1],
...
[20, 0, 0, 0]])
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: observation in the 'new' (MultiDiscrete) format
:type obs: np.ndarray
:param num_nodes: number of nodes in the network, defaults to 10
:type num_nodes: int, optional
:param num_links: number of links in the network, defaults to 10
:type num_links: int, optional
:param num_services: number of services on the network, defaults to 1
:type num_services: int, optional
:return: 2-d BOX observation space, in the same format as NodeLinkTable
:rtype: np.ndarray
"""
# Convert back to more readable, original format
reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
@@ -240,17 +270,28 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
return new_obs
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
"""
Return string describing change between two observations.
def describe_obs_change(
obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1
) -> str:
"""Build a string describing the difference between two observations.
example:
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
output = 'ID 1: SERVICE 2 set to GOOD'
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs1: First observation
:type obs1: np.ndarray
:param obs2: Second observation
:type obs2: np.ndarray
:param num_nodes: How many nodes are in the network laydown, defaults to 10
:type num_nodes: int, optional
:param num_links: How many links are in the network laydown, defaults to 10
:type num_links: int, optional
:param num_services: How many services are configured for this scenario, defaults to 1
:type num_services: int, optional
:return: A multi-line string with a human-readable description of the difference.
:rtype: str
"""
obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)