From 5618283cc592cac951882110522b95e29364c2c0 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 7 Jul 2023 10:28:00 +0100 Subject: [PATCH] Standardise docstring summary line placement. --- src/primaite/__init__.py | 6 +- src/primaite/acl/access_control_list.py | 12 ++-- src/primaite/acl/acl_rule.py | 18 +++-- src/primaite/agents/agent.py | 15 ++-- src/primaite/agents/hardcoded_acl.py | 24 ++++--- src/primaite/agents/hardcoded_node.py | 3 +- src/primaite/agents/rllib.py | 6 +- src/primaite/agents/sb3.py | 6 +- src/primaite/agents/utils.py | 63 +++++++++++------ src/primaite/cli.py | 18 +++-- src/primaite/config/lay_down_config.py | 18 +++-- src/primaite/config/training_config.py | 3 +- src/primaite/data_viz/session_plots.py | 3 +- src/primaite/environment/observations.py | 36 ++++++---- src/primaite/environment/primaite_env.py | 69 ++++++++++++------- src/primaite/environment/reward.py | 15 ++-- src/primaite/links/link.py | 24 ++++--- src/primaite/main.py | 3 +- src/primaite/nodes/active_node.py | 15 ++-- .../nodes/node_state_instruction_green.py | 18 +++-- .../nodes/node_state_instruction_red.py | 30 +++++--- src/primaite/nodes/passive_node.py | 3 +- src/primaite/nodes/service_node.py | 21 ++++-- src/primaite/notebooks/__init__.py | 3 +- src/primaite/pol/green_pol.py | 6 +- src/primaite/pol/ier.py | 36 ++++++---- src/primaite/pol/red_agent_pol.py | 9 ++- src/primaite/primaite_session.py | 6 +- src/primaite/setup/reset_demo_notebooks.py | 3 +- src/primaite/setup/reset_example_configs.py | 3 +- src/primaite/setup/setup_app_dirs.py | 3 +- src/primaite/transactions/transaction.py | 9 ++- src/primaite/utils/package_data.py | 3 +- src/primaite/utils/session_output_reader.py | 3 +- src/primaite/utils/session_output_writer.py | 3 +- tests/test_observation_space.py | 9 ++- 36 files changed, 350 insertions(+), 175 deletions(-) diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index a2d157c6..030860d8 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -66,7 +66,8 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``. # region Setup Logging class _LevelFormatter(Formatter): - """A custom level-specific formatter. + """ + A custom level-specific formatter. Credit to: https://stackoverflow.com/a/68154386 """ @@ -134,7 +135,8 @@ _LOGGER.addHandler(_FILE_HANDLER) def getLogger(name: str) -> Logger: # noqa - """Get a PrimAITE logger. + """ + Get a PrimAITE logger. :param name: The logger name. Use ``__name__``. :return: An instance of :py:class:`logging.Logger` with the PrimAITE diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index e1d6aa74..3ac9a8af 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -35,7 +35,8 @@ class AccessControlList: return False def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port): - """Checks for rules that block a protocol / port. + """ + Checks for rules that block a protocol / port. Args: _source_ip_address: the source IP address to check @@ -61,7 +62,8 @@ class AccessControlList: return True def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): - """Adds a new rule. + """ + Adds a new rule. Args: _permission: the permission value (e.g. "ALLOW" or "DENY") @@ -75,7 +77,8 @@ class AccessControlList: self.acl[hash_value] = new_rule def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): - """Removes a rule. + """ + Removes a rule. Args: _permission: the permission value (e.g. "ALLOW" or "DENY") @@ -97,7 +100,8 @@ class AccessControlList: self.acl.clear() def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port): - """Produces a hash value for a rule. + """ + Produces a hash value for a rule. Args: _permission: the permission value (e.g. "ALLOW" or "DENY") diff --git a/src/primaite/acl/acl_rule.py b/src/primaite/acl/acl_rule.py index 117c9457..a1fd93f2 100644 --- a/src/primaite/acl/acl_rule.py +++ b/src/primaite/acl/acl_rule.py @@ -22,7 +22,8 @@ class ACLRule: self.port = _port def __hash__(self): - """Override the hash function. + """ + Override the hash function. Returns: Returns hash of core parameters. @@ -38,7 +39,8 @@ class ACLRule: ) def get_permission(self): - """Gets the permission attribute. + """ + Gets the permission attribute. Returns: Returns permission attribute @@ -46,7 +48,8 @@ class ACLRule: return self.permission def get_source_ip(self): - """Gets the source IP address attribute. + """ + Gets the source IP address attribute. Returns: Returns source IP address attribute @@ -54,7 +57,8 @@ class ACLRule: return self.source_ip def get_dest_ip(self): - """Gets the desintation IP address attribute. + """ + Gets the desintation IP address attribute. Returns: Returns destination IP address attribute @@ -62,7 +66,8 @@ class ACLRule: return self.dest_ip def get_protocol(self): - """Gets the protocol attribute. + """ + Gets the protocol attribute. Returns: Returns protocol attribute @@ -70,7 +75,8 @@ class ACLRule: return self.protocol def get_port(self): - """Gets the port attribute. + """ + Gets the port attribute. Returns: Returns port attribute diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 7073d795..3b093f86 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -21,7 +21,8 @@ _LOGGER = getLogger(__name__) def get_session_path(session_timestamp: datetime) -> Path: - """Get the directory path the session will output to. + """ + Get the directory path the session will output to. This is set in the format of: ~/primaite/sessions//_. @@ -194,7 +195,8 @@ class AgentSessionABC(ABC): self, **kwargs, ): - """Train the agent. + """ + Train the agent. :param kwargs: Any agent-specific key-word args to be passed. """ @@ -211,7 +213,8 @@ class AgentSessionABC(ABC): self, **kwargs, ): - """Evaluate the agent. + """ + Evaluate the agent. :param kwargs: Any agent-specific key-word args to be passed. """ @@ -340,7 +343,8 @@ class HardCodedAgentSessionABC(AgentSessionABC): self, **kwargs, ): - """Train the agent. + """ + Train the agent. :param kwargs: Any agent-specific key-word args to be passed. """ @@ -354,7 +358,8 @@ class HardCodedAgentSessionABC(AgentSessionABC): self, **kwargs, ): - """Evaluate the agent. + """ + Evaluate the agent. :param kwargs: Any agent-specific key-word args to be passed. """ diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 5cc06bdc..c26bcacf 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -46,7 +46,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_green_iers def get_matching_acl_rules_for_ier(self, ier, acl, nodes): - """Get matching ACL rules for an IER. + """ + Get matching ACL rules for an IER. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -62,7 +63,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return matching_rules def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): - """Get blocking ACL rules for an IER. + """ + Get blocking ACL rules for an IER. .. warning:: Can return empty dict but IER can still be blocked by default @@ -81,7 +83,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_rules def get_allow_acl_rules_for_ier(self, ier, acl, nodes): - """Get all allowing ACL rules for an IER. + """ + Get all allowing ACL rules for an IER. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -105,7 +108,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): nodes, services_list, ): - """Get matching ACL rules. + """ + Get matching ACL rules. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -136,7 +140,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): nodes, services_list, ): - """Get the ALLOW ACL rules. + """ + Get the ALLOW ACL rules. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -168,7 +173,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): nodes, services_list, ): - """Get the DENY ACL rules. + """ + Get the DENY ACL rules. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -191,7 +197,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return allowed_rules def _calculate_action_full_view(self, obs): - """Calculate a good acl-based action for the blue agent to take. + """ + Calculate a good acl-based action for the blue agent to take. Knowledge of just the observation space is insufficient for a perfect solution, as we need to know: @@ -355,7 +362,8 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return action def _calculate_action_basic_view(self, obs): - """Calculate a good acl-based action for the blue agent to take. + """ + Calculate a good acl-based action for the blue agent to take. Uses ONLY information from the current observation with NO knowledge of previous actions taken and NO reward feedback. diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index 27a2a823..310fc178 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -6,7 +6,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): """An Agent Session class that implements a deterministic Node agent.""" def _calculate_action(self, obs): - """Calculate a good node-based action for the blue agent to take. + """ + Calculate a good node-based action for the blue agent to take. TODO: Add params and return in docstring. TODO: Typehint params and return. diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 044b760f..bd5c8585 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -140,7 +140,8 @@ class RLlibAgent(AgentSessionABC): self, **kwargs, ): - """Evaluate the agent. + """ + Evaluate the agent. :param kwargs: Any agent-specific key-word args to be passed. """ @@ -158,7 +159,8 @@ class RLlibAgent(AgentSessionABC): self, **kwargs, ): - """Evaluate the agent. + """ + Evaluate the agent. :param kwargs: Any agent-specific key-word args to be passed. """ diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index b81a0a18..90a24ee2 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -89,7 +89,8 @@ class SB3Agent(AgentSessionABC): self, **kwargs, ): - """Train the agent. + """ + Train the agent. :param kwargs: Any agent-specific key-word args to be passed. """ @@ -109,7 +110,8 @@ class SB3Agent(AgentSessionABC): deterministic: bool = True, **kwargs, ): - """Evaluate the agent. + """ + Evaluate the agent. :param deterministic: Whether the evaluation is deterministic. :param kwargs: Any agent-specific key-word args to be passed. diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index 8b3b57f5..0d4a8e2a 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -11,7 +11,8 @@ from primaite.common.enums import ( def transform_action_node_readable(action): - """Convert a node action from enumerated format to readable format. + """ + Convert a node action from enumerated format to readable format. example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] @@ -33,7 +34,8 @@ def transform_action_node_readable(action): def transform_action_acl_readable(action): - """Transform an ACL action to a more readable format. + """ + Transform an ACL action to a more readable format. example: [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] @@ -57,7 +59,8 @@ def transform_action_acl_readable(action): def is_valid_node_action(action): - """Is the node action an actual valid action. + """ + Is the node action an actual valid action. Only uses information about the action to determine if the action has an effect @@ -92,7 +95,8 @@ def is_valid_node_action(action): def is_valid_acl_action(action): - """Is the ACL action an actual valid action. + """ + Is the ACL action an actual valid action. Only uses information about the action to determine if the action has an effect. @@ -124,7 +128,8 @@ def is_valid_acl_action(action): def is_valid_acl_action_extra(action): - """Harsher version of valid acl actions, does not allow action. + """ + Harsher version of valid acl actions, does not allow action. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -147,7 +152,8 @@ def is_valid_acl_action_extra(action): def transform_change_obs_readable(obs): - """Transform list of transactions to readable list of each observation property. + """ + Transform list of transactions to readable list of each observation property. example: np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] @@ -169,7 +175,8 @@ def transform_change_obs_readable(obs): def transform_obs_readable(obs): - """Transform observation to readable format. + """ + Transform observation to readable format. np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']] @@ -185,7 +192,8 @@ def transform_obs_readable(obs): def convert_to_new_obs(obs, num_nodes=10): - """Convert original gym Box observation space to new multiDiscrete observation space. + """ + Convert original gym Box observation space to new multiDiscrete observation space. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -196,7 +204,8 @@ def convert_to_new_obs(obs, num_nodes=10): def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): - """Convert to old observation. + """ + Convert to old observation. Links filled with 0's as no information is included in new observation space. @@ -232,7 +241,8 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): - """Return string describing change between two observations. + """ + Return string describing change between two observations. example: obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) @@ -260,7 +270,8 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): def _describe_obs_change_helper(obs_change, is_link): - """Helper funcion to describe what has changed. + """ + Helper funcion to describe what has changed. example: [ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD" @@ -295,7 +306,8 @@ def _describe_obs_change_helper(obs_change, is_link): def transform_action_node_enum(action): - """Convert a node action from readable string format, to enumerated format. + """ + Convert a node action from readable string format, to enumerated format. example: [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] @@ -326,7 +338,8 @@ def transform_action_node_enum(action): def transform_action_node_readable(action): - """Convert a node action from enumerated format to readable format. + """ + Convert a node action from enumerated format to readable format. example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] @@ -348,7 +361,8 @@ def transform_action_node_readable(action): def node_action_description(action): - """Generate string describing a node-based action. + """ + Generate string describing a node-based action. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -375,7 +389,8 @@ def node_action_description(action): def transform_action_acl_enum(action): - """Convert acl action from readable str format, to enumerated format. + """ + Convert acl action from readable str format, to enumerated format. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -397,7 +412,8 @@ def transform_action_acl_enum(action): def acl_action_description(action): - """Generate string describing an acl-based action. + """ + Generate string describing an acl-based action. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -417,7 +433,8 @@ def acl_action_description(action): def get_node_of_ip(ip, node_dict): - """Get the node ID of an IP address. + """ + Get the node ID of an IP address. node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes) @@ -431,7 +448,8 @@ def get_node_of_ip(ip, node_dict): def is_valid_node_action(action): - """Is the node action an actual valid action. + """ + Is the node action an actual valid action. Only uses information about the action to determine if the action has an effect @@ -464,7 +482,8 @@ def is_valid_node_action(action): def is_valid_acl_action(action): - """Is the ACL action an actual valid action. + """ + Is the ACL action an actual valid action. Only uses information about the action to determine if the action has an effect @@ -496,7 +515,8 @@ def is_valid_acl_action(action): def is_valid_acl_action_extra(action): - """Harsher version of valid acl actions, does not allow action. + """ + Harsher version of valid acl actions, does not allow action. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -519,7 +539,8 @@ def is_valid_acl_action_extra(action): def get_new_action(old_action, action_dict): - """Get new action (e.g. 32) from old action e.g. [1,1,1,0]. + """ + Get new action (e.g. 32) from old action e.g. [1,1,1,0]. Old_action can be either node or acl action type diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 42825144..40e8cf0d 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -28,7 +28,8 @@ def build_dirs(): @app.command() def reset_notebooks(overwrite: bool = True): - """Force a reset of the demo notebooks in the users notebooks directory. + """ + Force a reset of the demo notebooks in the users notebooks directory. :param overwrite: If True, will overwrite existing demo notebooks. """ @@ -39,7 +40,8 @@ def reset_notebooks(overwrite: bool = True): @app.command() def logs(last_n: Annotated[int, typer.Option("-n")]): - """Print the PrimAITE log file. + """ + Print the PrimAITE log file. :param last_n: The number of lines to print. Default value is 10. """ @@ -59,7 +61,8 @@ _LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # n @app.command() def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None): - """View or set the PrimAITE Log Level. + """ + View or set the PrimAITE Log Level. To View, simply call: primaite log-level @@ -110,7 +113,8 @@ def clean_up(): @app.command() def setup(overwrite_existing: bool = True): - """Perform the PrimAITE first-time setup. + """ + Perform the PrimAITE first-time setup. WARNING: All user-data will be lost. """ @@ -148,7 +152,8 @@ def setup(overwrite_existing: bool = True): @app.command() def session(tc: Optional[str] = None, ldc: Optional[str] = None): - """Run a PrimAITE session. + """ + Run a PrimAITE session. tc: The training config filepath. Optional. If no value is passed then example default training config is used from: @@ -173,7 +178,8 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None): @app.command() def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None): - """View or set the plotly template for Session plots. + """ + View or set the plotly template for Session plots. To View, simply call: primaite plotly-template diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 587997b7..3a85b9da 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -12,7 +12,8 @@ _EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]: - """Convert a legacy lay down config dict to the new format. + """ + Convert a legacy lay down config dict to the new format. :param legacy_config_dict: A legacy lay down config dict. """ @@ -21,7 +22,8 @@ def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> D def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict: - """Read in a lay down config yaml file. + """ + Read in a lay down config yaml file. :param file_path: The config file path. :param legacy_file: True if the config file is legacy format, otherwise False. @@ -50,7 +52,8 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict: def ddos_basic_one_config_path() -> Path: - """The path to the example lay_down_config_1_DDOS_basic.yaml file. + """ + The path to the example lay_down_config_1_DDOS_basic.yaml file. :return: The file path. """ @@ -64,7 +67,8 @@ def ddos_basic_one_config_path() -> Path: def ddos_basic_two_config_path() -> Path: - """The path to the example lay_down_config_2_DDOS_basic.yaml file. + """ + The path to the example lay_down_config_2_DDOS_basic.yaml file. :return: The file path. """ @@ -78,7 +82,8 @@ def ddos_basic_two_config_path() -> Path: def dos_very_basic_config_path() -> Path: - """The path to the example lay_down_config_3_DOS_very_basic.yaml file. + """ + The path to the example lay_down_config_3_DOS_very_basic.yaml file. :return: The file path. """ @@ -92,7 +97,8 @@ def dos_very_basic_config_path() -> Path: def data_manipulation_config_path() -> Path: - """The path to the example lay_down_config_5_data_manipulation.yaml file. + """ + The path to the example lay_down_config_5_data_manipulation.yaml file. :return: The file path. """ diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 7bdf7995..30edb79b 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -180,7 +180,8 @@ class TrainingConfig: @classmethod def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig: - """Create an instance of TrainingConfig from a dict. + """ + Create an instance of TrainingConfig from a dict. :param config_dict: The training config dict. :return: The instance of TrainingConfig. diff --git a/src/primaite/data_viz/session_plots.py b/src/primaite/data_viz/session_plots.py index 542c6677..245b9774 100644 --- a/src/primaite/data_viz/session_plots.py +++ b/src/primaite/data_viz/session_plots.py @@ -22,7 +22,8 @@ def plot_av_reward_per_episode( title: Optional[str] = None, subtitle: Optional[str] = None, ) -> Figure: - """Plot the average reward per episode from a csv session output. + """ + Plot the average reward per episode from a csv session output. :param av_reward_per_episode_csv: The average reward per episode csv file path. diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 28e85b7f..53c173fd 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -50,7 +50,8 @@ class AbstractObservationComponent(ABC): class NodeLinkTable(AbstractObservationComponent): - """Table with nodes and links as rows and hardware/software status as cols. + """ + Table with nodes and links as rows and hardware/software status as cols. This will create the observation space formatted as a table of integers. There is one row per node, followed by one row per link. @@ -74,7 +75,8 @@ class NodeLinkTable(AbstractObservationComponent): _DATA_TYPE: type = np.int64 def __init__(self, env: "Primaite"): - """Initialise a NodeLinkTable observation space component. + """ + Initialise a NodeLinkTable observation space component. :param env: Training environment. :type env: Primaite @@ -100,7 +102,8 @@ class NodeLinkTable(AbstractObservationComponent): self.structure = self.generate_structure() def update(self): - """Update the observation based on current environment state. + """ + Update the observation based on current environment state. The structure of the observation space is described in :class:`.NodeLinkTable` """ @@ -181,7 +184,8 @@ class NodeLinkTable(AbstractObservationComponent): class NodeStatuses(AbstractObservationComponent): - """Flat list of nodes' hardware, OS, file system, and service states. + """ + Flat list of nodes' hardware, OS, file system, and service states. The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by integers. @@ -234,7 +238,8 @@ class NodeStatuses(AbstractObservationComponent): self.structure = self.generate_structure() def update(self): - """Update the observation based on current environment state. + """ + Update the observation based on current environment state. The structure of the observation space is described in :class:`.NodeStatuses` """ @@ -287,7 +292,8 @@ class NodeStatuses(AbstractObservationComponent): class LinkTrafficLevels(AbstractObservationComponent): - """Flat list of traffic levels encoded into banded categories. + """ + Flat list of traffic levels encoded into banded categories. For each link, total traffic or traffic per service is encoded into a categorical value. For example, if ``quantisation_levels=5``, the traffic levels represent these values: @@ -354,7 +360,8 @@ class LinkTrafficLevels(AbstractObservationComponent): self.structure = self.generate_structure() def update(self): - """Update the observation based on current environment state. + """ + Update the observation based on current environment state. The structure of the observation space is described in :class:`.LinkTrafficLevels` """ @@ -395,7 +402,8 @@ class LinkTrafficLevels(AbstractObservationComponent): class ObservationsHandler: - """Component-based observation space handler. + """ + Component-based observation space handler. This allows users to configure observation spaces by mixing and matching components. Each component can also define further parameters to make them more flexible. @@ -436,7 +444,8 @@ class ObservationsHandler: self._flat_observation = spaces.flatten(self._space, self._observation) def register(self, obs_component: AbstractObservationComponent): - """Add a component for this handler to track. + """ + Add a component for this handler to track. :param obs_component: The component to add. :type obs_component: AbstractObservationComponent @@ -445,7 +454,8 @@ class ObservationsHandler: self.update_space() def deregister(self, obs_component: AbstractObservationComponent): - """Remove a component from this handler. + """ + Remove a component from this handler. :param obs_component: Which component to remove. It must exist within this object's ``registered_obs_components`` attribute. @@ -488,7 +498,8 @@ class ObservationsHandler: @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): - """Parse a config dictinary, return a new observation handler populated with new observation component objects. + """ + Parse a config dictinary, return a new observation handler populated with new observation component objects. The expected format for the config dictionary is: @@ -533,7 +544,8 @@ class ObservationsHandler: return handler def describe_structure(self): - """Create a list of names for the features of the obs space. + """ + Create a list of names for the features of the obs space. The order of labels follows the flattened version of the space. """ diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 825818fd..9a5df13a 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -255,7 +255,8 @@ class Primaite(Env): self.total_step_count = 0 def reset(self): - """AI Gym Reset function. + """ + AI Gym Reset function. Returns: Environment observation space (reset) @@ -291,7 +292,8 @@ class Primaite(Env): return self.env_obs def step(self, action): - """AI Gym Step function. + """ + AI Gym Step function. Args: action: Action space from agent @@ -429,7 +431,8 @@ class Primaite(Env): print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load())) def interpret_action_and_apply(self, _action): - """Applies agent actions to the nodes and Access Control List. + """ + Applies agent actions to the nodes and Access Control List. Args: _action: The action space from the agent @@ -448,7 +451,8 @@ class Primaite(Env): logging.error("Invalid action type found") def apply_actions_to_nodes(self, _action): - """Applies agent actions to the nodes. + """ + Applies agent actions to the nodes. Args: _action: The action space from the agent @@ -535,7 +539,8 @@ class Primaite(Env): return def apply_actions_to_acl(self, _action): - """Applies agent actions to the Access Control List [TO DO]. + """ + Applies agent actions to the Access Control List [TO DO]. Args: _action: The action space from the agent @@ -612,7 +617,8 @@ class Primaite(Env): return def apply_time_based_updates(self): - """Updates anything that needs to count down and then change state. + """ + Updates anything that needs to count down and then change state. e.g. reset / patching status """ @@ -653,7 +659,8 @@ class Primaite(Env): pass def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Create the environment's observation handler. + """ + Create the environment's observation handler. :return: The observation space, initial observation (zeroed out array with the correct shape) :rtype: Tuple[spaces.Space, np.ndarray] @@ -709,7 +716,8 @@ class Primaite(Env): print("Environment configuration loaded") def create_node(self, item): - """Creates a node from config data. + """ + Creates a node from config data. Args: item: A config data item @@ -789,7 +797,8 @@ class Primaite(Env): self.network_reference.add_nodes_from([node_ref]) def create_link(self, item: Dict): - """Creates a link from config data. + """ + Creates a link from config data. Args: item: A config data item @@ -832,7 +841,8 @@ class Primaite(Env): ) def create_green_ier(self, item): - """Creates a green IER from config data. + """ + Creates a green IER from config data. Args: item: A config data item @@ -872,7 +882,8 @@ class Primaite(Env): ) def create_red_ier(self, item): - """Creates a red IER from config data. + """ + Creates a red IER from config data. Args: item: A config data item @@ -901,7 +912,8 @@ class Primaite(Env): ) def create_green_pol(self, item): - """Creates a green PoL object from config data. + """ + Creates a green PoL object from config data. Args: item: A config data item @@ -934,7 +946,8 @@ class Primaite(Env): ) def create_red_pol(self, item): - """Creates a red PoL object from config data. + """ + Creates a red PoL object from config data. Args: item: A config data item @@ -974,7 +987,8 @@ class Primaite(Env): ) def create_acl_rule(self, item): - """Creates an ACL rule from config data. + """ + Creates an ACL rule from config data. Args: item: A config data item @@ -994,7 +1008,8 @@ class Primaite(Env): ) def create_services_list(self, services): - """Creates a list of services (enum) from config data. + """ + Creates a list of services (enum) from config data. Args: item: A config data item representing the services @@ -1009,7 +1024,8 @@ class Primaite(Env): self.num_services = len(self.services_list) def create_ports_list(self, ports): - """Creates a list of ports from config data. + """ + Creates a list of ports from config data. Args: item: A config data item representing the ports @@ -1024,7 +1040,8 @@ class Primaite(Env): self.num_ports = len(self.ports_list) def get_observation_info(self, observation_info): - """Extracts observation_info. + """ + Extracts observation_info. :param observation_info: Config item that defines which type of observation space to use :type observation_info: str @@ -1032,7 +1049,8 @@ class Primaite(Env): self.observation_type = ObservationType[observation_info["type"]] def get_action_info(self, action_info): - """Extracts action_info. + """ + Extracts action_info. Args: item: A config data item representing action info @@ -1040,7 +1058,8 @@ class Primaite(Env): self.action_type = ActionType[action_info["type"]] def save_obs_config(self, obs_config: dict): - """Cache the config for the observation space. + """ + Cache the config for the observation space. This is necessary as the observation space can't be built while reading the config, it must be done after all the nodes, links, and services have been initialised. @@ -1052,7 +1071,8 @@ class Primaite(Env): self.obs_config = obs_config def reset_environment(self): - """# Resets environment. + """ + Resets environment. Uses config data config data in order to build the environment configuration. """ @@ -1076,7 +1096,8 @@ class Primaite(Env): ier_value.set_is_running(False) def reset_node(self, item): - """Resets the statuses of a node. + """ + Resets the statuses of a node. Args: item: A config data item @@ -1123,7 +1144,8 @@ class Primaite(Env): pass def create_node_action_dict(self): - """Creates a dictionary mapping each possible discrete action to more readable multidiscrete action. + """ + Creates a dictionary mapping each possible discrete action to more readable multidiscrete action. Note: Only actions that have the potential to change the state exist in the mapping (except for key 0) @@ -1187,7 +1209,8 @@ class Primaite(Env): return actions def create_node_and_acl_action_dict(self): - """Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action. + """ + Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action. The dictionary contains actions of both Node and ACL action types. """ diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 5cef47ef..19094a18 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -21,7 +21,8 @@ def calculate_reward_function( step_count, config_values, ): - """Compares the states of the initial and final nodes/links to get a reward. + """ + Compares the states of the initial and final nodes/links to get a reward. Args: initial_nodes: The nodes before red and blue agents take effect @@ -94,7 +95,8 @@ def calculate_reward_function( def score_node_operating_state(final_node, initial_node, reference_node, config_values): - """Calculates score relating to the hardware state of a node. + """ + Calculates score relating to the hardware state of a node. Args: final_node: The node after red and blue agents take effect @@ -142,7 +144,8 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ def score_node_os_state(final_node, initial_node, reference_node, config_values): - """Calculates score relating to the Software State of a node. + """ + Calculates score relating to the Software State of a node. Args: final_node: The node after red and blue agents take effect @@ -192,7 +195,8 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) def score_node_service_state(final_node, initial_node, reference_node, config_values): - """Calculates score relating to the service state(s) of a node. + """ + Calculates score relating to the service state(s) of a node. Args: final_node: The node after red and blue agents take effect @@ -263,7 +267,8 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va def score_node_file_system(final_node, initial_node, reference_node, config_values): - """Calculates score relating to the file system state of a node. + """ + Calculates score relating to the file system state of a node. Args: final_node: The node after red and blue agents take effect diff --git a/src/primaite/links/link.py b/src/primaite/links/link.py index 5892b8e2..f61281cd 100644 --- a/src/primaite/links/link.py +++ b/src/primaite/links/link.py @@ -29,7 +29,8 @@ class Link(object): self.add_protocol(protocol_name) def add_protocol(self, _protocol): - """Adds a new protocol to the list of protocols on this link. + """ + Adds a new protocol to the list of protocols on this link. Args: _protocol: The protocol to be added (enum) @@ -37,7 +38,8 @@ class Link(object): self.protocol_list.append(Protocol(_protocol)) def get_id(self): - """Gets link ID. + """ + Gets link ID. Returns: Link ID @@ -45,7 +47,8 @@ class Link(object): return self.id def get_source_node_name(self): - """Gets source node name. + """ + Gets source node name. Returns: Source node name @@ -53,7 +56,8 @@ class Link(object): return self.source_node_name def get_dest_node_name(self): - """Gets destination node name. + """ + Gets destination node name. Returns: Destination node name @@ -61,7 +65,8 @@ class Link(object): return self.dest_node_name def get_bandwidth(self): - """Gets bandwidth of link. + """ + Gets bandwidth of link. Returns: Link bandwidth (bps) @@ -69,7 +74,8 @@ class Link(object): return self.bandwidth def get_protocol_list(self): - """Gets list of protocols on this link. + """ + Gets list of protocols on this link. Returns: List of protocols on this link @@ -77,7 +83,8 @@ class Link(object): return self.protocol_list def get_current_load(self): - """Gets current total load on this link. + """ + Gets current total load on this link. Returns: Total load on this link (bps) @@ -88,7 +95,8 @@ class Link(object): return total_load def add_protocol_load(self, _protocol, _load): - """Adds a loading to a protocol on this link. + """ + Adds a loading to a protocol on this link. Args: _protocol: The protocol to load diff --git a/src/primaite/main.py b/src/primaite/main.py index 7b1d7ab3..f2d1b9c2 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -14,7 +14,8 @@ def run( training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path], ): - """Run the PrimAITE Session. + """ + Run the PrimAITE Session. :param training_config_path: The training config filepath. :param lay_down_config_path: The lay down config filepath. diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py index 3789b7a4..f86f818b 100644 --- a/src/primaite/nodes/active_node.py +++ b/src/primaite/nodes/active_node.py @@ -52,7 +52,8 @@ class ActiveNode(Node): @property def software_state(self) -> SoftwareState: - """Get the software_state. + """ + Get the software_state. :return: The software_state. """ @@ -60,7 +61,8 @@ class ActiveNode(Node): @software_state.setter def software_state(self, software_state: SoftwareState): - """Get the software_state. + """ + Get the software_state. :param software_state: Software State. """ @@ -78,7 +80,8 @@ class ActiveNode(Node): ) def set_software_state_if_not_compromised(self, software_state: SoftwareState): - """Sets Software State if the node is not compromised. + """ + Sets Software State if the node is not compromised. Args: software_state: Software State @@ -104,7 +107,8 @@ class ActiveNode(Node): self._software_state = SoftwareState.GOOD def set_file_system_state(self, file_system_state: FileSystemState): - """Sets the file system state (actual and observed). + """ + Sets the file system state (actual and observed). Args: file_system_state: File system state @@ -130,7 +134,8 @@ class ActiveNode(Node): ) def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState): - """Sets the file system state (actual and observed) if not in a compromised state. + """ + Sets the file system state (actual and observed) if not in a compromised state. Use for green PoL to prevent it overturning a compromised state diff --git a/src/primaite/nodes/node_state_instruction_green.py b/src/primaite/nodes/node_state_instruction_green.py index da4be35e..7ebe3886 100644 --- a/src/primaite/nodes/node_state_instruction_green.py +++ b/src/primaite/nodes/node_state_instruction_green.py @@ -35,7 +35,8 @@ class NodeStateInstructionGreen(object): self.state = _state def get_start_step(self): - """Gets the start step. + """ + Gets the start step. Returns: The start step @@ -43,7 +44,8 @@ class NodeStateInstructionGreen(object): return self.start_step def get_end_step(self): - """Gets the end step. + """ + Gets the end step. Returns: The end step @@ -51,7 +53,8 @@ class NodeStateInstructionGreen(object): return self.end_step def get_node_id(self): - """Gets the node ID. + """ + Gets the node ID. Returns: The node ID @@ -59,7 +62,8 @@ class NodeStateInstructionGreen(object): return self.node_id def get_node_pol_type(self): - """Gets the node pattern of life type (enum). + """ + Gets the node pattern of life type (enum). Returns: The node pattern of life type (enum) @@ -67,7 +71,8 @@ class NodeStateInstructionGreen(object): return self.node_pol_type def get_service_name(self): - """Gets the service name. + """ + Gets the service name. Returns: The service name @@ -75,7 +80,8 @@ class NodeStateInstructionGreen(object): return self.service_name def get_state(self): - """Gets the state (node or service). + """ + Gets the state (node or service). Returns: The state (node or service) diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index f8ce4e74..540625cc 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -51,7 +51,8 @@ class NodeStateInstructionRed(object): self.source_node_service_state = _pol_source_node_service_state def get_start_step(self): - """Gets the start step. + """ + Gets the start step. Returns: The start step @@ -59,7 +60,8 @@ class NodeStateInstructionRed(object): return self.start_step def get_end_step(self): - """Gets the end step. + """ + Gets the end step. Returns: The end step @@ -67,7 +69,8 @@ class NodeStateInstructionRed(object): return self.end_step def get_target_node_id(self): - """Gets the node ID. + """ + Gets the node ID. Returns: The node ID @@ -75,7 +78,8 @@ class NodeStateInstructionRed(object): return self.target_node_id def get_initiator(self): - """Gets the initiator. + """ + Gets the initiator. Returns: The initiator @@ -83,7 +87,8 @@ class NodeStateInstructionRed(object): return self.initiator def get_pol_type(self) -> NodePOLType: - """Gets the node pattern of life type (enum). + """ + Gets the node pattern of life type (enum). Returns: The node pattern of life type (enum) @@ -91,7 +96,8 @@ class NodeStateInstructionRed(object): return self.pol_type def get_service_name(self): - """Gets the service name. + """ + Gets the service name. Returns: The service name @@ -99,7 +105,8 @@ class NodeStateInstructionRed(object): return self.service_name def get_state(self): - """Gets the state (node or service). + """ + Gets the state (node or service). Returns: The state (node or service) @@ -107,7 +114,8 @@ class NodeStateInstructionRed(object): return self.state def get_source_node_id(self): - """Gets the source node id (used for initiator type SERVICE). + """ + Gets the source node id (used for initiator type SERVICE). Returns: The source node id @@ -115,7 +123,8 @@ class NodeStateInstructionRed(object): return self.source_node_id def get_source_node_service(self): - """Gets the source node service (used for initiator type SERVICE). + """ + Gets the source node service (used for initiator type SERVICE). Returns: The source node service @@ -123,7 +132,8 @@ class NodeStateInstructionRed(object): return self.source_node_service def get_source_node_service_state(self): - """Gets the source node service state (used for initiator type SERVICE). + """ + Gets the source node service state (used for initiator type SERVICE). Returns: The source node service state diff --git a/src/primaite/nodes/passive_node.py b/src/primaite/nodes/passive_node.py index 13b2d6ad..afe4e2d1 100644 --- a/src/primaite/nodes/passive_node.py +++ b/src/primaite/nodes/passive_node.py @@ -32,7 +32,8 @@ class PassiveNode(Node): @property def ip_address(self) -> str: - """Gets the node IP address as an empty string. + """ + Gets the node IP address as an empty string. No concept of IP address for passive nodes for now. diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index 7632e944..4ad52a1e 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -53,14 +53,16 @@ class ServiceNode(ActiveNode): self.services: Dict[str, Service] = {} def add_service(self, service: Service): - """Adds a service to the node. + """ + Adds a service to the node. :param service: The service to add """ self.services[service.name] = service def has_service(self, protocol_name: str) -> bool: - """Indicates whether a service is on a node. + """ + Indicates whether a service is on a node. :param protocol_name: The service (protocol)e. :return: True if service (protocol) is on the node, otherwise False. @@ -71,7 +73,8 @@ class ServiceNode(ActiveNode): return False def service_running(self, protocol_name: str) -> bool: - """Indicates whether a service is in a running state on the node. + """ + Indicates whether a service is in a running state on the node. :param protocol_name: The service (protocol) :return: True if service (protocol) is in a running state on the node, otherwise False. @@ -85,7 +88,8 @@ class ServiceNode(ActiveNode): return False def service_is_overwhelmed(self, protocol_name: str) -> bool: - """Indicates whether a service is in an overwhelmed state on the node. + """ + Indicates whether a service is in an overwhelmed state on the node. :param protocol_name: The service (protocol) :return: True if service (protocol) is in an overwhelmed state on the node, otherwise False. @@ -99,7 +103,8 @@ class ServiceNode(ActiveNode): return False def set_service_state(self, protocol_name: str, software_state: SoftwareState): - """Sets the software_state of a service (protocol) on the node. + """ + Sets the software_state of a service (protocol) on the node. :param protocol_name: The service (protocol). :param software_state: The software_state. @@ -127,7 +132,8 @@ class ServiceNode(ActiveNode): ) def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState): - """Sets the software_state of a service (protocol) on the node. + """ + Sets the software_state of a service (protocol) on the node. Done if the software_state is not "compromised". @@ -153,7 +159,8 @@ class ServiceNode(ActiveNode): ) def get_service_state(self, protocol_name): - """Gets the state of a service. + """ + Gets the state of a service. :return: The software_state of the service. """ diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py index da65da38..6ca1d3f6 100644 --- a/src/primaite/notebooks/__init__.py +++ b/src/primaite/notebooks/__init__.py @@ -11,7 +11,8 @@ _LOGGER = getLogger(__name__) def start_jupyter_session(): - """Starts a new Jupyter notebook session in the app notebooks directory. + """ + Starts a new Jupyter notebook session in the app notebooks directory. Currently only works on Windows OS. diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py index 91a6f787..e9dfef8c 100644 --- a/src/primaite/pol/green_pol.py +++ b/src/primaite/pol/green_pol.py @@ -25,7 +25,8 @@ def apply_iers( acl: AccessControlList, step: int, ): - """Applies IERs to the links (link pattern of life). + """ + Applies IERs to the links (link pattern of life). Args: network: The network modelled in the environment @@ -217,7 +218,8 @@ def apply_node_pol( node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]], step: int, ): - """Applies node pattern of life. + """ + Applies node pattern of life. Args: nodes: The nodes within the environment diff --git a/src/primaite/pol/ier.py b/src/primaite/pol/ier.py index 913a06da..2de8fe6f 100644 --- a/src/primaite/pol/ier.py +++ b/src/primaite/pol/ier.py @@ -1,5 +1,6 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -"""Information Exchange Requirements for APE. +""" +Information Exchange Requirements for APE. Used to represent an information flow from source to destination. """ @@ -47,7 +48,8 @@ class IER(object): self.running = _running def get_id(self): - """Gets IER ID. + """ + Gets IER ID. Returns: IER ID @@ -55,7 +57,8 @@ class IER(object): return self.id def get_start_step(self): - """Gets IER start step. + """ + Gets IER start step. Returns: IER start step @@ -63,7 +66,8 @@ class IER(object): return self.start_step def get_end_step(self): - """Gets IER end step. + """ + Gets IER end step. Returns: IER end step @@ -71,7 +75,8 @@ class IER(object): return self.end_step def get_load(self): - """Gets IER load. + """ + Gets IER load. Returns: IER load @@ -79,7 +84,8 @@ class IER(object): return self.load def get_protocol(self): - """Gets IER protocol. + """ + Gets IER protocol. Returns: IER protocol @@ -87,7 +93,8 @@ class IER(object): return self.protocol def get_port(self): - """Gets IER port. + """ + Gets IER port. Returns: IER port @@ -95,7 +102,8 @@ class IER(object): return self.port def get_source_node_id(self): - """Gets IER source node ID. + """ + Gets IER source node ID. Returns: IER source node ID @@ -103,7 +111,8 @@ class IER(object): return self.source_node_id def get_dest_node_id(self): - """Gets IER destination node ID. + """ + Gets IER destination node ID. Returns: IER destination node ID @@ -111,7 +120,8 @@ class IER(object): return self.dest_node_id def get_is_running(self): - """Informs whether the IER is currently running. + """ + Informs whether the IER is currently running. Returns: True if running @@ -119,7 +129,8 @@ class IER(object): return self.running def set_is_running(self, _value): - """Sets the running state of the IER. + """ + Sets the running state of the IER. Args: _value: running status @@ -127,7 +138,8 @@ class IER(object): self.running = _value def get_mission_criticality(self): - """Gets the IER mission criticality (used in the reward function). + """ + Gets the IER mission criticality (used in the reward function). Returns: Mission criticality value (0 lowest to 5 highest) diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index 86482903..bff19bf8 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -24,7 +24,8 @@ def apply_red_agent_iers( acl: AccessControlList, step: int, ): - """Applies IERs to the links (link POL) resulting from red agent attack. + """ + Applies IERs to the links (link POL) resulting from red agent attack. Args: network: The network modelled in the environment @@ -213,7 +214,8 @@ def apply_red_agent_node_pol( node_pol: Dict[str, NodeStateInstructionRed], step: int, ): - """Applies node pattern of life. + """ + Applies node pattern of life. Args: nodes: The nodes within the environment @@ -295,7 +297,8 @@ def apply_red_agent_node_pol( def is_red_ier_incoming(node, iers, node_pol_type): - """Checks if the RED IER is incoming. + """ + Checks if the RED IER is incoming. TODO: Write more descriptive docstring with params and returns. """ diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 1bfb7403..caa85e9e 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -125,7 +125,8 @@ class PrimaiteSession: self, **kwargs, ): - """Train the agent. + """ + Train the agent. :param kwargs: Any agent-framework specific key word args. """ @@ -136,7 +137,8 @@ class PrimaiteSession: self, **kwargs, ): - """Evaluate the agent. + """ + Evaluate the agent. :param kwargs: Any agent-framework specific key word args. """ diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py index 8d2a94c7..793f9ade 100644 --- a/src/primaite/setup/reset_demo_notebooks.py +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -12,7 +12,8 @@ _LOGGER = getLogger(__name__) def run(overwrite_existing: bool = True): - """Resets the demo jupyter notebooks in the users app notebooks directory. + """ + Resets the demo jupyter notebooks in the users app notebooks directory. :param overwrite_existing: A bool to toggle replacing existing edited notebooks on or off. """ diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py index a2e1f2c9..599de8dc 100644 --- a/src/primaite/setup/reset_example_configs.py +++ b/src/primaite/setup/reset_example_configs.py @@ -11,7 +11,8 @@ _LOGGER = getLogger(__name__) def run(overwrite_existing=True): - """Resets the example config files in the users app config directory. + """ + Resets the example config files in the users app config directory. :param overwrite_existing: A bool to toggle replacing existing edited config on or off. """ diff --git a/src/primaite/setup/setup_app_dirs.py b/src/primaite/setup/setup_app_dirs.py index bf7dbe59..693b11c1 100644 --- a/src/primaite/setup/setup_app_dirs.py +++ b/src/primaite/setup/setup_app_dirs.py @@ -5,7 +5,8 @@ _LOGGER = getLogger(__name__) def run(): - """Handles creation of application directories and user directories. + """ + Handles creation of application directories and user directories. Uses `platformdirs.PlatformDirs` and `pathlib.Path` to create the required app directories in the correct locations based on the users OS. diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index a74ef4f9..3a5a13db 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -39,7 +39,8 @@ class Transaction(object): "The env observation space description" def as_csv_data(self) -> Tuple[List, List]: - """Converts the Transaction to a csv data row and provides a header. + """ + Converts the Transaction to a csv data row and provides a header. :return: A tuple consisting of (header, data). """ @@ -68,7 +69,8 @@ class Transaction(object): def _turn_action_space_to_array(action_space) -> List[str]: - """Turns action space into a string array so it can be saved to csv. + """ + Turns action space into a string array so it can be saved to csv. :param action_space: The action space :return: The action space as an array of strings @@ -80,7 +82,8 @@ def _turn_action_space_to_array(action_space) -> List[str]: def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]: - """Turns observation space into a string array so it can be saved to csv. + """ + Turns observation space into a string array so it can be saved to csv. :param obs_space: The observation space :param obs_assets: The number of assets (i.e. nodes or links) in the observation space diff --git a/src/primaite/utils/package_data.py b/src/primaite/utils/package_data.py index 463a4309..59f36851 100644 --- a/src/primaite/utils/package_data.py +++ b/src/primaite/utils/package_data.py @@ -10,7 +10,8 @@ _LOGGER = getLogger(__name__) def get_file_path(path: str) -> Path: - """Get PrimAITE package data. + """ + Get PrimAITE package data. :Example: diff --git a/src/primaite/utils/session_output_reader.py b/src/primaite/utils/session_output_reader.py index 6b5cfdc3..e70c98e2 100644 --- a/src/primaite/utils/session_output_reader.py +++ b/src/primaite/utils/session_output_reader.py @@ -7,7 +7,8 @@ import polars as pl def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]: - """Read an average rewards per episode csv file and return as a dict. + """ + Read an average rewards per episode csv file and return as a dict. The dictionary keys are the episode number, and the values are the mean reward that episode. diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py index 5852a84d..104acc62 100644 --- a/src/primaite/utils/session_output_writer.py +++ b/src/primaite/utils/session_output_writer.py @@ -77,7 +77,8 @@ class SessionOutputWriter: _LOGGER.debug(f"Finished writing file: {self._csv_file_path}") def write(self, data: Union[Tuple, Transaction]): - """Write a row of session data. + """ + Write a row of session data. :param data: The row of data to write. Can be a Tuple or an instance of Transaction. """ diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index d1082049..d5844fd9 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -75,7 +75,8 @@ class TestNodeLinkTable: assert env.env_obs.shape == (5, 6) def test_value(self, temp_primaite_session): - """Test that the observation is generated correctly. + """ + Test that the observation is generated correctly. The laydown has: * 3 nodes (2 service nodes and 1 active node) @@ -157,7 +158,8 @@ class TestNodeStatuses: assert env.env_obs.shape == (15,) def test_values(self, temp_primaite_session): - """Test that the hardware and software states are encoded correctly. + """ + Test that the hardware and software states are encoded correctly. The laydown has: * one node with a compromised operating system state @@ -213,7 +215,8 @@ class TestLinkTrafficLevels: assert env.env_obs.shape == (2 * 2,) def test_values(self, temp_primaite_session): - """Test that traffic values are encoded correctly. + """ + Test that traffic values are encoded correctly. The laydown has: * two services