Standardise docstring summary line placement.

This commit is contained in:
Marek Wolan
2023-07-07 10:28:00 +01:00
parent 86725064ec
commit f4b98542b6
36 changed files with 350 additions and 175 deletions

View File

@@ -66,7 +66,8 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``.
# region Setup Logging
class _LevelFormatter(Formatter):
"""A custom level-specific formatter.
"""
A custom level-specific formatter.
Credit to: https://stackoverflow.com/a/68154386
"""
@@ -134,7 +135,8 @@ _LOGGER.addHandler(_FILE_HANDLER)
def getLogger(name: str) -> Logger: # noqa
"""Get a PrimAITE logger.
"""
Get a PrimAITE logger.
:param name: The logger name. Use ``__name__``.
:return: An instance of :py:class:`logging.Logger` with the PrimAITE

View File

@@ -35,7 +35,8 @@ class AccessControlList:
return False
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
"""Checks for rules that block a protocol / port.
"""
Checks for rules that block a protocol / port.
Args:
_source_ip_address: the source IP address to check
@@ -61,7 +62,8 @@ class AccessControlList:
return True
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
"""Adds a new rule.
"""
Adds a new rule.
Args:
_permission: the permission value (e.g. "ALLOW" or "DENY")
@@ -75,7 +77,8 @@ class AccessControlList:
self.acl[hash_value] = new_rule
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
"""Removes a rule.
"""
Removes a rule.
Args:
_permission: the permission value (e.g. "ALLOW" or "DENY")
@@ -97,7 +100,8 @@ class AccessControlList:
self.acl.clear()
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
"""Produces a hash value for a rule.
"""
Produces a hash value for a rule.
Args:
_permission: the permission value (e.g. "ALLOW" or "DENY")

View File

@@ -22,7 +22,8 @@ class ACLRule:
self.port = _port
def __hash__(self):
"""Override the hash function.
"""
Override the hash function.
Returns:
Returns hash of core parameters.
@@ -38,7 +39,8 @@ class ACLRule:
)
def get_permission(self):
"""Gets the permission attribute.
"""
Gets the permission attribute.
Returns:
Returns permission attribute
@@ -46,7 +48,8 @@ class ACLRule:
return self.permission
def get_source_ip(self):
"""Gets the source IP address attribute.
"""
Gets the source IP address attribute.
Returns:
Returns source IP address attribute
@@ -54,7 +57,8 @@ class ACLRule:
return self.source_ip
def get_dest_ip(self):
"""Gets the desintation IP address attribute.
"""
Gets the desintation IP address attribute.
Returns:
Returns destination IP address attribute
@@ -62,7 +66,8 @@ class ACLRule:
return self.dest_ip
def get_protocol(self):
"""Gets the protocol attribute.
"""
Gets the protocol attribute.
Returns:
Returns protocol attribute
@@ -70,7 +75,8 @@ class ACLRule:
return self.protocol
def get_port(self):
"""Gets the port attribute.
"""
Gets the port attribute.
Returns:
Returns port attribute

View File

@@ -21,7 +21,8 @@ _LOGGER = getLogger(__name__)
def get_session_path(session_timestamp: datetime) -> Path:
"""Get the directory path the session will output to.
"""
Get the directory path the session will output to.
This is set in the format of:
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
@@ -194,7 +195,8 @@ class AgentSessionABC(ABC):
self,
**kwargs,
):
"""Train the agent.
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
@@ -211,7 +213,8 @@ class AgentSessionABC(ABC):
self,
**kwargs,
):
"""Evaluate the agent.
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
@@ -340,7 +343,8 @@ class HardCodedAgentSessionABC(AgentSessionABC):
self,
**kwargs,
):
"""Train the agent.
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
@@ -354,7 +358,8 @@ class HardCodedAgentSessionABC(AgentSessionABC):
self,
**kwargs,
):
"""Evaluate the agent.
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""

View File

@@ -46,7 +46,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_green_iers
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
"""Get matching ACL rules for an IER.
"""
Get matching ACL rules for an IER.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
@@ -62,7 +63,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return matching_rules
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
"""Get blocking ACL rules for an IER.
"""
Get blocking ACL rules for an IER.
.. warning::
Can return empty dict but IER can still be blocked by default
@@ -81,7 +83,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_rules
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
"""Get all allowing ACL rules for an IER.
"""
Get all allowing ACL rules for an IER.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
@@ -105,7 +108,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
nodes,
services_list,
):
"""Get matching ACL rules.
"""
Get matching ACL rules.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
@@ -136,7 +140,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
nodes,
services_list,
):
"""Get the ALLOW ACL rules.
"""
Get the ALLOW ACL rules.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
@@ -168,7 +173,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
nodes,
services_list,
):
"""Get the DENY ACL rules.
"""
Get the DENY ACL rules.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
@@ -191,7 +197,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return allowed_rules
def _calculate_action_full_view(self, obs):
"""Calculate a good acl-based action for the blue agent to take.
"""
Calculate a good acl-based action for the blue agent to take.
Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
@@ -355,7 +362,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return action
def _calculate_action_basic_view(self, obs):
"""Calculate a good acl-based action for the blue agent to take.
"""
Calculate a good acl-based action for the blue agent to take.
Uses ONLY information from the current observation with NO knowledge
of previous actions taken and NO reward feedback.

View File

@@ -6,7 +6,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic Node agent."""
def _calculate_action(self, obs):
"""Calculate a good node-based action for the blue agent to take.
"""
Calculate a good node-based action for the blue agent to take.
TODO: Add params and return in docstring.
TODO: Typehint params and return.

