diff --git a/docs/conf.py b/docs/conf.py index 4e22ebc8..d6923446 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -55,4 +55,4 @@ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] html_theme = "furo" html_static_path = ["_static"] -html_favicon = 'source/primaite.ico' \ No newline at end of file +html_favicon = "source/primaite.ico" diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 030860d8..a2d157c6 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -66,8 +66,7 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``. # region Setup Logging class _LevelFormatter(Formatter): - """ - A custom level-specific formatter. + """A custom level-specific formatter. Credit to: https://stackoverflow.com/a/68154386 """ @@ -135,8 +134,7 @@ _LOGGER.addHandler(_FILE_HANDLER) def getLogger(name: str) -> Logger: # noqa - """ - Get a PrimAITE logger. + """Get a PrimAITE logger. :param name: The logger name. Use ``__name__``. :return: An instance of :py:class:`logging.Logger` with the PrimAITE diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 3b0e9234..42460b94 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -13,8 +13,7 @@ class AccessControlList: self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): - """ - Checks for IP address matches. + """Checks for IP address matches. Args: _rule: The rule being checked @@ -35,8 +34,7 @@ class AccessControlList: return False def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port): - """ - Checks for rules that block a protocol / port. + """Checks for rules that block a protocol / port. Args: _source_ip_address: the source IP address to check @@ -62,8 +60,7 @@ class AccessControlList: return True def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): - """ - Adds a new rule. + """Adds a new rule. Args: _permission: the permission value (e.g. "ALLOW" or "DENY") @@ -77,8 +74,7 @@ class AccessControlList: self.acl[hash_value] = new_rule def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): - """ - Removes a rule. + """Removes a rule. Args: _permission: the permission value (e.g. "ALLOW" or "DENY") @@ -100,8 +96,7 @@ class AccessControlList: self.acl.clear() def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port): - """ - Produces a hash value for a rule. + """Produces a hash value for a rule. Args: _permission: the permission value (e.g. "ALLOW" or "DENY") diff --git a/src/primaite/acl/acl_rule.py b/src/primaite/acl/acl_rule.py index 05daecc4..29f52f88 100644 --- a/src/primaite/acl/acl_rule.py +++ b/src/primaite/acl/acl_rule.py @@ -6,8 +6,7 @@ class ACLRule: """Access Control List Rule class.""" def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port): - """ - Init. + """Init. Args: _permission: The permission (ALLOW or DENY) @@ -23,8 +22,7 @@ class ACLRule: self.port = _port def __hash__(self): - """ - Override the hash function. + """Override the hash function. Returns: Returns hash of core parameters. @@ -40,8 +38,7 @@ class ACLRule: ) def get_permission(self): - """ - Gets the permission attribute. + """Gets the permission attribute. Returns: Returns permission attribute @@ -49,8 +46,7 @@ class ACLRule: return self.permission def get_source_ip(self): - """ - Gets the source IP address attribute. + """Gets the source IP address attribute. Returns: Returns source IP address attribute @@ -58,8 +54,7 @@ class ACLRule: return self.source_ip def get_dest_ip(self): - """ - Gets the desintation IP address attribute. + """Gets the desintation IP address attribute. Returns: Returns destination IP address attribute @@ -67,8 +62,7 @@ class ACLRule: return self.dest_ip def get_protocol(self): - """ - Gets the protocol attribute. + """Gets the protocol attribute. Returns: Returns protocol attribute @@ -76,8 +70,7 @@ class ACLRule: return self.protocol def get_port(self): - """ - Gets the port attribute. + """Gets the port attribute. Returns: Returns port attribute diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 685fe776..a43f2d0b 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -21,8 +21,7 @@ _LOGGER = getLogger(__name__) def get_session_path(session_timestamp: datetime) -> Path: - """ - Get the directory path the session will output to. + """Get the directory path the session will output to. This is set in the format of: ~/primaite/sessions//_. @@ -39,11 +38,10 @@ def get_session_path(session_timestamp: datetime) -> Path: class AgentSessionABC(ABC): - """ - An ABC that manages training and/or evaluation of agents in PrimAITE. + """An ABC that manages training and/or evaluation of agents in PrimAITE. - This class cannot be directly instantiated and must be inherited from - with all implemented abstract methods implemented. + This class cannot be directly instantiated and must be inherited from with all implemented abstract methods + implemented. """ @abstractmethod @@ -186,8 +184,7 @@ class AgentSessionABC(ABC): self, **kwargs, ): - """ - Train the agent. + """Train the agent. :param kwargs: Any agent-specific key-word args to be passed. """ @@ -204,8 +201,7 @@ class AgentSessionABC(ABC): self, **kwargs, ): - """ - Evaluate the agent. + """Evaluate the agent. :param kwargs: Any agent-specific key-word args to be passed. """ @@ -293,11 +289,10 @@ class AgentSessionABC(ABC): class HardCodedAgentSessionABC(AgentSessionABC): - """ - An Agent Session ABC for evaluation deterministic agents. + """An Agent Session ABC for evaluation deterministic agents. - This class cannot be directly instantiated and must be inherited from - with all implemented abstract methods implemented. + This class cannot be directly instantiated and must be inherited from with all implemented abstract methods + implemented. """ def __init__(self, training_config_path, lay_down_config_path): @@ -325,8 +320,7 @@ class HardCodedAgentSessionABC(AgentSessionABC): self, **kwargs, ): - """ - Train the agent. + """Train the agent. :param kwargs: Any agent-specific key-word args to be passed. """ @@ -340,8 +334,7 @@ class HardCodedAgentSessionABC(AgentSessionABC): self, **kwargs, ): - """ - Evaluate the agent. + """Evaluate the agent. :param kwargs: Any agent-specific key-word args to be passed. """ diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 263ccbdc..9ed9fd28 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -23,8 +23,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return self._calculate_action_full_view(obs) def get_blocked_green_iers(self, green_iers, acl, nodes): - """ - Get blocked green IERs. + """Get blocked green IERs. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -46,8 +45,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_green_iers def get_matching_acl_rules_for_ier(self, ier, acl, nodes): - """ - Get matching ACL rules for an IER. + """Get matching ACL rules for an IER. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -63,8 +61,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return matching_rules def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): - """ - Get blocking ACL rules for an IER. + """Get blocking ACL rules for an IER. .. warning:: Can return empty dict but IER can still be blocked by default @@ -83,8 +80,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return blocked_rules def get_allow_acl_rules_for_ier(self, ier, acl, nodes): - """ - Get all allowing ACL rules for an IER. + """Get all allowing ACL rules for an IER. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -108,8 +104,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): nodes, services_list, ): - """ - Get matching ACL rules. + """Get matching ACL rules. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -140,8 +135,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): nodes, services_list, ): - """ - Get the ALLOW ACL rules. + """Get the ALLOW ACL rules. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -173,8 +167,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): nodes, services_list, ): - """ - Get the DENY ACL rules. + """Get the DENY ACL rules. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -197,8 +190,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return allowed_rules def _calculate_action_full_view(self, obs): - """ - Calculate a good acl-based action for the blue agent to take. + """Calculate a good acl-based action for the blue agent to take. Knowledge of just the observation space is insufficient for a perfect solution, as we need to know: diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index 310fc178..27a2a823 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -6,8 +6,7 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): """An Agent Session class that implements a deterministic Node agent.""" def _calculate_action(self, obs): - """ - Calculate a good node-based action for the blue agent to take. + """Calculate a good node-based action for the blue agent to take. TODO: Add params and return in docstring. TODO: Typehint params and return. diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index d851ba9c..20503459 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -128,8 +128,7 @@ class RLlibAgent(AgentSessionABC): self, **kwargs, ): - """ - Evaluate the agent. + """Evaluate the agent. :param kwargs: Any agent-specific key-word args to be passed. """ @@ -147,8 +146,7 @@ class RLlibAgent(AgentSessionABC): self, **kwargs, ): - """ - Evaluate the agent. + """Evaluate the agent. :param kwargs: Any agent-specific key-word args to be passed. """ diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index f5ac44cb..58148d1f 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -77,8 +77,7 @@ class SB3Agent(AgentSessionABC): self, **kwargs, ): - """ - Train the agent. + """Train the agent. :param kwargs: Any agent-specific key-word args to be passed. """ @@ -98,8 +97,7 @@ class SB3Agent(AgentSessionABC): deterministic: bool = True, **kwargs, ): - """ - Evaluate the agent. + """Evaluate the agent. :param deterministic: Whether the evaluation is deterministic. :param kwargs: Any agent-specific key-word args to be passed. diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py index 5a6c9da5..df93e56d 100644 --- a/src/primaite/agents/simple.py +++ b/src/primaite/agents/simple.py @@ -3,8 +3,7 @@ from primaite.agents.utils import get_new_action, transform_action_acl_enum, tra class RandomAgent(HardCodedAgentSessionABC): - """ - A Random Agent. + """A Random Agent. Get a completely random action from the action space. """ @@ -14,11 +13,9 @@ class RandomAgent(HardCodedAgentSessionABC): class DummyAgent(HardCodedAgentSessionABC): - """ - A Dummy Agent. + """A Dummy Agent. - All action spaces setup so dummy action is always 0 regardless of action - type used. + All action spaces setup so dummy action is always 0 regardless of action type used. """ def _calculate_action(self, obs): @@ -26,8 +23,7 @@ class DummyAgent(HardCodedAgentSessionABC): class DoNothingACLAgent(HardCodedAgentSessionABC): - """ - A do nothing ACL agent. + """A do nothing ACL agent. A valid ACL action that has no effect; does nothing. """ @@ -41,8 +37,7 @@ class DoNothingACLAgent(HardCodedAgentSessionABC): class DoNothingNodeAgent(HardCodedAgentSessionABC): - """ - A do nothing Node agent. + """A do nothing Node agent. A valid Node action that has no effect; does nothing. """ diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index 8c59faf7..8b3b57f5 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -11,8 +11,7 @@ from primaite.common.enums import ( def transform_action_node_readable(action): - """ - Convert a node action from enumerated format to readable format. + """Convert a node action from enumerated format to readable format. example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] @@ -34,8 +33,7 @@ def transform_action_node_readable(action): def transform_action_acl_readable(action): - """ - Transform an ACL action to a more readable format. + """Transform an ACL action to a more readable format. example: [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] @@ -94,8 +92,7 @@ def is_valid_node_action(action): def is_valid_acl_action(action): - """ - Is the ACL action an actual valid action. + """Is the ACL action an actual valid action. Only uses information about the action to determine if the action has an effect. @@ -127,8 +124,7 @@ def is_valid_acl_action(action): def is_valid_acl_action_extra(action): - """ - Harsher version of valid acl actions, does not allow action. + """Harsher version of valid acl actions, does not allow action. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -151,8 +147,7 @@ def is_valid_acl_action_extra(action): def transform_change_obs_readable(obs): - """ - Transform list of transactions to readable list of each observation property. + """Transform list of transactions to readable list of each observation property. example: np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] @@ -174,8 +169,7 @@ def transform_change_obs_readable(obs): def transform_obs_readable(obs): - """ - Transform observation to readable format. + """Transform observation to readable format. np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']] @@ -191,8 +185,7 @@ def transform_obs_readable(obs): def convert_to_new_obs(obs, num_nodes=10): - """ - Convert original gym Box observation space to new multiDiscrete observation space. + """Convert original gym Box observation space to new multiDiscrete observation space. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -203,8 +196,7 @@ def convert_to_new_obs(obs, num_nodes=10): def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): - """ - Convert to old observation. + """Convert to old observation. Links filled with 0's as no information is included in new observation space. @@ -240,8 +232,7 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): - """ - Return string describing change between two observations. + """Return string describing change between two observations. example: obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) @@ -269,8 +260,7 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): def _describe_obs_change_helper(obs_change, is_link): - """ - Helper funcion to describe what has changed. + """Helper funcion to describe what has changed. example: [ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD" @@ -305,8 +295,7 @@ def _describe_obs_change_helper(obs_change, is_link): def transform_action_node_enum(action): - """ - Convert a node action from readable string format, to enumerated format. + """Convert a node action from readable string format, to enumerated format. example: [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] @@ -337,8 +326,7 @@ def transform_action_node_enum(action): def transform_action_node_readable(action): - """ - Convert a node action from enumerated format to readable format. + """Convert a node action from enumerated format to readable format. example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] @@ -360,8 +348,7 @@ def transform_action_node_readable(action): def node_action_description(action): - """ - Generate string describing a node-based action. + """Generate string describing a node-based action. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -388,8 +375,7 @@ def node_action_description(action): def transform_action_acl_enum(action): - """ - Convert acl action from readable str format, to enumerated format. + """Convert acl action from readable str format, to enumerated format. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -411,8 +397,7 @@ def transform_action_acl_enum(action): def acl_action_description(action): - """ - Generate string describing an acl-based action. + """Generate string describing an acl-based action. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -432,8 +417,7 @@ def acl_action_description(action): def get_node_of_ip(ip, node_dict): - """ - Get the node ID of an IP address. + """Get the node ID of an IP address. node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes) @@ -480,8 +464,7 @@ def is_valid_node_action(action): def is_valid_acl_action(action): - """ - Is the ACL action an actual valid action. + """Is the ACL action an actual valid action. Only uses information about the action to determine if the action has an effect @@ -513,8 +496,7 @@ def is_valid_acl_action(action): def is_valid_acl_action_extra(action): - """ - Harsher version of valid acl actions, does not allow action. + """Harsher version of valid acl actions, does not allow action. TODO: Add params and return in docstring. TODO: Typehint params and return. @@ -537,8 +519,7 @@ def is_valid_acl_action_extra(action): def get_new_action(old_action, action_dict): - """ - Get new action (e.g. 32) from old action e.g. [1,1,1,0]. + """Get new action (e.g. 32) from old action e.g. [1,1,1,0]. Old_action can be either node or acl action type diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 40e8cf0d..42825144 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -28,8 +28,7 @@ def build_dirs(): @app.command() def reset_notebooks(overwrite: bool = True): - """ - Force a reset of the demo notebooks in the users notebooks directory. + """Force a reset of the demo notebooks in the users notebooks directory. :param overwrite: If True, will overwrite existing demo notebooks. """ @@ -40,8 +39,7 @@ def reset_notebooks(overwrite: bool = True): @app.command() def logs(last_n: Annotated[int, typer.Option("-n")]): - """ - Print the PrimAITE log file. + """Print the PrimAITE log file. :param last_n: The number of lines to print. Default value is 10. """ @@ -61,8 +59,7 @@ _LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # n @app.command() def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None): - """ - View or set the PrimAITE Log Level. + """View or set the PrimAITE Log Level. To View, simply call: primaite log-level @@ -113,8 +110,7 @@ def clean_up(): @app.command() def setup(overwrite_existing: bool = True): - """ - Perform the PrimAITE first-time setup. + """Perform the PrimAITE first-time setup. WARNING: All user-data will be lost. """ @@ -152,8 +148,7 @@ def setup(overwrite_existing: bool = True): @app.command() def session(tc: Optional[str] = None, ldc: Optional[str] = None): - """ - Run a PrimAITE session. + """Run a PrimAITE session. tc: The training config filepath. Optional. If no value is passed then example default training config is used from: @@ -178,8 +173,7 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None): @app.command() def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None): - """ - View or set the plotly template for Session plots. + """View or set the plotly template for Session plots. To View, simply call: primaite plotly-template diff --git a/src/primaite/common/protocol.py b/src/primaite/common/protocol.py index 2e3683e8..ebda1fcf 100644 --- a/src/primaite/common/protocol.py +++ b/src/primaite/common/protocol.py @@ -6,8 +6,7 @@ class Protocol(object): """Protocol class.""" def __init__(self, _name): - """ - Init. + """Init. Args: _name: The protocol name @@ -16,8 +15,7 @@ class Protocol(object): self.load = 0 # bps def get_name(self): - """ - Gets the protocol name. + """Gets the protocol name. Returns: The protocol name @@ -25,8 +23,7 @@ class Protocol(object): return self.name def get_load(self): - """ - Gets the protocol load. + """Gets the protocol load. Returns: The protocol load (bps) @@ -34,8 +31,7 @@ class Protocol(object): return self.load def add_load(self, _load): - """ - Adds load to the protocol. + """Adds load to the protocol. Args: _load: The load to add diff --git a/src/primaite/common/service.py b/src/primaite/common/service.py index 51403879..c381f51f 100644 --- a/src/primaite/common/service.py +++ b/src/primaite/common/service.py @@ -8,8 +8,7 @@ class Service(object): """Service class.""" def __init__(self, name: str, port: str, software_state: SoftwareState): - """ - Init. + """Init. :param name: The service name. :param port: The service port. diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 08f77b2f..587997b7 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -12,8 +12,7 @@ _EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Convert a legacy lay down config dict to the new format. + """Convert a legacy lay down config dict to the new format. :param legacy_config_dict: A legacy lay down config dict. """ @@ -22,12 +21,10 @@ def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> D def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict: - """ - Read in a lay down config yaml file. + """Read in a lay down config yaml file. :param file_path: The config file path. - :param legacy_file: True if the config file is legacy format, otherwise - False. + :param legacy_file: True if the config file is legacy format, otherwise False. :return: The lay down config as a dict. :raises ValueError: If the file_path does not exist. """ @@ -53,8 +50,7 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict: def ddos_basic_one_config_path() -> Path: - """ - The path to the example lay_down_config_1_DDOS_basic.yaml file. + """The path to the example lay_down_config_1_DDOS_basic.yaml file. :return: The file path. """ @@ -68,8 +64,7 @@ def ddos_basic_one_config_path() -> Path: def ddos_basic_two_config_path() -> Path: - """ - The path to the example lay_down_config_2_DDOS_basic.yaml file. + """The path to the example lay_down_config_2_DDOS_basic.yaml file. :return: The file path. """ @@ -83,8 +78,7 @@ def ddos_basic_two_config_path() -> Path: def dos_very_basic_config_path() -> Path: - """ - The path to the example lay_down_config_3_DOS_very_basic.yaml file. + """The path to the example lay_down_config_3_DOS_very_basic.yaml file. :return: The file path. """ @@ -98,8 +92,7 @@ def dos_very_basic_config_path() -> Path: def data_manipulation_config_path() -> Path: - """ - The path to the example lay_down_config_5_data_manipulation.yaml file. + """The path to the example lay_down_config_5_data_manipulation.yaml file. :return: The file path. """ diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index bd73f65b..040ef6fa 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -24,8 +24,7 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training def main_training_config_path() -> Path: - """ - The path to the example training_config_main.yaml file. + """The path to the example training_config_main.yaml file. :return: The file path. """ @@ -180,8 +179,7 @@ class TrainingConfig: @classmethod def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig: - """ - Create an instance of TrainingConfig from a dict. + """Create an instance of TrainingConfig from a dict. :param config_dict: The training config dict. :return: The instance of TrainingConfig. @@ -236,8 +234,7 @@ class TrainingConfig: def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: - """ - Read in a training config yaml file. + """Read in a training config yaml file. :param file_path: The config file path. :param legacy_file: True if the config file is legacy format, otherwise @@ -281,18 +278,14 @@ def convert_legacy_training_config_dict( action_type: ActionType = ActionType.ANY, num_steps: int = 256, ) -> Dict[str, Any]: - """ - Convert a legacy training config dict to the new format. + """Convert a legacy training config dict to the new format. :param legacy_config_dict: A legacy training config dict. - :param agent_framework: The agent framework to use as legacy training - configs don't have agent_framework values. - :param agent_identifier: The red agent identifier to use as legacy - training configs don't have agent_identifier values. - :param action_type: The action space type to set as legacy training configs - don't have action_type values. - :param num_steps: The number of steps to set as legacy training configs - don't have num_steps values. + :param agent_framework: The agent framework to use as legacy training configs don't have agent_framework values. + :param agent_identifier: The red agent identifier to use as legacy training configs don't have agent_identifier + values. + :param action_type: The action space type to set as legacy training configs don't have action_type values. + :param num_steps: The number of steps to set as legacy training configs don't have num_steps values. :return: The converted training config dict. """ config_dict = { @@ -312,8 +305,7 @@ def convert_legacy_training_config_dict( def _get_new_key_from_legacy(legacy_key: str) -> str: - """ - Maps legacy training config keys to the new format keys. + """Maps legacy training config keys to the new format keys. :param legacy_key: A legacy training config key. :return: The mapped key. diff --git a/src/primaite/data_viz/session_plots.py b/src/primaite/data_viz/session_plots.py index 245b9774..542c6677 100644 --- a/src/primaite/data_viz/session_plots.py +++ b/src/primaite/data_viz/session_plots.py @@ -22,8 +22,7 @@ def plot_av_reward_per_episode( title: Optional[str] = None, subtitle: Optional[str] = None, ) -> Figure: - """ - Plot the average reward per episode from a csv session output. + """Plot the average reward per episode from a csv session output. :param av_reward_per_episode_csv: The average reward per episode csv file path. diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index f8b42e1c..e347a65c 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -376,8 +376,8 @@ class LinkTrafficLevels(AbstractObservationComponent): class ObservationsHandler: """Component-based observation space handler. - This allows users to configure observation spaces by mixing and matching components. - Each component can also define further parameters to make them more flexible. + This allows users to configure observation spaces by mixing and matching components. Each component can also define + further parameters to make them more flexible. """ _REGISTRY: Final[Dict[str, type]] = { diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 03c23f93..29662988 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -67,14 +67,12 @@ class Primaite(Env): session_path: Path, timestamp_str: str, ): - """ - The Primaite constructor. + """The Primaite constructor. :param training_config_path: The training config filepath. :param lay_down_config_path: The lay down config filepath. :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. + :param timestamp_str: The session timestamp in the format: _. """ self.session_path: Final[Path] = session_path self.timestamp_str: Final[str] = timestamp_str @@ -256,8 +254,7 @@ class Primaite(Env): self.total_step_count = 0 def reset(self): - """ - AI Gym Reset function. + """AI Gym Reset function. Returns: Environment observation space (reset) @@ -293,8 +290,7 @@ class Primaite(Env): return self.env_obs def step(self, action): - """ - AI Gym Step function. + """AI Gym Step function. Args: action: Action space from agent @@ -432,8 +428,7 @@ class Primaite(Env): print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load())) def interpret_action_and_apply(self, _action): - """ - Applies agent actions to the nodes and Access Control List. + """Applies agent actions to the nodes and Access Control List. Args: _action: The action space from the agent @@ -452,8 +447,7 @@ class Primaite(Env): logging.error("Invalid action type found") def apply_actions_to_nodes(self, _action): - """ - Applies agent actions to the nodes. + """Applies agent actions to the nodes. Args: _action: The action space from the agent @@ -540,8 +534,7 @@ class Primaite(Env): return def apply_actions_to_acl(self, _action): - """ - Applies agent actions to the Access Control List [TO DO]. + """Applies agent actions to the Access Control List [TO DO]. Args: _action: The action space from the agent @@ -618,8 +611,7 @@ class Primaite(Env): return def apply_time_based_updates(self): - """ - Updates anything that needs to count down and then change state. + """Updates anything that needs to count down and then change state. e.g. reset / patching status """ @@ -716,8 +708,7 @@ class Primaite(Env): print("Environment configuration loaded") def create_node(self, item): - """ - Creates a node from config data. + """Creates a node from config data. Args: item: A config data item @@ -797,8 +788,7 @@ class Primaite(Env): self.network_reference.add_nodes_from([node_ref]) def create_link(self, item: Dict): - """ - Creates a link from config data. + """Creates a link from config data. Args: item: A config data item @@ -841,8 +831,7 @@ class Primaite(Env): ) def create_green_ier(self, item): - """ - Creates a green IER from config data. + """Creates a green IER from config data. Args: item: A config data item @@ -882,8 +871,7 @@ class Primaite(Env): ) def create_red_ier(self, item): - """ - Creates a red IER from config data. + """Creates a red IER from config data. Args: item: A config data item @@ -912,8 +900,7 @@ class Primaite(Env): ) def create_green_pol(self, item): - """ - Creates a green PoL object from config data. + """Creates a green PoL object from config data. Args: item: A config data item @@ -946,8 +933,7 @@ class Primaite(Env): ) def create_red_pol(self, item): - """ - Creates a red PoL object from config data. + """Creates a red PoL object from config data. Args: item: A config data item @@ -987,8 +973,7 @@ class Primaite(Env): ) def create_acl_rule(self, item): - """ - Creates an ACL rule from config data. + """Creates an ACL rule from config data. Args: item: A config data item @@ -1008,8 +993,7 @@ class Primaite(Env): ) def create_services_list(self, services): - """ - Creates a list of services (enum) from config data. + """Creates a list of services (enum) from config data. Args: item: A config data item representing the services @@ -1024,8 +1008,7 @@ class Primaite(Env): self.num_services = len(self.services_list) def create_ports_list(self, ports): - """ - Creates a list of ports from config data. + """Creates a list of ports from config data. Args: item: A config data item representing the ports @@ -1048,8 +1031,7 @@ class Primaite(Env): self.observation_type = ObservationType[observation_info["type"]] def get_action_info(self, action_info): - """ - Extracts action_info. + """Extracts action_info. Args: item: A config data item representing action info @@ -1069,11 +1051,9 @@ class Primaite(Env): self.obs_config = obs_config def reset_environment(self): - """ - # Resets environment. + """# Resets environment. - Uses config data config data in order to build the environment - configuration. + Uses config data config data in order to build the environment configuration. """ for item in self.lay_down_config: if item["item_type"] == "NODE": @@ -1095,8 +1075,7 @@ class Primaite(Env): ier_value.set_is_running(False) def reset_node(self, item): - """ - Resets the statuses of a node. + """Resets the statuses of a node. Args: item: A config data item @@ -1143,8 +1122,7 @@ class Primaite(Env): pass def create_node_action_dict(self): - """ - Creates a dictionary mapping each possible discrete action to more readable multidiscrete action. + """Creates a dictionary mapping each possible discrete action to more readable multidiscrete action. Note: Only actions that have the potential to change the state exist in the mapping (except for key 0) @@ -1157,7 +1135,6 @@ class Primaite(Env): 5: [1, 3, 1, 0], ... } - """ # reserve 0 action to be a nothing action actions = {0: [1, 0, 0, 0]} @@ -1209,11 +1186,9 @@ class Primaite(Env): return actions def create_node_and_acl_action_dict(self): - """ - Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action. + """Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action. The dictionary contains actions of both Node and ACL action types. - """ node_action_dict = self.create_node_action_dict() acl_action_dict = self.create_acl_action_dict() diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 19094a18..5cef47ef 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -21,8 +21,7 @@ def calculate_reward_function( step_count, config_values, ): - """ - Compares the states of the initial and final nodes/links to get a reward. + """Compares the states of the initial and final nodes/links to get a reward. Args: initial_nodes: The nodes before red and blue agents take effect @@ -95,8 +94,7 @@ def calculate_reward_function( def score_node_operating_state(final_node, initial_node, reference_node, config_values): - """ - Calculates score relating to the hardware state of a node. + """Calculates score relating to the hardware state of a node. Args: final_node: The node after red and blue agents take effect @@ -144,8 +142,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ def score_node_os_state(final_node, initial_node, reference_node, config_values): - """ - Calculates score relating to the Software State of a node. + """Calculates score relating to the Software State of a node. Args: final_node: The node after red and blue agents take effect @@ -195,8 +192,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) def score_node_service_state(final_node, initial_node, reference_node, config_values): - """ - Calculates score relating to the service state(s) of a node. + """Calculates score relating to the service state(s) of a node. Args: final_node: The node after red and blue agents take effect @@ -267,8 +263,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va def score_node_file_system(final_node, initial_node, reference_node, config_values): - """ - Calculates score relating to the file system state of a node. + """Calculates score relating to the file system state of a node. Args: final_node: The node after red and blue agents take effect diff --git a/src/primaite/links/link.py b/src/primaite/links/link.py index 90235e9f..e8901b3d 100644 --- a/src/primaite/links/link.py +++ b/src/primaite/links/link.py @@ -9,8 +9,7 @@ class Link(object): """Link class.""" def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services): - """ - Init. + """Init. Args: _id: The IER id @@ -30,8 +29,7 @@ class Link(object): self.add_protocol(protocol_name) def add_protocol(self, _protocol): - """ - Adds a new protocol to the list of protocols on this link. + """Adds a new protocol to the list of protocols on this link. Args: _protocol: The protocol to be added (enum) @@ -39,8 +37,7 @@ class Link(object): self.protocol_list.append(Protocol(_protocol)) def get_id(self): - """ - Gets link ID. + """Gets link ID. Returns: Link ID @@ -48,8 +45,7 @@ class Link(object): return self.id def get_source_node_name(self): - """ - Gets source node name. + """Gets source node name. Returns: Source node name @@ -57,8 +53,7 @@ class Link(object): return self.source_node_name def get_dest_node_name(self): - """ - Gets destination node name. + """Gets destination node name. Returns: Destination node name @@ -66,8 +61,7 @@ class Link(object): return self.dest_node_name def get_bandwidth(self): - """ - Gets bandwidth of link. + """Gets bandwidth of link. Returns: Link bandwidth (bps) @@ -75,8 +69,7 @@ class Link(object): return self.bandwidth def get_protocol_list(self): - """ - Gets list of protocols on this link. + """Gets list of protocols on this link. Returns: List of protocols on this link @@ -84,8 +77,7 @@ class Link(object): return self.protocol_list def get_current_load(self): - """ - Gets current total load on this link. + """Gets current total load on this link. Returns: Total load on this link (bps) @@ -96,8 +88,7 @@ class Link(object): return total_load def add_protocol_load(self, _protocol, _load): - """ - Adds a loading to a protocol on this link. + """Adds a loading to a protocol on this link. Args: _protocol: The protocol to load diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py index 07a0ea0a..588ccd93 100644 --- a/src/primaite/nodes/active_node.py +++ b/src/primaite/nodes/active_node.py @@ -25,8 +25,7 @@ class ActiveNode(Node): file_system_state: FileSystemState, config_values: TrainingConfig, ): - """ - Init. + """Init. :param node_id: The node ID :param name: The node name @@ -52,8 +51,7 @@ class ActiveNode(Node): @property def software_state(self) -> SoftwareState: - """ - Get the software_state. + """Get the software_state. :return: The software_state. """ @@ -61,8 +59,7 @@ class ActiveNode(Node): @software_state.setter def software_state(self, software_state: SoftwareState): - """ - Get the software_state. + """Get the software_state. :param software_state: Software State. """ @@ -80,8 +77,7 @@ class ActiveNode(Node): ) def set_software_state_if_not_compromised(self, software_state: SoftwareState): - """ - Sets Software State if the node is not compromised. + """Sets Software State if the node is not compromised. Args: software_state: Software State @@ -107,8 +103,7 @@ class ActiveNode(Node): self._software_state = SoftwareState.GOOD def set_file_system_state(self, file_system_state: FileSystemState): - """ - Sets the file system state (actual and observed). + """Sets the file system state (actual and observed). Args: file_system_state: File system state @@ -134,8 +129,7 @@ class ActiveNode(Node): ) def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState): - """ - Sets the file system state (actual and observed) if not in a compromised state. + """Sets the file system state (actual and observed) if not in a compromised state. Use for green PoL to prevent it overturning a compromised state diff --git a/src/primaite/nodes/node.py b/src/primaite/nodes/node.py index bac1792d..40f6328f 100644 --- a/src/primaite/nodes/node.py +++ b/src/primaite/nodes/node.py @@ -18,8 +18,7 @@ class Node: hardware_state: HardwareState, config_values: TrainingConfig, ): - """ - Init. + """Init. :param node_id: The node id. :param name: The name of the node. diff --git a/src/primaite/nodes/node_state_instruction_green.py b/src/primaite/nodes/node_state_instruction_green.py index 2b1d94be..e1244144 100644 --- a/src/primaite/nodes/node_state_instruction_green.py +++ b/src/primaite/nodes/node_state_instruction_green.py @@ -15,8 +15,7 @@ class NodeStateInstructionGreen(object): _service_name, _state, ): - """ - Init. + """Init. Args: _id: The node state instruction id @@ -36,8 +35,7 @@ class NodeStateInstructionGreen(object): self.state = _state def get_start_step(self): - """ - Gets the start step. + """Gets the start step. Returns: The start step @@ -45,8 +43,7 @@ class NodeStateInstructionGreen(object): return self.start_step def get_end_step(self): - """ - Gets the end step. + """Gets the end step. Returns: The end step @@ -54,8 +51,7 @@ class NodeStateInstructionGreen(object): return self.end_step def get_node_id(self): - """ - Gets the node ID. + """Gets the node ID. Returns: The node ID @@ -63,8 +59,7 @@ class NodeStateInstructionGreen(object): return self.node_id def get_node_pol_type(self): - """ - Gets the node pattern of life type (enum). + """Gets the node pattern of life type (enum). Returns: The node pattern of life type (enum) @@ -72,8 +67,7 @@ class NodeStateInstructionGreen(object): return self.node_pol_type def get_service_name(self): - """ - Gets the service name. + """Gets the service name. Returns: The service name @@ -81,8 +75,7 @@ class NodeStateInstructionGreen(object): return self.service_name def get_state(self): - """ - Gets the state (node or service). + """Gets the state (node or service). Returns: The state (node or service) diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 4272ce24..3e2e734d 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -23,8 +23,7 @@ class NodeStateInstructionRed(object): _pol_source_node_service, _pol_source_node_service_state, ): - """ - Init. + """Init. Args: _id: The node state instruction id @@ -52,8 +51,7 @@ class NodeStateInstructionRed(object): self.source_node_service_state = _pol_source_node_service_state def get_start_step(self): - """ - Gets the start step. + """Gets the start step. Returns: The start step @@ -61,8 +59,7 @@ class NodeStateInstructionRed(object): return self.start_step def get_end_step(self): - """ - Gets the end step. + """Gets the end step. Returns: The end step @@ -70,8 +67,7 @@ class NodeStateInstructionRed(object): return self.end_step def get_target_node_id(self): - """ - Gets the node ID. + """Gets the node ID. Returns: The node ID @@ -79,8 +75,7 @@ class NodeStateInstructionRed(object): return self.target_node_id def get_initiator(self): - """ - Gets the initiator. + """Gets the initiator. Returns: The initiator @@ -88,8 +83,7 @@ class NodeStateInstructionRed(object): return self.initiator def get_pol_type(self) -> NodePOLType: - """ - Gets the node pattern of life type (enum). + """Gets the node pattern of life type (enum). Returns: The node pattern of life type (enum) @@ -97,8 +91,7 @@ class NodeStateInstructionRed(object): return self.pol_type def get_service_name(self): - """ - Gets the service name. + """Gets the service name. Returns: The service name @@ -106,8 +99,7 @@ class NodeStateInstructionRed(object): return self.service_name def get_state(self): - """ - Gets the state (node or service). + """Gets the state (node or service). Returns: The state (node or service) @@ -115,8 +107,7 @@ class NodeStateInstructionRed(object): return self.state def get_source_node_id(self): - """ - Gets the source node id (used for initiator type SERVICE). + """Gets the source node id (used for initiator type SERVICE). Returns: The source node id @@ -124,8 +115,7 @@ class NodeStateInstructionRed(object): return self.source_node_id def get_source_node_service(self): - """ - Gets the source node service (used for initiator type SERVICE). + """Gets the source node service (used for initiator type SERVICE). Returns: The source node service @@ -133,8 +123,7 @@ class NodeStateInstructionRed(object): return self.source_node_service def get_source_node_service_state(self): - """ - Gets the source node service state (used for initiator type SERVICE). + """Gets the source node service state (used for initiator type SERVICE). Returns: The source node service state diff --git a/src/primaite/nodes/passive_node.py b/src/primaite/nodes/passive_node.py index 9aa5c7d7..188b4ee3 100644 --- a/src/primaite/nodes/passive_node.py +++ b/src/primaite/nodes/passive_node.py @@ -17,8 +17,7 @@ class PassiveNode(Node): hardware_state: HardwareState, config_values: TrainingConfig, ): - """ - Init. + """Init. :param node_id: The node id. :param name: The name of the node. @@ -32,8 +31,7 @@ class PassiveNode(Node): @property def ip_address(self) -> str: - """ - Gets the node IP address as an empty string. + """Gets the node IP address as an empty string. No concept of IP address for passive nodes for now. diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index 5d69df92..0114f507 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -26,8 +26,7 @@ class ServiceNode(ActiveNode): file_system_state: FileSystemState, config_values: TrainingConfig, ): - """ - Init. + """Init. :param node_id: The node ID :param name: The node name @@ -53,16 +52,14 @@ class ServiceNode(ActiveNode): self.services: Dict[str, Service] = {} def add_service(self, service: Service): - """ - Adds a service to the node. + """Adds a service to the node. :param service: The service to add """ self.services[service.name] = service def has_service(self, protocol_name: str) -> bool: - """ - Indicates whether a service is on a node. + """Indicates whether a service is on a node. :param protocol_name: The service (protocol)e. :return: True if service (protocol) is on the node, otherwise False. @@ -73,12 +70,10 @@ class ServiceNode(ActiveNode): return False def service_running(self, protocol_name: str) -> bool: - """ - Indicates whether a service is in a running state on the node. + """Indicates whether a service is in a running state on the node. :param protocol_name: The service (protocol) - :return: True if service (protocol) is in a running state on the - node, otherwise False. + :return: True if service (protocol) is in a running state on the node, otherwise False. """ for service_key, service_value in self.services.items(): if service_key == protocol_name: @@ -89,12 +84,10 @@ class ServiceNode(ActiveNode): return False def service_is_overwhelmed(self, protocol_name: str) -> bool: - """ - Indicates whether a service is in an overwhelmed state on the node. + """Indicates whether a service is in an overwhelmed state on the node. :param protocol_name: The service (protocol) - :return: True if service (protocol) is in an overwhelmed state on the - node, otherwise False. + :return: True if service (protocol) is in an overwhelmed state on the node, otherwise False. """ for service_key, service_value in self.services.items(): if service_key == protocol_name: @@ -105,8 +98,7 @@ class ServiceNode(ActiveNode): return False def set_service_state(self, protocol_name: str, software_state: SoftwareState): - """ - Sets the software_state of a service (protocol) on the node. + """Sets the software_state of a service (protocol) on the node. :param protocol_name: The service (protocol). :param software_state: The software_state. @@ -134,8 +126,7 @@ class ServiceNode(ActiveNode): ) def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState): - """ - Sets the software_state of a service (protocol) on the node. + """Sets the software_state of a service (protocol) on the node. Done if the software_state is not "compromised". @@ -161,8 +152,7 @@ class ServiceNode(ActiveNode): ) def get_service_state(self, protocol_name): - """ - Gets the state of a service. + """Gets the state of a service. :return: The software_state of the service. """ diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py index 0e81e581..0730312e 100644 --- a/src/primaite/notebooks/__init__.py +++ b/src/primaite/notebooks/__init__.py @@ -10,8 +10,7 @@ _LOGGER = getLogger(__name__) def start_jupyter_session(): - """ - Starts a new Jupyter notebook session in the app notebooks directory. + """Starts a new Jupyter notebook session in the app notebooks directory. Currently only works on Windows OS. diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py index e9dfef8c..91a6f787 100644 --- a/src/primaite/pol/green_pol.py +++ b/src/primaite/pol/green_pol.py @@ -25,8 +25,7 @@ def apply_iers( acl: AccessControlList, step: int, ): - """ - Applies IERs to the links (link pattern of life). + """Applies IERs to the links (link pattern of life). Args: network: The network modelled in the environment @@ -218,8 +217,7 @@ def apply_node_pol( node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]], step: int, ): - """ - Applies node pattern of life. + """Applies node pattern of life. Args: nodes: The nodes within the environment diff --git a/src/primaite/pol/ier.py b/src/primaite/pol/ier.py index daa49727..09f32aeb 100644 --- a/src/primaite/pol/ier.py +++ b/src/primaite/pol/ier.py @@ -1,6 +1,5 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -""" -Information Exchange Requirements for APE. +"""Information Exchange Requirements for APE. Used to represent an information flow from source to destination. """ @@ -22,8 +21,7 @@ class IER(object): _mission_criticality, _running=False, ): - """ - Init. + """Init. Args: _id: The IER id @@ -49,8 +47,7 @@ class IER(object): self.running = _running def get_id(self): - """ - Gets IER ID. + """Gets IER ID. Returns: IER ID @@ -58,8 +55,7 @@ class IER(object): return self.id def get_start_step(self): - """ - Gets IER start step. + """Gets IER start step. Returns: IER start step @@ -67,8 +63,7 @@ class IER(object): return self.start_step def get_end_step(self): - """ - Gets IER end step. + """Gets IER end step. Returns: IER end step @@ -76,8 +71,7 @@ class IER(object): return self.end_step def get_load(self): - """ - Gets IER load. + """Gets IER load. Returns: IER load @@ -85,8 +79,7 @@ class IER(object): return self.load def get_protocol(self): - """ - Gets IER protocol. + """Gets IER protocol. Returns: IER protocol @@ -94,8 +87,7 @@ class IER(object): return self.protocol def get_port(self): - """ - Gets IER port. + """Gets IER port. Returns: IER port @@ -103,8 +95,7 @@ class IER(object): return self.port def get_source_node_id(self): - """ - Gets IER source node ID. + """Gets IER source node ID. Returns: IER source node ID @@ -112,8 +103,7 @@ class IER(object): return self.source_node_id def get_dest_node_id(self): - """ - Gets IER destination node ID. + """Gets IER destination node ID. Returns: IER destination node ID @@ -121,8 +111,7 @@ class IER(object): return self.dest_node_id def get_is_running(self): - """ - Informs whether the IER is currently running. + """Informs whether the IER is currently running. Returns: True if running @@ -130,8 +119,7 @@ class IER(object): return self.running def set_is_running(self, _value): - """ - Sets the running state of the IER. + """Sets the running state of the IER. Args: _value: running status @@ -139,8 +127,7 @@ class IER(object): self.running = _value def get_mission_criticality(self): - """ - Gets the IER mission criticality (used in the reward function). + """Gets the IER mission criticality (used in the reward function). Returns: Mission criticality value (0 lowest to 5 highest) diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index bff19bf8..86482903 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -24,8 +24,7 @@ def apply_red_agent_iers( acl: AccessControlList, step: int, ): - """ - Applies IERs to the links (link POL) resulting from red agent attack. + """Applies IERs to the links (link POL) resulting from red agent attack. Args: network: The network modelled in the environment @@ -214,8 +213,7 @@ def apply_red_agent_node_pol( node_pol: Dict[str, NodeStateInstructionRed], step: int, ): - """ - Applies node pattern of life. + """Applies node pattern of life. Args: nodes: The nodes within the environment @@ -297,8 +295,7 @@ def apply_red_agent_node_pol( def is_red_ier_incoming(node, iers, node_pol_type): - """ - Checks if the RED IER is incoming. + """Checks if the RED IER is incoming. TODO: Write more descriptive docstring with params and returns. """ diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index df3ebec1..ed2b9bf1 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -18,11 +18,9 @@ _LOGGER = getLogger(__name__) class PrimaiteSession: - """ - The PrimaiteSession class. + """The PrimaiteSession class. - Provides a single learning and evaluation entry point for all training - and lay down configurations. + Provides a single learning and evaluation entry point for all training and lay down configurations. """ def __init__( @@ -30,8 +28,7 @@ class PrimaiteSession: training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path], ): - """ - The PrimaiteSession constructor. + """The PrimaiteSession constructor. :param training_config_path: The training config path. :param lay_down_config_path: The lay down config path. @@ -125,8 +122,7 @@ class PrimaiteSession: self, **kwargs, ): - """ - Train the agent. + """Train the agent. :param kwargs: Any agent-framework specific key word args. """ @@ -137,8 +133,7 @@ class PrimaiteSession: self, **kwargs, ): - """ - Evaluate the agent. + """Evaluate the agent. :param kwargs: Any agent-framework specific key word args. """ diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py index 7fa96783..8d2a94c7 100644 --- a/src/primaite/setup/reset_demo_notebooks.py +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -12,11 +12,9 @@ _LOGGER = getLogger(__name__) def run(overwrite_existing: bool = True): - """ - Resets the demo jupyter notebooks in the users app notebooks directory. + """Resets the demo jupyter notebooks in the users app notebooks directory. - :param overwrite_existing: A bool to toggle replacing existing edited - notebooks on or off. + :param overwrite_existing: A bool to toggle replacing existing edited notebooks on or off. """ notebooks_package_data_root = pkg_resources.resource_filename("primaite", "notebooks/_package_data") for subdir, dirs, files in os.walk(notebooks_package_data_root): diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py index 5d62298c..a2e1f2c9 100644 --- a/src/primaite/setup/reset_example_configs.py +++ b/src/primaite/setup/reset_example_configs.py @@ -11,11 +11,9 @@ _LOGGER = getLogger(__name__) def run(overwrite_existing=True): - """ - Resets the example config files in the users app config directory. + """Resets the example config files in the users app config directory. - :param overwrite_existing: A bool to toggle replacing existing edited - config on or off. + :param overwrite_existing: A bool to toggle replacing existing edited config on or off. """ configs_package_data_root = pkg_resources.resource_filename("primaite", "config/_package_data") diff --git a/src/primaite/setup/setup_app_dirs.py b/src/primaite/setup/setup_app_dirs.py index 693b11c1..bf7dbe59 100644 --- a/src/primaite/setup/setup_app_dirs.py +++ b/src/primaite/setup/setup_app_dirs.py @@ -5,8 +5,7 @@ _LOGGER = getLogger(__name__) def run(): - """ - Handles creation of application directories and user directories. + """Handles creation of application directories and user directories. Uses `platformdirs.PlatformDirs` and `pathlib.Path` to create the required app directories in the correct locations based on the users OS. diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 7db2444a..21d4ee05 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -10,8 +10,7 @@ class Transaction(object): """Transaction class.""" def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int): - """ - Transaction constructor. + """Transaction constructor. :param agent_identifier: An identifier for the agent in use :param episode_number: The episode number @@ -39,8 +38,7 @@ class Transaction(object): "The env observation space description" def as_csv_data(self) -> Tuple[List, List]: - """ - Converts the Transaction to a csv data row and provides a header. + """Converts the Transaction to a csv data row and provides a header. :return: A tuple consisting of (header, data). """ @@ -69,8 +67,7 @@ class Transaction(object): def _turn_action_space_to_array(action_space) -> List[str]: - """ - Turns action space into a string array so it can be saved to csv. + """Turns action space into a string array so it can be saved to csv. :param action_space: The action space :return: The action space as an array of strings @@ -82,12 +79,10 @@ def _turn_action_space_to_array(action_space) -> List[str]: def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]: - """ - Turns observation space into a string array so it can be saved to csv. + """Turns observation space into a string array so it can be saved to csv. :param obs_space: The observation space - :param obs_assets: The number of assets (i.e. nodes or links) in the - observation space + :param obs_assets: The number of assets (i.e. nodes or links) in the observation space :param obs_features: The number of features associated with the asset :return: The observation space as an array of strings """ diff --git a/src/primaite/utils/package_data.py b/src/primaite/utils/package_data.py index 59f36851..463a4309 100644 --- a/src/primaite/utils/package_data.py +++ b/src/primaite/utils/package_data.py @@ -10,8 +10,7 @@ _LOGGER = getLogger(__name__) def get_file_path(path: str) -> Path: - """ - Get PrimAITE package data. + """Get PrimAITE package data. :Example: diff --git a/src/primaite/utils/session_output_reader.py b/src/primaite/utils/session_output_reader.py index d04f375e..6b5cfdc3 100644 --- a/src/primaite/utils/session_output_reader.py +++ b/src/primaite/utils/session_output_reader.py @@ -7,11 +7,9 @@ import polars as pl def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]: - """ - Read an average rewards per episode csv file and return as a dict. + """Read an average rewards per episode csv file and return as a dict. - The dictionary keys are the episode number, and the values are the mean - reward that episode. + The dictionary keys are the episode number, and the values are the mean reward that episode. :param av_rewards_csv_file: The average rewards per episode csv file path. :return: The average rewards per episode cdv as a dict. diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py index a05b0453..939ebdb5 100644 --- a/src/primaite/utils/session_output_writer.py +++ b/src/primaite/utils/session_output_writer.py @@ -12,8 +12,7 @@ _LOGGER: Logger = getLogger(__name__) class SessionOutputWriter: - """ - A session output writer class. + """A session output writer class. Is used to write session outputs to csv file. """ @@ -65,11 +64,9 @@ class SessionOutputWriter: _LOGGER.debug(f"Finished writing file: {self._csv_file_path}") def write(self, data: Union[Tuple, Transaction]): - """ - Write a row of session data. + """Write a row of session data. - :param data: The row of data to write. Can be a Tuple or an instance - of Transaction. + :param data: The row of data to write. Can be a Tuple or an instance of Transaction. """ if isinstance(data, Transaction): header, data = data.as_csv_data()