Standardise docstring summary line placement.
This commit is contained in:
@@ -66,7 +66,8 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``.
|
|||||||
|
|
||||||
# region Setup Logging
|
# region Setup Logging
|
||||||
class _LevelFormatter(Formatter):
|
class _LevelFormatter(Formatter):
|
||||||
"""A custom level-specific formatter.
|
"""
|
||||||
|
A custom level-specific formatter.
|
||||||
|
|
||||||
Credit to: https://stackoverflow.com/a/68154386
|
Credit to: https://stackoverflow.com/a/68154386
|
||||||
"""
|
"""
|
||||||
@@ -134,7 +135,8 @@ _LOGGER.addHandler(_FILE_HANDLER)
|
|||||||
|
|
||||||
|
|
||||||
def getLogger(name: str) -> Logger: # noqa
|
def getLogger(name: str) -> Logger: # noqa
|
||||||
"""Get a PrimAITE logger.
|
"""
|
||||||
|
Get a PrimAITE logger.
|
||||||
|
|
||||||
:param name: The logger name. Use ``__name__``.
|
:param name: The logger name. Use ``__name__``.
|
||||||
:return: An instance of :py:class:`logging.Logger` with the PrimAITE
|
:return: An instance of :py:class:`logging.Logger` with the PrimAITE
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ class AccessControlList:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
|
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
|
||||||
"""Checks for rules that block a protocol / port.
|
"""
|
||||||
|
Checks for rules that block a protocol / port.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
_source_ip_address: the source IP address to check
|
_source_ip_address: the source IP address to check
|
||||||
@@ -61,7 +62,8 @@ class AccessControlList:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||||
"""Adds a new rule.
|
"""
|
||||||
|
Adds a new rule.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||||
@@ -75,7 +77,8 @@ class AccessControlList:
|
|||||||
self.acl[hash_value] = new_rule
|
self.acl[hash_value] = new_rule
|
||||||
|
|
||||||
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||||
"""Removes a rule.
|
"""
|
||||||
|
Removes a rule.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||||
@@ -97,7 +100,8 @@ class AccessControlList:
|
|||||||
self.acl.clear()
|
self.acl.clear()
|
||||||
|
|
||||||
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||||
"""Produces a hash value for a rule.
|
"""
|
||||||
|
Produces a hash value for a rule.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ class ACLRule:
|
|||||||
self.port = _port
|
self.port = _port
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
"""Override the hash function.
|
"""
|
||||||
|
Override the hash function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns hash of core parameters.
|
Returns hash of core parameters.
|
||||||
@@ -38,7 +39,8 @@ class ACLRule:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_permission(self):
|
def get_permission(self):
|
||||||
"""Gets the permission attribute.
|
"""
|
||||||
|
Gets the permission attribute.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns permission attribute
|
Returns permission attribute
|
||||||
@@ -46,7 +48,8 @@ class ACLRule:
|
|||||||
return self.permission
|
return self.permission
|
||||||
|
|
||||||
def get_source_ip(self):
|
def get_source_ip(self):
|
||||||
"""Gets the source IP address attribute.
|
"""
|
||||||
|
Gets the source IP address attribute.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns source IP address attribute
|
Returns source IP address attribute
|
||||||
@@ -54,7 +57,8 @@ class ACLRule:
|
|||||||
return self.source_ip
|
return self.source_ip
|
||||||
|
|
||||||
def get_dest_ip(self):
|
def get_dest_ip(self):
|
||||||
"""Gets the desintation IP address attribute.
|
"""
|
||||||
|
Gets the desintation IP address attribute.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns destination IP address attribute
|
Returns destination IP address attribute
|
||||||
@@ -62,7 +66,8 @@ class ACLRule:
|
|||||||
return self.dest_ip
|
return self.dest_ip
|
||||||
|
|
||||||
def get_protocol(self):
|
def get_protocol(self):
|
||||||
"""Gets the protocol attribute.
|
"""
|
||||||
|
Gets the protocol attribute.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns protocol attribute
|
Returns protocol attribute
|
||||||
@@ -70,7 +75,8 @@ class ACLRule:
|
|||||||
return self.protocol
|
return self.protocol
|
||||||
|
|
||||||
def get_port(self):
|
def get_port(self):
|
||||||
"""Gets the port attribute.
|
"""
|
||||||
|
Gets the port attribute.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns port attribute
|
Returns port attribute
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ _LOGGER = getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def get_session_path(session_timestamp: datetime) -> Path:
|
def get_session_path(session_timestamp: datetime) -> Path:
|
||||||
"""Get the directory path the session will output to.
|
"""
|
||||||
|
Get the directory path the session will output to.
|
||||||
|
|
||||||
This is set in the format of:
|
This is set in the format of:
|
||||||
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
|
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
|
||||||
@@ -194,7 +195,8 @@ class AgentSessionABC(ABC):
|
|||||||
self,
|
self,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Train the agent.
|
"""
|
||||||
|
Train the agent.
|
||||||
|
|
||||||
:param kwargs: Any agent-specific key-word args to be passed.
|
:param kwargs: Any agent-specific key-word args to be passed.
|
||||||
"""
|
"""
|
||||||
@@ -211,7 +213,8 @@ class AgentSessionABC(ABC):
|
|||||||
self,
|
self,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Evaluate the agent.
|
"""
|
||||||
|
Evaluate the agent.
|
||||||
|
|
||||||
:param kwargs: Any agent-specific key-word args to be passed.
|
:param kwargs: Any agent-specific key-word args to be passed.
|
||||||
"""
|
"""
|
||||||
@@ -340,7 +343,8 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
|||||||
self,
|
self,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Train the agent.
|
"""
|
||||||
|
Train the agent.
|
||||||
|
|
||||||
:param kwargs: Any agent-specific key-word args to be passed.
|
:param kwargs: Any agent-specific key-word args to be passed.
|
||||||
"""
|
"""
|
||||||
@@ -354,7 +358,8 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
|||||||
self,
|
self,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Evaluate the agent.
|
"""
|
||||||
|
Evaluate the agent.
|
||||||
|
|
||||||
:param kwargs: Any agent-specific key-word args to be passed.
|
:param kwargs: Any agent-specific key-word args to be passed.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -46,7 +46,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
|||||||
return blocked_green_iers
|
return blocked_green_iers
|
||||||
|
|
||||||
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
|
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
|
||||||
"""Get matching ACL rules for an IER.
|
"""
|
||||||
|
Get matching ACL rules for an IER.
|
||||||
|
|
||||||
TODO: Add params and return in docstring.
|
TODO: Add params and return in docstring.
|
||||||
TODO: Typehint params and return.
|
TODO: Typehint params and return.
|
||||||
@@ -62,7 +63,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
|||||||
return matching_rules
|
return matching_rules
|
||||||
|
|
||||||
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
|
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
|
||||||
"""Get blocking ACL rules for an IER.
|
"""
|
||||||
|
Get blocking ACL rules for an IER.
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
Can return empty dict but IER can still be blocked by default
|
Can return empty dict but IER can still be blocked by default
|
||||||
@@ -81,7 +83,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
|||||||
return blocked_rules
|
return blocked_rules
|
||||||
|
|
||||||
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
|
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
|
||||||
"""Get all allowing ACL rules for an IER.
|
"""
|
||||||
|
Get all allowing ACL rules for an IER.
|
||||||
|
|
||||||
TODO: Add params and return in docstring.
|
TODO: Add params and return in docstring.
|
||||||
TODO: Typehint params and return.
|
TODO: Typehint params and return.
|
||||||
@@ -105,7 +108,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
|||||||
nodes,
|
nodes,
|
||||||
services_list,
|
services_list,
|
||||||
):
|
):
|
||||||
"""Get matching ACL rules.
|
"""
|
||||||
|
Get matching ACL rules.
|
||||||
|
|
||||||
TODO: Add params and return in docstring.
|
TODO: Add params and return in docstring.
|
||||||
TODO: Typehint params and return.
|
TODO: Typehint params and return.
|
||||||
@@ -136,7 +140,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
|||||||
nodes,
|
nodes,
|
||||||
services_list,
|
services_list,
|
||||||
):
|
):
|
||||||
"""Get the ALLOW ACL rules.
|
"""
|
||||||
|
Get the ALLOW ACL rules.
|
||||||
|
|
||||||
TODO: Add params and return in docstring.
|
TODO: Add params and return in docstring.
|
||||||
TODO: Typehint params and return.
|
TODO: Typehint params and return.
|
||||||
@@ -168,7 +173,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
|||||||
nodes,
|
nodes,
|
||||||
services_list,
|
services_list,
|
||||||
):
|
):
|
||||||
"""Get the DENY ACL rules.
|
"""
|
||||||
|
Get the DENY ACL rules.
|
||||||
|
|
||||||
TODO: Add params and return in docstring.
|
TODO: Add params and return in docstring.
|
||||||
TODO: Typehint params and return.
|
TODO: Typehint params and return.
|
||||||
@@ -191,7 +197,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
|||||||
return allowed_rules
|
return allowed_rules
|
||||||
|
|
||||||
def _calculate_action_full_view(self, obs):
|
def _calculate_action_full_view(self, obs):
|
||||||
"""Calculate a good acl-based action for the blue agent to take.
|
"""
|
||||||
|
Calculate a good acl-based action for the blue agent to take.
|
||||||
|
|
||||||
Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
|
Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
|
||||||
|
|
||||||
@@ -355,7 +362,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
|||||||
return action
|
return action
|
||||||
|
|
||||||
def _calculate_action_basic_view(self, obs):
|
def _calculate_action_basic_view(self, obs):
|
||||||
"""Calculate a good acl-based action for the blue agent to take.
|
"""
|
||||||
|
Calculate a good acl-based action for the blue agent to take.
|
||||||
|
|
||||||
Uses ONLY information from the current observation with NO knowledge
|
Uses ONLY information from the current observation with NO knowledge
|
||||||
of previous actions taken and NO reward feedback.
|
of previous actions taken and NO reward feedback.
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
|||||||
"""An Agent Session class that implements a deterministic Node agent."""
|
"""An Agent Session class that implements a deterministic Node agent."""
|
||||||
|
|
||||||
def _calculate_action(self, obs):
|
def _calculate_action(self, obs):
|
||||||
"""Calculate a good node-based action for the blue agent to take.
|
"""
|
||||||
|
Calculate a good node-based action for the blue agent to take.
|
||||||
|
|
||||||
TODO: Add params and return in docstring.
|
TODO: Add params and return in docstring.
|
||||||
TODO: Typehint params and return.
|
TODO: Typehint params and return.
|
||||||
|
|||||||
@@ -140,7 +140,8 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
self,
|
self,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Evaluate the agent.
|
"""
|
||||||
|
Evaluate the agent.
|
||||||
|
|
||||||
:param kwargs: Any agent-specific key-word args to be passed.
|
:param kwargs: Any agent-specific key-word args to be passed.
|
||||||
"""
|
"""
|
||||||
@@ -158,7 +159,8 @@ class RLlibAgent(AgentSessionABC):
|
|||||||
self,
|
self,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Evaluate the agent.
|
"""
|
||||||
|
Evaluate the agent.
|
||||||
|
|
||||||
:param kwargs: Any agent-specific key-word args to be passed.
|
:param kwargs: Any agent-specific key-word args to be passed.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -89,7 +89,8 @@ class SB3Agent(AgentSessionABC):
|
|||||||
self,
|
self,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Train the agent.
|
"""
|
||||||
|
Train the agent.
|
||||||
|
|
||||||
:param kwargs: Any agent-specific key-word args to be passed.
|
:param kwargs: Any agent-specific key-word args to be passed.
|
||||||
"""
|
"""
|
||||||
@@ -109,7 +110,8 @@ class SB3Agent(AgentSessionABC):
|
|||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Evaluate the agent.
|
"""
|
||||||
|
Evaluate the agent.
|
||||||
|
|
||||||
:param deterministic: Whether the evaluation is deterministic.
|
:param deterministic: Whether the evaluation is deterministic.
|
||||||
:param kwargs: Any agent-specific key-word args to be passed.
|
:param kwargs: Any agent-specific key-word args to be passed.
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ from primaite.common.enums import (
|
|||||||
|
|
||||||
|
|
||||||
def transform_action_node_readable(action):
|
def transform_action_node_readable(action):
|
||||||
"""Convert a node action from enumerated format to readable format.
|
"""
|
||||||
|
Convert a node action from enumerated format to readable format.
|
||||||
|
|
||||||
example:
|
example:
|
||||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||||
@@ -33,7 +34,8 @@ def transform_action_node_readable(action):
|
|||||||
|
|
||||||
|
|
||||||
def transform_action_acl_readable(action):
|
def transform_action_acl_readable(action):
|
||||||
"""Transform an ACL action to a more readable format.
|
"""
|
||||||
|
Transform an ACL action to a more readable format.
|
||||||
|
|
||||||
example:
|
example:
|
||||||
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
|
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
|
||||||
@@ -57,7 +59,8 @@ def transform_action_acl_readable(action):
|
|||||||
|
|
||||||
|
|
||||||
def is_valid_node_action(action):
|
def is_valid_node_action(action):
|
||||||
"""Is the node action an actual valid action.
|
"""
|
||||||
|
Is the node action an actual valid action.
|
||||||
|
|
||||||
Only uses information about the action to determine if the action has an effect
|
Only uses information about the action to determine if the action has an effect
|
||||||
|
|
||||||
@@ -92,7 +95,8 @@ def is_valid_node_action(action):
|
|||||||
|
|
||||||
|
|
||||||
def is_valid_acl_action(action):
|
def is_valid_acl_action(action):
|
||||||
"""Is the ACL action an actual valid action.
|
"""
|
||||||
|
Is the ACL action an actual valid action.
|
||||||
|
|
||||||
Only uses information about the action to determine if the action has an effect.
|
Only uses information about the action to determine if the action has an effect.
|
||||||
|
|
||||||
@@ -124,7 +128,8 @@ def is_valid_acl_action(action):
|
|||||||
|
|
||||||
|
|
||||||
def is_valid_acl_action_extra(action):
|
def is_valid_acl_action_extra(action):
|
||||||
"""Harsher version of valid acl actions, does not allow action.
|
"""
|
||||||
|
Harsher version of valid acl actions, does not allow action.
|
||||||
|
|
||||||
TODO: Add params and return in docstring.
|
TODO: Add params and return in docstring.
|
||||||
TODO: Typehint params and return.
|
TODO: Typehint params and return.
|
||||||
@@ -147,7 +152,8 @@ def is_valid_acl_action_extra(action):
|
|||||||
|
|
||||||
|
|
||||||
def transform_change_obs_readable(obs):
|
def transform_change_obs_readable(obs):
|
||||||
"""Transform list of transactions to readable list of each observation property.
|
"""
|
||||||
|
Transform list of transactions to readable list of each observation property.
|
||||||
|
|
||||||
example:
|
example:
|
||||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
|
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
|
||||||
@@ -169,7 +175,8 @@ def transform_change_obs_readable(obs):
|
|||||||
|
|
||||||
|
|
||||||
def transform_obs_readable(obs):
|
def transform_obs_readable(obs):
|
||||||
"""Transform observation to readable format.
|
"""
|
||||||
|
Transform observation to readable format.
|
||||||
|
|
||||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
|
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
|
||||||
|
|
||||||
@@ -185,7 +192,8 @@ def transform_obs_readable(obs):
|
|||||||
|
|
||||||
|
|
||||||
def convert_to_new_obs(obs, num_nodes=10):
|
def convert_to_new_obs(obs, num_nodes=10):
|
||||||
"""Convert original gym Box observation space to new multiDiscrete observation space.
|
"""
|
||||||
|
Convert original gym Box observation space to new multiDiscrete observation space.
|
||||||
|
|
||||||
TODO: Add params and return in docstring.
|
TODO: Add params and return in docstring.
|
||||||
TODO: Typehint params and return.
|
TODO: Typehint params and return.
|
||||||
@@ -196,7 +204,8 @@ def convert_to_new_obs(obs, num_nodes=10):
|
|||||||
|
|
||||||
|
|
||||||
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||||
"""Convert to old observation.
|
"""
|
||||||
|
Convert to old observation.
|
||||||
|
|
||||||
Links filled with 0's as no information is included in new observation space.
|
Links filled with 0's as no information is included in new observation space.
|
||||||
|
|
||||||
@@ -232,7 +241,8 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
|||||||
|
|
||||||
|
|
||||||
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
||||||
"""Return string describing change between two observations.
|
"""
|
||||||
|
Return string describing change between two observations.
|
||||||
|
|
||||||
example:
|
example:
|
||||||
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
|
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
|
||||||
@@ -260,7 +270,8 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
|||||||
|
|
||||||
|
|
||||||
def _describe_obs_change_helper(obs_change, is_link):
|
def _describe_obs_change_helper(obs_change, is_link):
|
||||||
"""Helper funcion to describe what has changed.
|
"""
|
||||||
|
Helper funcion to describe what has changed.
|
||||||
|
|
||||||
example:
|
example:
|
||||||
[ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD"
|
[ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD"
|
||||||
@@ -295,7 +306,8 @@ def _describe_obs_change_helper(obs_change, is_link):
|
|||||||
|
|
||||||
|
|
||||||
def transform_action_node_enum(action):
|
def transform_action_node_enum(action):
|
||||||
"""Convert a node action from readable string format, to enumerated format.
|
"""
|
||||||
|
Convert a node action from readable string format, to enumerated format.
|
||||||
|
|
||||||
example:
|
example:
|
||||||
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
|
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
|
||||||
@@ -326,7 +338,8 @@ def transform_action_node_enum(action):
|
|||||||
|
|
||||||
|
|
||||||
def transform_action_node_readable(action):
|
def transform_action_node_readable(action):
|
||||||
"""Convert a node action from enumerated format to readable format.
|
"""
|
||||||
|
Convert a node action from enumerated format to readable format.
|
||||||
|
|
||||||
example:
|
example:
|
||||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||||
@@ -348,7 +361,8 @@ def transform_action_node_readable(action):
|
|||||||
|
|
||||||
|
|
||||||
def node_action_description(action):
|
def node_action_description(action):
|
||||||
"""Generate string describing a node-based action.
|
"""
|
||||||
|
Generate string describing a node-based action.
|
||||||
|
|
||||||
TODO: Add params and return in docstring.
|
TODO: Add params and return in docstring.
|
||||||
TODO: Typehint params and return.
|
TODO: Typehint params and return.
|
||||||
@@ -375,7 +389,8 @@ def node_action_description(action):
|
|||||||
|
|
||||||
|
|
||||||
def transform_action_acl_enum(action):
|
def transform_action_acl_enum(action):
|
||||||
"""Convert acl action from readable str format, to enumerated format.
|
"""
|
||||||
|
Convert acl action from readable str format, to enumerated format.
|
||||||
|
|
||||||
TODO: Add params and return in docstring.
|
TODO: Add params and return in docstring.
|
||||||
TODO: Typehint params and return.
|
TODO: Typehint params and return.
|
||||||
@@ -397,7 +412,8 @@ def transform_action_acl_enum(action):
|
|||||||
|
|
||||||
|
|
||||||
def acl_action_description(action):
|
def acl_action_description(action):
|
||||||
"""Generate string describing an acl-based action.
|
"""
|
||||||
|
Generate string describing an acl-based action.
|
||||||
|
|
||||||
TODO: Add params and return in docstring.
|
TODO: Add params and return in docstring.
|
||||||
TODO: Typehint params and return.
|
TODO: Typehint params and return.
|
||||||
@@ -417,7 +433,8 @@ def acl_action_description(action):
|
|||||||
|
|
||||||
|
|
||||||
def get_node_of_ip(ip, node_dict):
|
def get_node_of_ip(ip, node_dict):
|
||||||
"""Get the node ID of an IP address.
|
"""
|
||||||
|
Get the node ID of an IP address.
|
||||||
|
|
||||||
node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes)
|
node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes)
|
||||||
|
|
||||||
@@ -431,7 +448,8 @@ def get_node_of_ip(ip, node_dict):
|
|||||||
|
|
||||||
|
|
||||||
def is_valid_node_action(action):
|
def is_valid_node_action(action):
|
||||||
"""Is the node action an actual valid action.
|
"""
|
||||||
|
Is the node action an actual valid action.
|
||||||
|
|
||||||
Only uses information about the action to determine if the action has an effect
|
Only uses information about the action to determine if the action has an effect
|
||||||
|
|
||||||
@@ -464,7 +482,8 @@ def is_valid_node_action(action):
|
|||||||
|
|
||||||
|
|
||||||
def is_valid_acl_action(action):
|
def is_valid_acl_action(action):
|
||||||
"""Is the ACL action an actual valid action.
|
"""
|
||||||
|
Is the ACL action an actual valid action.
|
||||||
|
|
||||||
Only uses information about the action to determine if the action has an effect
|
Only uses information about the action to determine if the action has an effect
|
||||||
|
|
||||||
@@ -496,7 +515,8 @@ def is_valid_acl_action(action):
|
|||||||
|
|
||||||
|
|
||||||
def is_valid_acl_action_extra(action):
|
def is_valid_acl_action_extra(action):
|
||||||
"""Harsher version of valid acl actions, does not allow action.
|
"""
|
||||||
|
Harsher version of valid acl actions, does not allow action.
|
||||||
|
|
||||||
TODO: Add params and return in docstring.
|
TODO: Add params and return in docstring.
|
||||||
TODO: Typehint params and return.
|
TODO: Typehint params and return.
|
||||||
@@ -519,7 +539,8 @@ def is_valid_acl_action_extra(action):
|
|||||||
|
|
||||||
|
|
||||||
def get_new_action(old_action, action_dict):
|
def get_new_action(old_action, action_dict):
|
||||||
"""Get new action (e.g. 32) from old action e.g. [1,1,1,0].
|
"""
|
||||||
|
Get new action (e.g. 32) from old action e.g. [1,1,1,0].
|
||||||
|
|
||||||
Old_action can be either node or acl action type
|
Old_action can be either node or acl action type
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,8 @@ def build_dirs():
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def reset_notebooks(overwrite: bool = True):
|
def reset_notebooks(overwrite: bool = True):
|
||||||
"""Force a reset of the demo notebooks in the users notebooks directory.
|
"""
|
||||||
|
Force a reset of the demo notebooks in the users notebooks directory.
|
||||||
|
|
||||||
:param overwrite: If True, will overwrite existing demo notebooks.
|
:param overwrite: If True, will overwrite existing demo notebooks.
|
||||||
"""
|
"""
|
||||||
@@ -39,7 +40,8 @@ def reset_notebooks(overwrite: bool = True):
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def logs(last_n: Annotated[int, typer.Option("-n")]):
|
def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||||
"""Print the PrimAITE log file.
|
"""
|
||||||
|
Print the PrimAITE log file.
|
||||||
|
|
||||||
:param last_n: The number of lines to print. Default value is 10.
|
:param last_n: The number of lines to print. Default value is 10.
|
||||||
"""
|
"""
|
||||||
@@ -59,7 +61,8 @@ _LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # n
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
|
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
|
||||||
"""View or set the PrimAITE Log Level.
|
"""
|
||||||
|
View or set the PrimAITE Log Level.
|
||||||
|
|
||||||
To View, simply call: primaite log-level
|
To View, simply call: primaite log-level
|
||||||
|
|
||||||
@@ -110,7 +113,8 @@ def clean_up():
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def setup(overwrite_existing: bool = True):
|
def setup(overwrite_existing: bool = True):
|
||||||
"""Perform the PrimAITE first-time setup.
|
"""
|
||||||
|
Perform the PrimAITE first-time setup.
|
||||||
|
|
||||||
WARNING: All user-data will be lost.
|
WARNING: All user-data will be lost.
|
||||||
"""
|
"""
|
||||||
@@ -148,7 +152,8 @@ def setup(overwrite_existing: bool = True):
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
||||||
"""Run a PrimAITE session.
|
"""
|
||||||
|
Run a PrimAITE session.
|
||||||
|
|
||||||
tc: The training config filepath. Optional. If no value is passed then
|
tc: The training config filepath. Optional. If no value is passed then
|
||||||
example default training config is used from:
|
example default training config is used from:
|
||||||
@@ -173,7 +178,8 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
|
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
|
||||||
"""View or set the plotly template for Session plots.
|
"""
|
||||||
|
View or set the plotly template for Session plots.
|
||||||
|
|
||||||
To View, simply call: primaite plotly-template
|
To View, simply call: primaite plotly-template
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ _EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down
|
|||||||
|
|
||||||
|
|
||||||
def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Convert a legacy lay down config dict to the new format.
|
"""
|
||||||
|
Convert a legacy lay down config dict to the new format.
|
||||||
|
|
||||||
:param legacy_config_dict: A legacy lay down config dict.
|
:param legacy_config_dict: A legacy lay down config dict.
|
||||||
"""
|
"""
|
||||||
@@ -21,7 +22,8 @@ def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> D
|
|||||||
|
|
||||||
|
|
||||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
|
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
|
||||||
"""Read in a lay down config yaml file.
|
"""
|
||||||
|
Read in a lay down config yaml file.
|
||||||
|
|
||||||
:param file_path: The config file path.
|
:param file_path: The config file path.
|
||||||
:param legacy_file: True if the config file is legacy format, otherwise False.
|
:param legacy_file: True if the config file is legacy format, otherwise False.
|
||||||
@@ -50,7 +52,8 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
|
|||||||
|
|
||||||
|
|
||||||
def ddos_basic_one_config_path() -> Path:
|
def ddos_basic_one_config_path() -> Path:
|
||||||
"""The path to the example lay_down_config_1_DDOS_basic.yaml file.
|
"""
|
||||||
|
The path to the example lay_down_config_1_DDOS_basic.yaml file.
|
||||||
|
|
||||||
:return: The file path.
|
:return: The file path.
|
||||||
"""
|
"""
|
||||||
@@ -64,7 +67,8 @@ def ddos_basic_one_config_path() -> Path:
|
|||||||
|
|
||||||
|
|
||||||
def ddos_basic_two_config_path() -> Path:
|
def ddos_basic_two_config_path() -> Path:
|
||||||
"""The path to the example lay_down_config_2_DDOS_basic.yaml file.
|
"""
|
||||||
|
The path to the example lay_down_config_2_DDOS_basic.yaml file.
|
||||||
|
|
||||||
:return: The file path.
|
:return: The file path.
|
||||||
"""
|
"""
|
||||||
@@ -78,7 +82,8 @@ def ddos_basic_two_config_path() -> Path:
|
|||||||
|
|
||||||
|
|
||||||
def dos_very_basic_config_path() -> Path:
|
def dos_very_basic_config_path() -> Path:
|
||||||
"""The path to the example lay_down_config_3_DOS_very_basic.yaml file.
|
"""
|
||||||
|
The path to the example lay_down_config_3_DOS_very_basic.yaml file.
|
||||||
|
|
||||||
:return: The file path.
|
:return: The file path.
|
||||||
"""
|
"""
|
||||||
@@ -92,7 +97,8 @@ def dos_very_basic_config_path() -> Path:
|
|||||||
|
|
||||||
|
|
||||||
def data_manipulation_config_path() -> Path:
|
def data_manipulation_config_path() -> Path:
|
||||||
"""The path to the example lay_down_config_5_data_manipulation.yaml file.
|
"""
|
||||||
|
The path to the example lay_down_config_5_data_manipulation.yaml file.
|
||||||
|
|
||||||
:return: The file path.
|
:return: The file path.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -180,7 +180,8 @@ class TrainingConfig:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
|
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
|
||||||
"""Create an instance of TrainingConfig from a dict.
|
"""
|
||||||
|
Create an instance of TrainingConfig from a dict.
|
||||||
|
|
||||||
:param config_dict: The training config dict.
|
:param config_dict: The training config dict.
|
||||||
:return: The instance of TrainingConfig.
|
:return: The instance of TrainingConfig.
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ def plot_av_reward_per_episode(
|
|||||||
title: Optional[str] = None,
|
title: Optional[str] = None,
|
||||||
subtitle: Optional[str] = None,
|
subtitle: Optional[str] = None,
|
||||||
) -> Figure:
|
) -> Figure:
|
||||||
"""Plot the average reward per episode from a csv session output.
|
"""
|
||||||
|
Plot the average reward per episode from a csv session output.
|
||||||
|
|
||||||
:param av_reward_per_episode_csv: The average reward per episode csv
|
:param av_reward_per_episode_csv: The average reward per episode csv
|
||||||
file path.
|
file path.
|
||||||
|
|||||||
@@ -50,7 +50,8 @@ class AbstractObservationComponent(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class NodeLinkTable(AbstractObservationComponent):
|
class NodeLinkTable(AbstractObservationComponent):
|
||||||
"""Table with nodes and links as rows and hardware/software status as cols.
|
"""
|
||||||
|
Table with nodes and links as rows and hardware/software status as cols.
|
||||||
|
|
||||||
This will create the observation space formatted as a table of integers.
|
This will create the observation space formatted as a table of integers.
|
||||||
There is one row per node, followed by one row per link.
|
There is one row per node, followed by one row per link.
|
||||||
@@ -74,7 +75,8 @@ class NodeLinkTable(AbstractObservationComponent):
|
|||||||
_DATA_TYPE: type = np.int64
|
_DATA_TYPE: type = np.int64
|
||||||
|
|
||||||
def __init__(self, env: "Primaite"):
|
def __init__(self, env: "Primaite"):
|
||||||
"""Initialise a NodeLinkTable observation space component.
|
"""
|
||||||
|
Initialise a NodeLinkTable observation space component.
|
||||||
|
|
||||||
:param env: Training environment.
|
:param env: Training environment.
|
||||||
:type env: Primaite
|
:type env: Primaite
|
||||||
@@ -100,7 +102,8 @@ class NodeLinkTable(AbstractObservationComponent):
|
|||||||
self.structure = self.generate_structure()
|
self.structure = self.generate_structure()
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
"""Update the observation based on current environment state.
|
"""
|
||||||
|
Update the observation based on current environment state.
|
||||||
|
|
||||||
The structure of the observation space is described in :class:`.NodeLinkTable`
|
The structure of the observation space is described in :class:`.NodeLinkTable`
|
||||||
"""
|
"""
|
||||||
@@ -181,7 +184,8 @@ class NodeLinkTable(AbstractObservationComponent):
|
|||||||
|
|
||||||
|
|
||||||
class NodeStatuses(AbstractObservationComponent):
|
class NodeStatuses(AbstractObservationComponent):
|
||||||
"""Flat list of nodes' hardware, OS, file system, and service states.
|
"""
|
||||||
|
Flat list of nodes' hardware, OS, file system, and service states.
|
||||||
|
|
||||||
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
||||||
integers.
|
integers.
|
||||||
@@ -234,7 +238,8 @@ class NodeStatuses(AbstractObservationComponent):
|
|||||||
self.structure = self.generate_structure()
|
self.structure = self.generate_structure()
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
"""Update the observation based on current environment state.
|
"""
|
||||||
|
Update the observation based on current environment state.
|
||||||
|
|
||||||
The structure of the observation space is described in :class:`.NodeStatuses`
|
The structure of the observation space is described in :class:`.NodeStatuses`
|
||||||
"""
|
"""
|
||||||
@@ -287,7 +292,8 @@ class NodeStatuses(AbstractObservationComponent):
|
|||||||
|
|
||||||
|
|
||||||
class LinkTrafficLevels(AbstractObservationComponent):
|
class LinkTrafficLevels(AbstractObservationComponent):
|
||||||
"""Flat list of traffic levels encoded into banded categories.
|
"""
|
||||||
|
Flat list of traffic levels encoded into banded categories.
|
||||||
|
|
||||||
For each link, total traffic or traffic per service is encoded into a categorical value.
|
For each link, total traffic or traffic per service is encoded into a categorical value.
|
||||||
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
|
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
|
||||||
@@ -354,7 +360,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
|||||||
self.structure = self.generate_structure()
|
self.structure = self.generate_structure()
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
"""Update the observation based on current environment state.
|
"""
|
||||||
|
Update the observation based on current environment state.
|
||||||
|
|
||||||
The structure of the observation space is described in :class:`.LinkTrafficLevels`
|
The structure of the observation space is described in :class:`.LinkTrafficLevels`
|
||||||
"""
|
"""
|
||||||
@@ -395,7 +402,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
|||||||
|
|
||||||
|
|
||||||
class ObservationsHandler:
|
class ObservationsHandler:
|
||||||
"""Component-based observation space handler.
|
"""
|
||||||
|
Component-based observation space handler.
|
||||||
|
|
||||||
This allows users to configure observation spaces by mixing and matching components. Each component can also define
|
This allows users to configure observation spaces by mixing and matching components. Each component can also define
|
||||||
further parameters to make them more flexible.
|
further parameters to make them more flexible.
|
||||||
@@ -436,7 +444,8 @@ class ObservationsHandler:
|
|||||||
self._flat_observation = spaces.flatten(self._space, self._observation)
|
self._flat_observation = spaces.flatten(self._space, self._observation)
|
||||||
|
|
||||||
def register(self, obs_component: AbstractObservationComponent):
|
def register(self, obs_component: AbstractObservationComponent):
|
||||||
"""Add a component for this handler to track.
|
"""
|
||||||
|
Add a component for this handler to track.
|
||||||
|
|
||||||
:param obs_component: The component to add.
|
:param obs_component: The component to add.
|
||||||
:type obs_component: AbstractObservationComponent
|
:type obs_component: AbstractObservationComponent
|
||||||
@@ -445,7 +454,8 @@ class ObservationsHandler:
|
|||||||
self.update_space()
|
self.update_space()
|
||||||
|
|
||||||
def deregister(self, obs_component: AbstractObservationComponent):
|
def deregister(self, obs_component: AbstractObservationComponent):
|
||||||
"""Remove a component from this handler.
|
"""
|
||||||
|
Remove a component from this handler.
|
||||||
|
|
||||||
:param obs_component: Which component to remove. It must exist within this object's
|
:param obs_component: Which component to remove. It must exist within this object's
|
||||||
``registered_obs_components`` attribute.
|
``registered_obs_components`` attribute.
|
||||||
@@ -488,7 +498,8 @@ class ObservationsHandler:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||||
"""Parse a config dictinary, return a new observation handler populated with new observation component objects.
|
"""
|
||||||
|
Parse a config dictinary, return a new observation handler populated with new observation component objects.
|
||||||
|
|
||||||
The expected format for the config dictionary is:
|
The expected format for the config dictionary is:
|
||||||
|
|
||||||
@@ -533,7 +544,8 @@ class ObservationsHandler:
|
|||||||
return handler
|
return handler
|
||||||
|
|
||||||
def describe_structure(self):
|
def describe_structure(self):
|
||||||
"""Create a list of names for the features of the obs space.
|
"""
|
||||||
|
Create a list of names for the features of the obs space.
|
||||||
|
|
||||||
The order of labels follows the flattened version of the space.
|
The order of labels follows the flattened version of the space.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -255,7 +255,8 @@ class Primaite(Env):
|
|||||||
self.total_step_count = 0
|
self.total_step_count = 0
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""AI Gym Reset function.
|
"""
|
||||||
|
AI Gym Reset function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Environment observation space (reset)
|
Environment observation space (reset)
|
||||||
@@ -291,7 +292,8 @@ class Primaite(Env):
|
|||||||
return self.env_obs
|
return self.env_obs
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""AI Gym Step function.
|
"""
|
||||||
|
AI Gym Step function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
action: Action space from agent
|
action: Action space from agent
|
||||||
@@ -429,7 +431,8 @@ class Primaite(Env):
|
|||||||
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
|
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
|
||||||
|
|
||||||
def interpret_action_and_apply(self, _action):
|
def interpret_action_and_apply(self, _action):
|
||||||
"""Applies agent actions to the nodes and Access Control List.
|
"""
|
||||||
|
Applies agent actions to the nodes and Access Control List.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
_action: The action space from the agent
|
_action: The action space from the agent
|
||||||
@@ -448,7 +451,8 @@ class Primaite(Env):
|
|||||||
logging.error("Invalid action type found")
|
logging.error("Invalid action type found")
|
||||||
|
|
||||||
def apply_actions_to_nodes(self, _action):
|
def apply_actions_to_nodes(self, _action):
|
||||||
"""Applies agent actions to the nodes.
|
"""
|
||||||
|
Applies agent actions to the nodes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
_action: The action space from the agent
|
_action: The action space from the agent
|
||||||
@@ -535,7 +539,8 @@ class Primaite(Env):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def apply_actions_to_acl(self, _action):
|
def apply_actions_to_acl(self, _action):
|
||||||
"""Applies agent actions to the Access Control List [TO DO].
|
"""
|
||||||
|
Applies agent actions to the Access Control List [TO DO].
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
_action: The action space from the agent
|
_action: The action space from the agent
|
||||||
@@ -612,7 +617,8 @@ class Primaite(Env):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def apply_time_based_updates(self):
|
def apply_time_based_updates(self):
|
||||||
"""Updates anything that needs to count down and then change state.
|
"""
|
||||||
|
Updates anything that needs to count down and then change state.
|
||||||
|
|
||||||
e.g. reset / patching status
|
e.g. reset / patching status
|
||||||
"""
|
"""
|
||||||
@@ -653,7 +659,8 @@ class Primaite(Env):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
|
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
|
||||||
"""Create the environment's observation handler.
|
"""
|
||||||
|
Create the environment's observation handler.
|
||||||
|
|
||||||
:return: The observation space, initial observation (zeroed out array with the correct shape)
|
:return: The observation space, initial observation (zeroed out array with the correct shape)
|
||||||
:rtype: Tuple[spaces.Space, np.ndarray]
|
:rtype: Tuple[spaces.Space, np.ndarray]
|
||||||
@@ -709,7 +716,8 @@ class Primaite(Env):
|
|||||||
print("Environment configuration loaded")
|
print("Environment configuration loaded")
|
||||||
|
|
||||||
def create_node(self, item):
|
def create_node(self, item):
|
||||||
"""Creates a node from config data.
|
"""
|
||||||
|
Creates a node from config data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
item: A config data item
|
item: A config data item
|
||||||
@@ -789,7 +797,8 @@ class Primaite(Env):
|
|||||||
self.network_reference.add_nodes_from([node_ref])
|
self.network_reference.add_nodes_from([node_ref])
|
||||||
|
|
||||||
def create_link(self, item: Dict):
|
def create_link(self, item: Dict):
|
||||||
"""Creates a link from config data.
|
"""
|
||||||
|
Creates a link from config data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
item: A config data item
|
item: A config data item
|
||||||
@@ -832,7 +841,8 @@ class Primaite(Env):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_green_ier(self, item):
|
def create_green_ier(self, item):
|
||||||
"""Creates a green IER from config data.
|
"""
|
||||||
|
Creates a green IER from config data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
item: A config data item
|
item: A config data item
|
||||||
@@ -872,7 +882,8 @@ class Primaite(Env):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_red_ier(self, item):
|
def create_red_ier(self, item):
|
||||||
"""Creates a red IER from config data.
|
"""
|
||||||
|
Creates a red IER from config data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
item: A config data item
|
item: A config data item
|
||||||
@@ -901,7 +912,8 @@ class Primaite(Env):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_green_pol(self, item):
|
def create_green_pol(self, item):
|
||||||
"""Creates a green PoL object from config data.
|
"""
|
||||||
|
Creates a green PoL object from config data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
item: A config data item
|
item: A config data item
|
||||||
@@ -934,7 +946,8 @@ class Primaite(Env):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_red_pol(self, item):
|
def create_red_pol(self, item):
|
||||||
"""Creates a red PoL object from config data.
|
"""
|
||||||
|
Creates a red PoL object from config data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
item: A config data item
|
item: A config data item
|
||||||
@@ -974,7 +987,8 @@ class Primaite(Env):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_acl_rule(self, item):
|
def create_acl_rule(self, item):
|
||||||
"""Creates an ACL rule from config data.
|
"""
|
||||||
|
Creates an ACL rule from config data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
item: A config data item
|
item: A config data item
|
||||||
@@ -994,7 +1008,8 @@ class Primaite(Env):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_services_list(self, services):
|
def create_services_list(self, services):
|
||||||
"""Creates a list of services (enum) from config data.
|
"""
|
||||||
|
Creates a list of services (enum) from config data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
item: A config data item representing the services
|
item: A config data item representing the services
|
||||||
@@ -1009,7 +1024,8 @@ class Primaite(Env):
|
|||||||
self.num_services = len(self.services_list)
|
self.num_services = len(self.services_list)
|
||||||
|
|
||||||
def create_ports_list(self, ports):
|
def create_ports_list(self, ports):
|
||||||
"""Creates a list of ports from config data.
|
"""
|
||||||
|
Creates a list of ports from config data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
item: A config data item representing the ports
|
item: A config data item representing the ports
|
||||||
@@ -1024,7 +1040,8 @@ class Primaite(Env):
|
|||||||
self.num_ports = len(self.ports_list)
|
self.num_ports = len(self.ports_list)
|
||||||
|
|
||||||
def get_observation_info(self, observation_info):
|
def get_observation_info(self, observation_info):
|
||||||
"""Extracts observation_info.
|
"""
|
||||||
|
Extracts observation_info.
|
||||||
|
|
||||||
:param observation_info: Config item that defines which type of observation space to use
|
:param observation_info: Config item that defines which type of observation space to use
|
||||||
:type observation_info: str
|
:type observation_info: str
|
||||||
@@ -1032,7 +1049,8 @@ class Primaite(Env):
|
|||||||
self.observation_type = ObservationType[observation_info["type"]]
|
self.observation_type = ObservationType[observation_info["type"]]
|
||||||
|
|
||||||
def get_action_info(self, action_info):
|
def get_action_info(self, action_info):
|
||||||
"""Extracts action_info.
|
"""
|
||||||
|
Extracts action_info.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
item: A config data item representing action info
|
item: A config data item representing action info
|
||||||
@@ -1040,7 +1058,8 @@ class Primaite(Env):
|
|||||||
self.action_type = ActionType[action_info["type"]]
|
self.action_type = ActionType[action_info["type"]]
|
||||||
|
|
||||||
def save_obs_config(self, obs_config: dict):
|
def save_obs_config(self, obs_config: dict):
|
||||||
"""Cache the config for the observation space.
|
"""
|
||||||
|
Cache the config for the observation space.
|
||||||
|
|
||||||
This is necessary as the observation space can't be built while reading the config,
|
This is necessary as the observation space can't be built while reading the config,
|
||||||
it must be done after all the nodes, links, and services have been initialised.
|
it must be done after all the nodes, links, and services have been initialised.
|
||||||
@@ -1052,7 +1071,8 @@ class Primaite(Env):
|
|||||||
self.obs_config = obs_config
|
self.obs_config = obs_config
|
||||||
|
|
||||||
def reset_environment(self):
|
def reset_environment(self):
|
||||||
"""# Resets environment.
|
"""
|
||||||
|
Resets environment.
|
||||||
|
|
||||||
Uses config data config data in order to build the environment configuration.
|
Uses config data config data in order to build the environment configuration.
|
||||||
"""
|
"""
|
||||||
@@ -1076,7 +1096,8 @@ class Primaite(Env):
|
|||||||
ier_value.set_is_running(False)
|
ier_value.set_is_running(False)
|
||||||
|
|
||||||
def reset_node(self, item):
|
def reset_node(self, item):
|
||||||
"""Resets the statuses of a node.
|
"""
|
||||||
|
Resets the statuses of a node.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
item: A config data item
|
item: A config data item
|
||||||
@@ -1123,7 +1144,8 @@ class Primaite(Env):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def create_node_action_dict(self):
|
def create_node_action_dict(self):
|
||||||
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
|
"""
|
||||||
|
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
|
||||||
|
|
||||||
Note: Only actions that have the potential to change the state exist in the mapping (except for key 0)
|
Note: Only actions that have the potential to change the state exist in the mapping (except for key 0)
|
||||||
|
|
||||||
@@ -1187,7 +1209,8 @@ class Primaite(Env):
|
|||||||
return actions
|
return actions
|
||||||
|
|
||||||
def create_node_and_acl_action_dict(self):
|
def create_node_and_acl_action_dict(self):
|
||||||
"""Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
|
"""
|
||||||
|
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
|
||||||
|
|
||||||
The dictionary contains actions of both Node and ACL action types.
|
The dictionary contains actions of both Node and ACL action types.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ def calculate_reward_function(
|
|||||||
step_count,
|
step_count,
|
||||||
config_values,
|
config_values,
|
||||||
):
|
):
|
||||||
"""Compares the states of the initial and final nodes/links to get a reward.
|
"""
|
||||||
|
Compares the states of the initial and final nodes/links to get a reward.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
initial_nodes: The nodes before red and blue agents take effect
|
initial_nodes: The nodes before red and blue agents take effect
|
||||||
@@ -94,7 +95,8 @@ def calculate_reward_function(
|
|||||||
|
|
||||||
|
|
||||||
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
|
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
|
||||||
"""Calculates score relating to the hardware state of a node.
|
"""
|
||||||
|
Calculates score relating to the hardware state of a node.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
final_node: The node after red and blue agents take effect
|
final_node: The node after red and blue agents take effect
|
||||||
@@ -142,7 +144,8 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
|||||||
|
|
||||||
|
|
||||||
def score_node_os_state(final_node, initial_node, reference_node, config_values):
|
def score_node_os_state(final_node, initial_node, reference_node, config_values):
|
||||||
"""Calculates score relating to the Software State of a node.
|
"""
|
||||||
|
Calculates score relating to the Software State of a node.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
final_node: The node after red and blue agents take effect
|
final_node: The node after red and blue agents take effect
|
||||||
@@ -192,7 +195,8 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
|||||||
|
|
||||||
|
|
||||||
def score_node_service_state(final_node, initial_node, reference_node, config_values):
|
def score_node_service_state(final_node, initial_node, reference_node, config_values):
|
||||||
"""Calculates score relating to the service state(s) of a node.
|
"""
|
||||||
|
Calculates score relating to the service state(s) of a node.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
final_node: The node after red and blue agents take effect
|
final_node: The node after red and blue agents take effect
|
||||||
@@ -263,7 +267,8 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
|||||||
|
|
||||||
|
|
||||||
def score_node_file_system(final_node, initial_node, reference_node, config_values):
|
def score_node_file_system(final_node, initial_node, reference_node, config_values):
|
||||||
"""Calculates score relating to the file system state of a node.
|
"""
|
||||||
|
Calculates score relating to the file system state of a node.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
final_node: The node after red and blue agents take effect
|
final_node: The node after red and blue agents take effect
|
||||||
|
|||||||
@@ -29,7 +29,8 @@ class Link(object):
|
|||||||
self.add_protocol(protocol_name)
|
self.add_protocol(protocol_name)
|
||||||
|
|
||||||
def add_protocol(self, _protocol):
|
def add_protocol(self, _protocol):
|
||||||
"""Adds a new protocol to the list of protocols on this link.
|
"""
|
||||||
|
Adds a new protocol to the list of protocols on this link.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
_protocol: The protocol to be added (enum)
|
_protocol: The protocol to be added (enum)
|
||||||
@@ -37,7 +38,8 @@ class Link(object):
|
|||||||
self.protocol_list.append(Protocol(_protocol))
|
self.protocol_list.append(Protocol(_protocol))
|
||||||
|
|
||||||
def get_id(self):
|
def get_id(self):
|
||||||
"""Gets link ID.
|
"""
|
||||||
|
Gets link ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Link ID
|
Link ID
|
||||||
@@ -45,7 +47,8 @@ class Link(object):
|
|||||||
return self.id
|
return self.id
|
||||||
|
|
||||||
def get_source_node_name(self):
|
def get_source_node_name(self):
|
||||||
"""Gets source node name.
|
"""
|
||||||
|
Gets source node name.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Source node name
|
Source node name
|
||||||
@@ -53,7 +56,8 @@ class Link(object):
|
|||||||
return self.source_node_name
|
return self.source_node_name
|
||||||
|
|
||||||
def get_dest_node_name(self):
|
def get_dest_node_name(self):
|
||||||
"""Gets destination node name.
|
"""
|
||||||
|
Gets destination node name.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Destination node name
|
Destination node name
|
||||||
@@ -61,7 +65,8 @@ class Link(object):
|
|||||||
return self.dest_node_name
|
return self.dest_node_name
|
||||||
|
|
||||||
def get_bandwidth(self):
|
def get_bandwidth(self):
|
||||||
"""Gets bandwidth of link.
|
"""
|
||||||
|
Gets bandwidth of link.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Link bandwidth (bps)
|
Link bandwidth (bps)
|
||||||
@@ -69,7 +74,8 @@ class Link(object):
|
|||||||
return self.bandwidth
|
return self.bandwidth
|
||||||
|
|
||||||
def get_protocol_list(self):
|
def get_protocol_list(self):
|
||||||
"""Gets list of protocols on this link.
|
"""
|
||||||
|
Gets list of protocols on this link.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of protocols on this link
|
List of protocols on this link
|
||||||
@@ -77,7 +83,8 @@ class Link(object):
|
|||||||
return self.protocol_list
|
return self.protocol_list
|
||||||
|
|
||||||
def get_current_load(self):
|
def get_current_load(self):
|
||||||
"""Gets current total load on this link.
|
"""
|
||||||
|
Gets current total load on this link.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Total load on this link (bps)
|
Total load on this link (bps)
|
||||||
@@ -88,7 +95,8 @@ class Link(object):
|
|||||||
return total_load
|
return total_load
|
||||||
|
|
||||||
def add_protocol_load(self, _protocol, _load):
|
def add_protocol_load(self, _protocol, _load):
|
||||||
"""Adds a loading to a protocol on this link.
|
"""
|
||||||
|
Adds a loading to a protocol on this link.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
_protocol: The protocol to load
|
_protocol: The protocol to load
|
||||||
|
|||||||
@@ -14,7 +14,8 @@ def run(
|
|||||||
training_config_path: Union[str, Path],
|
training_config_path: Union[str, Path],
|
||||||
lay_down_config_path: Union[str, Path],
|
lay_down_config_path: Union[str, Path],
|
||||||
):
|
):
|
||||||
"""Run the PrimAITE Session.
|
"""
|
||||||
|
Run the PrimAITE Session.
|
||||||
|
|
||||||
:param training_config_path: The training config filepath.
|
:param training_config_path: The training config filepath.
|
||||||
:param lay_down_config_path: The lay down config filepath.
|
:param lay_down_config_path: The lay down config filepath.
|
||||||
|
|||||||
@@ -52,7 +52,8 @@ class ActiveNode(Node):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def software_state(self) -> SoftwareState:
|
def software_state(self) -> SoftwareState:
|
||||||
"""Get the software_state.
|
"""
|
||||||
|
Get the software_state.
|
||||||
|
|
||||||
:return: The software_state.
|
:return: The software_state.
|
||||||
"""
|
"""
|
||||||
@@ -60,7 +61,8 @@ class ActiveNode(Node):
|
|||||||
|
|
||||||
@software_state.setter
|
@software_state.setter
|
||||||
def software_state(self, software_state: SoftwareState):
|
def software_state(self, software_state: SoftwareState):
|
||||||
"""Get the software_state.
|
"""
|
||||||
|
Get the software_state.
|
||||||
|
|
||||||
:param software_state: Software State.
|
:param software_state: Software State.
|
||||||
"""
|
"""
|
||||||
@@ -78,7 +80,8 @@ class ActiveNode(Node):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
|
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
|
||||||
"""Sets Software State if the node is not compromised.
|
"""
|
||||||
|
Sets Software State if the node is not compromised.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
software_state: Software State
|
software_state: Software State
|
||||||
@@ -104,7 +107,8 @@ class ActiveNode(Node):
|
|||||||
self._software_state = SoftwareState.GOOD
|
self._software_state = SoftwareState.GOOD
|
||||||
|
|
||||||
def set_file_system_state(self, file_system_state: FileSystemState):
|
def set_file_system_state(self, file_system_state: FileSystemState):
|
||||||
"""Sets the file system state (actual and observed).
|
"""
|
||||||
|
Sets the file system state (actual and observed).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_system_state: File system state
|
file_system_state: File system state
|
||||||
@@ -130,7 +134,8 @@ class ActiveNode(Node):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
|
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
|
||||||
"""Sets the file system state (actual and observed) if not in a compromised state.
|
"""
|
||||||
|
Sets the file system state (actual and observed) if not in a compromised state.
|
||||||
|
|
||||||
Use for green PoL to prevent it overturning a compromised state
|
Use for green PoL to prevent it overturning a compromised state
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ class NodeStateInstructionGreen(object):
|
|||||||
self.state = _state
|
self.state = _state
|
||||||
|
|
||||||
def get_start_step(self):
|
def get_start_step(self):
|
||||||
"""Gets the start step.
|
"""
|
||||||
|
Gets the start step.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The start step
|
The start step
|
||||||
@@ -43,7 +44,8 @@ class NodeStateInstructionGreen(object):
|
|||||||
return self.start_step
|
return self.start_step
|
||||||
|
|
||||||
def get_end_step(self):
|
def get_end_step(self):
|
||||||
"""Gets the end step.
|
"""
|
||||||
|
Gets the end step.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The end step
|
The end step
|
||||||
@@ -51,7 +53,8 @@ class NodeStateInstructionGreen(object):
|
|||||||
return self.end_step
|
return self.end_step
|
||||||
|
|
||||||
def get_node_id(self):
|
def get_node_id(self):
|
||||||
"""Gets the node ID.
|
"""
|
||||||
|
Gets the node ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The node ID
|
The node ID
|
||||||
@@ -59,7 +62,8 @@ class NodeStateInstructionGreen(object):
|
|||||||
return self.node_id
|
return self.node_id
|
||||||
|
|
||||||
def get_node_pol_type(self):
|
def get_node_pol_type(self):
|
||||||
"""Gets the node pattern of life type (enum).
|
"""
|
||||||
|
Gets the node pattern of life type (enum).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The node pattern of life type (enum)
|
The node pattern of life type (enum)
|
||||||
@@ -67,7 +71,8 @@ class NodeStateInstructionGreen(object):
|
|||||||
return self.node_pol_type
|
return self.node_pol_type
|
||||||
|
|
||||||
def get_service_name(self):
|
def get_service_name(self):
|
||||||
"""Gets the service name.
|
"""
|
||||||
|
Gets the service name.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The service name
|
The service name
|
||||||
@@ -75,7 +80,8 @@ class NodeStateInstructionGreen(object):
|
|||||||
return self.service_name
|
return self.service_name
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
"""Gets the state (node or service).
|
"""
|
||||||
|
Gets the state (node or service).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The state (node or service)
|
The state (node or service)
|
||||||
|
|||||||
@@ -51,7 +51,8 @@ class NodeStateInstructionRed(object):
|
|||||||
self.source_node_service_state = _pol_source_node_service_state
|
self.source_node_service_state = _pol_source_node_service_state
|
||||||
|
|
||||||
def get_start_step(self):
|
def get_start_step(self):
|
||||||
"""Gets the start step.
|
"""
|
||||||
|
Gets the start step.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The start step
|
The start step
|
||||||
@@ -59,7 +60,8 @@ class NodeStateInstructionRed(object):
|
|||||||
return self.start_step
|
return self.start_step
|
||||||
|
|
||||||
def get_end_step(self):
|
def get_end_step(self):
|
||||||
"""Gets the end step.
|
"""
|
||||||
|
Gets the end step.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The end step
|
The end step
|
||||||
@@ -67,7 +69,8 @@ class NodeStateInstructionRed(object):
|
|||||||
return self.end_step
|
return self.end_step
|
||||||
|
|
||||||
def get_target_node_id(self):
|
def get_target_node_id(self):
|
||||||
"""Gets the node ID.
|
"""
|
||||||
|
Gets the node ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The node ID
|
The node ID
|
||||||
@@ -75,7 +78,8 @@ class NodeStateInstructionRed(object):
|
|||||||
return self.target_node_id
|
return self.target_node_id
|
||||||
|
|
||||||
def get_initiator(self):
|
def get_initiator(self):
|
||||||
"""Gets the initiator.
|
"""
|
||||||
|
Gets the initiator.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The initiator
|
The initiator
|
||||||
@@ -83,7 +87,8 @@ class NodeStateInstructionRed(object):
|
|||||||
return self.initiator
|
return self.initiator
|
||||||
|
|
||||||
def get_pol_type(self) -> NodePOLType:
|
def get_pol_type(self) -> NodePOLType:
|
||||||
"""Gets the node pattern of life type (enum).
|
"""
|
||||||
|
Gets the node pattern of life type (enum).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The node pattern of life type (enum)
|
The node pattern of life type (enum)
|
||||||
@@ -91,7 +96,8 @@ class NodeStateInstructionRed(object):
|
|||||||
return self.pol_type
|
return self.pol_type
|
||||||
|
|
||||||
def get_service_name(self):
|
def get_service_name(self):
|
||||||
"""Gets the service name.
|
"""
|
||||||
|
Gets the service name.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The service name
|
The service name
|
||||||
@@ -99,7 +105,8 @@ class NodeStateInstructionRed(object):
|
|||||||
return self.service_name
|
return self.service_name
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
"""Gets the state (node or service).
|
"""
|
||||||
|
Gets the state (node or service).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The state (node or service)
|
The state (node or service)
|
||||||
@@ -107,7 +114,8 @@ class NodeStateInstructionRed(object):
|
|||||||
return self.state
|
return self.state
|
||||||
|
|
||||||
def get_source_node_id(self):
|
def get_source_node_id(self):
|
||||||
"""Gets the source node id (used for initiator type SERVICE).
|
"""
|
||||||
|
Gets the source node id (used for initiator type SERVICE).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The source node id
|
The source node id
|
||||||
@@ -115,7 +123,8 @@ class NodeStateInstructionRed(object):
|
|||||||
return self.source_node_id
|
return self.source_node_id
|
||||||
|
|
||||||
def get_source_node_service(self):
|
def get_source_node_service(self):
|
||||||
"""Gets the source node service (used for initiator type SERVICE).
|
"""
|
||||||
|
Gets the source node service (used for initiator type SERVICE).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The source node service
|
The source node service
|
||||||
@@ -123,7 +132,8 @@ class NodeStateInstructionRed(object):
|
|||||||
return self.source_node_service
|
return self.source_node_service
|
||||||
|
|
||||||
def get_source_node_service_state(self):
|
def get_source_node_service_state(self):
|
||||||
"""Gets the source node service state (used for initiator type SERVICE).
|
"""
|
||||||
|
Gets the source node service state (used for initiator type SERVICE).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The source node service state
|
The source node service state
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ class PassiveNode(Node):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def ip_address(self) -> str:
|
def ip_address(self) -> str:
|
||||||
"""Gets the node IP address as an empty string.
|
"""
|
||||||
|
Gets the node IP address as an empty string.
|
||||||
|
|
||||||
No concept of IP address for passive nodes for now.
|
No concept of IP address for passive nodes for now.
|
||||||
|
|
||||||
|
|||||||
@@ -53,14 +53,16 @@ class ServiceNode(ActiveNode):
|
|||||||
self.services: Dict[str, Service] = {}
|
self.services: Dict[str, Service] = {}
|
||||||
|
|
||||||
def add_service(self, service: Service):
|
def add_service(self, service: Service):
|
||||||
"""Adds a service to the node.
|
"""
|
||||||
|
Adds a service to the node.
|
||||||
|
|
||||||
:param service: The service to add
|
:param service: The service to add
|
||||||
"""
|
"""
|
||||||
self.services[service.name] = service
|
self.services[service.name] = service
|
||||||
|
|
||||||
def has_service(self, protocol_name: str) -> bool:
|
def has_service(self, protocol_name: str) -> bool:
|
||||||
"""Indicates whether a service is on a node.
|
"""
|
||||||
|
Indicates whether a service is on a node.
|
||||||
|
|
||||||
:param protocol_name: The service (protocol)e.
|
:param protocol_name: The service (protocol)e.
|
||||||
:return: True if service (protocol) is on the node, otherwise False.
|
:return: True if service (protocol) is on the node, otherwise False.
|
||||||
@@ -71,7 +73,8 @@ class ServiceNode(ActiveNode):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def service_running(self, protocol_name: str) -> bool:
|
def service_running(self, protocol_name: str) -> bool:
|
||||||
"""Indicates whether a service is in a running state on the node.
|
"""
|
||||||
|
Indicates whether a service is in a running state on the node.
|
||||||
|
|
||||||
:param protocol_name: The service (protocol)
|
:param protocol_name: The service (protocol)
|
||||||
:return: True if service (protocol) is in a running state on the node, otherwise False.
|
:return: True if service (protocol) is in a running state on the node, otherwise False.
|
||||||
@@ -85,7 +88,8 @@ class ServiceNode(ActiveNode):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def service_is_overwhelmed(self, protocol_name: str) -> bool:
|
def service_is_overwhelmed(self, protocol_name: str) -> bool:
|
||||||
"""Indicates whether a service is in an overwhelmed state on the node.
|
"""
|
||||||
|
Indicates whether a service is in an overwhelmed state on the node.
|
||||||
|
|
||||||
:param protocol_name: The service (protocol)
|
:param protocol_name: The service (protocol)
|
||||||
:return: True if service (protocol) is in an overwhelmed state on the node, otherwise False.
|
:return: True if service (protocol) is in an overwhelmed state on the node, otherwise False.
|
||||||
@@ -99,7 +103,8 @@ class ServiceNode(ActiveNode):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
|
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
|
||||||
"""Sets the software_state of a service (protocol) on the node.
|
"""
|
||||||
|
Sets the software_state of a service (protocol) on the node.
|
||||||
|
|
||||||
:param protocol_name: The service (protocol).
|
:param protocol_name: The service (protocol).
|
||||||
:param software_state: The software_state.
|
:param software_state: The software_state.
|
||||||
@@ -127,7 +132,8 @@ class ServiceNode(ActiveNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
|
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
|
||||||
"""Sets the software_state of a service (protocol) on the node.
|
"""
|
||||||
|
Sets the software_state of a service (protocol) on the node.
|
||||||
|
|
||||||
Done if the software_state is not "compromised".
|
Done if the software_state is not "compromised".
|
||||||
|
|
||||||
@@ -153,7 +159,8 @@ class ServiceNode(ActiveNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_service_state(self, protocol_name):
|
def get_service_state(self, protocol_name):
|
||||||
"""Gets the state of a service.
|
"""
|
||||||
|
Gets the state of a service.
|
||||||
|
|
||||||
:return: The software_state of the service.
|
:return: The software_state of the service.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ _LOGGER = getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def start_jupyter_session():
|
def start_jupyter_session():
|
||||||
"""Starts a new Jupyter notebook session in the app notebooks directory.
|
"""
|
||||||
|
Starts a new Jupyter notebook session in the app notebooks directory.
|
||||||
|
|
||||||
Currently only works on Windows OS.
|
Currently only works on Windows OS.
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ def apply_iers(
|
|||||||
acl: AccessControlList,
|
acl: AccessControlList,
|
||||||
step: int,
|
step: int,
|
||||||
):
|
):
|
||||||
"""Applies IERs to the links (link pattern of life).
|
"""
|
||||||
|
Applies IERs to the links (link pattern of life).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
network: The network modelled in the environment
|
network: The network modelled in the environment
|
||||||
@@ -217,7 +218,8 @@ def apply_node_pol(
|
|||||||
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
|
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
|
||||||
step: int,
|
step: int,
|
||||||
):
|
):
|
||||||
"""Applies node pattern of life.
|
"""
|
||||||
|
Applies node pattern of life.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
nodes: The nodes within the environment
|
nodes: The nodes within the environment
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||||
"""Information Exchange Requirements for APE.
|
"""
|
||||||
|
Information Exchange Requirements for APE.
|
||||||
|
|
||||||
Used to represent an information flow from source to destination.
|
Used to represent an information flow from source to destination.
|
||||||
"""
|
"""
|
||||||
@@ -47,7 +48,8 @@ class IER(object):
|
|||||||
self.running = _running
|
self.running = _running
|
||||||
|
|
||||||
def get_id(self):
|
def get_id(self):
|
||||||
"""Gets IER ID.
|
"""
|
||||||
|
Gets IER ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IER ID
|
IER ID
|
||||||
@@ -55,7 +57,8 @@ class IER(object):
|
|||||||
return self.id
|
return self.id
|
||||||
|
|
||||||
def get_start_step(self):
|
def get_start_step(self):
|
||||||
"""Gets IER start step.
|
"""
|
||||||
|
Gets IER start step.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IER start step
|
IER start step
|
||||||
@@ -63,7 +66,8 @@ class IER(object):
|
|||||||
return self.start_step
|
return self.start_step
|
||||||
|
|
||||||
def get_end_step(self):
|
def get_end_step(self):
|
||||||
"""Gets IER end step.
|
"""
|
||||||
|
Gets IER end step.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IER end step
|
IER end step
|
||||||
@@ -71,7 +75,8 @@ class IER(object):
|
|||||||
return self.end_step
|
return self.end_step
|
||||||
|
|
||||||
def get_load(self):
|
def get_load(self):
|
||||||
"""Gets IER load.
|
"""
|
||||||
|
Gets IER load.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IER load
|
IER load
|
||||||
@@ -79,7 +84,8 @@ class IER(object):
|
|||||||
return self.load
|
return self.load
|
||||||
|
|
||||||
def get_protocol(self):
|
def get_protocol(self):
|
||||||
"""Gets IER protocol.
|
"""
|
||||||
|
Gets IER protocol.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IER protocol
|
IER protocol
|
||||||
@@ -87,7 +93,8 @@ class IER(object):
|
|||||||
return self.protocol
|
return self.protocol
|
||||||
|
|
||||||
def get_port(self):
|
def get_port(self):
|
||||||
"""Gets IER port.
|
"""
|
||||||
|
Gets IER port.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IER port
|
IER port
|
||||||
@@ -95,7 +102,8 @@ class IER(object):
|
|||||||
return self.port
|
return self.port
|
||||||
|
|
||||||
def get_source_node_id(self):
|
def get_source_node_id(self):
|
||||||
"""Gets IER source node ID.
|
"""
|
||||||
|
Gets IER source node ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IER source node ID
|
IER source node ID
|
||||||
@@ -103,7 +111,8 @@ class IER(object):
|
|||||||
return self.source_node_id
|
return self.source_node_id
|
||||||
|
|
||||||
def get_dest_node_id(self):
|
def get_dest_node_id(self):
|
||||||
"""Gets IER destination node ID.
|
"""
|
||||||
|
Gets IER destination node ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IER destination node ID
|
IER destination node ID
|
||||||
@@ -111,7 +120,8 @@ class IER(object):
|
|||||||
return self.dest_node_id
|
return self.dest_node_id
|
||||||
|
|
||||||
def get_is_running(self):
|
def get_is_running(self):
|
||||||
"""Informs whether the IER is currently running.
|
"""
|
||||||
|
Informs whether the IER is currently running.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if running
|
True if running
|
||||||
@@ -119,7 +129,8 @@ class IER(object):
|
|||||||
return self.running
|
return self.running
|
||||||
|
|
||||||
def set_is_running(self, _value):
|
def set_is_running(self, _value):
|
||||||
"""Sets the running state of the IER.
|
"""
|
||||||
|
Sets the running state of the IER.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
_value: running status
|
_value: running status
|
||||||
@@ -127,7 +138,8 @@ class IER(object):
|
|||||||
self.running = _value
|
self.running = _value
|
||||||
|
|
||||||
def get_mission_criticality(self):
|
def get_mission_criticality(self):
|
||||||
"""Gets the IER mission criticality (used in the reward function).
|
"""
|
||||||
|
Gets the IER mission criticality (used in the reward function).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Mission criticality value (0 lowest to 5 highest)
|
Mission criticality value (0 lowest to 5 highest)
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ def apply_red_agent_iers(
|
|||||||
acl: AccessControlList,
|
acl: AccessControlList,
|
||||||
step: int,
|
step: int,
|
||||||
):
|
):
|
||||||
"""Applies IERs to the links (link POL) resulting from red agent attack.
|
"""
|
||||||
|
Applies IERs to the links (link POL) resulting from red agent attack.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
network: The network modelled in the environment
|
network: The network modelled in the environment
|
||||||
@@ -213,7 +214,8 @@ def apply_red_agent_node_pol(
|
|||||||
node_pol: Dict[str, NodeStateInstructionRed],
|
node_pol: Dict[str, NodeStateInstructionRed],
|
||||||
step: int,
|
step: int,
|
||||||
):
|
):
|
||||||
"""Applies node pattern of life.
|
"""
|
||||||
|
Applies node pattern of life.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
nodes: The nodes within the environment
|
nodes: The nodes within the environment
|
||||||
@@ -295,7 +297,8 @@ def apply_red_agent_node_pol(
|
|||||||
|
|
||||||
|
|
||||||
def is_red_ier_incoming(node, iers, node_pol_type):
|
def is_red_ier_incoming(node, iers, node_pol_type):
|
||||||
"""Checks if the RED IER is incoming.
|
"""
|
||||||
|
Checks if the RED IER is incoming.
|
||||||
|
|
||||||
TODO: Write more descriptive docstring with params and returns.
|
TODO: Write more descriptive docstring with params and returns.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -125,7 +125,8 @@ class PrimaiteSession:
|
|||||||
self,
|
self,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Train the agent.
|
"""
|
||||||
|
Train the agent.
|
||||||
|
|
||||||
:param kwargs: Any agent-framework specific key word args.
|
:param kwargs: Any agent-framework specific key word args.
|
||||||
"""
|
"""
|
||||||
@@ -136,7 +137,8 @@ class PrimaiteSession:
|
|||||||
self,
|
self,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Evaluate the agent.
|
"""
|
||||||
|
Evaluate the agent.
|
||||||
|
|
||||||
:param kwargs: Any agent-framework specific key word args.
|
:param kwargs: Any agent-framework specific key word args.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ _LOGGER = getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def run(overwrite_existing: bool = True):
|
def run(overwrite_existing: bool = True):
|
||||||
"""Resets the demo jupyter notebooks in the users app notebooks directory.
|
"""
|
||||||
|
Resets the demo jupyter notebooks in the users app notebooks directory.
|
||||||
|
|
||||||
:param overwrite_existing: A bool to toggle replacing existing edited notebooks on or off.
|
:param overwrite_existing: A bool to toggle replacing existing edited notebooks on or off.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ _LOGGER = getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def run(overwrite_existing=True):
|
def run(overwrite_existing=True):
|
||||||
"""Resets the example config files in the users app config directory.
|
"""
|
||||||
|
Resets the example config files in the users app config directory.
|
||||||
|
|
||||||
:param overwrite_existing: A bool to toggle replacing existing edited config on or off.
|
:param overwrite_existing: A bool to toggle replacing existing edited config on or off.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ _LOGGER = getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
"""Handles creation of application directories and user directories.
|
"""
|
||||||
|
Handles creation of application directories and user directories.
|
||||||
|
|
||||||
Uses `platformdirs.PlatformDirs` and `pathlib.Path` to create the required
|
Uses `platformdirs.PlatformDirs` and `pathlib.Path` to create the required
|
||||||
app directories in the correct locations based on the users OS.
|
app directories in the correct locations based on the users OS.
|
||||||
|
|||||||
@@ -39,7 +39,8 @@ class Transaction(object):
|
|||||||
"The env observation space description"
|
"The env observation space description"
|
||||||
|
|
||||||
def as_csv_data(self) -> Tuple[List, List]:
|
def as_csv_data(self) -> Tuple[List, List]:
|
||||||
"""Converts the Transaction to a csv data row and provides a header.
|
"""
|
||||||
|
Converts the Transaction to a csv data row and provides a header.
|
||||||
|
|
||||||
:return: A tuple consisting of (header, data).
|
:return: A tuple consisting of (header, data).
|
||||||
"""
|
"""
|
||||||
@@ -68,7 +69,8 @@ class Transaction(object):
|
|||||||
|
|
||||||
|
|
||||||
def _turn_action_space_to_array(action_space) -> List[str]:
|
def _turn_action_space_to_array(action_space) -> List[str]:
|
||||||
"""Turns action space into a string array so it can be saved to csv.
|
"""
|
||||||
|
Turns action space into a string array so it can be saved to csv.
|
||||||
|
|
||||||
:param action_space: The action space
|
:param action_space: The action space
|
||||||
:return: The action space as an array of strings
|
:return: The action space as an array of strings
|
||||||
@@ -80,7 +82,8 @@ def _turn_action_space_to_array(action_space) -> List[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
|
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
|
||||||
"""Turns observation space into a string array so it can be saved to csv.
|
"""
|
||||||
|
Turns observation space into a string array so it can be saved to csv.
|
||||||
|
|
||||||
:param obs_space: The observation space
|
:param obs_space: The observation space
|
||||||
:param obs_assets: The number of assets (i.e. nodes or links) in the observation space
|
:param obs_assets: The number of assets (i.e. nodes or links) in the observation space
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ _LOGGER = getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def get_file_path(path: str) -> Path:
|
def get_file_path(path: str) -> Path:
|
||||||
"""Get PrimAITE package data.
|
"""
|
||||||
|
Get PrimAITE package data.
|
||||||
|
|
||||||
:Example:
|
:Example:
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ import polars as pl
|
|||||||
|
|
||||||
|
|
||||||
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||||
"""Read an average rewards per episode csv file and return as a dict.
|
"""
|
||||||
|
Read an average rewards per episode csv file and return as a dict.
|
||||||
|
|
||||||
The dictionary keys are the episode number, and the values are the mean reward that episode.
|
The dictionary keys are the episode number, and the values are the mean reward that episode.
|
||||||
|
|
||||||
|
|||||||
@@ -77,7 +77,8 @@ class SessionOutputWriter:
|
|||||||
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
|
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
|
||||||
|
|
||||||
def write(self, data: Union[Tuple, Transaction]):
|
def write(self, data: Union[Tuple, Transaction]):
|
||||||
"""Write a row of session data.
|
"""
|
||||||
|
Write a row of session data.
|
||||||
|
|
||||||
:param data: The row of data to write. Can be a Tuple or an instance of Transaction.
|
:param data: The row of data to write. Can be a Tuple or an instance of Transaction.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -75,7 +75,8 @@ class TestNodeLinkTable:
|
|||||||
assert env.env_obs.shape == (5, 6)
|
assert env.env_obs.shape == (5, 6)
|
||||||
|
|
||||||
def test_value(self, temp_primaite_session):
|
def test_value(self, temp_primaite_session):
|
||||||
"""Test that the observation is generated correctly.
|
"""
|
||||||
|
Test that the observation is generated correctly.
|
||||||
|
|
||||||
The laydown has:
|
The laydown has:
|
||||||
* 3 nodes (2 service nodes and 1 active node)
|
* 3 nodes (2 service nodes and 1 active node)
|
||||||
@@ -157,7 +158,8 @@ class TestNodeStatuses:
|
|||||||
assert env.env_obs.shape == (15,)
|
assert env.env_obs.shape == (15,)
|
||||||
|
|
||||||
def test_values(self, temp_primaite_session):
|
def test_values(self, temp_primaite_session):
|
||||||
"""Test that the hardware and software states are encoded correctly.
|
"""
|
||||||
|
Test that the hardware and software states are encoded correctly.
|
||||||
|
|
||||||
The laydown has:
|
The laydown has:
|
||||||
* one node with a compromised operating system state
|
* one node with a compromised operating system state
|
||||||
@@ -213,7 +215,8 @@ class TestLinkTrafficLevels:
|
|||||||
assert env.env_obs.shape == (2 * 2,)
|
assert env.env_obs.shape == (2 * 2,)
|
||||||
|
|
||||||
def test_values(self, temp_primaite_session):
|
def test_values(self, temp_primaite_session):
|
||||||
"""Test that traffic values are encoded correctly.
|
"""
|
||||||
|
Test that traffic values are encoded correctly.
|
||||||
|
|
||||||
The laydown has:
|
The laydown has:
|
||||||
* two services
|
* two services
|
||||||
|
|||||||
Reference in New Issue
Block a user