View File

@@ -140,7 +140,8 @@ class RLlibAgent(AgentSessionABC):
self,
**kwargs,
):
"""Evaluate the agent.
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
@@ -158,7 +159,8 @@ class RLlibAgent(AgentSessionABC):
self,
**kwargs,
):
"""Evaluate the agent.
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""

View File

@@ -89,7 +89,8 @@ class SB3Agent(AgentSessionABC):
self,
**kwargs,
):
"""Train the agent.
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
@@ -109,7 +110,8 @@ class SB3Agent(AgentSessionABC):
deterministic: bool = True,
**kwargs,
):
"""Evaluate the agent.
"""
Evaluate the agent.
:param deterministic: Whether the evaluation is deterministic.
:param kwargs: Any agent-specific key-word args to be passed.

View File

@@ -11,7 +11,8 @@ from primaite.common.enums import (
def transform_action_node_readable(action):
"""Convert a node action from enumerated format to readable format.
"""
Convert a node action from enumerated format to readable format.
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
@@ -33,7 +34,8 @@ def transform_action_node_readable(action):
def transform_action_acl_readable(action):
"""Transform an ACL action to a more readable format.
"""
Transform an ACL action to a more readable format.
example:
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
@@ -57,7 +59,8 @@ def transform_action_acl_readable(action):
def is_valid_node_action(action):
"""Is the node action an actual valid action.
"""
Is the node action an actual valid action.
Only uses information about the action to determine if the action has an effect
@@ -92,7 +95,8 @@ def is_valid_node_action(action):
def is_valid_acl_action(action):
"""Is the ACL action an actual valid action.
"""
Is the ACL action an actual valid action.
Only uses information about the action to determine if the action has an effect.
@@ -124,7 +128,8 @@ def is_valid_acl_action(action):
def is_valid_acl_action_extra(action):
"""Harsher version of valid acl actions, does not allow action.
"""
Harsher version of valid acl actions, does not allow action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
@@ -147,7 +152,8 @@ def is_valid_acl_action_extra(action):
def transform_change_obs_readable(obs):
"""Transform list of transactions to readable list of each observation property.
"""
Transform list of transactions to readable list of each observation property.
example:
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
@@ -169,7 +175,8 @@ def transform_change_obs_readable(obs):
def transform_obs_readable(obs):
"""Transform observation to readable format.
"""
Transform observation to readable format.
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
@@ -185,7 +192,8 @@ def transform_obs_readable(obs):
def convert_to_new_obs(obs, num_nodes=10):
"""Convert original gym Box observation space to new multiDiscrete observation space.
"""
Convert original gym Box observation space to new multiDiscrete observation space.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
@@ -196,7 +204,8 @@ def convert_to_new_obs(obs, num_nodes=10):
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
"""Convert to old observation.
"""
Convert to old observation.
Links filled with 0's as no information is included in new observation space.
@@ -232,7 +241,8 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
"""Return string describing change between two observations.
"""
Return string describing change between two observations.
example:
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
@@ -260,7 +270,8 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
def _describe_obs_change_helper(obs_change, is_link):
"""Helper funcion to describe what has changed.
"""
Helper funcion to describe what has changed.
example:
[ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD"
@@ -295,7 +306,8 @@ def _describe_obs_change_helper(obs_change, is_link):
def transform_action_node_enum(action):
"""Convert a node action from readable string format, to enumerated format.
"""
Convert a node action from readable string format, to enumerated format.
example:
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
@@ -326,7 +338,8 @@ def transform_action_node_enum(action):
def transform_action_node_readable(action):
"""Convert a node action from enumerated format to readable format.
"""
Convert a node action from enumerated format to readable format.
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
@@ -348,7 +361,8 @@ def transform_action_node_readable(action):
def node_action_description(action):
"""Generate string describing a node-based action.
"""
Generate string describing a node-based action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
@@ -375,7 +389,8 @@ def node_action_description(action):
def transform_action_acl_enum(action):
"""Convert acl action from readable str format, to enumerated format.
"""
Convert acl action from readable str format, to enumerated format.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
@@ -397,7 +412,8 @@ def transform_action_acl_enum(action):
def acl_action_description(action):
"""Generate string describing an acl-based action.
"""
Generate string describing an acl-based action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
@@ -417,7 +433,8 @@ def acl_action_description(action):
def get_node_of_ip(ip, node_dict):
"""Get the node ID of an IP address.
"""
Get the node ID of an IP address.
node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes)
@@ -431,7 +448,8 @@ def get_node_of_ip(ip, node_dict):
def is_valid_node_action(action):
"""Is the node action an actual valid action.
"""
Is the node action an actual valid action.
Only uses information about the action to determine if the action has an effect
@@ -464,7 +482,8 @@ def is_valid_node_action(action):
def is_valid_acl_action(action):
"""Is the ACL action an actual valid action.
"""
Is the ACL action an actual valid action.
Only uses information about the action to determine if the action has an effect
@@ -496,7 +515,8 @@ def is_valid_acl_action(action):
def is_valid_acl_action_extra(action):
"""Harsher version of valid acl actions, does not allow action.
"""
Harsher version of valid acl actions, does not allow action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
@@ -519,7 +539,8 @@ def is_valid_acl_action_extra(action):
def get_new_action(old_action, action_dict):
"""Get new action (e.g. 32) from old action e.g. [1,1,1,0].
"""
Get new action (e.g. 32) from old action e.g. [1,1,1,0].
Old_action can be either node or acl action type

View File

@@ -28,7 +28,8 @@ def build_dirs():
@app.command()
def reset_notebooks(overwrite: bool = True):
"""Force a reset of the demo notebooks in the users notebooks directory.
"""
Force a reset of the demo notebooks in the users notebooks directory.
:param overwrite: If True, will overwrite existing demo notebooks.
"""
@@ -39,7 +40,8 @@ def reset_notebooks(overwrite: bool = True):
@app.command()
def logs(last_n: Annotated[int, typer.Option("-n")]):
"""Print the PrimAITE log file.
"""
Print the PrimAITE log file.
:param last_n: The number of lines to print. Default value is 10.
"""
@@ -59,7 +61,8 @@ _LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # n
@app.command()
def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None):
"""View or set the PrimAITE Log Level.
"""
View or set the PrimAITE Log Level.
To View, simply call: primaite log-level
@@ -110,7 +113,8 @@ def clean_up():
@app.command()
def setup(overwrite_existing: bool = True):
"""Perform the PrimAITE first-time setup.
"""
Perform the PrimAITE first-time setup.
WARNING: All user-data will be lost.
"""
@@ -148,7 +152,8 @@ def setup(overwrite_existing: bool = True):
@app.command()
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
"""Run a PrimAITE session.
"""
Run a PrimAITE session.
tc: The training config filepath. Optional. If no value is passed then
example default training config is used from:
@@ -173,7 +178,8 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None):
@app.command()
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
"""View or set the plotly template for Session plots.
"""
View or set the plotly template for Session plots.
To View, simply call: primaite plotly-template

