Standardise docstring summary line placement.
This commit is contained in:
@@ -66,7 +66,8 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``.
|
||||
|
||||
# region Setup Logging
|
||||
class _LevelFormatter(Formatter):
|
||||
"""A custom level-specific formatter.
|
||||
"""
|
||||
A custom level-specific formatter.
|
||||
|
||||
Credit to: https://stackoverflow.com/a/68154386
|
||||
"""
|
||||
@@ -134,7 +135,8 @@ _LOGGER.addHandler(_FILE_HANDLER)
|
||||
|
||||
|
||||
def getLogger(name: str) -> Logger: # noqa
|
||||
"""Get a PrimAITE logger.
|
||||
"""
|
||||
Get a PrimAITE logger.
|
||||
|
||||
:param name: The logger name. Use ``__name__``.
|
||||
:return: An instance of :py:class:`logging.Logger` with the PrimAITE
|
||||
|
||||
@@ -35,7 +35,8 @@ class AccessControlList:
|
||||
return False
|
||||
|
||||
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
|
||||
"""Checks for rules that block a protocol / port.
|
||||
"""
|
||||
Checks for rules that block a protocol / port.
|
||||
|
||||
Args:
|
||||
_source_ip_address: the source IP address to check
|
||||
@@ -61,7 +62,8 @@ class AccessControlList:
|
||||
return True
|
||||
|
||||
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""Adds a new rule.
|
||||
"""
|
||||
Adds a new rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
@@ -75,7 +77,8 @@ class AccessControlList:
|
||||
self.acl[hash_value] = new_rule
|
||||
|
||||
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""Removes a rule.
|
||||
"""
|
||||
Removes a rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
@@ -97,7 +100,8 @@ class AccessControlList:
|
||||
self.acl.clear()
|
||||
|
||||
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""Produces a hash value for a rule.
|
||||
"""
|
||||
Produces a hash value for a rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
|
||||
@@ -22,7 +22,8 @@ class ACLRule:
|
||||
self.port = _port
|
||||
|
||||
def __hash__(self):
|
||||
"""Override the hash function.
|
||||
"""
|
||||
Override the hash function.
|
||||
|
||||
Returns:
|
||||
Returns hash of core parameters.
|
||||
@@ -38,7 +39,8 @@ class ACLRule:
|
||||
)
|
||||
|
||||
def get_permission(self):
|
||||
"""Gets the permission attribute.
|
||||
"""
|
||||
Gets the permission attribute.
|
||||
|
||||
Returns:
|
||||
Returns permission attribute
|
||||
@@ -46,7 +48,8 @@ class ACLRule:
|
||||
return self.permission
|
||||
|
||||
def get_source_ip(self):
|
||||
"""Gets the source IP address attribute.
|
||||
"""
|
||||
Gets the source IP address attribute.
|
||||
|
||||
Returns:
|
||||
Returns source IP address attribute
|
||||
@@ -54,7 +57,8 @@ class ACLRule:
|
||||
return self.source_ip
|
||||
|
||||
def get_dest_ip(self):
|
||||
"""Gets the desintation IP address attribute.
|
||||
"""
|
||||
Gets the desintation IP address attribute.
|
||||
|
||||
Returns:
|
||||
Returns destination IP address attribute
|
||||
@@ -62,7 +66,8 @@ class ACLRule:
|
||||
return self.dest_ip
|
||||
|
||||
def get_protocol(self):
|
||||
"""Gets the protocol attribute.
|
||||
"""
|
||||
Gets the protocol attribute.
|
||||
|
||||
Returns:
|
||||
Returns protocol attribute
|
||||
@@ -70,7 +75,8 @@ class ACLRule:
|
||||
return self.protocol
|
||||
|
||||
def get_port(self):
|
||||
"""Gets the port attribute.
|
||||
"""
|
||||
Gets the port attribute.
|
||||
|
||||
Returns:
|
||||
Returns port attribute
|
||||
|
||||
@@ -21,7 +21,8 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def get_session_path(session_timestamp: datetime) -> Path:
|
||||
"""Get the directory path the session will output to.
|
||||
"""
|
||||
Get the directory path the session will output to.
|
||||
|
||||
This is set in the format of:
|
||||
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
@@ -194,7 +195,8 @@ class AgentSessionABC(ABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""Train the agent.
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
@@ -211,7 +213,8 @@ class AgentSessionABC(ABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""Evaluate the agent.
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
@@ -340,7 +343,8 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""Train the agent.
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
@@ -354,7 +358,8 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""Evaluate the agent.
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
|
||||
@@ -46,7 +46,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
return blocked_green_iers
|
||||
|
||||
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
"""Get matching ACL rules for an IER.
|
||||
"""
|
||||
Get matching ACL rules for an IER.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -62,7 +63,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
return matching_rules
|
||||
|
||||
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
"""Get blocking ACL rules for an IER.
|
||||
"""
|
||||
Get blocking ACL rules for an IER.
|
||||
|
||||
.. warning::
|
||||
Can return empty dict but IER can still be blocked by default
|
||||
@@ -81,7 +83,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
return blocked_rules
|
||||
|
||||
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
"""Get all allowing ACL rules for an IER.
|
||||
"""
|
||||
Get all allowing ACL rules for an IER.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -105,7 +108,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
nodes,
|
||||
services_list,
|
||||
):
|
||||
"""Get matching ACL rules.
|
||||
"""
|
||||
Get matching ACL rules.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -136,7 +140,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
nodes,
|
||||
services_list,
|
||||
):
|
||||
"""Get the ALLOW ACL rules.
|
||||
"""
|
||||
Get the ALLOW ACL rules.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -168,7 +173,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
nodes,
|
||||
services_list,
|
||||
):
|
||||
"""Get the DENY ACL rules.
|
||||
"""
|
||||
Get the DENY ACL rules.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -191,7 +197,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
return allowed_rules
|
||||
|
||||
def _calculate_action_full_view(self, obs):
|
||||
"""Calculate a good acl-based action for the blue agent to take.
|
||||
"""
|
||||
Calculate a good acl-based action for the blue agent to take.
|
||||
|
||||
Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
|
||||
|
||||
@@ -355,7 +362,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
return action
|
||||
|
||||
def _calculate_action_basic_view(self, obs):
|
||||
"""Calculate a good acl-based action for the blue agent to take.
|
||||
"""
|
||||
Calculate a good acl-based action for the blue agent to take.
|
||||
|
||||
Uses ONLY information from the current observation with NO knowledge
|
||||
of previous actions taken and NO reward feedback.
|
||||
|
||||
@@ -6,7 +6,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic Node agent."""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
"""Calculate a good node-based action for the blue agent to take.
|
||||
"""
|
||||
Calculate a good node-based action for the blue agent to take.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
|
||||
@@ -140,7 +140,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""Evaluate the agent.
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
@@ -158,7 +159,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""Evaluate the agent.
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
|
||||
@@ -89,7 +89,8 @@ class SB3Agent(AgentSessionABC):
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""Train the agent.
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
@@ -109,7 +110,8 @@ class SB3Agent(AgentSessionABC):
|
||||
deterministic: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Evaluate the agent.
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param deterministic: Whether the evaluation is deterministic.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
|
||||
@@ -11,7 +11,8 @@ from primaite.common.enums import (
|
||||
|
||||
|
||||
def transform_action_node_readable(action):
|
||||
"""Convert a node action from enumerated format to readable format.
|
||||
"""
|
||||
Convert a node action from enumerated format to readable format.
|
||||
|
||||
example:
|
||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||
@@ -33,7 +34,8 @@ def transform_action_node_readable(action):
|
||||
|
||||
|
||||
def transform_action_acl_readable(action):
|
||||
"""Transform an ACL action to a more readable format.
|
||||
"""
|
||||
Transform an ACL action to a more readable format.
|
||||
|
||||
example:
|
||||
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
|
||||
@@ -57,7 +59,8 @@ def transform_action_acl_readable(action):
|
||||
|
||||
|
||||
def is_valid_node_action(action):
|
||||
"""Is the node action an actual valid action.
|
||||
"""
|
||||
Is the node action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
|
||||
@@ -92,7 +95,8 @@ def is_valid_node_action(action):
|
||||
|
||||
|
||||
def is_valid_acl_action(action):
|
||||
"""Is the ACL action an actual valid action.
|
||||
"""
|
||||
Is the ACL action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect.
|
||||
|
||||
@@ -124,7 +128,8 @@ def is_valid_acl_action(action):
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action):
|
||||
"""Harsher version of valid acl actions, does not allow action.
|
||||
"""
|
||||
Harsher version of valid acl actions, does not allow action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -147,7 +152,8 @@ def is_valid_acl_action_extra(action):
|
||||
|
||||
|
||||
def transform_change_obs_readable(obs):
|
||||
"""Transform list of transactions to readable list of each observation property.
|
||||
"""
|
||||
Transform list of transactions to readable list of each observation property.
|
||||
|
||||
example:
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
|
||||
@@ -169,7 +175,8 @@ def transform_change_obs_readable(obs):
|
||||
|
||||
|
||||
def transform_obs_readable(obs):
|
||||
"""Transform observation to readable format.
|
||||
"""
|
||||
Transform observation to readable format.
|
||||
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
|
||||
|
||||
@@ -185,7 +192,8 @@ def transform_obs_readable(obs):
|
||||
|
||||
|
||||
def convert_to_new_obs(obs, num_nodes=10):
|
||||
"""Convert original gym Box observation space to new multiDiscrete observation space.
|
||||
"""
|
||||
Convert original gym Box observation space to new multiDiscrete observation space.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -196,7 +204,8 @@ def convert_to_new_obs(obs, num_nodes=10):
|
||||
|
||||
|
||||
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
"""Convert to old observation.
|
||||
"""
|
||||
Convert to old observation.
|
||||
|
||||
Links filled with 0's as no information is included in new observation space.
|
||||
|
||||
@@ -232,7 +241,8 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
|
||||
|
||||
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
||||
"""Return string describing change between two observations.
|
||||
"""
|
||||
Return string describing change between two observations.
|
||||
|
||||
example:
|
||||
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
|
||||
@@ -260,7 +270,8 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
||||
|
||||
|
||||
def _describe_obs_change_helper(obs_change, is_link):
|
||||
"""Helper funcion to describe what has changed.
|
||||
"""
|
||||
Helper funcion to describe what has changed.
|
||||
|
||||
example:
|
||||
[ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD"
|
||||
@@ -295,7 +306,8 @@ def _describe_obs_change_helper(obs_change, is_link):
|
||||
|
||||
|
||||
def transform_action_node_enum(action):
|
||||
"""Convert a node action from readable string format, to enumerated format.
|
||||
"""
|
||||
Convert a node action from readable string format, to enumerated format.
|
||||
|
||||
example:
|
||||
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
|
||||
@@ -326,7 +338,8 @@ def transform_action_node_enum(action):
|
||||
|
||||
|
||||
def transform_action_node_readable(action):
|
||||
"""Convert a node action from enumerated format to readable format.
|
||||
"""
|
||||
Convert a node action from enumerated format to readable format.
|
||||
|
||||
example:
|
||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||
@@ -348,7 +361,8 @@ def transform_action_node_readable(action):
|
||||
|
||||
|
||||
def node_action_description(action):
|
||||
"""Generate string describing a node-based action.
|
||||
"""
|
||||
Generate string describing a node-based action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -375,7 +389,8 @@ def node_action_description(action):
|
||||
|
||||
|
||||
def transform_action_acl_enum(action):
|
||||
"""Convert acl action from readable str format, to enumerated format.
|
||||
"""
|
||||
Convert acl action from readable str format, to enumerated format.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -397,7 +412,8 @@ def transform_action_acl_enum(action):
|
||||
|
||||
|
||||
def acl_action_description(action):
|
||||
"""Generate string describing an acl-based action.
|
||||
"""
|
||||
Generate string describing an acl-based action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -417,7 +433,8 @@ def acl_action_description(action):
|
||||
|
||||
|
||||
def get_node_of_ip(ip, node_dict):
|
||||
"""Get the node ID of an IP address.
|
||||
"""
|
||||
Get the node ID of an IP address.
|
||||
|
||||
node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes)
|
||||
|
||||
@@ -431,7 +448,8 @@ def get_node_of_ip(ip, node_dict):
|
||||
|
||||
|
||||
def is_valid_node_action(action):
|
||||
"""Is the node action an actual valid action.
|
||||
"""
|
||||
Is the node action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
|
||||
@@ -464,7 +482,8 @@ def is_valid_node_action(action):
|
||||
|
||||
|
||||
def is_valid_acl_action(action):
|
||||
"""Is the ACL action an actual valid action.
|
||||
"""
|
||||
Is the ACL action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
|
||||
@@ -496,7 +515,8 @@ def is_valid_acl_action(action):
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action):
|
||||
"""Harsher version of valid acl actions, does not allow action.
|
||||
"""
|
||||
Harsher version of valid acl actions, does not allow action.
|
||||
|
||||
TODO: Add params and return in docstring.
|
||||
TODO: Typehint params and return.
|
||||
@@ -519,7 +539,8 @@ def is_valid_acl_action_extra(action):
|
||||
|
||||
|
||||
def get_new_action(old_action, action_dict):
|
||||
"""Get new action (e.g. 32) from old action e.g. [1,1,1,0].
|
||||
"""
|
||||
Get new action (e.g. 32) from old action e.g. [1,1,1,0].
|
||||
|
||||
Old_action can be either node or acl action type
|
||||
|
||||
|
||||
@@ -28,7 +28,8 @@ def build_dirs():
|
||||
|
||||
@app.command()
|
||||
def reset_notebooks(overwrite: bool = True):
|
||||
"""Force a reset of the demo notebooks in the users notebooks directory.
|
||||
"""
|
||||
Force a reset of the demo notebooks in the users notebooks directory.
|
||||
|
||||
:param overwrite: If True, will overwrite existing demo notebooks.
|
||||
"""
|
||||
@@ -39,7 +40,8 @@ def reset_notebooks(overwrite: bool = True):
|
||||
|
||||
@app.command()
|
||||
def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
"""Print the PrimAITE log file.
|
||||
"""
|
||||
Print the PrimAITE log file.
|
||||
|
||||
:param last_n: The number of lines to print. Default value is 10.
|
||||
"""
|
||||
@@ -59,7 +61,8 @@ _LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # n
|
||||
|
||||
@app.command()
|
||||
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
|
||||
"""View or set the PrimAITE Log Level.
|
||||
"""
|
||||
View or set the PrimAITE Log Level.
|
||||
|
||||
To View, simply call: primaite log-level
|
||||
|
||||
@@ -110,7 +113,8 @@ def clean_up():
|
||||
|
||||
@app.command()
|
||||
def setup(overwrite_existing: bool = True):
|
||||
"""Perform the PrimAITE first-time setup.
|
||||
"""
|
||||
Perform the PrimAITE first-time setup.
|
||||
|
||||
WARNING: All user-data will be lost.
|
||||
"""
|
||||
@@ -148,7 +152,8 @@ def setup(overwrite_existing: bool = True):
|
||||
|
||||
@app.command()
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
||||
"""Run a PrimAITE session.
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
tc: The training config filepath. Optional. If no value is passed then
|
||||
example default training config is used from:
|
||||
@@ -173,7 +178,8 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
||||
|
||||
@app.command()
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
|
||||
"""View or set the plotly template for Session plots.
|
||||
"""
|
||||
View or set the plotly template for Session plots.
|
||||
|
||||
To View, simply call: primaite plotly-template
|
||||
|
||||
|
||||
@@ -12,7 +12,8 @@ _EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down
|
||||
|
||||
|
||||
def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert a legacy lay down config dict to the new format.
|
||||
"""
|
||||
Convert a legacy lay down config dict to the new format.
|
||||
|
||||
:param legacy_config_dict: A legacy lay down config dict.
|
||||
"""
|
||||
@@ -21,7 +22,8 @@ def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> D
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
|
||||
"""Read in a lay down config yaml file.
|
||||
"""
|
||||
Read in a lay down config yaml file.
|
||||
|
||||
:param file_path: The config file path.
|
||||
:param legacy_file: True if the config file is legacy format, otherwise False.
|
||||
@@ -50,7 +52,8 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
|
||||
|
||||
|
||||
def ddos_basic_one_config_path() -> Path:
|
||||
"""The path to the example lay_down_config_1_DDOS_basic.yaml file.
|
||||
"""
|
||||
The path to the example lay_down_config_1_DDOS_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
@@ -64,7 +67,8 @@ def ddos_basic_one_config_path() -> Path:
|
||||
|
||||
|
||||
def ddos_basic_two_config_path() -> Path:
|
||||
"""The path to the example lay_down_config_2_DDOS_basic.yaml file.
|
||||
"""
|
||||
The path to the example lay_down_config_2_DDOS_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
@@ -78,7 +82,8 @@ def ddos_basic_two_config_path() -> Path:
|
||||
|
||||
|
||||
def dos_very_basic_config_path() -> Path:
|
||||
"""The path to the example lay_down_config_3_DOS_very_basic.yaml file.
|
||||
"""
|
||||
The path to the example lay_down_config_3_DOS_very_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
@@ -92,7 +97,8 @@ def dos_very_basic_config_path() -> Path:
|
||||
|
||||
|
||||
def data_manipulation_config_path() -> Path:
|
||||
"""The path to the example lay_down_config_5_data_manipulation.yaml file.
|
||||
"""
|
||||
The path to the example lay_down_config_5_data_manipulation.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
|
||||
@@ -180,7 +180,8 @@ class TrainingConfig:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
|
||||
"""Create an instance of TrainingConfig from a dict.
|
||||
"""
|
||||
Create an instance of TrainingConfig from a dict.
|
||||
|
||||
:param config_dict: The training config dict.
|
||||
:return: The instance of TrainingConfig.
|
||||
|
||||
@@ -22,7 +22,8 @@ def plot_av_reward_per_episode(
|
||||
title: Optional[str] = None,
|
||||
subtitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
"""Plot the average reward per episode from a csv session output.
|
||||
"""
|
||||
Plot the average reward per episode from a csv session output.
|
||||
|
||||
:param av_reward_per_episode_csv: The average reward per episode csv
|
||||
file path.
|
||||
|
||||
@@ -50,7 +50,8 @@ class AbstractObservationComponent(ABC):
|
||||
|
||||
|
||||
class NodeLinkTable(AbstractObservationComponent):
|
||||
"""Table with nodes and links as rows and hardware/software status as cols.
|
||||
"""
|
||||
Table with nodes and links as rows and hardware/software status as cols.
|
||||
|
||||
This will create the observation space formatted as a table of integers.
|
||||
There is one row per node, followed by one row per link.
|
||||
@@ -74,7 +75,8 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
"""Initialise a NodeLinkTable observation space component.
|
||||
"""
|
||||
Initialise a NodeLinkTable observation space component.
|
||||
|
||||
:param env: Training environment.
|
||||
:type env: Primaite
|
||||
@@ -100,7 +102,8 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.NodeLinkTable`
|
||||
"""
|
||||
@@ -181,7 +184,8 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
|
||||
|
||||
class NodeStatuses(AbstractObservationComponent):
|
||||
"""Flat list of nodes' hardware, OS, file system, and service states.
|
||||
"""
|
||||
Flat list of nodes' hardware, OS, file system, and service states.
|
||||
|
||||
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
||||
integers.
|
||||
@@ -234,7 +238,8 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.NodeStatuses`
|
||||
"""
|
||||
@@ -287,7 +292,8 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
|
||||
|
||||
class LinkTrafficLevels(AbstractObservationComponent):
|
||||
"""Flat list of traffic levels encoded into banded categories.
|
||||
"""
|
||||
Flat list of traffic levels encoded into banded categories.
|
||||
|
||||
For each link, total traffic or traffic per service is encoded into a categorical value.
|
||||
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
|
||||
@@ -354,7 +360,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.LinkTrafficLevels`
|
||||
"""
|
||||
@@ -395,7 +402,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
"""Component-based observation space handler.
|
||||
"""
|
||||
Component-based observation space handler.
|
||||
|
||||
This allows users to configure observation spaces by mixing and matching components. Each component can also define
|
||||
further parameters to make them more flexible.
|
||||
@@ -436,7 +444,8 @@ class ObservationsHandler:
|
||||
self._flat_observation = spaces.flatten(self._space, self._observation)
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent):
|
||||
"""Add a component for this handler to track.
|
||||
"""
|
||||
Add a component for this handler to track.
|
||||
|
||||
:param obs_component: The component to add.
|
||||
:type obs_component: AbstractObservationComponent
|
||||
@@ -445,7 +454,8 @@ class ObservationsHandler:
|
||||
self.update_space()
|
||||
|
||||
def deregister(self, obs_component: AbstractObservationComponent):
|
||||
"""Remove a component from this handler.
|
||||
"""
|
||||
Remove a component from this handler.
|
||||
|
||||
:param obs_component: Which component to remove. It must exist within this object's
|
||||
``registered_obs_components`` attribute.
|
||||
@@ -488,7 +498,8 @@ class ObservationsHandler:
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||
"""Parse a config dictinary, return a new observation handler populated with new observation component objects.
|
||||
"""
|
||||
Parse a config dictinary, return a new observation handler populated with new observation component objects.
|
||||
|
||||
The expected format for the config dictionary is:
|
||||
|
||||
@@ -533,7 +544,8 @@ class ObservationsHandler:
|
||||
return handler
|
||||
|
||||
def describe_structure(self):
|
||||
"""Create a list of names for the features of the obs space.
|
||||
"""
|
||||
Create a list of names for the features of the obs space.
|
||||
|
||||
The order of labels follows the flattened version of the space.
|
||||
"""
|
||||
|
||||
@@ -255,7 +255,8 @@ class Primaite(Env):
|
||||
self.total_step_count = 0
|
||||
|
||||
def reset(self):
|
||||
"""AI Gym Reset function.
|
||||
"""
|
||||
AI Gym Reset function.
|
||||
|
||||
Returns:
|
||||
Environment observation space (reset)
|
||||
@@ -291,7 +292,8 @@ class Primaite(Env):
|
||||
return self.env_obs
|
||||
|
||||
def step(self, action):
|
||||
"""AI Gym Step function.
|
||||
"""
|
||||
AI Gym Step function.
|
||||
|
||||
Args:
|
||||
action: Action space from agent
|
||||
@@ -429,7 +431,8 @@ class Primaite(Env):
|
||||
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
|
||||
|
||||
def interpret_action_and_apply(self, _action):
|
||||
"""Applies agent actions to the nodes and Access Control List.
|
||||
"""
|
||||
Applies agent actions to the nodes and Access Control List.
|
||||
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
@@ -448,7 +451,8 @@ class Primaite(Env):
|
||||
logging.error("Invalid action type found")
|
||||
|
||||
def apply_actions_to_nodes(self, _action):
|
||||
"""Applies agent actions to the nodes.
|
||||
"""
|
||||
Applies agent actions to the nodes.
|
||||
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
@@ -535,7 +539,8 @@ class Primaite(Env):
|
||||
return
|
||||
|
||||
def apply_actions_to_acl(self, _action):
|
||||
"""Applies agent actions to the Access Control List [TO DO].
|
||||
"""
|
||||
Applies agent actions to the Access Control List [TO DO].
|
||||
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
@@ -612,7 +617,8 @@ class Primaite(Env):
|
||||
return
|
||||
|
||||
def apply_time_based_updates(self):
|
||||
"""Updates anything that needs to count down and then change state.
|
||||
"""
|
||||
Updates anything that needs to count down and then change state.
|
||||
|
||||
e.g. reset / patching status
|
||||
"""
|
||||
@@ -653,7 +659,8 @@ class Primaite(Env):
|
||||
pass
|
||||
|
||||
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
|
||||
"""Create the environment's observation handler.
|
||||
"""
|
||||
Create the environment's observation handler.
|
||||
|
||||
:return: The observation space, initial observation (zeroed out array with the correct shape)
|
||||
:rtype: Tuple[spaces.Space, np.ndarray]
|
||||
@@ -709,7 +716,8 @@ class Primaite(Env):
|
||||
print("Environment configuration loaded")
|
||||
|
||||
def create_node(self, item):
|
||||
"""Creates a node from config data.
|
||||
"""
|
||||
Creates a node from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -789,7 +797,8 @@ class Primaite(Env):
|
||||
self.network_reference.add_nodes_from([node_ref])
|
||||
|
||||
def create_link(self, item: Dict):
|
||||
"""Creates a link from config data.
|
||||
"""
|
||||
Creates a link from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -832,7 +841,8 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
def create_green_ier(self, item):
|
||||
"""Creates a green IER from config data.
|
||||
"""
|
||||
Creates a green IER from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -872,7 +882,8 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
def create_red_ier(self, item):
|
||||
"""Creates a red IER from config data.
|
||||
"""
|
||||
Creates a red IER from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -901,7 +912,8 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
def create_green_pol(self, item):
|
||||
"""Creates a green PoL object from config data.
|
||||
"""
|
||||
Creates a green PoL object from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -934,7 +946,8 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
def create_red_pol(self, item):
|
||||
"""Creates a red PoL object from config data.
|
||||
"""
|
||||
Creates a red PoL object from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -974,7 +987,8 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
def create_acl_rule(self, item):
|
||||
"""Creates an ACL rule from config data.
|
||||
"""
|
||||
Creates an ACL rule from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -994,7 +1008,8 @@ class Primaite(Env):
|
||||
)
|
||||
|
||||
def create_services_list(self, services):
|
||||
"""Creates a list of services (enum) from config data.
|
||||
"""
|
||||
Creates a list of services (enum) from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item representing the services
|
||||
@@ -1009,7 +1024,8 @@ class Primaite(Env):
|
||||
self.num_services = len(self.services_list)
|
||||
|
||||
def create_ports_list(self, ports):
|
||||
"""Creates a list of ports from config data.
|
||||
"""
|
||||
Creates a list of ports from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item representing the ports
|
||||
@@ -1024,7 +1040,8 @@ class Primaite(Env):
|
||||
self.num_ports = len(self.ports_list)
|
||||
|
||||
def get_observation_info(self, observation_info):
|
||||
"""Extracts observation_info.
|
||||
"""
|
||||
Extracts observation_info.
|
||||
|
||||
:param observation_info: Config item that defines which type of observation space to use
|
||||
:type observation_info: str
|
||||
@@ -1032,7 +1049,8 @@ class Primaite(Env):
|
||||
self.observation_type = ObservationType[observation_info["type"]]
|
||||
|
||||
def get_action_info(self, action_info):
|
||||
"""Extracts action_info.
|
||||
"""
|
||||
Extracts action_info.
|
||||
|
||||
Args:
|
||||
item: A config data item representing action info
|
||||
@@ -1040,7 +1058,8 @@ class Primaite(Env):
|
||||
self.action_type = ActionType[action_info["type"]]
|
||||
|
||||
def save_obs_config(self, obs_config: dict):
|
||||
"""Cache the config for the observation space.
|
||||
"""
|
||||
Cache the config for the observation space.
|
||||
|
||||
This is necessary as the observation space can't be built while reading the config,
|
||||
it must be done after all the nodes, links, and services have been initialised.
|
||||
@@ -1052,7 +1071,8 @@ class Primaite(Env):
|
||||
self.obs_config = obs_config
|
||||
|
||||
def reset_environment(self):
|
||||
"""# Resets environment.
|
||||
"""
|
||||
Resets environment.
|
||||
|
||||
Uses config data config data in order to build the environment configuration.
|
||||
"""
|
||||
@@ -1076,7 +1096,8 @@ class Primaite(Env):
|
||||
ier_value.set_is_running(False)
|
||||
|
||||
def reset_node(self, item):
|
||||
"""Resets the statuses of a node.
|
||||
"""
|
||||
Resets the statuses of a node.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
@@ -1123,7 +1144,8 @@ class Primaite(Env):
|
||||
pass
|
||||
|
||||
def create_node_action_dict(self):
|
||||
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
|
||||
"""
|
||||
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
|
||||
|
||||
Note: Only actions that have the potential to change the state exist in the mapping (except for key 0)
|
||||
|
||||
@@ -1187,7 +1209,8 @@ class Primaite(Env):
|
||||
return actions
|
||||
|
||||
def create_node_and_acl_action_dict(self):
|
||||
"""Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
|
||||
"""
|
||||
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
|
||||
|
||||
The dictionary contains actions of both Node and ACL action types.
|
||||
"""
|
||||
|
||||
@@ -21,7 +21,8 @@ def calculate_reward_function(
|
||||
step_count,
|
||||
config_values,
|
||||
):
|
||||
"""Compares the states of the initial and final nodes/links to get a reward.
|
||||
"""
|
||||
Compares the states of the initial and final nodes/links to get a reward.
|
||||
|
||||
Args:
|
||||
initial_nodes: The nodes before red and blue agents take effect
|
||||
@@ -94,7 +95,8 @@ def calculate_reward_function(
|
||||
|
||||
|
||||
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
|
||||
"""Calculates score relating to the hardware state of a node.
|
||||
"""
|
||||
Calculates score relating to the hardware state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
@@ -142,7 +144,8 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
|
||||
|
||||
def score_node_os_state(final_node, initial_node, reference_node, config_values):
|
||||
"""Calculates score relating to the Software State of a node.
|
||||
"""
|
||||
Calculates score relating to the Software State of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
@@ -192,7 +195,8 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
|
||||
def score_node_service_state(final_node, initial_node, reference_node, config_values):
|
||||
"""Calculates score relating to the service state(s) of a node.
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
@@ -263,7 +267,8 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
|
||||
|
||||
def score_node_file_system(final_node, initial_node, reference_node, config_values):
|
||||
"""Calculates score relating to the file system state of a node.
|
||||
"""
|
||||
Calculates score relating to the file system state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
|
||||
@@ -29,7 +29,8 @@ class Link(object):
|
||||
self.add_protocol(protocol_name)
|
||||
|
||||
def add_protocol(self, _protocol):
|
||||
"""Adds a new protocol to the list of protocols on this link.
|
||||
"""
|
||||
Adds a new protocol to the list of protocols on this link.
|
||||
|
||||
Args:
|
||||
_protocol: The protocol to be added (enum)
|
||||
@@ -37,7 +38,8 @@ class Link(object):
|
||||
self.protocol_list.append(Protocol(_protocol))
|
||||
|
||||
def get_id(self):
|
||||
"""Gets link ID.
|
||||
"""
|
||||
Gets link ID.
|
||||
|
||||
Returns:
|
||||
Link ID
|
||||
@@ -45,7 +47,8 @@ class Link(object):
|
||||
return self.id
|
||||
|
||||
def get_source_node_name(self):
|
||||
"""Gets source node name.
|
||||
"""
|
||||
Gets source node name.
|
||||
|
||||
Returns:
|
||||
Source node name
|
||||
@@ -53,7 +56,8 @@ class Link(object):
|
||||
return self.source_node_name
|
||||
|
||||
def get_dest_node_name(self):
|
||||
"""Gets destination node name.
|
||||
"""
|
||||
Gets destination node name.
|
||||
|
||||
Returns:
|
||||
Destination node name
|
||||
@@ -61,7 +65,8 @@ class Link(object):
|
||||
return self.dest_node_name
|
||||
|
||||
def get_bandwidth(self):
|
||||
"""Gets bandwidth of link.
|
||||
"""
|
||||
Gets bandwidth of link.
|
||||
|
||||
Returns:
|
||||
Link bandwidth (bps)
|
||||
@@ -69,7 +74,8 @@ class Link(object):
|
||||
return self.bandwidth
|
||||
|
||||
def get_protocol_list(self):
|
||||
"""Gets list of protocols on this link.
|
||||
"""
|
||||
Gets list of protocols on this link.
|
||||
|
||||
Returns:
|
||||
List of protocols on this link
|
||||
@@ -77,7 +83,8 @@ class Link(object):
|
||||
return self.protocol_list
|
||||
|
||||
def get_current_load(self):
|
||||
"""Gets current total load on this link.
|
||||
"""
|
||||
Gets current total load on this link.
|
||||
|
||||
Returns:
|
||||
Total load on this link (bps)
|
||||
@@ -88,7 +95,8 @@ class Link(object):
|
||||
return total_load
|
||||
|
||||
def add_protocol_load(self, _protocol, _load):
|
||||
"""Adds a loading to a protocol on this link.
|
||||
"""
|
||||
Adds a loading to a protocol on this link.
|
||||
|
||||
Args:
|
||||
_protocol: The protocol to load
|
||||
|
||||
@@ -14,7 +14,8 @@ def run(
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
):
|
||||
"""Run the PrimAITE Session.
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
|
||||
:param training_config_path: The training config filepath.
|
||||
:param lay_down_config_path: The lay down config filepath.
|
||||
|
||||
@@ -52,7 +52,8 @@ class ActiveNode(Node):
|
||||
|
||||
@property
|
||||
def software_state(self) -> SoftwareState:
|
||||
"""Get the software_state.
|
||||
"""
|
||||
Get the software_state.
|
||||
|
||||
:return: The software_state.
|
||||
"""
|
||||
@@ -60,7 +61,8 @@ class ActiveNode(Node):
|
||||
|
||||
@software_state.setter
|
||||
def software_state(self, software_state: SoftwareState):
|
||||
"""Get the software_state.
|
||||
"""
|
||||
Get the software_state.
|
||||
|
||||
:param software_state: Software State.
|
||||
"""
|
||||
@@ -78,7 +80,8 @@ class ActiveNode(Node):
|
||||
)
|
||||
|
||||
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
|
||||
"""Sets Software State if the node is not compromised.
|
||||
"""
|
||||
Sets Software State if the node is not compromised.
|
||||
|
||||
Args:
|
||||
software_state: Software State
|
||||
@@ -104,7 +107,8 @@ class ActiveNode(Node):
|
||||
self._software_state = SoftwareState.GOOD
|
||||
|
||||
def set_file_system_state(self, file_system_state: FileSystemState):
|
||||
"""Sets the file system state (actual and observed).
|
||||
"""
|
||||
Sets the file system state (actual and observed).
|
||||
|
||||
Args:
|
||||
file_system_state: File system state
|
||||
@@ -130,7 +134,8 @@ class ActiveNode(Node):
|
||||
)
|
||||
|
||||
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
|
||||
"""Sets the file system state (actual and observed) if not in a compromised state.
|
||||
"""
|
||||
Sets the file system state (actual and observed) if not in a compromised state.
|
||||
|
||||
Use for green PoL to prevent it overturning a compromised state
|
||||
|
||||
|
||||
@@ -35,7 +35,8 @@ class NodeStateInstructionGreen(object):
|
||||
self.state = _state
|
||||
|
||||
def get_start_step(self):
|
||||
"""Gets the start step.
|
||||
"""
|
||||
Gets the start step.
|
||||
|
||||
Returns:
|
||||
The start step
|
||||
@@ -43,7 +44,8 @@ class NodeStateInstructionGreen(object):
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
"""Gets the end step.
|
||||
"""
|
||||
Gets the end step.
|
||||
|
||||
Returns:
|
||||
The end step
|
||||
@@ -51,7 +53,8 @@ class NodeStateInstructionGreen(object):
|
||||
return self.end_step
|
||||
|
||||
def get_node_id(self):
|
||||
"""Gets the node ID.
|
||||
"""
|
||||
Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
@@ -59,7 +62,8 @@ class NodeStateInstructionGreen(object):
|
||||
return self.node_id
|
||||
|
||||
def get_node_pol_type(self):
|
||||
"""Gets the node pattern of life type (enum).
|
||||
"""
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
Returns:
|
||||
The node pattern of life type (enum)
|
||||
@@ -67,7 +71,8 @@ class NodeStateInstructionGreen(object):
|
||||
return self.node_pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
"""Gets the service name.
|
||||
"""
|
||||
Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
@@ -75,7 +80,8 @@ class NodeStateInstructionGreen(object):
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
"""Gets the state (node or service).
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
|
||||
Returns:
|
||||
The state (node or service)
|
||||
|
||||
@@ -51,7 +51,8 @@ class NodeStateInstructionRed(object):
|
||||
self.source_node_service_state = _pol_source_node_service_state
|
||||
|
||||
def get_start_step(self):
|
||||
"""Gets the start step.
|
||||
"""
|
||||
Gets the start step.
|
||||
|
||||
Returns:
|
||||
The start step
|
||||
@@ -59,7 +60,8 @@ class NodeStateInstructionRed(object):
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
"""Gets the end step.
|
||||
"""
|
||||
Gets the end step.
|
||||
|
||||
Returns:
|
||||
The end step
|
||||
@@ -67,7 +69,8 @@ class NodeStateInstructionRed(object):
|
||||
return self.end_step
|
||||
|
||||
def get_target_node_id(self):
|
||||
"""Gets the node ID.
|
||||
"""
|
||||
Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
@@ -75,7 +78,8 @@ class NodeStateInstructionRed(object):
|
||||
return self.target_node_id
|
||||
|
||||
def get_initiator(self):
|
||||
"""Gets the initiator.
|
||||
"""
|
||||
Gets the initiator.
|
||||
|
||||
Returns:
|
||||
The initiator
|
||||
@@ -83,7 +87,8 @@ class NodeStateInstructionRed(object):
|
||||
return self.initiator
|
||||
|
||||
def get_pol_type(self) -> NodePOLType:
|
||||
"""Gets the node pattern of life type (enum).
|
||||
"""
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
Returns:
|
||||
The node pattern of life type (enum)
|
||||
@@ -91,7 +96,8 @@ class NodeStateInstructionRed(object):
|
||||
return self.pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
"""Gets the service name.
|
||||
"""
|
||||
Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
@@ -99,7 +105,8 @@ class NodeStateInstructionRed(object):
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
"""Gets the state (node or service).
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
|
||||
Returns:
|
||||
The state (node or service)
|
||||
@@ -107,7 +114,8 @@ class NodeStateInstructionRed(object):
|
||||
return self.state
|
||||
|
||||
def get_source_node_id(self):
|
||||
"""Gets the source node id (used for initiator type SERVICE).
|
||||
"""
|
||||
Gets the source node id (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node id
|
||||
@@ -115,7 +123,8 @@ class NodeStateInstructionRed(object):
|
||||
return self.source_node_id
|
||||
|
||||
def get_source_node_service(self):
|
||||
"""Gets the source node service (used for initiator type SERVICE).
|
||||
"""
|
||||
Gets the source node service (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node service
|
||||
@@ -123,7 +132,8 @@ class NodeStateInstructionRed(object):
|
||||
return self.source_node_service
|
||||
|
||||
def get_source_node_service_state(self):
|
||||
"""Gets the source node service state (used for initiator type SERVICE).
|
||||
"""
|
||||
Gets the source node service state (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node service state
|
||||
|
||||
@@ -32,7 +32,8 @@ class PassiveNode(Node):
|
||||
|
||||
@property
|
||||
def ip_address(self) -> str:
|
||||
"""Gets the node IP address as an empty string.
|
||||
"""
|
||||
Gets the node IP address as an empty string.
|
||||
|
||||
No concept of IP address for passive nodes for now.
|
||||
|
||||
|
||||
@@ -53,14 +53,16 @@ class ServiceNode(ActiveNode):
|
||||
self.services: Dict[str, Service] = {}
|
||||
|
||||
def add_service(self, service: Service):
|
||||
"""Adds a service to the node.
|
||||
"""
|
||||
Adds a service to the node.
|
||||
|
||||
:param service: The service to add
|
||||
"""
|
||||
self.services[service.name] = service
|
||||
|
||||
def has_service(self, protocol_name: str) -> bool:
|
||||
"""Indicates whether a service is on a node.
|
||||
"""
|
||||
Indicates whether a service is on a node.
|
||||
|
||||
:param protocol_name: The service (protocol)e.
|
||||
:return: True if service (protocol) is on the node, otherwise False.
|
||||
@@ -71,7 +73,8 @@ class ServiceNode(ActiveNode):
|
||||
return False
|
||||
|
||||
def service_running(self, protocol_name: str) -> bool:
|
||||
"""Indicates whether a service is in a running state on the node.
|
||||
"""
|
||||
Indicates whether a service is in a running state on the node.
|
||||
|
||||
:param protocol_name: The service (protocol)
|
||||
:return: True if service (protocol) is in a running state on the node, otherwise False.
|
||||
@@ -85,7 +88,8 @@ class ServiceNode(ActiveNode):
|
||||
return False
|
||||
|
||||
def service_is_overwhelmed(self, protocol_name: str) -> bool:
|
||||
"""Indicates whether a service is in an overwhelmed state on the node.
|
||||
"""
|
||||
Indicates whether a service is in an overwhelmed state on the node.
|
||||
|
||||
:param protocol_name: The service (protocol)
|
||||
:return: True if service (protocol) is in an overwhelmed state on the node, otherwise False.
|
||||
@@ -99,7 +103,8 @@ class ServiceNode(ActiveNode):
|
||||
return False
|
||||
|
||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
|
||||
"""Sets the software_state of a service (protocol) on the node.
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
:param protocol_name: The service (protocol).
|
||||
:param software_state: The software_state.
|
||||
@@ -127,7 +132,8 @@ class ServiceNode(ActiveNode):
|
||||
)
|
||||
|
||||
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
|
||||
"""Sets the software_state of a service (protocol) on the node.
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
Done if the software_state is not "compromised".
|
||||
|
||||
@@ -153,7 +159,8 @@ class ServiceNode(ActiveNode):
|
||||
)
|
||||
|
||||
def get_service_state(self, protocol_name):
|
||||
"""Gets the state of a service.
|
||||
"""
|
||||
Gets the state of a service.
|
||||
|
||||
:return: The software_state of the service.
|
||||
"""
|
||||
|
||||
@@ -11,7 +11,8 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def start_jupyter_session():
|
||||
"""Starts a new Jupyter notebook session in the app notebooks directory.
|
||||
"""
|
||||
Starts a new Jupyter notebook session in the app notebooks directory.
|
||||
|
||||
Currently only works on Windows OS.
|
||||
|
||||
|
||||
@@ -25,7 +25,8 @@ def apply_iers(
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
):
|
||||
"""Applies IERs to the links (link pattern of life).
|
||||
"""
|
||||
Applies IERs to the links (link pattern of life).
|
||||
|
||||
Args:
|
||||
network: The network modelled in the environment
|
||||
@@ -217,7 +218,8 @@ def apply_node_pol(
|
||||
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
|
||||
step: int,
|
||||
):
|
||||
"""Applies node pattern of life.
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
|
||||
Args:
|
||||
nodes: The nodes within the environment
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""Information Exchange Requirements for APE.
|
||||
"""
|
||||
Information Exchange Requirements for APE.
|
||||
|
||||
Used to represent an information flow from source to destination.
|
||||
"""
|
||||
@@ -47,7 +48,8 @@ class IER(object):
|
||||
self.running = _running
|
||||
|
||||
def get_id(self):
|
||||
"""Gets IER ID.
|
||||
"""
|
||||
Gets IER ID.
|
||||
|
||||
Returns:
|
||||
IER ID
|
||||
@@ -55,7 +57,8 @@ class IER(object):
|
||||
return self.id
|
||||
|
||||
def get_start_step(self):
|
||||
"""Gets IER start step.
|
||||
"""
|
||||
Gets IER start step.
|
||||
|
||||
Returns:
|
||||
IER start step
|
||||
@@ -63,7 +66,8 @@ class IER(object):
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
"""Gets IER end step.
|
||||
"""
|
||||
Gets IER end step.
|
||||
|
||||
Returns:
|
||||
IER end step
|
||||
@@ -71,7 +75,8 @@ class IER(object):
|
||||
return self.end_step
|
||||
|
||||
def get_load(self):
|
||||
"""Gets IER load.
|
||||
"""
|
||||
Gets IER load.
|
||||
|
||||
Returns:
|
||||
IER load
|
||||
@@ -79,7 +84,8 @@ class IER(object):
|
||||
return self.load
|
||||
|
||||
def get_protocol(self):
|
||||
"""Gets IER protocol.
|
||||
"""
|
||||
Gets IER protocol.
|
||||
|
||||
Returns:
|
||||
IER protocol
|
||||
@@ -87,7 +93,8 @@ class IER(object):
|
||||
return self.protocol
|
||||
|
||||
def get_port(self):
|
||||
"""Gets IER port.
|
||||
"""
|
||||
Gets IER port.
|
||||
|
||||
Returns:
|
||||
IER port
|
||||
@@ -95,7 +102,8 @@ class IER(object):
|
||||
return self.port
|
||||
|
||||
def get_source_node_id(self):
|
||||
"""Gets IER source node ID.
|
||||
"""
|
||||
Gets IER source node ID.
|
||||
|
||||
Returns:
|
||||
IER source node ID
|
||||
@@ -103,7 +111,8 @@ class IER(object):
|
||||
return self.source_node_id
|
||||
|
||||
def get_dest_node_id(self):
|
||||
"""Gets IER destination node ID.
|
||||
"""
|
||||
Gets IER destination node ID.
|
||||
|
||||
Returns:
|
||||
IER destination node ID
|
||||
@@ -111,7 +120,8 @@ class IER(object):
|
||||
return self.dest_node_id
|
||||
|
||||
def get_is_running(self):
|
||||
"""Informs whether the IER is currently running.
|
||||
"""
|
||||
Informs whether the IER is currently running.
|
||||
|
||||
Returns:
|
||||
True if running
|
||||
@@ -119,7 +129,8 @@ class IER(object):
|
||||
return self.running
|
||||
|
||||
def set_is_running(self, _value):
|
||||
"""Sets the running state of the IER.
|
||||
"""
|
||||
Sets the running state of the IER.
|
||||
|
||||
Args:
|
||||
_value: running status
|
||||
@@ -127,7 +138,8 @@ class IER(object):
|
||||
self.running = _value
|
||||
|
||||
def get_mission_criticality(self):
|
||||
"""Gets the IER mission criticality (used in the reward function).
|
||||
"""
|
||||
Gets the IER mission criticality (used in the reward function).
|
||||
|
||||
Returns:
|
||||
Mission criticality value (0 lowest to 5 highest)
|
||||
|
||||
@@ -24,7 +24,8 @@ def apply_red_agent_iers(
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
):
|
||||
"""Applies IERs to the links (link POL) resulting from red agent attack.
|
||||
"""
|
||||
Applies IERs to the links (link POL) resulting from red agent attack.
|
||||
|
||||
Args:
|
||||
network: The network modelled in the environment
|
||||
@@ -213,7 +214,8 @@ def apply_red_agent_node_pol(
|
||||
node_pol: Dict[str, NodeStateInstructionRed],
|
||||
step: int,
|
||||
):
|
||||
"""Applies node pattern of life.
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
|
||||
Args:
|
||||
nodes: The nodes within the environment
|
||||
@@ -295,7 +297,8 @@ def apply_red_agent_node_pol(
|
||||
|
||||
|
||||
def is_red_ier_incoming(node, iers, node_pol_type):
|
||||
"""Checks if the RED IER is incoming.
|
||||
"""
|
||||
Checks if the RED IER is incoming.
|
||||
|
||||
TODO: Write more descriptive docstring with params and returns.
|
||||
"""
|
||||
|
||||
@@ -125,7 +125,8 @@ class PrimaiteSession:
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""Train the agent.
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
@@ -136,7 +137,8 @@ class PrimaiteSession:
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""Evaluate the agent.
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
|
||||
@@ -12,7 +12,8 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing: bool = True):
|
||||
"""Resets the demo jupyter notebooks in the users app notebooks directory.
|
||||
"""
|
||||
Resets the demo jupyter notebooks in the users app notebooks directory.
|
||||
|
||||
:param overwrite_existing: A bool to toggle replacing existing edited notebooks on or off.
|
||||
"""
|
||||
|
||||
@@ -11,7 +11,8 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing=True):
|
||||
"""Resets the example config files in the users app config directory.
|
||||
"""
|
||||
Resets the example config files in the users app config directory.
|
||||
|
||||
:param overwrite_existing: A bool to toggle replacing existing edited config on or off.
|
||||
"""
|
||||
|
||||
@@ -5,7 +5,8 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run():
|
||||
"""Handles creation of application directories and user directories.
|
||||
"""
|
||||
Handles creation of application directories and user directories.
|
||||
|
||||
Uses `platformdirs.PlatformDirs` and `pathlib.Path` to create the required
|
||||
app directories in the correct locations based on the users OS.
|
||||
|
||||
@@ -39,7 +39,8 @@ class Transaction(object):
|
||||
"The env observation space description"
|
||||
|
||||
def as_csv_data(self) -> Tuple[List, List]:
|
||||
"""Converts the Transaction to a csv data row and provides a header.
|
||||
"""
|
||||
Converts the Transaction to a csv data row and provides a header.
|
||||
|
||||
:return: A tuple consisting of (header, data).
|
||||
"""
|
||||
@@ -68,7 +69,8 @@ class Transaction(object):
|
||||
|
||||
|
||||
def _turn_action_space_to_array(action_space) -> List[str]:
|
||||
"""Turns action space into a string array so it can be saved to csv.
|
||||
"""
|
||||
Turns action space into a string array so it can be saved to csv.
|
||||
|
||||
:param action_space: The action space
|
||||
:return: The action space as an array of strings
|
||||
@@ -80,7 +82,8 @@ def _turn_action_space_to_array(action_space) -> List[str]:
|
||||
|
||||
|
||||
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
|
||||
"""Turns observation space into a string array so it can be saved to csv.
|
||||
"""
|
||||
Turns observation space into a string array so it can be saved to csv.
|
||||
|
||||
:param obs_space: The observation space
|
||||
:param obs_assets: The number of assets (i.e. nodes or links) in the observation space
|
||||
|
||||
@@ -10,7 +10,8 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def get_file_path(path: str) -> Path:
|
||||
"""Get PrimAITE package data.
|
||||
"""
|
||||
Get PrimAITE package data.
|
||||
|
||||
:Example:
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ import polars as pl
|
||||
|
||||
|
||||
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||
"""Read an average rewards per episode csv file and return as a dict.
|
||||
"""
|
||||
Read an average rewards per episode csv file and return as a dict.
|
||||
|
||||
The dictionary keys are the episode number, and the values are the mean reward that episode.
|
||||
|
||||
|
||||
@@ -77,7 +77,8 @@ class SessionOutputWriter:
|
||||
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
|
||||
|
||||
def write(self, data: Union[Tuple, Transaction]):
|
||||
"""Write a row of session data.
|
||||
"""
|
||||
Write a row of session data.
|
||||
|
||||
:param data: The row of data to write. Can be a Tuple or an instance of Transaction.
|
||||
"""
|
||||
|
||||
@@ -75,7 +75,8 @@ class TestNodeLinkTable:
|
||||
assert env.env_obs.shape == (5, 6)
|
||||
|
||||
def test_value(self, temp_primaite_session):
|
||||
"""Test that the observation is generated correctly.
|
||||
"""
|
||||
Test that the observation is generated correctly.
|
||||
|
||||
The laydown has:
|
||||
* 3 nodes (2 service nodes and 1 active node)
|
||||
@@ -157,7 +158,8 @@ class TestNodeStatuses:
|
||||
assert env.env_obs.shape == (15,)
|
||||
|
||||
def test_values(self, temp_primaite_session):
|
||||
"""Test that the hardware and software states are encoded correctly.
|
||||
"""
|
||||
Test that the hardware and software states are encoded correctly.
|
||||
|
||||
The laydown has:
|
||||
* one node with a compromised operating system state
|
||||
@@ -213,7 +215,8 @@ class TestLinkTrafficLevels:
|
||||
assert env.env_obs.shape == (2 * 2,)
|
||||
|
||||
def test_values(self, temp_primaite_session):
|
||||
"""Test that traffic values are encoded correctly.
|
||||
"""
|
||||
Test that traffic values are encoded correctly.
|
||||
|
||||
The laydown has:
|
||||
* two services
|
||||
|
||||
Reference in New Issue
Block a user