View File

@@ -12,7 +12,8 @@ _EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down
def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Convert a legacy lay down config dict to the new format.
"""
Convert a legacy lay down config dict to the new format.
:param legacy_config_dict: A legacy lay down config dict.
"""
@@ -21,7 +22,8 @@ def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> D
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
"""Read in a lay down config yaml file.
"""
Read in a lay down config yaml file.
:param file_path: The config file path.
:param legacy_file: True if the config file is legacy format, otherwise False.
@@ -50,7 +52,8 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
def ddos_basic_one_config_path() -> Path:
"""The path to the example lay_down_config_1_DDOS_basic.yaml file.
"""
The path to the example lay_down_config_1_DDOS_basic.yaml file.
:return: The file path.
"""
@@ -64,7 +67,8 @@ def ddos_basic_one_config_path() -> Path:
def ddos_basic_two_config_path() -> Path:
"""The path to the example lay_down_config_2_DDOS_basic.yaml file.
"""
The path to the example lay_down_config_2_DDOS_basic.yaml file.
:return: The file path.
"""
@@ -78,7 +82,8 @@ def ddos_basic_two_config_path() -> Path:
def dos_very_basic_config_path() -> Path:
"""The path to the example lay_down_config_3_DOS_very_basic.yaml file.
"""
The path to the example lay_down_config_3_DOS_very_basic.yaml file.
:return: The file path.
"""
@@ -92,7 +97,8 @@ def dos_very_basic_config_path() -> Path:
def data_manipulation_config_path() -> Path:
"""The path to the example lay_down_config_5_data_manipulation.yaml file.
"""
The path to the example lay_down_config_5_data_manipulation.yaml file.
:return: The file path.
"""

View File

@@ -180,7 +180,8 @@ class TrainingConfig:
@classmethod
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
"""Create an instance of TrainingConfig from a dict.
"""
Create an instance of TrainingConfig from a dict.
:param config_dict: The training config dict.
:return: The instance of TrainingConfig.

View File

@@ -22,7 +22,8 @@ def plot_av_reward_per_episode(
title: Optional[str] = None,
subtitle: Optional[str] = None,
) -> Figure:
"""Plot the average reward per episode from a csv session output.
"""
Plot the average reward per episode from a csv session output.
:param av_reward_per_episode_csv: The average reward per episode csv
file path.

View File

@@ -50,7 +50,8 @@ class AbstractObservationComponent(ABC):
class NodeLinkTable(AbstractObservationComponent):
"""Table with nodes and links as rows and hardware/software status as cols.
"""
Table with nodes and links as rows and hardware/software status as cols.
This will create the observation space formatted as a table of integers.
There is one row per node, followed by one row per link.
@@ -74,7 +75,8 @@ class NodeLinkTable(AbstractObservationComponent):
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
"""Initialise a NodeLinkTable observation space component.
"""
Initialise a NodeLinkTable observation space component.
:param env: Training environment.
:type env: Primaite
@@ -100,7 +102,8 @@ class NodeLinkTable(AbstractObservationComponent):
self.structure = self.generate_structure()
def update(self):
"""Update the observation based on current environment state.
"""
Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeLinkTable`
"""
@@ -181,7 +184,8 @@ class NodeLinkTable(AbstractObservationComponent):
class NodeStatuses(AbstractObservationComponent):
"""Flat list of nodes' hardware, OS, file system, and service states.
"""
Flat list of nodes' hardware, OS, file system, and service states.
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
integers.
@@ -234,7 +238,8 @@ class NodeStatuses(AbstractObservationComponent):
self.structure = self.generate_structure()
def update(self):
"""Update the observation based on current environment state.
"""
Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeStatuses`
"""
@@ -287,7 +292,8 @@ class NodeStatuses(AbstractObservationComponent):
class LinkTrafficLevels(AbstractObservationComponent):
"""Flat list of traffic levels encoded into banded categories.
"""
Flat list of traffic levels encoded into banded categories.
For each link, total traffic or traffic per service is encoded into a categorical value.
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
@@ -354,7 +360,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
self.structure = self.generate_structure()
def update(self):
"""Update the observation based on current environment state.
"""
Update the observation based on current environment state.
The structure of the observation space is described in :class:`.LinkTrafficLevels`
"""
@@ -395,7 +402,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
class ObservationsHandler:
"""Component-based observation space handler.
"""
Component-based observation space handler.
This allows users to configure observation spaces by mixing and matching components. Each component can also define
further parameters to make them more flexible.
@@ -436,7 +444,8 @@ class ObservationsHandler:
self._flat_observation = spaces.flatten(self._space, self._observation)
def register(self, obs_component: AbstractObservationComponent):
"""Add a component for this handler to track.
"""
Add a component for this handler to track.
:param obs_component: The component to add.
:type obs_component: AbstractObservationComponent
@@ -445,7 +454,8 @@ class ObservationsHandler:
self.update_space()
def deregister(self, obs_component: AbstractObservationComponent):
"""Remove a component from this handler.
"""
Remove a component from this handler.
:param obs_component: Which component to remove. It must exist within this object's
``registered_obs_components`` attribute.
@@ -488,7 +498,8 @@ class ObservationsHandler:
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):
"""Parse a config dictinary, return a new observation handler populated with new observation component objects.
"""
Parse a config dictinary, return a new observation handler populated with new observation component objects.
The expected format for the config dictionary is:
@@ -533,7 +544,8 @@ class ObservationsHandler:
return handler
def describe_structure(self):
"""Create a list of names for the features of the obs space.
"""
Create a list of names for the features of the obs space.
The order of labels follows the flattened version of the space.
"""

View File

@@ -255,7 +255,8 @@ class Primaite(Env):
self.total_step_count = 0
def reset(self):
"""AI Gym Reset function.
"""
AI Gym Reset function.
Returns:
Environment observation space (reset)
@@ -291,7 +292,8 @@ class Primaite(Env):
return self.env_obs
def step(self, action):
"""AI Gym Step function.
"""
AI Gym Step function.
Args:
action: Action space from agent
@@ -429,7 +431,8 @@ class Primaite(Env):
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
def interpret_action_and_apply(self, _action):
"""Applies agent actions to the nodes and Access Control List.
"""
Applies agent actions to the nodes and Access Control List.
Args:
_action: The action space from the agent
@@ -448,7 +451,8 @@ class Primaite(Env):
logging.error("Invalid action type found")
def apply_actions_to_nodes(self, _action):
"""Applies agent actions to the nodes.
"""
Applies agent actions to the nodes.
Args:
_action: The action space from the agent
@@ -535,7 +539,8 @@ class Primaite(Env):
return
def apply_actions_to_acl(self, _action):
"""Applies agent actions to the Access Control List [TO DO].
"""
Applies agent actions to the Access Control List [TO DO].
Args:
_action: The action space from the agent
@@ -612,7 +617,8 @@ class Primaite(Env):
return
def apply_time_based_updates(self):
"""Updates anything that needs to count down and then change state.
"""
Updates anything that needs to count down and then change state.
e.g. reset / patching status
"""
@@ -653,7 +659,8 @@ class Primaite(Env):
pass
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Create the environment's observation handler.
"""
Create the environment's observation handler.
:return: The observation space, initial observation (zeroed out array with the correct shape)
:rtype: Tuple[spaces.Space, np.ndarray]
@@ -709,7 +716,8 @@ class Primaite(Env):
print("Environment configuration loaded")
def create_node(self, item):
"""Creates a node from config data.
"""
Creates a node from config data.
Args:
item: A config data item
@@ -789,7 +797,8 @@ class Primaite(Env):
self.network_reference.add_nodes_from([node_ref])
def create_link(self, item: Dict):
"""Creates a link from config data.
"""
Creates a link from config data.
Args:
item: A config data item
@@ -832,7 +841,8 @@ class Primaite(Env):
)
def create_green_ier(self, item):
"""Creates a green IER from config data.
"""
Creates a green IER from config data.
Args:
item: A config data item
@@ -872,7 +882,8 @@ class Primaite(Env):
)
def create_red_ier(self, item):
"""Creates a red IER from config data.
"""
Creates a red IER from config data.
Args:
item: A config data item
@@ -901,7 +912,8 @@ class Primaite(Env):
)
def create_green_pol(self, item):
"""Creates a green PoL object from config data.
"""
Creates a green PoL object from config data.
Args:
item: A config data item
@@ -934,7 +946,8 @@ class Primaite(Env):
)
def create_red_pol(self, item):
"""Creates a red PoL object from config data.
"""
Creates a red PoL object from config data.
Args:
item: A config data item
@@ -974,7 +987,8 @@ class Primaite(Env):
)
def create_acl_rule(self, item):
"""Creates an ACL rule from config data.
"""
Creates an ACL rule from config data.
Args:
item: A config data item
@@ -994,7 +1008,8 @@ class Primaite(Env):
)
def create_services_list(self, services):
"""Creates a list of services (enum) from config data.
"""
Creates a list of services (enum) from config data.
Args:
item: A config data item representing the services
@@ -1009,7 +1024,8 @@ class Primaite(Env):
self.num_services = len(self.services_list)
def create_ports_list(self, ports):
"""Creates a list of ports from config data.
"""
Creates a list of ports from config data.
Args:
item: A config data item representing the ports
@@ -1024,7 +1040,8 @@ class Primaite(Env):
self.num_ports = len(self.ports_list)
def get_observation_info(self, observation_info):
"""Extracts observation_info.
"""
Extracts observation_info.
:param observation_info: Config item that defines which type of observation space to use
:type observation_info: str
@@ -1032,7 +1049,8 @@ class Primaite(Env):
self.observation_type = ObservationType[observation_info["type"]]
def get_action_info(self, action_info):
"""Extracts action_info.
"""
Extracts action_info.
Args:
item: A config data item representing action info
@@ -1040,7 +1058,8 @@ class Primaite(Env):
self.action_type = ActionType[action_info["type"]]
def save_obs_config(self, obs_config: dict):
"""Cache the config for the observation space.
"""
Cache the config for the observation space.
This is necessary as the observation space can't be built while reading the config,
it must be done after all the nodes, links, and services have been initialised.
@@ -1052,7 +1071,8 @@ class Primaite(Env):
self.obs_config = obs_config
def reset_environment(self):
"""# Resets environment.
"""
Resets environment.
Uses config data config data in order to build the environment configuration.
"""
@@ -1076,7 +1096,8 @@ class Primaite(Env):
ier_value.set_is_running(False)
def reset_node(self, item):
"""Resets the statuses of a node.
"""
Resets the statuses of a node.
Args:
item: A config data item
@@ -1123,7 +1144,8 @@ class Primaite(Env):
pass
def create_node_action_dict(self):
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
"""
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
Note: Only actions that have the potential to change the state exist in the mapping (except for key 0)
@@ -1187,7 +1209,8 @@ class Primaite(Env):
return actions
def create_node_and_acl_action_dict(self):
"""Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
"""
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
The dictionary contains actions of both Node and ACL action types.
"""

View File

@@ -21,7 +21,8 @@ def calculate_reward_function(
step_count,
config_values,
):
"""Compares the states of the initial and final nodes/links to get a reward.
"""
Compares the states of the initial and final nodes/links to get a reward.
Args:
initial_nodes: The nodes before red and blue agents take effect
@@ -94,7 +95,8 @@ def calculate_reward_function(
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
"""Calculates score relating to the hardware state of a node.
"""
Calculates score relating to the hardware state of a node.
Args:
final_node: The node after red and blue agents take effect
@@ -142,7 +144,8 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
def score_node_os_state(final_node, initial_node, reference_node, config_values):
"""Calculates score relating to the Software State of a node.
"""
Calculates score relating to the Software State of a node.
Args:
final_node: The node after red and blue agents take effect
@@ -192,7 +195,8 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
def score_node_service_state(final_node, initial_node, reference_node, config_values):
"""Calculates score relating to the service state(s) of a node.
"""
Calculates score relating to the service state(s) of a node.
Args:
final_node: The node after red and blue agents take effect
@@ -263,7 +267,8 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
def score_node_file_system(final_node, initial_node, reference_node, config_values):
"""Calculates score relating to the file system state of a node.
"""
Calculates score relating to the file system state of a node.
Args:
final_node: The node after red and blue agents take effect

View File

@@ -29,7 +29,8 @@ class Link(object):
self.add_protocol(protocol_name)
def add_protocol(self, _protocol):
"""Adds a new protocol to the list of protocols on this link.
"""
Adds a new protocol to the list of protocols on this link.
Args:
_protocol: The protocol to be added (enum)
@@ -37,7 +38,8 @@ class Link(object):
self.protocol_list.append(Protocol(_protocol))
def get_id(self):
"""Gets link ID.
"""
Gets link ID.
Returns:
Link ID
@@ -45,7 +47,8 @@ class Link(object):
return self.id
def get_source_node_name(self):
"""Gets source node name.
"""
Gets source node name.
Returns:
Source node name
@@ -53,7 +56,8 @@ class Link(object):
return self.source_node_name
def get_dest_node_name(self):
"""Gets destination node name.
"""
Gets destination node name.
Returns:
Destination node name
@@ -61,7 +65,8 @@ class Link(object):
return self.dest_node_name
def get_bandwidth(self):
"""Gets bandwidth of link.
"""
Gets bandwidth of link.
Returns:
Link bandwidth (bps)
@@ -69,7 +74,8 @@ class Link(object):
return self.bandwidth
def get_protocol_list(self):
"""Gets list of protocols on this link.
"""
Gets list of protocols on this link.
Returns:
List of protocols on this link
@@ -77,7 +83,8 @@ class Link(object):
return self.protocol_list
def get_current_load(self):
"""Gets current total load on this link.
"""
Gets current total load on this link.
Returns:
Total load on this link (bps)
@@ -88,7 +95,8 @@ class Link(object):
return total_load
def add_protocol_load(self, _protocol, _load):
"""Adds a loading to a protocol on this link.
"""
Adds a loading to a protocol on this link.
Args:
_protocol: The protocol to load

View File

@@ -14,7 +14,8 @@ def run(
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
):
"""Run the PrimAITE Session.
"""
Run the PrimAITE Session.
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.

View File

@@ -52,7 +52,8 @@ class ActiveNode(Node):
@property
def software_state(self) -> SoftwareState:
"""Get the software_state.
"""
Get the software_state.
:return: The software_state.
"""
@@ -60,7 +61,8 @@ class ActiveNode(Node):
@software_state.setter
def software_state(self, software_state: SoftwareState):
"""Get the software_state.
"""
Get the software_state.
:param software_state: Software State.
"""
@@ -78,7 +80,8 @@ class ActiveNode(Node):
)
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
"""Sets Software State if the node is not compromised.
"""
Sets Software State if the node is not compromised.
Args:
software_state: Software State
@@ -104,7 +107,8 @@ class ActiveNode(Node):
self._software_state = SoftwareState.GOOD
def set_file_system_state(self, file_system_state: FileSystemState):
"""Sets the file system state (actual and observed).
"""
Sets the file system state (actual and observed).
Args:
file_system_state: File system state
@@ -130,7 +134,8 @@ class ActiveNode(Node):
)
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
"""Sets the file system state (actual and observed) if not in a compromised state.
"""
Sets the file system state (actual and observed) if not in a compromised state.
Use for green PoL to prevent it overturning a compromised state

View File

@@ -35,7 +35,8 @@ class NodeStateInstructionGreen(object):
self.state = _state
def get_start_step(self):
"""Gets the start step.
"""
Gets the start step.
Returns:
The start step
@@ -43,7 +44,8 @@ class NodeStateInstructionGreen(object):
return self.start_step
def get_end_step(self):
"""Gets the end step.
"""
Gets the end step.
Returns:
The end step
@@ -51,7 +53,8 @@ class NodeStateInstructionGreen(object):
return self.end_step
def get_node_id(self):
"""Gets the node ID.
"""
Gets the node ID.
Returns:
The node ID
@@ -59,7 +62,8 @@ class NodeStateInstructionGreen(object):
return self.node_id
def get_node_pol_type(self):
"""Gets the node pattern of life type (enum).
"""
Gets the node pattern of life type (enum).
Returns:
The node pattern of life type (enum)
@@ -67,7 +71,8 @@ class NodeStateInstructionGreen(object):
return self.node_pol_type
def get_service_name(self):
"""Gets the service name.
"""
Gets the service name.
Returns:
The service name
@@ -75,7 +80,8 @@ class NodeStateInstructionGreen(object):
return self.service_name
def get_state(self):
"""Gets the state (node or service).
"""
Gets the state (node or service).
Returns:
The state (node or service)

View File

@@ -51,7 +51,8 @@ class NodeStateInstructionRed(object):
self.source_node_service_state = _pol_source_node_service_state
def get_start_step(self):
"""Gets the start step.
"""
Gets the start step.
Returns:
The start step
@@ -59,7 +60,8 @@ class NodeStateInstructionRed(object):
return self.start_step
def get_end_step(self):
"""Gets the end step.
"""
Gets the end step.
Returns:
The end step
@@ -67,7 +69,8 @@ class NodeStateInstructionRed(object):
return self.end_step
def get_target_node_id(self):
"""Gets the node ID.
"""
Gets the node ID.
Returns:
The node ID
@@ -75,7 +78,8 @@ class NodeStateInstructionRed(object):
return self.target_node_id
def get_initiator(self):
"""Gets the initiator.
"""
Gets the initiator.
Returns:
The initiator
@@ -83,7 +87,8 @@ class NodeStateInstructionRed(object):
return self.initiator
def get_pol_type(self) -> NodePOLType:
"""Gets the node pattern of life type (enum).
"""
Gets the node pattern of life type (enum).
Returns:
The node pattern of life type (enum)
@@ -91,7 +96,8 @@ class NodeStateInstructionRed(object):
return self.pol_type
def get_service_name(self):
"""Gets the service name.
"""
Gets the service name.
Returns:
The service name
@@ -99,7 +105,8 @@ class NodeStateInstructionRed(object):
return self.service_name
def get_state(self):
"""Gets the state (node or service).
"""
Gets the state (node or service).
Returns:
The state (node or service)
@@ -107,7 +114,8 @@ class NodeStateInstructionRed(object):
return self.state
def get_source_node_id(self):
"""Gets the source node id (used for initiator type SERVICE).
"""
Gets the source node id (used for initiator type SERVICE).
Returns:
The source node id
@@ -115,7 +123,8 @@ class NodeStateInstructionRed(object):
return self.source_node_id
def get_source_node_service(self):
"""Gets the source node service (used for initiator type SERVICE).
"""
Gets the source node service (used for initiator type SERVICE).
Returns:
The source node service
@@ -123,7 +132,8 @@ class NodeStateInstructionRed(object):
return self.source_node_service
def get_source_node_service_state(self):
"""Gets the source node service state (used for initiator type SERVICE).
"""
Gets the source node service state (used for initiator type SERVICE).
Returns:
The source node service state

View File

@@ -32,7 +32,8 @@ class PassiveNode(Node):
@property
def ip_address(self) -> str:
"""Gets the node IP address as an empty string.
"""
Gets the node IP address as an empty string.
No concept of IP address for passive nodes for now.

View File

@@ -53,14 +53,16 @@ class ServiceNode(ActiveNode):
self.services: Dict[str, Service] = {}
def add_service(self, service: Service):
"""Adds a service to the node.
"""
Adds a service to the node.
:param service: The service to add
"""
self.services[service.name] = service
def has_service(self, protocol_name: str) -> bool:
"""Indicates whether a service is on a node.
"""
Indicates whether a service is on a node.
:param protocol_name: The service (protocol)e.
:return: True if service (protocol) is on the node, otherwise False.
@@ -71,7 +73,8 @@ class ServiceNode(ActiveNode):
return False
def service_running(self, protocol_name: str) -> bool:
"""Indicates whether a service is in a running state on the node.
"""
Indicates whether a service is in a running state on the node.
:param protocol_name: The service (protocol)
:return: True if service (protocol) is in a running state on the node, otherwise False.
@@ -85,7 +88,8 @@ class ServiceNode(ActiveNode):
return False
def service_is_overwhelmed(self, protocol_name: str) -> bool:
"""Indicates whether a service is in an overwhelmed state on the node.
"""
Indicates whether a service is in an overwhelmed state on the node.
:param protocol_name: The service (protocol)
:return: True if service (protocol) is in an overwhelmed state on the node, otherwise False.
@@ -99,7 +103,8 @@ class ServiceNode(ActiveNode):
return False
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
"""Sets the software_state of a service (protocol) on the node.
"""
Sets the software_state of a service (protocol) on the node.
:param protocol_name: The service (protocol).
:param software_state: The software_state.
@@ -127,7 +132,8 @@ class ServiceNode(ActiveNode):
)
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
"""Sets the software_state of a service (protocol) on the node.
"""
Sets the software_state of a service (protocol) on the node.
Done if the software_state is not "compromised".
@@ -153,7 +159,8 @@ class ServiceNode(ActiveNode):
)
def get_service_state(self, protocol_name):
"""Gets the state of a service.
"""
Gets the state of a service.
:return: The software_state of the service.
"""

View File

@@ -11,7 +11,8 @@ _LOGGER = getLogger(__name__)
def start_jupyter_session():
"""Starts a new Jupyter notebook session in the app notebooks directory.
"""
Starts a new Jupyter notebook session in the app notebooks directory.
Currently only works on Windows OS.

View File

@@ -25,7 +25,8 @@ def apply_iers(
acl: AccessControlList,
step: int,
):
"""Applies IERs to the links (link pattern of life).
"""
Applies IERs to the links (link pattern of life).
Args:
network: The network modelled in the environment
@@ -217,7 +218,8 @@ def apply_node_pol(
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
step: int,
):
"""Applies node pattern of life.
"""
Applies node pattern of life.
Args:
nodes: The nodes within the environment

View File

@@ -1,5 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Information Exchange Requirements for APE.
"""
Information Exchange Requirements for APE.
Used to represent an information flow from source to destination.
"""
@@ -47,7 +48,8 @@ class IER(object):
self.running = _running
def get_id(self):
"""Gets IER ID.
"""
Gets IER ID.
Returns:
IER ID
@@ -55,7 +57,8 @@ class IER(object):
return self.id
def get_start_step(self):
"""Gets IER start step.
"""
Gets IER start step.
Returns:
IER start step
@@ -63,7 +66,8 @@ class IER(object):
return self.start_step
def get_end_step(self):
"""Gets IER end step.
"""
Gets IER end step.
Returns:
IER end step
@@ -71,7 +75,8 @@ class IER(object):
return self.end_step
def get_load(self):
"""Gets IER load.
"""
Gets IER load.
Returns:
IER load
@@ -79,7 +84,8 @@ class IER(object):
return self.load
def get_protocol(self):
"""Gets IER protocol.
"""
Gets IER protocol.
Returns:
IER protocol
@@ -87,7 +93,8 @@ class IER(object):
return self.protocol
def get_port(self):
"""Gets IER port.
"""
Gets IER port.
Returns:
IER port
@@ -95,7 +102,8 @@ class IER(object):
return self.port
def get_source_node_id(self):
"""Gets IER source node ID.
"""
Gets IER source node ID.
Returns:
IER source node ID
@@ -103,7 +111,8 @@ class IER(object):
return self.source_node_id
def get_dest_node_id(self):
"""Gets IER destination node ID.
"""
Gets IER destination node ID.
Returns:
IER destination node ID
@@ -111,7 +120,8 @@ class IER(object):
return self.dest_node_id
def get_is_running(self):
"""Informs whether the IER is currently running.
"""
Informs whether the IER is currently running.
Returns:
True if running
@@ -119,7 +129,8 @@ class IER(object):
return self.running
def set_is_running(self, _value):
"""Sets the running state of the IER.
"""
Sets the running state of the IER.
Args:
_value: running status
@@ -127,7 +138,8 @@ class IER(object):
self.running = _value
def get_mission_criticality(self):
"""Gets the IER mission criticality (used in the reward function).
"""
Gets the IER mission criticality (used in the reward function).
Returns:
Mission criticality value (0 lowest to 5 highest)

View File

@@ -24,7 +24,8 @@ def apply_red_agent_iers(
acl: AccessControlList,
step: int,
):
"""Applies IERs to the links (link POL) resulting from red agent attack.
"""
Applies IERs to the links (link POL) resulting from red agent attack.
Args:
network: The network modelled in the environment
@@ -213,7 +214,8 @@ def apply_red_agent_node_pol(
node_pol: Dict[str, NodeStateInstructionRed],
step: int,
):
"""Applies node pattern of life.
"""
Applies node pattern of life.
Args:
nodes: The nodes within the environment
@@ -295,7 +297,8 @@ def apply_red_agent_node_pol(
def is_red_ier_incoming(node, iers, node_pol_type):
"""Checks if the RED IER is incoming.
"""
Checks if the RED IER is incoming.
TODO: Write more descriptive docstring with params and returns.
"""

View File

@@ -125,7 +125,8 @@ class PrimaiteSession:
self,
**kwargs,
):
"""Train the agent.
"""
Train the agent.
:param kwargs: Any agent-framework specific key word args.
"""
@@ -136,7 +137,8 @@ class PrimaiteSession:
self,
**kwargs,
):
"""Evaluate the agent.
"""
Evaluate the agent.
:param kwargs: Any agent-framework specific key word args.
"""

View File

@@ -12,7 +12,8 @@ _LOGGER = getLogger(__name__)
def run(overwrite_existing: bool = True):
"""Resets the demo jupyter notebooks in the users app notebooks directory.
"""
Resets the demo jupyter notebooks in the users app notebooks directory.
:param overwrite_existing: A bool to toggle replacing existing edited notebooks on or off.
"""

View File

@@ -11,7 +11,8 @@ _LOGGER = getLogger(__name__)
def run(overwrite_existing=True):
"""Resets the example config files in the users app config directory.
"""
Resets the example config files in the users app config directory.
:param overwrite_existing: A bool to toggle replacing existing edited config on or off.
"""

View File

@@ -5,7 +5,8 @@ _LOGGER = getLogger(__name__)
def run():
"""Handles creation of application directories and user directories.
"""
Handles creation of application directories and user directories.
Uses `platformdirs.PlatformDirs` and `pathlib.Path` to create the required
app directories in the correct locations based on the users OS.

View File

@@ -39,7 +39,8 @@ class Transaction(object):
"The env observation space description"
def as_csv_data(self) -> Tuple[List, List]:
"""Converts the Transaction to a csv data row and provides a header.
"""
Converts the Transaction to a csv data row and provides a header.
:return: A tuple consisting of (header, data).
"""
@@ -68,7 +69,8 @@ class Transaction(object):
def _turn_action_space_to_array(action_space) -> List[str]:
"""Turns action space into a string array so it can be saved to csv.
"""
Turns action space into a string array so it can be saved to csv.
:param action_space: The action space
:return: The action space as an array of strings
@@ -80,7 +82,8 @@ def _turn_action_space_to_array(action_space) -> List[str]:
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
"""Turns observation space into a string array so it can be saved to csv.
"""
Turns observation space into a string array so it can be saved to csv.
:param obs_space: The observation space
:param obs_assets: The number of assets (i.e. nodes or links) in the observation space

View File

@@ -10,7 +10,8 @@ _LOGGER = getLogger(__name__)
def get_file_path(path: str) -> Path:
"""Get PrimAITE package data.
"""
Get PrimAITE package data.
:Example:

View File

@@ -7,7 +7,8 @@ import polars as pl
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
"""Read an average rewards per episode csv file and return as a dict.
"""
Read an average rewards per episode csv file and return as a dict.
The dictionary keys are the episode number, and the values are the mean reward that episode.

View File

@@ -77,7 +77,8 @@ class SessionOutputWriter:
_LOGGER.debug(f"Finished writing file: {self._csv_file_path}")
def write(self, data: Union[Tuple, Transaction]):
"""Write a row of session data.
"""
Write a row of session data.
:param data: The row of data to write. Can be a Tuple or an instance of Transaction.
"""

View File

@@ -75,7 +75,8 @@ class TestNodeLinkTable:
assert env.env_obs.shape == (5, 6)
def test_value(self, temp_primaite_session):
"""Test that the observation is generated correctly.
"""
Test that the observation is generated correctly.
The laydown has:
* 3 nodes (2 service nodes and 1 active node)
@@ -157,7 +158,8 @@ class TestNodeStatuses:
assert env.env_obs.shape == (15,)
def test_values(self, temp_primaite_session):
"""Test that the hardware and software states are encoded correctly.
"""
Test that the hardware and software states are encoded correctly.
The laydown has:
* one node with a compromised operating system state
@@ -213,7 +215,8 @@ class TestLinkTrafficLevels:
assert env.env_obs.shape == (2 * 2,)
def test_values(self, temp_primaite_session):
"""Test that traffic values are encoded correctly.
"""
Test that traffic values are encoded correctly.
The laydown has:
* two services