diff --git a/src/primaite/acl/__init__.py b/src/primaite/acl/__init__.py deleted file mode 100644 index 6dc02583..00000000 --- a/src/primaite/acl/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Access Control List. Models firewall functionality.""" diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py deleted file mode 100644 index 88943f8f..00000000 --- a/src/primaite/acl/access_control_list.py +++ /dev/null @@ -1,198 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""A class that implements the access control list implementation for the network.""" -import logging -from typing import Dict, Final, List, Union - -from primaite.acl.acl_rule import ACLRule -from primaite.common.enums import RulePermissionType - -_LOGGER: Final[logging.Logger] = logging.getLogger(__name__) - - -class AccessControlList: - """Access Control List class.""" - - def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None: - """Init.""" - # Implicit ALLOW or DENY firewall spec - self.acl_implicit_permission = implicit_permission - # Implicit rule in ACL list - if self.acl_implicit_permission == RulePermissionType.DENY: - self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY") - elif self.acl_implicit_permission == RulePermissionType.ALLOW: - self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY") - else: - raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}") - - # Maximum number of ACL Rules in ACL - self.max_acl_rules: int = max_acl_rules - # A list of ACL Rules - self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1) - - @property - def acl(self) -> List[Union[ACLRule, None]]: - """Public access method for private _acl.""" - return self._acl + [self.acl_implicit_rule] - - def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool: - """Checks for IP address matches. - - :param _rule: The rule object to check - :type _rule: ACLRule - :param _source_ip_address: Source IP address to compare - :type _source_ip_address: str - :param _dest_ip_address: Destination IP address to compare - :type _dest_ip_address: str - :return: True if there is a match, otherwise False. - :rtype: bool - """ - if ( - (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address) - or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address) - or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY") - or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY") - ): - return True - else: - return False - - def is_blocked(self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool: - """ - Checks for rules that block a protocol / port. - - Args: - _source_ip_address: the source IP address to check - _dest_ip_address: the destination IP address to check - _protocol: the protocol to check - _port: the port to check - - Returns: - Indicates block if all conditions are satisfied. - """ - for rule in self.acl: - if isinstance(rule, ACLRule): - if self.check_address_match(rule, _source_ip_address, _dest_ip_address): - if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY") and ( - str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" - ): - # There's a matching rule. Get the permission - if rule.get_permission() == RulePermissionType.DENY: - return True - elif rule.get_permission() == RulePermissionType.ALLOW: - return False - - # If there has been no rule to allow the IER through, it will return a blocked signal by default - return True - - def add_rule( - self, - _permission: RulePermissionType, - _source_ip: str, - _dest_ip: str, - _protocol: str, - _port: str, - _position: str, - ) -> None: - """ - Adds a new rule. - - Args: - _permission: the permission value (e.g. "ALLOW" or "DENY") - _source_ip: the source IP address - _dest_ip: the destination IP address - _protocol: the protocol - _port: the port - _position: position to insert ACL rule into ACL list (starting from index 1 and NOT 0) - """ - try: - position_index = int(_position) - except TypeError: - _LOGGER.info(f"Position {_position} could not be converted to integer.") - return - - new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) - # Checks position is in correct range - if self.max_acl_rules - 1 > position_index > -1: - try: - _LOGGER.info(f"Position {position_index} is valid.") - # Check to see Agent will not overwrite current ACL in ACL list - if self._acl[position_index] is None: - _LOGGER.info(f"Inserting rule {new_rule} at position {position_index}") - # Adds rule - self._acl[position_index] = new_rule - else: - # Cannot overwrite it - _LOGGER.info(f"Error: inserting rule at non-empty position {position_index}") - return - except Exception: - _LOGGER.info(f"New Rule could NOT be added to list at position {position_index}.") - else: - _LOGGER.info(f"Position {position_index} is an invalid/overwrites implicit firewall rule") - - def remove_rule( - self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str - ) -> None: - """ - Removes a rule. - - Args: - _permission: the permission value (e.g. "ALLOW" or "DENY") - _source_ip: the source IP address - _dest_ip: the destination IP address - _protocol: the protocol - _port: the port - """ - rule_to_delete = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) - delete_rule_hash = hash(rule_to_delete) - - for index in range(0, len(self._acl)): - if isinstance(self._acl[index], ACLRule) and hash(self._acl[index]) == delete_rule_hash: - self._acl[index] = None - - def remove_all_rules(self) -> None: - """Removes all rules.""" - for i in range(len(self._acl)): - self._acl[i] = None - - def get_dictionary_hash( - self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str - ) -> int: - """ - Produces a hash value for a rule. - - Args: - _permission: the permission value (e.g. "ALLOW" or "DENY") - _source_ip: the source IP address - _dest_ip: the destination IP address - _protocol: the protocol - _port: the port - - Returns: - Hash value based on rule parameters. - """ - rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) - hash_value = hash(rule) - return hash_value - - def get_relevant_rules( - self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str - ) -> Dict[int, ACLRule]: - """Get all ACL rules that relate to the given arguments. - - :param _source_ip_address: the source IP address to check - :param _dest_ip_address: the destination IP address to check - :param _protocol: the protocol to check - :param _port: the port to check - :return: Dictionary of all ACL rules that relate to the given arguments - :rtype: Dict[int, ACLRule] - """ - relevant_rules = {} - for rule in self.acl: - if self.check_address_match(rule, _source_ip_address, _dest_ip_address): - if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" or _protocol == "ANY") and ( - str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" or str(_port) == "ANY" - ): - # There's a matching rule. - relevant_rules[self._acl.index(rule)] = rule - - return relevant_rules diff --git a/src/primaite/acl/acl_rule.py b/src/primaite/acl/acl_rule.py deleted file mode 100644 index 9c8deacd..00000000 --- a/src/primaite/acl/acl_rule.py +++ /dev/null @@ -1,87 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""A class that implements an access control list rule.""" -from primaite.common.enums import RulePermissionType - - -class ACLRule: - """Access Control List Rule class.""" - - def __init__( - self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str - ) -> None: - """ - Initialise an ACL Rule. - - :param _permission: The permission (ALLOW or DENY) - :param _source_ip: The source IP address - :param _dest_ip: The destination IP address - :param _protocol: The rule protocol - :param _port: The rule port - """ - self.permission: RulePermissionType = _permission - self.source_ip: str = _source_ip - self.dest_ip: str = _dest_ip - self.protocol: str = _protocol - self.port: str = _port - - def __hash__(self) -> int: - """ - Override the hash function. - - Returns: - Returns hash of core parameters. - """ - return hash( - ( - self.permission, - self.source_ip, - self.dest_ip, - self.protocol, - self.port, - ) - ) - - def get_permission(self) -> str: - """ - Gets the permission attribute. - - Returns: - Returns permission attribute - """ - return self.permission - - def get_source_ip(self) -> str: - """ - Gets the source IP address attribute. - - Returns: - Returns source IP address attribute - """ - return self.source_ip - - def get_dest_ip(self) -> str: - """ - Gets the desintation IP address attribute. - - Returns: - Returns destination IP address attribute - """ - return self.dest_ip - - def get_protocol(self) -> str: - """ - Gets the protocol attribute. - - Returns: - Returns protocol attribute - """ - return self.protocol - - def get_port(self) -> str: - """ - Gets the port attribute. - - Returns: - Returns port attribute - """ - return self.port diff --git a/src/primaite/agents/__init__.py b/src/primaite/agents/__init__.py deleted file mode 100644 index c742daf3..00000000 --- a/src/primaite/agents/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Common interface between RL agents from different libraries and PrimAITE.""" diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py deleted file mode 100644 index 359790ad..00000000 --- a/src/primaite/agents/agent_abc.py +++ /dev/null @@ -1,319 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -from __future__ import annotations - -import json -from abc import ABC, abstractmethod -from datetime import datetime -from logging import Logger -from pathlib import Path -from typing import Any, Dict, Optional, Union -from uuid import uuid4 - -import primaite -from primaite import getLogger, PRIMAITE_PATHS -from primaite.config import lay_down_config, training_config -from primaite.config.training_config import TrainingConfig -from primaite.data_viz.session_plots import plot_av_reward_per_episode -from primaite.environment.primaite_env import Primaite -from primaite.utils.session_metadata_parser import parse_session_metadata - -_LOGGER: Logger = getLogger(__name__) - - -def get_session_path(session_timestamp: datetime) -> Path: - """ - Get the directory path the session will output to. - - This is set in the format of: - ~/primaite/2.0.0/sessions//_. - - :param session_timestamp: This is the datetime that the session started. - :return: The session directory path. - """ - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = PRIMAITE_PATHS.user_sessions_path / date_dir / session_path - session_path.mkdir(exist_ok=True, parents=True) - - return session_path - - -class AgentSessionABC(ABC): - """ - An ABC that manages training and/or evaluation of agents in PrimAITE. - - This class cannot be directly instantiated and must be inherited from with all implemented abstract methods - implemented. - """ - - @abstractmethod - def __init__( - self, - training_config_path: Optional[Union[str, Path]] = None, - lay_down_config_path: Optional[Union[str, Path]] = None, - session_path: Optional[Union[str, Path]] = None, - legacy_training_config: bool = False, - legacy_lay_down_config: bool = False, - ) -> None: - """ - Initialise an agent session from config files, or load a previous session. - - If training configuration and laydown configuration are provided with a session path, - the session path will be used. - - :param training_config_path: YAML file containing configurable items defined in - `primaite.config.training_config.TrainingConfig` - :type training_config_path: Union[path, str] - :param lay_down_config_path: YAML file containing configurable items for generating network laydown. - :type lay_down_config_path: Union[path, str] - :param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0, - otherwise False. - :param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0, - otherwise False. - :param session_path: directory path of the session to load - """ - # initialise variables - self._env: Primaite - self._agent = None - self._can_learn: bool = False - self._can_evaluate: bool = False - self.is_eval = False - self.legacy_training_config = legacy_training_config - self.legacy_lay_down_config = legacy_lay_down_config - - self.session_timestamp: datetime = datetime.now() - - # convert session to path - if session_path is not None: - if not isinstance(session_path, Path): - session_path = Path(session_path) - - # if a session path is provided, load it - if not session_path.exists(): - raise Exception(f"Session could not be loaded. Path does not exist: {session_path}") - - # load session - self.load(session_path) - else: - # set training config path - if not isinstance(training_config_path, Path): - training_config_path = Path(training_config_path) - self._training_config_path: Union[Path, str] = training_config_path - self._training_config: TrainingConfig = training_config.load( - self._training_config_path, legacy_file=legacy_training_config - ) - - if not isinstance(lay_down_config_path, Path): - lay_down_config_path = Path(lay_down_config_path) - self._lay_down_config_path: Union[Path, str] = lay_down_config_path - self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config) - self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level - - # set random UUID for session - self._uuid = str(uuid4()) - "The session timestamp" - self.session_path = get_session_path(self.session_timestamp) - "The Session path" - - @property - def timestamp_str(self) -> str: - """The session timestamp as a string.""" - return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - - @property - def learning_path(self) -> Path: - """The learning outputs path.""" - path = self.session_path / "learning" - path.mkdir(exist_ok=True, parents=True) - return path - - @property - def evaluation_path(self) -> Path: - """The evaluation outputs path.""" - path = self.session_path / "evaluation" - path.mkdir(exist_ok=True, parents=True) - return path - - @property - def checkpoints_path(self) -> Path: - """The Session checkpoints path.""" - path = self.learning_path / "checkpoints" - path.mkdir(exist_ok=True, parents=True) - return path - - @property - def uuid(self) -> str: - """The Agent Session UUID.""" - return self._uuid - - def _write_session_metadata_file(self) -> None: - """ - Write the ``session_metadata.json`` file. - - Creates a ``session_metadata.json`` in the ``session_path`` directory - and adds the following key/value pairs: - - - uuid: The UUID assigned to the session upon instantiation. - - start_datetime: The date & time the session started in iso format. - - end_datetime: NULL. - - total_episodes: NULL. - - total_time_steps: NULL. - - env: - - training_config: - - All training config items - - lay_down_config: - - All lay down config items - - """ - metadata_dict = { - "uuid": self.uuid, - "start_datetime": self.session_timestamp.isoformat(), - "end_datetime": None, - "learning": {"total_episodes": None, "total_time_steps": None}, - "evaluation": {"total_episodes": None, "total_time_steps": None}, - "env": { - "training_config": self._training_config.to_dict(json_serializable=True), - "lay_down_config": self._lay_down_config, - }, - } - filepath = self.session_path / "session_metadata.json" - _LOGGER.debug(f"Writing Session Metadata file: {filepath}") - with open(filepath, "w") as file: - json.dump(metadata_dict, file) - _LOGGER.debug("Finished writing session metadata file") - - def _update_session_metadata_file(self) -> None: - """ - Update the ``session_metadata.json`` file. - - Updates the `session_metadata.json`` in the ``session_path`` directory - with the following key/value pairs: - - - end_datetime: The date & time the session ended in iso format. - - total_episodes: The total number of training episodes completed. - - total_time_steps: The total number of training time steps completed. - """ - with open(self.session_path / "session_metadata.json", "r") as file: - metadata_dict = json.load(file) - - metadata_dict["end_datetime"] = datetime.now().isoformat() - if not self.is_eval: - metadata_dict["learning"]["total_episodes"] = self._env.actual_episode_count # noqa - metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa - else: - metadata_dict["evaluation"]["total_episodes"] = self._env.actual_episode_count # noqa - metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa - - filepath = self.session_path / "session_metadata.json" - _LOGGER.debug(f"Updating Session Metadata file: {filepath}") - with open(filepath, "w") as file: - json.dump(metadata_dict, file) - _LOGGER.debug("Finished updating session metadata file") - - @abstractmethod - def _setup(self) -> None: - _LOGGER.info( - "Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})" - ) - _LOGGER.info(f"The output directory for this session is: {self.session_path}") - self._write_session_metadata_file() - self._can_learn = True - self._can_evaluate = False - - @abstractmethod - def _save_checkpoint(self) -> None: - pass - - @abstractmethod - def learn( - self, - **kwargs: Any, - ) -> None: - """ - Train the agent. - - :param kwargs: Any agent-specific key-word args to be passed. - """ - if self._can_learn: - _LOGGER.info("Finished learning") - _LOGGER.debug("Writing transactions") - self._update_session_metadata_file() - self._can_evaluate = True - self.is_eval = False - - @abstractmethod - def evaluate( - self, - **kwargs: Any, - ) -> None: - """ - Evaluate the agent. - - :param kwargs: Any agent-specific key-word args to be passed. - """ - if self._can_evaluate: - self._update_session_metadata_file() - self.is_eval = True - self._plot_av_reward_per_episode(learning_session=False) - _LOGGER.info("Finished evaluation") - - @abstractmethod - def _get_latest_checkpoint(self) -> None: - pass - - def load(self, path: Union[str, Path]) -> None: - """Load an agent from file.""" - md_dict, training_config_path, laydown_config_path = parse_session_metadata(path) - - # set training config path - self._training_config_path: Union[Path, str] = training_config_path - self._training_config: TrainingConfig = training_config.load(self._training_config_path) - self._lay_down_config_path: Union[Path, str] = laydown_config_path - self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) - self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level - - # set random UUID for session - self._uuid = md_dict["uuid"] - - # set the session path - self.session_path = path - "The Session path" - - @property - def _saved_agent_path(self) -> Path: - file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip" - return self.learning_path / file_name - - @abstractmethod - def save(self) -> None: - """Save the agent.""" - pass - - @abstractmethod - def export(self) -> None: - """Export the agent to transportable file format.""" - pass - - def close(self) -> None: - """Closes the agent.""" - self._env.episode_av_reward_writer.close() # noqa - self._env.transaction_writer.close() # noqa - - def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None: - # self.close() - title = f"PrimAITE Session {self.timestamp_str} " - subtitle = str(self._training_config) - csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" - image_file = f"average_reward_per_episode_{self.timestamp_str}.png" - if learning_session: - title += "(Learning)" - path = self.learning_path / csv_file - image_path = self.learning_path / image_file - else: - title += "(Evaluation)" - path = self.evaluation_path / csv_file - image_path = self.evaluation_path / image_file - - fig = plot_av_reward_per_episode(path, title, subtitle) - fig.write_image(image_path) - _LOGGER.debug(f"Saved average rewards per episode plot to: {path}") diff --git a/src/primaite/agents/hardcoded_abc.py b/src/primaite/agents/hardcoded_abc.py deleted file mode 100644 index e75edbc5..00000000 --- a/src/primaite/agents/hardcoded_abc.py +++ /dev/null @@ -1,118 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -import time -from abc import abstractmethod -from pathlib import Path -from typing import Any, Optional, Union - -import numpy as np - -from primaite import getLogger -from primaite.agents.agent_abc import AgentSessionABC -from primaite.environment.primaite_env import Primaite - -_LOGGER = getLogger(__name__) - - -class HardCodedAgentSessionABC(AgentSessionABC): - """ - An Agent Session ABC for evaluation deterministic agents. - - This class cannot be directly instantiated and must be inherited from with all implemented abstract methods - implemented. - """ - - def __init__( - self, - training_config_path: Optional[Union[str, Path]] = "", - lay_down_config_path: Optional[Union[str, Path]] = "", - session_path: Optional[Union[str, Path]] = None, - ) -> None: - """ - Initialise a hardcoded agent session. - - :param training_config_path: YAML file containing configurable items defined in - `primaite.config.training_config.TrainingConfig` - :type training_config_path: Union[path, str] - :param lay_down_config_path: YAML file containing configurable items for generating network laydown. - :type lay_down_config_path: Union[path, str] - """ - super().__init__(training_config_path, lay_down_config_path, session_path) - self._setup() - - def _setup(self) -> None: - self._env: Primaite = Primaite( - training_config_path=self._training_config_path, - lay_down_config_path=self._lay_down_config_path, - session_path=self.session_path, - timestamp_str=self.timestamp_str, - ) - super()._setup() - self._can_learn = False - self._can_evaluate = True - - def _save_checkpoint(self) -> None: - pass - - def _get_latest_checkpoint(self) -> None: - pass - - def learn( - self, - **kwargs: Any, - ) -> None: - """ - Train the agent. - - :param kwargs: Any agent-specific key-word args to be passed. - """ - _LOGGER.warning("Deterministic agents cannot learn") - - @abstractmethod - def _calculate_action(self, obs: np.ndarray) -> None: - pass - - def evaluate( - self, - **kwargs: Any, - ) -> None: - """ - Evaluate the agent. - - :param kwargs: Any agent-specific key-word args to be passed. - """ - self._env.set_as_eval() # noqa - self.is_eval = True - - time_steps = self._training_config.num_eval_steps - episodes = self._training_config.num_eval_episodes - - obs = self._env.reset() - for episode in range(episodes): - # Reset env and collect initial observation - for step in range(time_steps): - # Calculate action - action = self._calculate_action(obs) - - # Perform the step - obs, reward, done, info = self._env.step(action) - - if done: - break - - # Introduce a delay between steps - time.sleep(self._training_config.time_delay / 1000) - obs = self._env.reset() - self._env.close() - - @classmethod - def load(cls, path: Union[str, Path] = None) -> None: - """Load an agent from file.""" - _LOGGER.warning("Deterministic agents cannot be loaded") - - def save(self) -> None: - """Save the agent.""" - _LOGGER.warning("Deterministic agents cannot be saved") - - def export(self) -> None: - """Export the agent to transportable file format.""" - _LOGGER.warning("Deterministic agents cannot be exported") diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py deleted file mode 100644 index 2440da06..00000000 --- a/src/primaite/agents/hardcoded_acl.py +++ /dev/null @@ -1,515 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -from typing import Dict, List, Union - -import numpy as np - -from primaite.acl.access_control_list import AccessControlList -from primaite.acl.acl_rule import ACLRule -from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC -from primaite.agents.utils import ( - get_new_action, - get_node_of_ip, - transform_action_acl_enum, - transform_change_obs_readable, -) -from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import HardCodedAgentView -from primaite.nodes.active_node import ActiveNode -from primaite.nodes.service_node import ServiceNode -from primaite.pol.ier import IER - - -class HardCodedACLAgent(HardCodedAgentSessionABC): - """An Agent Session class that implements a deterministic ACL agent.""" - - def _calculate_action(self, obs: np.ndarray) -> int: - if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC: - # Basic view action using only the current observation - return self._calculate_action_basic_view(obs) - else: - # full view action using observation space, action - # history and reward feedback - return self._calculate_action_full_view(obs) - - def get_blocked_green_iers( - self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion] - ) -> Dict[str, IER]: - """Get blocked green IERs. - - :param green_iers: Green IERs to check for being - :type green_iers: Dict[str, IER] - :param acl: Firewall rules - :type acl: AccessControlList - :param nodes: Nodes in the network - :type nodes: Dict[str,NodeUnion] - :return: Same as `green_iers` input dict, but filtered to only contain the blocked ones. - :rtype: Dict[str, IER] - """ - blocked_green_iers = {} - - for green_ier_id, green_ier in green_iers.items(): - source_node_id = green_ier.get_source_node_id() - source_node_address = nodes[source_node_id].ip_address - dest_node_id = green_ier.get_dest_node_id() - dest_node_address = nodes[dest_node_id].ip_address - protocol = green_ier.get_protocol() # e.g. 'TCP' - port = green_ier.get_port() - - # Can be blocked by an ACL or by default (no allow rule exists) - if acl.is_blocked(source_node_address, dest_node_address, protocol, port): - blocked_green_iers[green_ier_id] = green_ier - - return blocked_green_iers - - def get_matching_acl_rules_for_ier( - self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] - ) -> Dict[int, ACLRule]: - """Get list of ACL rules which are relevant to an IER. - - :param ier: Information Exchange Request to query against the ACL list - :type ier: IER - :param acl: Firewall rules - :type acl: AccessControlList - :param nodes: Nodes in the network - :type nodes: Dict[str,NodeUnion] - :return: _description_ - :rtype: _type_ - """ - source_node_id = ier.get_source_node_id() - source_node_address = nodes[source_node_id].ip_address - dest_node_id = ier.get_dest_node_id() - dest_node_address = nodes[dest_node_id].ip_address - protocol = ier.get_protocol() # e.g. 'TCP' - port = ier.get_port() - matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) - return matching_rules - - def get_blocking_acl_rules_for_ier( - self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] - ) -> Dict[int, ACLRule]: - """ - Get blocking ACL rules for an IER. - - .. warning:: - Can return empty dict but IER can still be blocked by default - (No ALLOW rule, therefore blocked). - - :param ier: Information Exchange Request to query against the ACL list - :type ier: IER - :param acl: Firewall rules - :type acl: AccessControlList - :param nodes: Nodes in the network - :type nodes: Dict[str,NodeUnion] - :return: _description_ - :rtype: _type_ - """ - matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) - - blocked_rules = {} - for rule_key, rule_value in matching_rules.items(): - if rule_value.get_permission() == "DENY": - blocked_rules[rule_key] = rule_value - - return blocked_rules - - def get_allow_acl_rules_for_ier( - self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion] - ) -> Dict[int, ACLRule]: - """Get all allowing ACL rules for an IER. - - :param ier: Information Exchange Request to query against the ACL list - :type ier: IER - :param acl: Firewall rules - :type acl: AccessControlList - :param nodes: Nodes in the network - :type nodes: Dict[str,NodeUnion] - :return: _description_ - :rtype: _type_ - """ - matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) - - allowed_rules = {} - for rule_key, rule_value in matching_rules.items(): - if rule_value.get_permission() == "ALLOW": - allowed_rules[rule_key] = rule_value - - return allowed_rules - - def get_matching_acl_rules( - self, - source_node_id: str, - dest_node_id: str, - protocol: str, - port: str, - acl: AccessControlList, - nodes: Dict[str, Union[ServiceNode, ActiveNode]], - services_list: List[str], - ) -> Dict[int, ACLRule]: - """Filter ACL rules to only those which are relevant to the specified nodes. - - :param source_node_id: Source node - :type source_node_id: str - :param dest_node_id: Destination nodes - :type dest_node_id: str - :param protocol: Network protocol - :type protocol: str - :param port: Network port - :type port: str - :param acl: Access Control list which will be filtered - :type acl: AccessControlList - :param nodes: The environment's node directory. - :type nodes: Dict[str, Union[ServiceNode, ActiveNode]] - :param services_list: List of services registered for the environment. - :type services_list: List[str] - :return: Filtered version of 'acl' - :rtype: Dict[str, ACLRule] - """ - if source_node_id != "ANY": - source_node_address = nodes[str(source_node_id)].ip_address - else: - source_node_address = source_node_id - - if dest_node_id != "ANY": - dest_node_address = nodes[str(dest_node_id)].ip_address - else: - dest_node_address = dest_node_id - - if protocol != "ANY": - protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services - # TODO: This should throw an error because protocol is a string - - matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) - return matching_rules - - def get_allow_acl_rules( - self, - source_node_id: int, - dest_node_id: str, - protocol: int, - port: str, - acl: AccessControlList, - nodes: Dict[str, NodeUnion], - services_list: List[str], - ) -> Dict[int, ACLRule]: - """List ALLOW rules relating to specified nodes. - - :param source_node_id: Source node id - :type source_node_id: int - :param dest_node_id: Destination node - :type dest_node_id: str - :param protocol: Network protocol - :type protocol: int - :param port: Port - :type port: str - :param acl: Firewall ruleset which is applied to the network - :type acl: AccessControlList - :param nodes: The simulation's node store - :type nodes: Dict[str, NodeUnion] - :param services_list: Services list - :type services_list: List[str] - :return: Filtered ACL Rule directory which includes only those rules which affect the specified source and - desination nodes - :rtype: Dict[str, ACLRule] - """ - matching_rules = self.get_matching_acl_rules( - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ) - - allowed_rules = {} - for rule_key, rule_value in matching_rules.items(): - if rule_value.get_permission() == "ALLOW": - allowed_rules[rule_key] = rule_value - - return allowed_rules - - def get_deny_acl_rules( - self, - source_node_id: int, - dest_node_id: str, - protocol: int, - port: str, - acl: AccessControlList, - nodes: Dict[str, NodeUnion], - services_list: List[str], - ) -> Dict[int, ACLRule]: - """List DENY rules relating to specified nodes. - - :param source_node_id: Source node id - :type source_node_id: int - :param dest_node_id: Destination node - :type dest_node_id: str - :param protocol: Network protocol - :type protocol: int - :param port: Port - :type port: str - :param acl: Firewall ruleset which is applied to the network - :type acl: AccessControlList - :param nodes: The simulation's node store - :type nodes: Dict[str, NodeUnion] - :param services_list: Services list - :type services_list: List[str] - :return: Filtered ACL Rule directory which includes only those rules which affect the specified source and - desination nodes - :rtype: Dict[str, ACLRule] - """ - matching_rules = self.get_matching_acl_rules( - source_node_id, - dest_node_id, - protocol, - port, - acl, - nodes, - services_list, - ) - - allowed_rules = {} - for rule_key, rule_value in matching_rules.items(): - if rule_value.get_permission() == "DENY": - allowed_rules[rule_key] = rule_value - - return allowed_rules - - def _calculate_action_full_view(self, obs: np.ndarray) -> int: - """ - Calculate a good acl-based action for the blue agent to take. - - Knowledge of just the observation space is insufficient for a perfect solution, as we need to know: - - - Which ACL rules already exist, - otherwise: - - The agent would perminently get stuck in a loop of performing the same action over and over. - (best action is to block something, but its already blocked but doesn't know this) - - The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule, - if it doesnt know what rules exist) - - The Green IERs (optional) - It often needs to know which traffic it should be allowing. For example - in the default config one of the green IERs is blocked by default, but it has no way of knowing this - based on the observation space. Additionally, potentially in the future, once a node state - has been fixed (no longer compromised), it needs a way to know it should reallow traffic. - A RL agent can learn what the green IERs are on its own - but the rule based agent cannot easily do this. - - There doesn't seem like there's much that can be done if an Operating or OS State is compromised - - If a service node becomes compromised there's a decision to make - do we block that service? - Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED - Cons: Will block a green IER, decreasing the reward - We decide to block the service. - - Potentially a better solution (for the reward) would be to block the incomming traffic from compromised - nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing - an overwhelmed state, so we don't do this. - - :param obs: current observation from the gym environment - :type obs: np.ndarray - :return: Optimal action to take in the environment (chosen from the discrete action space) - :rtype: int - """ - # obs = convert_to_old_obs(obs) - r_obs = transform_change_obs_readable(obs) - _, _, _, *s = r_obs - - if len(r_obs) == 4: # only 1 service - s = [*s] - - # 1. Check if node is compromised. If so we want to block its outwards services - # a. If it is comprimised check if there's an allow rule we should delete. - # cons: might delete a multi-rule from any source node (ANY -> x) - # b. OPTIONAL (Deny rules not needed): Check if there already exists an existing Deny Rule so not to duplicate - # c. OPTIONAL (no allow rule = blocked): Add a DENY rule - found_action = False - for service_num, service_states in enumerate(s): - for x, service_state in enumerate(service_states): - if service_state == "COMPROMISED": - action_source_id = x + 1 # +1 as 0 is any - action_destination_id = "ANY" - action_protocol = service_num + 1 # +1 as 0 is any - action_port = "ANY" - - allow_rules = self.get_allow_acl_rules( - action_source_id, - action_destination_id, - action_protocol, - action_port, - self._env.acl, - self._env.nodes, - self._env.services_list, - ) - deny_rules = self.get_deny_acl_rules( - action_source_id, - action_destination_id, - action_protocol, - action_port, - self._env.acl, - self._env.nodes, - self._env.services_list, - ) - if len(allow_rules) > 0: - # Check if there's an allow rule we should delete - rule = list(allow_rules.values())[0] - action_decision = "DELETE" - action_permission = "ALLOW" - action_source_ip = rule.get_source_ip() - action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes)) - action_destination_ip = rule.get_dest_ip() - action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes)) - action_protocol_name = rule.get_protocol() - action_protocol = ( - self._env.services_list.index(action_protocol_name) + 1 - ) # convert name e.g. 'TCP' to index - action_port_name = rule.get_port() - action_port = ( - self._env.ports_list.index(action_port_name) + 1 - ) # convert port name e.g. '80' to index - - found_action = True - break - elif len(deny_rules) > 0: - # TODO OPTIONAL - # If there's already a DENY RULE, that blocks EVERYTHING from the source ip we don't need - # to create another - # Check to see if the DENY rule really blocks everything (ANY) or just a specific rule - continue - else: - # TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked - action_decision = "CREATE" - action_permission = "DENY" - break - if found_action: - break - - # 2. If NO Node is Comprimised, or the node has already been blocked, check the green IERs and - # add an Allow rule if the green IER is being blocked. - # a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule): - # If there's a DENY rule delete it if: - # - There isn't already a deny rule - # - It doesnt allows a comprimised node to become operational. - # b. Add an ALLOW rule if: - # - There isn't already an allow rule - # - It doesnt allows a comprimised node to become operational - - if not found_action: - # Which Green IERS are blocked - blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes) - for ier_key, ier in blocked_green_iers.items(): - # Which ALLOW rules are allowing this IER (none) - allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes) - - # If there are no blocking rules, it may be being blocked by default - # If there is already an allow rule - node_id_to_check = int(ier.get_source_node_id()) - service_name_to_check = ier.get_protocol() - service_id_to_check = self._env.services_list.index(service_name_to_check) - - # Service state of the the source node in the ier - service_state = s[service_id_to_check][node_id_to_check - 1] - - if len(allowing_rules) == 0 and service_state != "COMPROMISED": - action_decision = "CREATE" - action_permission = "ALLOW" - action_source_id = int(ier.get_source_node_id()) - action_destination_id = int(ier.get_dest_node_id()) - action_protocol_name = ier.get_protocol() - action_protocol = ( - self._env.services_list.index(action_protocol_name) + 1 - ) # convert name e.g. 'TCP' to index - action_port_name = ier.get_port() - action_port = ( - self._env.ports_list.index(action_port_name) + 1 - ) # convert port name e.g. '80' to index - - found_action = True - break - - if found_action: - action = [ - action_decision, - action_permission, - action_source_id, - action_destination_id, - action_protocol, - action_port, - ] - action = transform_action_acl_enum(action) - action = get_new_action(action, self._env.action_dict) - else: - # If no good/useful action has been found, just perform a nothing action - action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"] - action = transform_action_acl_enum(action) - action = get_new_action(action, self._env.action_dict) - return action - - def _calculate_action_basic_view(self, obs: np.ndarray) -> int: - """ - Calculate a good acl-based action for the blue agent to take. - - Uses ONLY information from the current observation with NO knowledge - of previous actions taken and NO reward feedback. - - We rely on randomness to select the precise action, as we want to - block all traffic originating from a compromised node, without being - able to tell: - 1. Which ACL rules already exist - 2. Which actions the agent has already tried. - - There is a high probability that the correct rule will not be deleted - before the state becomes overwhelmed. - - Currently, a deny rule does not overwrite an allow rule. The allow - rules must be deleted. - - :param obs: current observation from the gym environment - :type obs: np.ndarray - :return: Optimal action to take in the environment (chosen from the discrete action space) - :rtype: int - """ - action_dict = self._env.action_dict - r_obs = transform_change_obs_readable(obs) - _, o, _, *s = r_obs - - if len(r_obs) == 4: # only 1 service - s = [*s] - - number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links) - for service_num, service_states in enumerate(s): - comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"] - if len(comprimised_states) == 0: - # No states are COMPROMISED, try the next service - continue - - compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any - action_decision = "DELETE" - action_permission = "ALLOW" - action_source_ip = compromised_node - # Randomly select a destination ID to block - action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"]) - action_destination_ip = ( - int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip - ) - action_protocol = service_num + 1 # +1 as 0 is any - # Randomly select a port to block - # Bad assumption that number of protocols equals number of ports - # AND no rules exist with an ANY port - action_port = np.random.choice(list(range(1, len(s) + 1))) - - action = [ - action_decision, - action_permission, - action_source_ip, - action_destination_ip, - action_protocol, - action_port, - ] - action = transform_action_acl_enum(action) - action = get_new_action(action, action_dict) - # We can only perform 1 action on each step - return action - - # If no good/useful action has been found, just perform a nothing action - nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"] - nothing_action = transform_action_acl_enum(nothing_action) - nothing_action = get_new_action(nothing_action, action_dict) - return nothing_action diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py deleted file mode 100644 index b08d8967..00000000 --- a/src/primaite/agents/hardcoded_node.py +++ /dev/null @@ -1,125 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -import numpy as np - -from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC -from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable - - -class HardCodedNodeAgent(HardCodedAgentSessionABC): - """An Agent Session class that implements a deterministic Node agent.""" - - def _calculate_action(self, obs: np.ndarray) -> int: - """ - Calculate a good node-based action for the blue agent to take. - - :param obs: current observation from the gym environment - :type obs: np.ndarray - :return: Optimal action to take in the environment (chosen from the discrete action space) - :rtype: int - """ - action_dict = self._env.action_dict - r_obs = transform_change_obs_readable(obs) - _, o, os, *s = r_obs - - if len(r_obs) == 4: # only 1 service - s = [*s] - - # Check in order of most important states (order doesn't currently - # matter, but it probably should) - # First see if any OS states are compromised - for x, os_state in enumerate(os): - if os_state == "COMPROMISED": - action_node_id = x + 1 - action_node_property = "OS" - property_action = "PATCHING" - action_service_index = 0 # does nothing isn't relevant for os - action = [ - action_node_id, - action_node_property, - property_action, - action_service_index, - ] - action = transform_action_node_enum(action) - action = get_new_action(action, action_dict) - # We can only perform 1 action on each step - return action - - # Next, see if any Services are compromised - # We fix the compromised state before overwhelemd state, - # If a compromised entry node is fixed before the overwhelmed state is triggered, instruction is ignored - for service_num, service in enumerate(s): - for x, service_state in enumerate(service): - if service_state == "COMPROMISED": - action_node_id = x + 1 - action_node_property = "SERVICE" - property_action = "PATCHING" - action_service_index = service_num - - action = [ - action_node_id, - action_node_property, - property_action, - action_service_index, - ] - action = transform_action_node_enum(action) - action = get_new_action(action, action_dict) - # We can only perform 1 action on each step - return action - - # Next, See if any services are overwhelmed - # perhaps this should be fixed automatically when the compromised PCs issues are also resolved - # Currently there's no reason that an Overwhelmed state cannot be resolved before resolving the compromised PCs - - for service_num, service in enumerate(s): - for x, service_state in enumerate(service): - if service_state == "OVERWHELMED": - action_node_id = x + 1 - action_node_property = "SERVICE" - property_action = "PATCHING" - action_service_index = service_num - - action = [ - action_node_id, - action_node_property, - property_action, - action_service_index, - ] - action = transform_action_node_enum(action) - action = get_new_action(action, action_dict) - # We can only perform 1 action on each step - return action - - # Finally, turn on any off nodes - for x, operating_state in enumerate(o): - if os_state == "OFF": - action_node_id = x + 1 - action_node_property = "OPERATING" - property_action = "ON" # Why reset it when we can just turn it on - action_service_index = 0 # does nothing isn't relevant for operating state - action = [ - action_node_id, - action_node_property, - property_action, - action_service_index, - ] - # TODO: transform_action_node_enum takes only one argument, not sure why two are given here. - action = transform_action_node_enum(action, action_dict) - action = get_new_action(action, action_dict) - # We can only perform 1 action on each step - return action - - # If no good actions, just go with an action that wont do any harm - action_node_id = 1 - action_node_property = "NONE" - property_action = "ON" - action_service_index = 0 - action = [ - action_node_id, - action_node_property, - property_action, - action_service_index, - ] - action = transform_action_node_enum(action) - action = get_new_action(action, action_dict) - - return action diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py deleted file mode 100644 index 96bb0737..00000000 --- a/src/primaite/agents/rllib.py +++ /dev/null @@ -1,287 +0,0 @@ -# # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -# from __future__ import annotations - -# import json -# import shutil -# import zipfile -# from datetime import datetime -# from logging import Logger -# from pathlib import Path -# from typing import Any, Callable, Dict, Optional, Union -# from uuid import uuid4 - -# from primaite import getLogger -# from primaite.agents.agent_abc import AgentSessionABC -# from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType -# from primaite.environment.primaite_env import Primaite - -# # from ray.rllib.algorithms import Algorithm -# # from ray.rllib.algorithms.a2c import A2CConfig -# # from ray.rllib.algorithms.ppo import PPOConfig -# # from ray.tune.logger import UnifiedLogger -# # from ray.tune.registry import register_env - - -# # from primaite.exceptions import RLlibAgentError - -# _LOGGER: Logger = getLogger(__name__) - - -# # TODO: verify type of env_config -# def _env_creator(env_config: Dict[str, Any]) -> Primaite: -# return Primaite( -# training_config_path=env_config["training_config_path"], -# lay_down_config_path=env_config["lay_down_config_path"], -# session_path=env_config["session_path"], -# timestamp_str=env_config["timestamp_str"], -# ) - -# # # TODO: verify type hint return type -# # def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]: -# # logdir = session_path / "ray_results" -# # logdir.mkdir(parents=True, exist_ok=True) - -# # def logger_creator(config: Dict) -> UnifiedLogger: -# # return UnifiedLogger(config, logdir, loggers=None) - -# return logger_creator - - -# # class RLlibAgent(AgentSessionABC): -# # """An AgentSession class that implements a Ray RLlib agent.""" - -# # def __init__( -# # self, -# # training_config_path: Optional[Union[str, Path]] = "", -# # lay_down_config_path: Optional[Union[str, Path]] = "", -# # session_path: Optional[Union[str, Path]] = None, -# # ) -> None: -# # """ -# # Initialise the RLLib Agent training session. - -# # :param training_config_path: YAML file containing configurable items defined in -# # `primaite.config.training_config.TrainingConfig` -# # :type training_config_path: Union[path, str] -# # :param lay_down_config_path: YAML file containing configurable items for generating network laydown. -# # :type lay_down_config_path: Union[path, str] -# # :raises ValueError: If the training config contains a bad value for agent_framework (should be "RLLIB") -# # :raises ValueError: If the training config contains a bad value for agent_identifies (should be `PPO` -# # or `A2C`) -# # """ -# # # TODO: implement RLlib agent loading -# # if session_path is not None: -# # msg = "RLlib agent loading has not been implemented yet" -# # _LOGGER.critical(msg) -# # raise NotImplementedError(msg) - -# # super().__init__(training_config_path, lay_down_config_path) -# # if self._training_config.session_type == SessionType.EVAL: -# # msg = "Cannot evaluate an RLlib agent that hasn't been through training yet." -# # _LOGGER.critical(msg) -# # raise RLlibAgentError(msg) -# # if not self._training_config.agent_framework == AgentFramework.RLLIB: -# # msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" -# # _LOGGER.error(msg) -# # raise ValueError(msg) -# # self._agent_config_class: Union[PPOConfig, A2CConfig] -# # if self._training_config.agent_identifier == AgentIdentifier.PPO: -# # self._agent_config_class = PPOConfig -# # elif self._training_config.agent_identifier == AgentIdentifier.A2C: -# # self._agent_config_class = A2CConfig -# # else: -# # msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}" -# # _LOGGER.error(msg) -# # raise ValueError(msg) -# # self._agent_config: Union[PPOConfig, A2CConfig] - -# # self._current_result: dict -# # self._setup() -# # _LOGGER.debug( -# # f"Created {self.__class__.__name__} using: " -# # f"agent_framework={self._training_config.agent_framework}, " -# # f"agent_identifier=" -# # f"{self._training_config.agent_identifier}, " -# # f"deep_learning_framework=" -# # f"{self._training_config.deep_learning_framework}" -# # ) -# # self._train_agent = None # Required to capture the learning agent to close after eval - -# # def _update_session_metadata_file(self) -> None: -# # """ -# # Update the ``session_metadata.json`` file. - -# # Updates the `session_metadata.json`` in the ``session_path`` directory -# # with the following key/value pairs: - -# # - end_datetime: The date & time the session ended in iso format. -# # - total_episodes: The total number of training episodes completed. -# # - total_time_steps: The total number of training time steps completed. -# # """ -# # with open(self.session_path / "session_metadata.json", "r") as file: -# # metadata_dict = json.load(file) - -# # metadata_dict["end_datetime"] = datetime.now().isoformat() -# # if not self.is_eval: -# # metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa -# # metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa -# # else: -# # metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa -# # metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa - -# # filepath = self.session_path / "session_metadata.json" -# # _LOGGER.debug(f"Updating Session Metadata file: {filepath}") -# # with open(filepath, "w") as file: -# # json.dump(metadata_dict, file) -# # _LOGGER.debug("Finished updating session metadata file") - -# # def _setup(self) -> None: -# # super()._setup() -# # register_env("primaite", _env_creator) -# # self._agent_config = self._agent_config_class() - -# # self._agent_config.environment( -# # env="primaite", -# # env_config=dict( -# # training_config_path=self._training_config_path, -# # lay_down_config_path=self._lay_down_config_path, -# # session_path=self.session_path, -# # timestamp_str=self.timestamp_str, -# # ), -# # ) -# # self._agent_config.seed = self._training_config.seed - -# # self._agent_config.training(train_batch_size=self._training_config.num_train_steps) -# # self._agent_config.framework(framework="tf") - -# # self._agent_config.rollouts( -# # num_rollout_workers=1, -# # num_envs_per_worker=1, -# # horizon=self._training_config.num_train_steps, -# # ) -# # self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path)) - -# # def _save_checkpoint(self) -> None: -# # checkpoint_n = self._training_config.checkpoint_every_n_episodes -# # episode_count = self._current_result["episodes_total"] -# # save_checkpoint = False -# # if checkpoint_n: -# # save_checkpoint = episode_count % checkpoint_n == 0 -# # if episode_count and save_checkpoint: -# # self._agent.save(str(self.checkpoints_path)) - -# # def learn( -# # self, -# # **kwargs: Any, -# # ) -> None: -# # """ -# # Evaluate the agent. - -# # :param kwargs: Any agent-specific key-word args to be passed. -# # """ -# # time_steps = self._training_config.num_train_steps -# # episodes = self._training_config.num_train_episodes - -# # _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") -# # for i in range(episodes): -# # self._current_result = self._agent.train() -# # self._save_checkpoint() -# # self.save() -# # super().learn() -# # # Done this way as the RLlib eval can only be performed if the session hasn't been stopped -# # if self._training_config.session_type is not SessionType.TRAIN: -# # self._train_agent = self._agent -# # else: -# # self._agent.stop() -# # self._plot_av_reward_per_episode(learning_session=True) - -# # def _unpack_saved_agent_into_eval(self) -> Path: -# # """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval.""" -# # agent_restore_path = self.evaluation_path / "agent_restore" -# # if agent_restore_path.exists(): -# # shutil.rmtree(agent_restore_path) -# # agent_restore_path.mkdir() -# # with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file: -# # zip_file.extractall(agent_restore_path) -# # return agent_restore_path - -# # def _setup_eval(self): -# # self._can_learn = False -# # self._can_evaluate = True -# # self._agent.restore(str(self._unpack_saved_agent_into_eval())) - -# # def evaluate( -# # self, -# # **kwargs, -# # ): -# # """ -# # Evaluate the agent. - -# # :param kwargs: Any agent-specific key-word args to be passed. -# # """ -# # time_steps = self._training_config.num_eval_steps -# # episodes = self._training_config.num_eval_episodes - -# # self._setup_eval() - -# # self._env: Primaite = Primaite( -# # self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str -# # ) - -# # self._env.set_as_eval() -# # self.is_eval = True -# # if self._training_config.deterministic: -# # deterministic_str = "deterministic" -# # else: -# # deterministic_str = "non-deterministic" -# # _LOGGER.info( -# # f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..." -# # ) -# # for episode in range(episodes): -# # obs = self._env.reset() -# # for step in range(time_steps): -# # action = self._agent.compute_single_action(observation=obs, explore=False) - -# # obs, rewards, done, info = self._env.step(action) - -# # self._env.reset() -# # self._env.close() -# # super().evaluate() -# # # Now we're safe to close the learning agent and write the mean rewards per episode for it -# # if self._training_config.session_type is not SessionType.TRAIN: -# # self._train_agent.stop() -# # self._plot_av_reward_per_episode(learning_session=True) -# # # Perform a clean-up of the unpacked agent -# # if (self.evaluation_path / "agent_restore").exists(): -# # shutil.rmtree((self.evaluation_path / "agent_restore")) - -# # def _get_latest_checkpoint(self) -> None: -# # raise NotImplementedError - -# # @classmethod -# # def load(cls, path: Union[str, Path]) -> RLlibAgent: -# # """Load an agent from file.""" -# # raise NotImplementedError - -# # def save(self, overwrite_existing: bool = True) -> None: -# # """Save the agent.""" -# # # Make temp dir to save in isolation -# # temp_dir = self.learning_path / str(uuid4()) -# # temp_dir.mkdir() - -# # # Save the agent to the temp dir -# # self._agent.save(str(temp_dir)) - -# # # Capture the saved Rllib checkpoint inside the temp directory -# # for file in temp_dir.iterdir(): -# # checkpoint_dir = file -# # break - -# # # Zip the folder -# # shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa - -# # # Drop the temp directory -# # shutil.rmtree(temp_dir) - -# # def export(self) -> None: -# # """Export the agent to transportable file format.""" -# # raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py deleted file mode 100644 index 92c5ee5f..00000000 --- a/src/primaite/agents/sb3.py +++ /dev/null @@ -1,206 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -from __future__ import annotations - -import json -from logging import Logger -from pathlib import Path -from typing import Any, Optional, Union - -import numpy as np -from stable_baselines3 import A2C, PPO -from stable_baselines3.ppo import MlpPolicy as PPOMlp - -from primaite import getLogger -from primaite.agents.agent_abc import AgentSessionABC -from primaite.common.enums import AgentFramework, AgentIdentifier -from primaite.environment.primaite_env import Primaite - -_LOGGER: Logger = getLogger(__name__) - - -class SB3Agent(AgentSessionABC): - """An AgentSession class that implements a Stable Baselines3 agent.""" - - def __init__( - self, - training_config_path: Optional[Union[str, Path]] = None, - lay_down_config_path: Optional[Union[str, Path]] = None, - session_path: Optional[Union[str, Path]] = None, - legacy_training_config: bool = False, - legacy_lay_down_config: bool = False, - ) -> None: - """ - Initialise the SB3 Agent training session. - - :param training_config_path: YAML file containing configurable items defined in - `primaite.config.training_config.TrainingConfig` - :type training_config_path: Union[path, str] - :param lay_down_config_path: YAML file containing configurable items for generating network laydown. - :type lay_down_config_path: Union[path, str] - :param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0, - otherwise False. - :param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0, - otherwise False. - :raises ValueError: If the training config contains an unexpected value for agent_framework (should be "SB3") - :raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO` - or `A2C`) - """ - super().__init__( - training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config - ) - if not self._training_config.agent_framework == AgentFramework.SB3: - msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}" - _LOGGER.error(msg) - raise ValueError(msg) - self._agent_class: Union[PPO, A2C] - if self._training_config.agent_identifier == AgentIdentifier.PPO: - self._agent_class = PPO - elif self._training_config.agent_identifier == AgentIdentifier.A2C: - self._agent_class = A2C - else: - msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}" - _LOGGER.error(msg) - raise ValueError(msg) - - self._tensorboard_log_path = self.learning_path / "tensorboard_logs" - self._tensorboard_log_path.mkdir(parents=True, exist_ok=True) - - _LOGGER.debug( - f"Created {self.__class__.__name__} using: " - f"agent_framework={self._training_config.agent_framework}, " - f"agent_identifier=" - f"{self._training_config.agent_identifier}" - ) - - self.is_eval = False - - self._setup() - - def _setup(self) -> None: - """Set up the SB3 Agent.""" - self._env = Primaite( - training_config_path=self._training_config_path, - lay_down_config_path=self._lay_down_config_path, - session_path=self.session_path, - timestamp_str=self.timestamp_str, - legacy_training_config=self.legacy_training_config, - legacy_lay_down_config=self.legacy_lay_down_config, - ) - - # check if there is a zip file that needs to be loaded - load_file = next(self.session_path.rglob("*.zip"), None) - - if not load_file: - # create a new env and agent - - self._agent = self._agent_class( - PPOMlp, - self._env, - verbose=self.sb3_output_verbose_level, - n_steps=self._training_config.num_train_steps, - tensorboard_log=str(self._tensorboard_log_path), - seed=self._training_config.seed, - ) - else: - # set env values from session metadata - with open(self.session_path / "session_metadata.json", "r") as file: - md_dict = json.load(file) - - # load environment values - if self.is_eval: - # evaluation always starts at 0 - self._env.episode_count = 0 - self._env.total_step_count = 0 - else: - # carry on from previous learning sessions - self._env.episode_count = md_dict["learning"]["total_episodes"] - self._env.total_step_count = md_dict["learning"]["total_time_steps"] - - # load the file - self._agent = self._agent_class.load(load_file, env=self._env) - - # set agent values - self._agent.verbose = self.sb3_output_verbose_level - self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs" - - super()._setup() - - def _save_checkpoint(self) -> None: - checkpoint_n = self._training_config.checkpoint_every_n_episodes - episode_count = self._env.episode_count - save_checkpoint = False - if checkpoint_n: - save_checkpoint = episode_count % checkpoint_n == 0 - if episode_count and save_checkpoint: - checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" - self._agent.save(checkpoint_path) - _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") - - def _get_latest_checkpoint(self) -> None: - pass - - def learn( - self, - **kwargs: Any, - ) -> None: - """ - Train the agent. - - :param kwargs: Any agent-specific key-word args to be passed. - """ - time_steps = self._training_config.num_train_steps - episodes = self._training_config.num_train_episodes - self.is_eval = False - _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") - for i in range(episodes): - self._agent.learn(total_timesteps=time_steps) - self._save_checkpoint() - self._env._write_av_reward_per_episode() # noqa - self.save() - self._env.close() - super().learn() - - # save agent - self.save() - - self._plot_av_reward_per_episode(learning_session=True) - - def evaluate( - self, - **kwargs: Any, - ) -> None: - """ - Evaluate the agent. - - :param kwargs: Any agent-specific key-word args to be passed. - """ - time_steps = self._training_config.num_eval_steps - episodes = self._training_config.num_eval_episodes - self._env.set_as_eval() - self.is_eval = True - if self._training_config.deterministic: - deterministic_str = "deterministic" - else: - deterministic_str = "non-deterministic" - _LOGGER.info( - f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..." - ) - for episode in range(episodes): - obs = self._env.reset() - - for step in range(time_steps): - action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic) - if isinstance(action, np.ndarray): - action = np.int64(action) - obs, rewards, done, info = self._env.step(action) - self._env._write_av_reward_per_episode() # noqa - self._env.close() - super().evaluate() - - def save(self) -> None: - """Save the agent.""" - self._agent.save(self._saved_agent_path) - - def export(self) -> None: - """Export the agent to transportable file format.""" - raise NotImplementedError diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py deleted file mode 100644 index bfdff869..00000000 --- a/src/primaite/agents/simple.py +++ /dev/null @@ -1,59 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK - -import numpy as np - -from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC -from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum - - -class RandomAgent(HardCodedAgentSessionABC): - """ - A Random Agent. - - Get a completely random action from the action space. - """ - - def _calculate_action(self, obs: np.ndarray) -> int: - return self._env.action_space.sample() - - -class DummyAgent(HardCodedAgentSessionABC): - """ - A Dummy Agent. - - All action spaces setup so dummy action is always 0 regardless of action type used. - """ - - def _calculate_action(self, obs: np.ndarray) -> int: - return 0 - - -class DoNothingACLAgent(HardCodedAgentSessionABC): - """ - A do nothing ACL agent. - - A valid ACL action that has no effect; does nothing. - """ - - def _calculate_action(self, obs: np.ndarray) -> int: - nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"] - nothing_action = transform_action_acl_enum(nothing_action) - nothing_action = get_new_action(nothing_action, self._env.action_dict) - - return nothing_action - - -class DoNothingNodeAgent(HardCodedAgentSessionABC): - """ - A do nothing Node agent. - - A valid Node action that has no effect; does nothing. - """ - - def _calculate_action(self, obs: np.ndarray) -> int: - nothing_action = [1, "NONE", "ON", 0] - nothing_action = transform_action_node_enum(nothing_action) - nothing_action = get_new_action(nothing_action, self._env.action_dict) - # nothing_action should currently always be 0 - - return nothing_action diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py deleted file mode 100644 index 08d46294..00000000 --- a/src/primaite/agents/utils.py +++ /dev/null @@ -1,450 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -from typing import Dict, List, Union - -import numpy as np - -from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import ( - HardwareState, - LinkStatus, - NodeHardwareAction, - NodePOLType, - NodeSoftwareAction, - SoftwareState, -) - - -def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]: - """Convert a node action from enumerated format to readable format. - - example: - [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] - - :param action: Agent action, formatted as a list of ints, for more information check out - `primaite.environment.primaite_env.Primaite` - :type action: List[int] - :return: The same action list, but with the encodings translated back into meaningful labels - :rtype: List[Union[int,str]] - """ - action_node_property = NodePOLType(action[1]).name - - if action_node_property == "OPERATING": - property_action = NodeHardwareAction(action[2]).name - elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1: - property_action = NodeSoftwareAction(action[2]).name - else: - property_action = "NONE" - - new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]] - return new_action - - -def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]: - """ - Transform an ACL action to a more readable format. - - example: - [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] - - :param action: Agent action, formatted as a list of ints, for more information check out - `primaite.environment.primaite_env.Primaite` - :type action: List[int] - :return: The same action list, but with the encodings translated back into meaningful labels - :rtype: List[Union[int,str]] - """ - action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"} - action_permissions = {0: "DENY", 1: "ALLOW"} - - action_decision = action_decisions[action[0]] - action_permission = action_permissions[action[1]] - - # For IPs, Ports and Protocols, 0 means any, otherwise its just an index - new_action = [action_decision, action_permission] + list(action[2:6]) - for n, val in enumerate(list(action[2:6])): - if val == 0: - new_action[n + 2] = "ANY" - - return new_action - - -def is_valid_node_action(action: List[int]) -> bool: - """ - Is the node action an actual valid action. - - Only uses information about the action to determine if the action has an effect - - Does NOT consider: - - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - - Node already being in that state (turning an ON node ON) - - :param action: Agent action, formatted as a list of ints, for more information check out - `primaite.environment.primaite_env.Primaite` - :type action: List[int] - :return: Whether the action is valid - :rtype: bool - """ - action_r = transform_action_node_readable(action) - - node_property = action_r[1] - node_action = action_r[2] - - # print("node property", node_property, "\nnode action", node_action) - - if node_property == "NONE": - return False - if node_action == "NONE": - return False - if node_property == "OPERATING" and node_action == "PATCHING": - # Operating State cannot PATCH - return False - if node_property != "OPERATING" and node_action not in [ - "NONE", - "PATCHING", - ]: - # Software States can only do Nothing or Patch - return False - return True - - -def is_valid_acl_action(action: List[int]) -> bool: - """ - Is the ACL action an actual valid action. - - Only uses information about the action to determine if the action has an effect. - - Does NOT consider: - - Trying to create identical rules - - Trying to create a rule which is a subset of another rule (caused by "ANY") - - :param action: Agent action, formatted as a list of ints, for more information check out - `primaite.environment.primaite_env.Primaite` - :type action: List[int] - :return: Whether the action is valid - :rtype: bool - """ - action_r = transform_action_acl_readable(action) - - action_decision = action_r[0] - action_permission = action_r[1] - action_source_id = action_r[2] - action_destination_id = action_r[3] - - if action_decision == "NONE": - return False - if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY": - # ACL rule towards itself - return False - if action_permission == "DENY": - # DENY is unnecessary, we can create and delete allow rules instead - # No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY. - return False - - return True - - -def is_valid_acl_action_extra(action: List[int]) -> bool: - """ - Harsher version of valid acl actions, does not allow action. - - :param action: Agent action, formatted as a list of ints, for more information check out - `primaite.environment.primaite_env.Primaite` - :type action: List[int] - :return: Whether the action is valid - :rtype: bool - """ - if is_valid_acl_action(action) is False: - return False - - action_r = transform_action_acl_readable(action) - action_protocol = action_r[4] - action_port = action_r[5] - - # Don't allow protocols or ports to be ANY - # in the future we might want to do the opposite, and only have ANY option for ports and service - if action_protocol == "ANY": - return False - if action_port == "ANY": - return False - - return True - - -def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]: - """Transform list of transactions to readable list of each observation property. - - example: - np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] - - :param obs: Raw observation from the environment. - :type obs: np.ndarray - :return: The same observation, but the encoded integer values are replaced with readable names. - :rtype: List[List[Union[str, int]]] - """ - ids = [i for i in obs[:, 0]] - operating_states = [HardwareState(i).name for i in obs[:, 1]] - os_states = [SoftwareState(i).name for i in obs[:, 2]] - new_obs = [ids, operating_states, os_states] - - for service in range(4, obs.shape[1]): - # Links bit/s don't have a service state - service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]] - new_obs.append(service_states) - - return new_obs - - -def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]: - """Transform observation to readable format. - - example - np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']] - - :param obs: Raw observation from the environment. - :type obs: np.ndarray - :return: The same observation, but the encoded integer values are replaced with readable names. - :rtype: List[List[Union[str, int]]] - """ - changed_obs = transform_change_obs_readable(obs) - new_obs = list(zip(*changed_obs)) - # Convert list of tuples to list of lists - new_obs = [list(i) for i in new_obs] - - return new_obs - - -def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray: - """Convert original gym Box observation space to new multiDiscrete observation space. - - :param obs: observation in the 'old' (NodeLinkTable) format - :type obs: np.ndarray - :param num_nodes: number of nodes in the network, defaults to 10 - :type num_nodes: int, optional - :return: reformatted observation - :rtype: np.ndarray - """ - # Remove ID columns, remove links and flatten to MultiDiscrete observation space - new_obs = obs[:num_nodes, 1:].flatten() - return new_obs - - -def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray: - """Convert to old observation. - - Links filled with 0's as no information is included in new observation space. - - example: - obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1]) - - new_obs = array([[ 1, 1, 1, 1], - [ 2, 1, 1, 1], - [ 3, 1, 1, 1], - ... - [20, 0, 0, 0]]) - - :param obs: observation in the 'new' (MultiDiscrete) format - :type obs: np.ndarray - :param num_nodes: number of nodes in the network, defaults to 10 - :type num_nodes: int, optional - :param num_links: number of links in the network, defaults to 10 - :type num_links: int, optional - :param num_services: number of services on the network, defaults to 1 - :type num_services: int, optional - :return: 2-d BOX observation space, in the same format as NodeLinkTable - :rtype: np.ndarray - """ - # Convert back to more readable, original format - reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2) - - # Add empty links back and add node ID back - s = np.zeros( - [reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], - dtype=np.int64, - ) - s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back - s[:num_nodes, 1:] = reshaped_nodes # put values back in - new_obs = s - - # Add links back in - links = obs[-num_links:] - # Links will be added to the last protocol/service slot but they are not specific to that service - new_obs[num_nodes:, -1] = links - - return new_obs - - -def describe_obs_change( - obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1 -) -> str: - """Build a string describing the difference between two observations. - - example: - obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) - obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]]) - output = 'ID 1: SERVICE 2 set to GOOD' - - :param obs1: First observation - :type obs1: np.ndarray - :param obs2: Second observation - :type obs2: np.ndarray - :param num_nodes: How many nodes are in the network laydown, defaults to 10 - :type num_nodes: int, optional - :param num_links: How many links are in the network laydown, defaults to 10 - :type num_links: int, optional - :param num_services: How many services are configured for this scenario, defaults to 1 - :type num_services: int, optional - :return: A multi-line string with a human-readable description of the difference. - :rtype: str - """ - obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services) - obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services) - list_of_changes = [] - for n, row in enumerate(obs1 - obs2): - if row.any() != 0: - relevant_changes = np.where(row != 0, obs2[n], -1) - relevant_changes[0] = obs2[n, 0] # ID is always relevant - is_link = relevant_changes[0] > num_nodes - desc = _describe_obs_change_helper(relevant_changes, is_link) - list_of_changes.append(desc) - - change_string = "\n ".join(list_of_changes) - if len(list_of_changes) > 0: - change_string = "\n " + change_string - return change_string - - -def _describe_obs_change_helper(obs_change: List[int], is_link: bool) -> str: - """ - Helper funcion to describe what has changed. - - example: - [ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD" - - Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.' - - :param obs_change: List of integers generated within the `describe_obs_change` function. It should correspond to one - row of the observation table, and have `-1` at locations where the observation hasn't changed, and the new - status where it has changed. - :type obs_change: List[int] - :param is_link: Whether the row of the observation space corresponds to a link. False means it represents a node. - :type is_link: bool - :return: A human-readable description of the difference between the two observation rows. - :rtype: str - """ - # Indexes where a change has occured, not including 0th index - index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] - # Node pol types, Indexes >= 3 are service nodes - NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed] - # Account for hardware states, software sattes and links - states = [ - LinkStatus(obs_change[i]).name - if is_link - else HardwareState(obs_change[i]).name - if i == 1 - else SoftwareState(obs_change[i]).name - for i in index_changed - ] - - if not is_link: - desc = f"ID {obs_change[0]}:" - for node_pol_type, state in list(zip(NodePOLTypes, states)): - desc = desc + " " + node_pol_type + " changed to " + state + "." - else: - desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}." - - return desc - - -def transform_action_node_enum(action: List[Union[str, int]]) -> List[int]: - """Convert a node action from readable string format, to enumerated format. - - example: - [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] - :param action: Action in 'readable' format - :type action: List[Union[str,int]] - :return: Action with verbs encoded as ints - :rtype: List[int] - """ - action_node_id = action[0] - action_node_property = NodePOLType[action[1]].value - - if action[1] == "OPERATING": - property_action = NodeHardwareAction[action[2]].value - elif action[1] == "OS" or action[1] == "SERVICE": - property_action = NodeSoftwareAction[action[2]].value - else: - property_action = 0 - - action_service_index = action[3] - - new_action = [ - action_node_id, - action_node_property, - property_action, - action_service_index, - ] - - return new_action - - -def transform_action_acl_enum(action: List[Union[int, str]]) -> np.ndarray: - """ - Convert acl action from readable str format, to enumerated format. - - :param action: ACL-based action expressed as a list of human-readable ints and strings - :type action: List[Union[int,str]] - :return: The same action but encoded to contain only integers. - :rtype: np.ndarray - """ - action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2} - action_permissions = {"DENY": 0, "ALLOW": 1} - - action_decision = action_decisions[action[0]] - action_permission = action_permissions[action[1]] - - # For IPs, Ports and Protocols, ANY has value 0, otherwise its just an index - new_action = [action_decision, action_permission] + list(action[2:6]) - for n, val in enumerate(list(action[2:6])): - if val == "ANY": - new_action[n + 2] = 0 - - new_action = np.array(new_action) - return new_action - - -def get_node_of_ip(ip: str, node_dict: Dict[str, NodeUnion]) -> str: - """Get the node ID of an IP address. - - node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes) - - :param ip: The IP address of the node whose ID is required - :type ip: str - :param node_dict: The environment's node registry dictionary - :type node_dict: Dict[str,NodeUnion] - :return: The key from the registry dict that corresponds to the node with the IP adress provided by `ip` - :rtype: str - """ - for node_key, node_value in node_dict.items(): - node_ip = node_value.ip_address - if node_ip == ip: - return node_key - - -def get_new_action(old_action: np.ndarray, action_dict: Dict[int, List]) -> int: - """ - Get new action (e.g. 32) from old action e.g. [1,1,1,0]. - - Old_action can be either node or acl action type - - :param old_action: Action expressed as a list of choices, eg. [1,1,1,0] - :type old_action: np.ndarray - :param action_dict: Dictionary for translating the multidiscrete actions into the list-based actions. - :type action_dict: Dict[int,List] - :return: Action key correspoinding to the input `old_action` - :rtype: int - """ - for key, val in action_dict.items(): - if list(val) == list(old_action): - return key - # Not all possible actions are included in dict, only valid action are - # if action is not in the dict, its an invalid action so return 0 - return 0 diff --git a/src/primaite/common/__init__.py b/src/primaite/common/__init__.py deleted file mode 100644 index 5770bcbc..00000000 --- a/src/primaite/common/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Objects which are shared between many PrimAITE modules.""" diff --git a/src/primaite/common/custom_typing.py b/src/primaite/common/custom_typing.py deleted file mode 100644 index 4130e71a..00000000 --- a/src/primaite/common/custom_typing.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Union - -from primaite.nodes.active_node import ActiveNode -from primaite.nodes.passive_node import PassiveNode -from primaite.nodes.service_node import ServiceNode - -NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode] -"""A Union of ActiveNode, PassiveNode, and ServiceNode.""" diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py deleted file mode 100644 index c33e764b..00000000 --- a/src/primaite/common/enums.py +++ /dev/null @@ -1,208 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Enumerations for APE.""" - -from enum import Enum, IntEnum - - -class NodeType(Enum): - """Node type enumeration.""" - - CCTV = 1 - SWITCH = 2 - COMPUTER = 3 - LINK = 4 - MONITOR = 5 - PRINTER = 6 - LOP = 7 - RTU = 8 - ACTUATOR = 9 - SERVER = 10 - - -class Priority(Enum): - """Node priority enumeration.""" - - P1 = 1 - P2 = 2 - P3 = 3 - P4 = 4 - P5 = 5 - - -class HardwareState(Enum): - """Node hardware state enumeration.""" - - NONE = 0 - ON = 1 - OFF = 2 - RESETTING = 3 - SHUTTING_DOWN = 4 - BOOTING = 5 - - -class SoftwareState(Enum): - """Software or Service state enumeration.""" - - NONE = 0 - GOOD = 1 - PATCHING = 2 - COMPROMISED = 3 - OVERWHELMED = 4 - - -class NodePOLType(Enum): - """Node Pattern of Life type enumeration.""" - - NONE = 0 - OPERATING = 1 - OS = 2 - SERVICE = 3 - FILE = 4 - - -class NodePOLInitiator(Enum): - """Node Pattern of Life initiator enumeration.""" - - DIRECT = 1 - IER = 2 - SERVICE = 3 - - -class Protocol(Enum): - """Service protocol enumeration.""" - - LDAP = 0 - FTP = 1 - HTTPS = 2 - SMTP = 3 - RTP = 4 - IPP = 5 - TCP = 6 - NONE = 7 - - -class SessionType(Enum): - """The type of PrimAITE Session to be run.""" - - TRAIN = 1 - "Train an agent" - EVAL = 2 - "Evaluate an agent" - TRAIN_EVAL = 3 - "Train then evaluate an agent" - - -class AgentFramework(Enum): - """The agent algorithm framework/package.""" - - CUSTOM = 0 - "Custom Agent" - SB3 = 1 - "Stable Baselines3" - # RLLIB = 2 - # "Ray RLlib" - - -class DeepLearningFramework(Enum): - """The deep learning framework.""" - - TF = "tf" - "Tensorflow" - TF2 = "tf2" - "Tensorflow 2.x" - TORCH = "torch" - "PyTorch" - - -class AgentIdentifier(Enum): - """The Red Agent algo/class.""" - - A2C = 1 - "Advantage Actor Critic" - PPO = 2 - "Proximal Policy Optimization" - HARDCODED = 3 - "The Hardcoded agents" - DO_NOTHING = 4 - "The DoNothing agents" - RANDOM = 5 - "The RandomAgent" - DUMMY = 6 - "The DummyAgent" - - -class HardCodedAgentView(Enum): - """The view the deterministic hard-coded agent has of the environment.""" - - BASIC = 1 - "The current observation space only" - FULL = 2 - "Full environment view with actions taken and reward feedback" - - -class ActionType(Enum): - """Action type enumeration.""" - - NODE = 0 - ACL = 1 - ANY = 2 - - -# TODO: this is not used anymore, write a ticket to delete it. -class ObservationType(Enum): - """Observation type enumeration.""" - - BOX = 0 - MULTIDISCRETE = 1 - - -class FileSystemState(Enum): - """File System State.""" - - GOOD = 1 - CORRUPT = 2 - DESTROYED = 3 - REPAIRING = 4 - RESTORING = 5 - - -class NodeHardwareAction(Enum): - """Node hardware action.""" - - NONE = 0 - ON = 1 - OFF = 2 - RESET = 3 - - -class NodeSoftwareAction(Enum): - """Node software action.""" - - NONE = 0 - PATCHING = 1 - - -class LinkStatus(Enum): - """Link traffic status.""" - - NONE = 0 - LOW = 1 - MEDIUM = 2 - HIGH = 3 - OVERLOAD = 4 - - -class SB3OutputVerboseLevel(IntEnum): - """The Stable Baselines3 learn/eval output verbosity level.""" - - NONE = 0 - INFO = 1 - DEBUG = 2 - - -class RulePermissionType(Enum): - """Any firewall rule type.""" - - NONE = 0 - DENY = 1 - ALLOW = 2 diff --git a/src/primaite/common/protocol.py b/src/primaite/common/protocol.py deleted file mode 100644 index 6940ba3f..00000000 --- a/src/primaite/common/protocol.py +++ /dev/null @@ -1,47 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""The protocol class.""" - - -class Protocol(object): - """Protocol class.""" - - def __init__(self, _name: str) -> None: - """ - Initialise a protocol. - - :param _name: The name of the protocol - :type _name: str - """ - self.name: str = _name - self.load: int = 0 # bps - - def get_name(self) -> str: - """ - Gets the protocol name. - - Returns: - The protocol name - """ - return self.name - - def get_load(self) -> int: - """ - Gets the protocol load. - - Returns: - The protocol load (bps) - """ - return self.load - - def add_load(self, _load: int) -> None: - """ - Adds load to the protocol. - - Args: - _load: The load to add - """ - self.load += _load - - def clear_load(self) -> None: - """Clears the load on this protocol.""" - self.load = 0 diff --git a/src/primaite/common/service.py b/src/primaite/common/service.py deleted file mode 100644 index 956815e8..00000000 --- a/src/primaite/common/service.py +++ /dev/null @@ -1,28 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""The Service class.""" - -from primaite.common.enums import SoftwareState - - -class Service(object): - """Service class.""" - - def __init__(self, name: str, port: str, software_state: SoftwareState) -> None: - """ - Initialise a service. - - :param name: The service name. - :param port: The service port. - :param software_state: The service SoftwareState. - """ - self.name: str = name - self.port: str = port - self.software_state: SoftwareState = software_state - self.patching_count: int = 0 - - def reduce_patching_count(self) -> None: - """Reduces the patching count for the service.""" - self.patching_count -= 1 - if self.patching_count <= 0: - self.patching_count = 0 - self.software_state = SoftwareState.GOOD diff --git a/src/primaite/config/__init__.py b/src/primaite/config/__init__.py deleted file mode 100644 index 92f5a7d2..00000000 --- a/src/primaite/config/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Configuration parameters for running experiments.""" diff --git a/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml deleted file mode 100644 index dad0ff4b..00000000 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml +++ /dev/null @@ -1,166 +0,0 @@ -- item_type: PORTS - ports_list: - - port: '80' -- item_type: SERVICES - service_list: - - name: TCP -- item_type: NODE - node_id: '1' - name: PC1 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.2 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '2' - name: SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.3 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '3' - name: PC2 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.4 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '4' - name: SWITCH1 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.5 - software_state: GOOD - file_system_state: GOOD -- item_type: NODE - node_id: '5' - name: SWITCH2 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.6 - software_state: GOOD - file_system_state: GOOD -- item_type: NODE - node_id: '6' - name: SWITCH3 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.7 - software_state: GOOD - file_system_state: GOOD -- item_type: LINK - id: '7' - name: link1 - bandwidth: 1000000000 - source: '1' - destination: '4' -- item_type: LINK - id: '8' - name: link2 - bandwidth: 1000000000 - source: '4' - destination: '2' -- item_type: LINK - id: '9' - name: link3 - bandwidth: 1000000000 - source: '2' - destination: '5' -- item_type: LINK - id: '10' - name: link4 - bandwidth: 1000000000 - source: '2' - destination: '6' -- item_type: LINK - id: '11' - name: link5 - bandwidth: 1000000000 - source: '5' - destination: '3' -- item_type: LINK - id: '12' - name: link6 - bandwidth: 1000000000 - source: '6' - destination: '3' -- item_type: GREEN_IER - id: '13' - start_step: 1 - end_step: 128 - load: 100000 - protocol: TCP - port: '80' - source: '3' - destination: '2' - mission_criticality: 5 -- item_type: RED_POL - id: '14' - start_step: 50 - end_step: 50 - targetNodeId: '1' - initiator: DIRECT - type: SERVICE - protocol: TCP - state: COMPROMISED - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA -- item_type: RED_IER - id: '15' - start_step: 60 - end_step: 100 - load: 1000000 - protocol: TCP - port: '80' - source: '1' - destination: '2' - mission_criticality: 0 -- item_type: RED_POL - id: '16' - start_step: 80 - end_step: 80 - targetNodeId: '2' - initiator: IER - type: SERVICE - protocol: TCP - state: COMPROMISED - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA -- item_type: ACL_RULE - id: '17' - permission: ALLOW - source: ANY - destination: ANY - protocol: ANY - port: ANY - position: 0 diff --git a/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml deleted file mode 100644 index e91859d2..00000000 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml +++ /dev/null @@ -1,366 +0,0 @@ -- item_type: PORTS - ports_list: - - port: '80' -- item_type: SERVICES - service_list: - - name: TCP -- item_type: NODE - node_id: '1' - name: PC1 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.10.11 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '2' - name: PC2 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.10.12 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '3' - name: PC3 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.10.13 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '4' - name: PC4 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.20.14 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '5' - name: SWITCH1 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.2 - software_state: GOOD - file_system_state: GOOD -- item_type: NODE - node_id: '6' - name: IDS - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.4 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '7' - name: SWITCH2 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.3 - software_state: GOOD - file_system_state: GOOD -- item_type: NODE - node_id: '8' - name: LOP1 - node_class: SERVICE - node_type: LOP - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.12 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '9' - name: SERVER1 - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.10.14 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '10' - name: SERVER2 - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.20.15 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: LINK - id: '11' - name: link1 - bandwidth: 1000000000 - source: '1' - destination: '5' -- item_type: LINK - id: '12' - name: link2 - bandwidth: 1000000000 - source: '2' - destination: '5' -- item_type: LINK - id: '13' - name: link3 - bandwidth: 1000000000 - source: '3' - destination: '5' -- item_type: LINK - id: '14' - name: link4 - bandwidth: 1000000000 - source: '4' - destination: '5' -- item_type: LINK - id: '15' - name: link5 - bandwidth: 1000000000 - source: '5' - destination: '6' -- item_type: LINK - id: '16' - name: link6 - bandwidth: 1000000000 - source: '5' - destination: '8' -- item_type: LINK - id: '17' - name: link7 - bandwidth: 1000000000 - source: '6' - destination: '7' -- item_type: LINK - id: '18' - name: link8 - bandwidth: 1000000000 - source: '8' - destination: '7' -- item_type: LINK - id: '19' - name: link9 - bandwidth: 1000000000 - source: '7' - destination: '9' -- item_type: LINK - id: '20' - name: link10 - bandwidth: 1000000000 - source: '7' - destination: '10' -- item_type: GREEN_IER - id: '21' - start_step: 1 - end_step: 128 - load: 100000 - protocol: TCP - port: '80' - source: '1' - destination: '9' - mission_criticality: 2 -- item_type: GREEN_IER - id: '22' - start_step: 1 - end_step: 128 - load: 100000 - protocol: TCP - port: '80' - source: '2' - destination: '9' - mission_criticality: 2 -- item_type: GREEN_IER - id: '23' - start_step: 1 - end_step: 128 - load: 100000 - protocol: TCP - port: '80' - source: '9' - destination: '3' - mission_criticality: 5 -- item_type: GREEN_IER - id: '24' - start_step: 1 - end_step: 128 - load: 100000 - protocol: TCP - port: '80' - source: '4' - destination: '10' - mission_criticality: 2 -- item_type: ACL_RULE - id: '25' - permission: ALLOW - source: 192.168.10.11 - destination: 192.168.10.14 - protocol: TCP - port: 80 - position: 0 -- item_type: ACL_RULE - id: '26' - permission: ALLOW - source: 192.168.10.12 - destination: 192.168.10.14 - protocol: TCP - port: 80 - position: 1 -- item_type: ACL_RULE - id: '27' - permission: ALLOW - source: 192.168.10.13 - destination: 192.168.10.14 - protocol: TCP - port: 80 - position: 2 -- item_type: ACL_RULE - id: '28' - permission: ALLOW - source: 192.168.20.14 - destination: 192.168.20.15 - protocol: TCP - port: 80 - position: 3 -- item_type: ACL_RULE - id: '29' - permission: ALLOW - source: 192.168.10.14 - destination: 192.168.10.13 - protocol: TCP - port: 80 - position: 4 -- item_type: ACL_RULE - id: '30' - permission: DENY - source: 192.168.10.11 - destination: 192.168.20.15 - protocol: TCP - port: 80 - position: 5 -- item_type: ACL_RULE - id: '31' - permission: DENY - source: 192.168.10.12 - destination: 192.168.20.15 - protocol: TCP - port: 80 - position: 6 -- item_type: ACL_RULE - id: '32' - permission: DENY - source: 192.168.10.13 - destination: 192.168.20.15 - protocol: TCP - port: 80 - position: 7 -- item_type: ACL_RULE - id: '33' - permission: DENY - source: 192.168.20.14 - destination: 192.168.10.14 - protocol: TCP - port: 80 - position: 8 -- item_type: RED_POL - id: '34' - start_step: 20 - end_step: 20 - targetNodeId: '1' - initiator: DIRECT - type: SERVICE - protocol: TCP - state: COMPROMISED - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA -- item_type: RED_POL - id: '35' - start_step: 20 - end_step: 20 - targetNodeId: '2' - initiator: DIRECT - type: SERVICE - protocol: TCP - state: COMPROMISED - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA -- item_type: RED_IER - id: '36' - start_step: 30 - end_step: 128 - load: 440000000 - protocol: TCP - port: '80' - source: '1' - destination: '9' - mission_criticality: 0 -- item_type: RED_IER - id: '37' - start_step: 30 - end_step: 128 - load: 440000000 - protocol: TCP - port: '80' - source: '2' - destination: '9' - mission_criticality: 0 -- item_type: RED_POL - id: '38' - start_step: 30 - end_step: 30 - targetNodeId: '9' - initiator: IER - type: SERVICE - protocol: TCP - state: OVERWHELMED - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA diff --git a/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml deleted file mode 100644 index 453b6abb..00000000 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml +++ /dev/null @@ -1,164 +0,0 @@ -- item_type: PORTS - ports_list: - - port: '80' -- item_type: SERVICES - service_list: - - name: TCP -- item_type: NODE - node_id: '1' - name: PC1 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.2 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '2' - name: PC2 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.3 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '3' - name: SWITCH1 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.1 - software_state: GOOD - file_system_state: GOOD -- item_type: NODE - node_id: '4' - name: SERVER1 - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.4 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: LINK - id: '5' - name: link1 - bandwidth: 1000000000 - source: '1' - destination: '3' -- item_type: LINK - id: '6' - name: link2 - bandwidth: 1000000000 - source: '2' - destination: '3' -- item_type: LINK - id: '7' - name: link3 - bandwidth: 1000000000 - source: '3' - destination: '4' -- item_type: GREEN_IER - id: '8' - start_step: 1 - end_step: 256 - load: 10000 - protocol: TCP - port: '80' - source: '1' - destination: '4' - mission_criticality: 1 -- item_type: GREEN_IER - id: '9' - start_step: 1 - end_step: 256 - load: 10000 - protocol: TCP - port: '80' - source: '2' - destination: '4' - mission_criticality: 1 -- item_type: GREEN_IER - id: '10' - start_step: 1 - end_step: 256 - load: 10000 - protocol: TCP - port: '80' - source: '4' - destination: '2' - mission_criticality: 5 -- item_type: ACL_RULE - id: '11' - permission: ALLOW - source: 192.168.1.2 - destination: 192.168.1.4 - protocol: TCP - port: 80 - position: 0 -- item_type: ACL_RULE - id: '12' - permission: ALLOW - source: 192.168.1.3 - destination: 192.168.1.4 - protocol: TCP - port: 80 - position: 1 -- item_type: ACL_RULE - id: '13' - permission: ALLOW - source: 192.168.1.4 - destination: 192.168.1.3 - protocol: TCP - port: 80 - position: 2 -- item_type: RED_POL - id: '14' - start_step: 20 - end_step: 20 - targetNodeId: '1' - initiator: DIRECT - type: SERVICE - protocol: TCP - state: COMPROMISED - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA -- item_type: RED_IER - id: '15' - start_step: 30 - end_step: 256 - load: 10000000 - protocol: TCP - port: '80' - source: '1' - destination: '4' - mission_criticality: 0 -- item_type: RED_POL - id: '16' - start_step: 40 - end_step: 40 - targetNodeId: '4' - initiator: IER - type: SERVICE - protocol: TCP - state: OVERWHELMED - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA diff --git a/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml deleted file mode 100644 index 96596514..00000000 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml +++ /dev/null @@ -1,546 +0,0 @@ -- item_type: PORTS - ports_list: - - port: '80' - - port: '1433' - - port: '53' -- item_type: SERVICES - service_list: - - name: TCP - - name: TCP_SQL - - name: UDP -- item_type: NODE - node_id: '1' - name: CLIENT_1 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.10.11 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- item_type: NODE - node_id: '2' - name: CLIENT_2 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.10.12 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '3' - name: SWITCH_1 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.10.1 - software_state: GOOD - file_system_state: GOOD -- item_type: NODE - node_id: '4' - name: SECURITY_SUITE - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.10 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- item_type: NODE - node_id: '5' - name: MANAGEMENT_CONSOLE - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.12 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- item_type: NODE - node_id: '6' - name: SWITCH_2 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.2.1 - software_state: GOOD - file_system_state: GOOD -- item_type: NODE - node_id: '7' - name: WEB_SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.2.10 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: TCP_SQL - port: '1433' - state: GOOD -- item_type: NODE - node_id: '8' - name: DATABASE_SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.2.14 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: TCP_SQL - port: '1433' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- item_type: NODE - node_id: '9' - name: BACKUP_SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.2.16 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: LINK - id: '10' - name: LINK_1 - bandwidth: 1000000000 - source: '1' - destination: '3' -- item_type: LINK - id: '11' - name: LINK_2 - bandwidth: 1000000000 - source: '2' - destination: '3' -- item_type: LINK - id: '12' - name: LINK_3 - bandwidth: 1000000000 - source: '3' - destination: '4' -- item_type: LINK - id: '13' - name: LINK_4 - bandwidth: 1000000000 - source: '3' - destination: '5' -- item_type: LINK - id: '14' - name: LINK_5 - bandwidth: 1000000000 - source: '4' - destination: '6' -- item_type: LINK - id: '15' - name: LINK_6 - bandwidth: 1000000000 - source: '5' - destination: '6' -- item_type: LINK - id: '16' - name: LINK_7 - bandwidth: 1000000000 - source: '6' - destination: '7' -- item_type: LINK - id: '17' - name: LINK_8 - bandwidth: 1000000000 - source: '6' - destination: '8' -- item_type: LINK - id: '18' - name: LINK_9 - bandwidth: 1000000000 - source: '6' - destination: '9' -- item_type: GREEN_IER - id: '19' - start_step: 1 - end_step: 256 - load: 10000 - protocol: TCP - port: '80' - source: '1' - destination: '7' - mission_criticality: 5 -- item_type: GREEN_IER - id: '20' - start_step: 1 - end_step: 256 - load: 10000 - protocol: TCP - port: '80' - source: '7' - destination: '1' - mission_criticality: 5 -- item_type: GREEN_IER - id: '21' - start_step: 1 - end_step: 256 - load: 10000 - protocol: TCP - port: '80' - source: '2' - destination: '7' - mission_criticality: 5 -- item_type: GREEN_IER - id: '22' - start_step: 1 - end_step: 256 - load: 10000 - protocol: TCP - port: '80' - source: '7' - destination: '2' - mission_criticality: 5 -- item_type: GREEN_IER - id: '23' - start_step: 1 - end_step: 256 - load: 5000 - protocol: TCP_SQL - port: '1433' - source: '7' - destination: '8' - mission_criticality: 5 -- item_type: GREEN_IER - id: '24' - start_step: 1 - end_step: 256 - load: 100000 - protocol: TCP_SQL - port: '1433' - source: '8' - destination: '7' - mission_criticality: 5 -- item_type: GREEN_IER - id: '25' - start_step: 1 - end_step: 256 - load: 50000 - protocol: TCP - port: '80' - source: '1' - destination: '9' - mission_criticality: 2 -- item_type: GREEN_IER - id: '26' - start_step: 1 - end_step: 256 - load: 50000 - protocol: TCP - port: '80' - source: '2' - destination: '9' - mission_criticality: 2 -- item_type: GREEN_IER - id: '27' - start_step: 1 - end_step: 256 - load: 5000 - protocol: TCP - port: '80' - source: '5' - destination: '7' - mission_criticality: 1 -- item_type: GREEN_IER - id: '28' - start_step: 1 - end_step: 256 - load: 5000 - protocol: TCP - port: '80' - source: '7' - destination: '5' - mission_criticality: 1 -- item_type: GREEN_IER - id: '29' - start_step: 1 - end_step: 256 - load: 5000 - protocol: TCP - port: '80' - source: '5' - destination: '8' - mission_criticality: 1 -- item_type: GREEN_IER - id: '30' - start_step: 1 - end_step: 256 - load: 5000 - protocol: TCP - port: '80' - source: '8' - destination: '5' - mission_criticality: 1 -- item_type: GREEN_IER - id: '31' - start_step: 1 - end_step: 256 - load: 5000 - protocol: TCP - port: '80' - source: '5' - destination: '9' - mission_criticality: 1 -- item_type: GREEN_IER - id: '32' - start_step: 1 - end_step: 256 - load: 5000 - protocol: TCP - port: '80' - source: '9' - destination: '5' - mission_criticality: 1 -- item_type: ACL_RULE - id: '33' - permission: ALLOW - source: 192.168.10.11 - destination: 192.168.2.10 - protocol: ANY - port: ANY - position: 0 -- item_type: ACL_RULE - id: '34' - permission: ALLOW - source: 192.168.10.11 - destination: 192.168.2.14 - protocol: ANY - port: ANY - position: 1 -- item_type: ACL_RULE - id: '35' - permission: ALLOW - source: 192.168.10.12 - destination: 192.168.2.14 - protocol: ANY - port: ANY - position: 2 -- item_type: ACL_RULE - id: '36' - permission: ALLOW - source: 192.168.10.12 - destination: 192.168.2.10 - protocol: ANY - port: ANY - position: 3 -- item_type: ACL_RULE - id: '37' - permission: ALLOW - source: 192.168.2.10 - destination: 192.168.10.11 - protocol: ANY - port: ANY - position: 4 -- item_type: ACL_RULE - id: '38' - permission: ALLOW - source: 192.168.2.10 - destination: 192.168.10.12 - protocol: ANY - port: ANY - position: 5 -- item_type: ACL_RULE - id: '39' - permission: ALLOW - source: 192.168.2.10 - destination: 192.168.2.14 - protocol: ANY - port: ANY - position: 6 -- item_type: ACL_RULE - id: '40' - permission: ALLOW - source: 192.168.2.14 - destination: 192.168.2.10 - protocol: ANY - port: ANY - position: 7 -- item_type: ACL_RULE - id: '41' - permission: ALLOW - source: 192.168.10.11 - destination: 192.168.2.16 - protocol: ANY - port: ANY - position: 8 -- item_type: ACL_RULE - id: '42' - permission: ALLOW - source: 192.168.10.12 - destination: 192.168.2.16 - protocol: ANY - port: ANY - position: 9 -- item_type: ACL_RULE - id: '43' - permission: ALLOW - source: 192.168.1.12 - destination: 192.168.2.10 - protocol: ANY - port: ANY - position: 10 -- item_type: ACL_RULE - id: '44' - permission: ALLOW - source: 192.168.1.12 - destination: 192.168.2.14 - protocol: ANY - port: ANY - position: 11 -- item_type: ACL_RULE - id: '45' - permission: ALLOW - source: 192.168.1.12 - destination: 192.168.2.16 - protocol: ANY - port: ANY - position: 12 -- item_type: ACL_RULE - id: '46' - permission: ALLOW - source: 192.168.2.10 - destination: 192.168.1.12 - protocol: ANY - port: ANY - position: 13 -- item_type: ACL_RULE - id: '47' - permission: ALLOW - source: 192.168.2.14 - destination: 192.168.1.12 - protocol: ANY - port: ANY - position: 14 -- item_type: ACL_RULE - id: '48' - permission: ALLOW - source: 192.168.2.16 - destination: 192.168.1.12 - protocol: ANY - port: ANY - position: 15 -- item_type: ACL_RULE - id: '49' - permission: DENY - source: ANY - destination: ANY - protocol: ANY - port: ANY - position: 16 -- item_type: RED_POL - id: '50' - start_step: 50 - end_step: 50 - targetNodeId: '1' - initiator: DIRECT - type: SERVICE - protocol: UDP - state: COMPROMISED - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA -- item_type: RED_IER - id: '51' - start_step: 75 - end_step: 105 - load: 10000 - protocol: UDP - port: '53' - source: '1' - destination: '8' - mission_criticality: 0 -- item_type: RED_POL - id: '52' - start_step: 100 - end_step: 100 - targetNodeId: '8' - initiator: IER - type: SERVICE - protocol: UDP - state: COMPROMISED - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA -- item_type: RED_POL - id: '53' - start_step: 105 - end_step: 105 - targetNodeId: '8' - initiator: SERVICE - type: FILE - protocol: NA - state: CORRUPT - sourceNodeId: '8' - sourceNodeService: UDP - sourceNodeServiceState: COMPROMISED -- item_type: RED_POL - id: '54' - start_step: 105 - end_step: 105 - targetNodeId: '8' - initiator: SERVICE - type: SERVICE - protocol: TCP_SQL - state: COMPROMISED - sourceNodeId: '8' - sourceNodeService: UDP - sourceNodeServiceState: COMPROMISED -- item_type: RED_POL - id: '55' - start_step: 125 - end_step: 125 - targetNodeId: '7' - initiator: SERVICE - type: SERVICE - protocol: TCP - state: OVERWHELMED - sourceNodeId: '8' - sourceNodeService: TCP_SQL - sourceNodeServiceState: COMPROMISED diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml deleted file mode 100644 index db4ed692..00000000 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ /dev/null @@ -1,168 +0,0 @@ -# Training Config File - -# Sets which agent algorithm framework will be used. -# Options are: -# "SB3" (Stable Baselines3) -# "RLLIB" (Ray RLlib) -# "CUSTOM" (Custom Agent) -agent_framework: SB3 - -# Sets which deep learning framework will be used (by RLlib ONLY). -# Default is TF (Tensorflow). -# Options are: -# "TF" (Tensorflow) -# TF2 (Tensorflow 2.X) -# TORCH (PyTorch) -deep_learning_framework: TF2 - -# Sets which Agent class will be used. -# Options are: -# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) -# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) -# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) -# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) -# "RANDOM" (primaite.agents.simple.RandomAgent) -# "DUMMY" (primaite.agents.simple.DummyAgent) -agent_identifier: PPO - -# Sets whether Red Agent POL and IER is randomised. -# Options are: -# True -# False -random_red_agent: False - -# The (integer) seed to be used in random number generation -# Default is None (null) -seed: null - -# Set whether the agent evaluation will be deterministic instead of stochastic -# Options are: -# True -# False -deterministic: False - -# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. -# Options are: -# "BASIC" (The current observation space only) -# "FULL" (Full environment view with actions taken and reward feedback) -hard_coded_agent_view: FULL - -# Sets How the Action Space is defined: -# "NODE" -# "ACL" -# "ANY" node and acl actions -action_type: ANY -# observation space -observation_space: - flatten: true - components: - - name: NODE_LINK_TABLE - - name: NODE_STATUSES - - name: LINK_TRAFFIC_LEVELS - - name: ACCESS_CONTROL_LIST - -# Number of episodes for training to run per session -num_train_episodes: 10 - -# Number of time_steps for training per episode -num_train_steps: 256 - -# Number of episodes for evaluation to run per session -num_eval_episodes: 1 - -# Number of time_steps for evaluation per episode -num_eval_steps: 256 - -# Sets how often the agent will save a checkpoint (every n time episodes). -# Set to 0 if no checkpoints are required. Default is 10 -checkpoint_every_n_episodes: 10 - -# Time delay (milliseconds) between steps for CUSTOM agents. -time_delay: 5 - -# Type of session to be run. Options are: -# "TRAIN" (Trains an agent) -# "EVAL" (Evaluates an agent) -# "TRAIN_EVAL" (Trains then evaluates an agent) -session_type: TRAIN_EVAL - -# Environment config values -# The high value for the observation space -observation_space_high_value: 1000000000 - -# Implicit ACL firewall rule at end of ACL list to be the default action (ALLOW or DENY) -implicit_acl_rule: DENY -# Total number of ACL rules allowed in the environment -max_number_acl_rules: 30 - -# The Stable Baselines3 learn/eval output verbosity level: -# Options are: -# "NONE" (No Output) -# "INFO" (Info Messages (such as devices and wrappers used)) -# "DEBUG" (All Messages) -sb3_output_verbose_level: NONE - -# Reward values -# Generic -all_ok: 0 -# Node Hardware State -off_should_be_on: -0.001 -off_should_be_resetting: -0.0005 -on_should_be_off: -0.0002 -on_should_be_resetting: -0.0005 -resetting_should_be_on: -0.0005 -resetting_should_be_off: -0.0002 -resetting: -0.0003 -# Node Software or Service State -good_should_be_patching: 0.0002 -good_should_be_compromised: 0.0005 -good_should_be_overwhelmed: 0.0005 -patching_should_be_good: -0.0005 -patching_should_be_compromised: 0.0002 -patching_should_be_overwhelmed: 0.0002 -patching: -0.0003 -compromised_should_be_good: -0.002 -compromised_should_be_patching: -0.002 -compromised_should_be_overwhelmed: -0.002 -compromised: -0.002 -overwhelmed_should_be_good: -0.002 -overwhelmed_should_be_patching: -0.002 -overwhelmed_should_be_compromised: -0.002 -overwhelmed: -0.002 -# Node File System State -good_should_be_repairing: 0.0002 -good_should_be_restoring: 0.0002 -good_should_be_corrupt: 0.0005 -good_should_be_destroyed: 0.001 -repairing_should_be_good: -0.0005 -repairing_should_be_restoring: 0.0002 -repairing_should_be_corrupt: 0.0002 -repairing_should_be_destroyed: 0.0000 -repairing: -0.0003 -restoring_should_be_good: -0.001 -restoring_should_be_repairing: -0.0002 -restoring_should_be_corrupt: 0.0001 -restoring_should_be_destroyed: 0.0002 -restoring: -0.0006 -corrupt_should_be_good: -0.001 -corrupt_should_be_repairing: -0.001 -corrupt_should_be_restoring: -0.001 -corrupt_should_be_destroyed: 0.0002 -corrupt: -0.001 -destroyed_should_be_good: -0.002 -destroyed_should_be_repairing: -0.002 -destroyed_should_be_restoring: -0.002 -destroyed_should_be_corrupt: -0.002 -destroyed: -0.002 -scanning: -0.0002 -# IER status -red_ier_running: -0.0005 -green_ier_blocked: -0.001 - -# Patching / Reset durations -os_patching_duration: 5 # The time taken to patch the OS -node_reset_duration: 5 # The time taken to reset a node (hardware) -service_patching_duration: 5 # The time taken to patch a service -file_system_repairing_limit: 5 # The time take to repair the file system -file_system_restoring_limit: 5 # The time take to restore the file system -file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py deleted file mode 100644 index fe3e3429..00000000 --- a/src/primaite/config/lay_down_config.py +++ /dev/null @@ -1,141 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -from logging import Logger -from pathlib import Path -from typing import Any, Dict, Final, List, Union - -import yaml - -from primaite import getLogger, PRIMAITE_PATHS - -_LOGGER: Logger = getLogger(__name__) - -_EXAMPLE_LAY_DOWN: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "lay_down" - - -def convert_legacy_lay_down_config(legacy_config: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Convert a legacy lay down config to the new format. - - :param legacy_config: A legacy lay down config. - """ - field_conversion_map = { - "itemType": "item_type", - "portsList": "ports_list", - "serviceList": "service_list", - "baseType": "node_class", - "nodeType": "node_type", - "hardwareState": "hardware_state", - "softwareState": "software_state", - "startStep": "start_step", - "endStep": "end_step", - "fileSystemState": "file_system_state", - "ipAddress": "ip_address", - "missionCriticality": "mission_criticality", - } - new_config = [] - for item in legacy_config: - if "itemType" in item: - if item["itemType"] in ["ACTIONS", "STEPS"]: - continue - new_dict = {} - for key in item.keys(): - conversion_key = field_conversion_map.get(key) - if key == "id" and "itemType" in item: - if item["itemType"] == "NODE": - conversion_key = "node_id" - if conversion_key: - new_dict[conversion_key] = item[key] - else: - new_dict[key] = item[key] - new_config.append(new_dict) - return new_config - - -def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict: - """ - Read in a lay down config yaml file. - - :param file_path: The config file path. - :param legacy_file: True if the config file is legacy format, otherwise False. - :return: The lay down config as a dict. - :raises ValueError: If the file_path does not exist. - """ - if not isinstance(file_path, Path): - file_path = Path(file_path) - if file_path.exists(): - with open(file_path, "r") as file: - config = yaml.safe_load(file) - _LOGGER.debug(f"Loading lay down config file: {file_path}") - if legacy_file: - try: - config = convert_legacy_lay_down_config(config) - except KeyError: - msg = ( - f"Failed to convert lay down config file {file_path} " - f"from legacy format. Attempting to use file as is." - ) - _LOGGER.error(msg) - return config - msg = f"Cannot load the lay down config as it does not exist: {file_path}" - _LOGGER.error(msg) - raise ValueError(msg) - - -def ddos_basic_one_config_path() -> Path: - """ - The path to the example lay_down_config_1_DDOS_basic.yaml file. - - :return: The file path. - """ - path = _EXAMPLE_LAY_DOWN / "lay_down_config_1_DDOS_basic.yaml" - if not path.exists(): - msg = "Example config not found. Please run 'primaite setup'" - _LOGGER.critical(msg) - raise FileNotFoundError(msg) - - return path - - -def ddos_basic_two_config_path() -> Path: - """ - The path to the example lay_down_config_2_DDOS_basic.yaml file. - - :return: The file path. - """ - path = _EXAMPLE_LAY_DOWN / "lay_down_config_2_DDOS_basic.yaml" - if not path.exists(): - msg = "Example config not found. Please run 'primaite setup'" - _LOGGER.critical(msg) - raise FileNotFoundError(msg) - - return path - - -def dos_very_basic_config_path() -> Path: - """ - The path to the example lay_down_config_3_DOS_very_basic.yaml file. - - :return: The file path. - """ - path = _EXAMPLE_LAY_DOWN / "lay_down_config_3_DOS_very_basic.yaml" - if not path.exists(): - msg = "Example config not found. Please run 'primaite setup'" - _LOGGER.critical(msg) - raise FileNotFoundError(msg) - - return path - - -def data_manipulation_config_path() -> Path: - """ - The path to the example lay_down_config_5_data_manipulation.yaml file. - - :return: The file path. - """ - path = _EXAMPLE_LAY_DOWN / "lay_down_config_5_data_manipulation.yaml" - if not path.exists(): - msg = "Example config not found. Please run 'primaite setup'" - _LOGGER.critical(msg) - raise FileNotFoundError(msg) - - return path diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py deleted file mode 100644 index f81bb6f7..00000000 --- a/src/primaite/config/training_config.py +++ /dev/null @@ -1,438 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -from __future__ import annotations - -from dataclasses import dataclass, field -from logging import Logger -from pathlib import Path -from typing import Any, Dict, Final, Optional, Union - -import yaml - -from primaite import getLogger, PRIMAITE_PATHS -from primaite.common.enums import ( - ActionType, - AgentFramework, - AgentIdentifier, - DeepLearningFramework, - HardCodedAgentView, - RulePermissionType, - SB3OutputVerboseLevel, - SessionType, -) - -_LOGGER: Logger = getLogger(__name__) - -_EXAMPLE_TRAINING: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "training" - - -def main_training_config_path() -> Path: - """ - The path to the example training_config_main.yaml file. - - :return: The file path. - """ - path = _EXAMPLE_TRAINING / "training_config_main.yaml" - if not path.exists(): - msg = "Example config not found. Please run 'primaite setup'" - _LOGGER.critical(msg) - raise FileNotFoundError(msg) - - return path - - -@dataclass() -class TrainingConfig: - """The Training Config class.""" - - agent_framework: AgentFramework = AgentFramework.SB3 - "The AgentFramework" - - deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF - "The DeepLearningFramework" - - agent_identifier: AgentIdentifier = AgentIdentifier.PPO - "The AgentIdentifier" - - hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL - "The view the deterministic hard-coded agent has of the environment" - - random_red_agent: bool = False - "Creates Random Red Agent Attacks" - - action_type: ActionType = ActionType.ANY - "The ActionType to use" - - num_train_episodes: int = 10 - "The number of episodes to train over during an training session" - - num_train_steps: int = 256 - "The number of steps in an episode during an training session" - - num_eval_episodes: int = 1 - "The number of episodes to train over during an evaluation session" - - num_eval_steps: int = 256 - "The number of steps in an episode during an evaluation session" - - checkpoint_every_n_episodes: int = 5 - "The agent will save a checkpoint every n episodes" - - observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}) - "The observation space config dict" - - time_delay: int = 10 - "The delay between steps (ms). Applies to generic agents only" - - # file - session_type: SessionType = SessionType.TRAIN - "The type of PrimAITE session to run" - - load_agent: bool = False - "Determine whether to load an agent from file" - - agent_load_file: Optional[str] = None - "File path and file name of agent if you're loading one in" - - # Environment - observation_space_high_value: int = 1000000000 - "The high value for the observation space" - - sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE - "Stable Baselines3 learn/eval output verbosity level" - - implicit_acl_rule: RulePermissionType = RulePermissionType.DENY - "ALLOW or DENY implicit firewall rule to go at the end of list of ACL list." - - max_number_acl_rules: int = 30 - "Sets a limit for number of acl rules allowed in the list and environment." - - # Reward values - # Generic - all_ok: float = 0 - - # Node Hardware State - off_should_be_on: float = -0.001 - off_should_be_resetting: float = -0.0005 - on_should_be_off: float = -0.0002 - on_should_be_resetting: float = -0.0005 - resetting_should_be_on: float = -0.0005 - resetting_should_be_off: float = -0.0002 - resetting: float = -0.0003 - - # Node Software or Service State - good_should_be_patching: float = 0.0002 - good_should_be_compromised: float = 0.0005 - good_should_be_overwhelmed: float = 0.0005 - patching_should_be_good: float = -0.0005 - patching_should_be_compromised: float = 0.0002 - patching_should_be_overwhelmed: float = 0.0002 - patching: float = -0.0003 - compromised_should_be_good: float = -0.002 - compromised_should_be_patching: float = -0.002 - compromised_should_be_overwhelmed: float = -0.002 - compromised: float = -0.002 - overwhelmed_should_be_good: float = -0.002 - overwhelmed_should_be_patching: float = -0.002 - overwhelmed_should_be_compromised: float = -0.002 - overwhelmed: float = -0.002 - - # Node File System State - good_should_be_repairing: float = 0.0002 - good_should_be_restoring: float = 0.0002 - good_should_be_corrupt: float = 0.0005 - good_should_be_destroyed: float = 0.001 - repairing_should_be_good: float = -0.0005 - repairing_should_be_restoring: float = 0.0002 - repairing_should_be_corrupt: float = 0.0002 - repairing_should_be_destroyed: float = 0.0000 - repairing: float = -0.0003 - restoring_should_be_good: float = -0.001 - restoring_should_be_repairing: float = -0.0002 - restoring_should_be_corrupt: float = 0.0001 - restoring_should_be_destroyed: float = 0.0002 - restoring: float = -0.0006 - corrupt_should_be_good: float = -0.001 - corrupt_should_be_repairing: float = -0.001 - corrupt_should_be_restoring: float = -0.001 - corrupt_should_be_destroyed: float = 0.0002 - corrupt: float = -0.001 - destroyed_should_be_good: float = -0.002 - destroyed_should_be_repairing: float = -0.002 - destroyed_should_be_restoring: float = -0.002 - destroyed_should_be_corrupt: float = -0.002 - destroyed: float = -0.002 - scanning: float = -0.0002 - - # IER status - red_ier_running: float = -0.0005 - green_ier_blocked: float = -0.001 - - # Patching / Reset durations - os_patching_duration: int = 5 - "The time taken to patch the OS" - - node_reset_duration: int = 5 - "The time taken to reset a node (hardware)" - - node_booting_duration: int = 3 - "The Time taken to turn on the node" - - node_shutdown_duration: int = 2 - "The time taken to turn off the node" - - service_patching_duration: int = 5 - "The time taken to patch a service" - - file_system_repairing_limit: int = 5 - "The time take to repair the file system" - - file_system_restoring_limit: int = 5 - "The time take to restore the file system" - - file_system_scanning_limit: int = 5 - "The time taken to scan the file system" - - deterministic: bool = False - "If true, the training will be deterministic" - - seed: Optional[int] = None - "The random number generator seed to be used while training the agent" - - @classmethod - def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig: - """ - Create an instance of TrainingConfig from a dict. - - :param config_dict: The training config dict. - :return: The instance of TrainingConfig. - """ - field_enum_map = { - "agent_framework": AgentFramework, - "deep_learning_framework": DeepLearningFramework, - "agent_identifier": AgentIdentifier, - "action_type": ActionType, - "session_type": SessionType, - "sb3_output_verbose_level": SB3OutputVerboseLevel, - "hard_coded_agent_view": HardCodedAgentView, - "implicit_acl_rule": RulePermissionType, - } - - # convert the string representation of enums into the actual enum values themselves? - for key, value in field_enum_map.items(): - if key in config_dict: - config_dict[key] = value[config_dict[key]] - - return TrainingConfig(**config_dict) - - def to_dict(self, json_serializable: bool = True) -> Dict: - """ - Serialise the ``TrainingConfig`` as dict. - - :param json_serializable: If True, Enums are converted to their - string name. - :return: The ``TrainingConfig`` as a dict. - """ - data = self.__dict__ - if json_serializable: - data["agent_framework"] = self.agent_framework.name - data["deep_learning_framework"] = self.deep_learning_framework.name - data["agent_identifier"] = self.agent_identifier.name - data["action_type"] = self.action_type.name - data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name - data["session_type"] = self.session_type.name - data["hard_coded_agent_view"] = self.hard_coded_agent_view.name - data["implicit_acl_rule"] = self.implicit_acl_rule.name - - return data - - def __str__(self) -> str: - obs_str = ",".join([c["name"] for c in self.observation_space["components"]]) - tc = f"{self.agent_framework}, " - # if self.agent_framework is AgentFramework.RLLIB: - # tc += f"{self.deep_learning_framework}, " - tc += f"{self.agent_identifier}, " - if self.agent_identifier is AgentIdentifier.HARDCODED: - tc += f"{self.hard_coded_agent_view}, " - tc += f"{self.action_type}, " - tc += f"observation_space={obs_str}, " - if self.session_type is SessionType.TRAIN: - tc += f"{self.num_train_episodes} episodes @ " - tc += f"{self.num_train_steps} steps" - elif self.session_type is SessionType.EVAL: - tc += f"{self.num_eval_episodes} episodes @ " - tc += f"{self.num_eval_steps} steps" - else: - tc += f"Training: {self.num_eval_episodes} episodes @ " - tc += f"{self.num_eval_steps} steps" - tc += f"Evaluation: {self.num_eval_episodes} episodes @ " - tc += f"{self.num_eval_steps} steps" - return tc - - -def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: - """ - Read in a training config yaml file. - - :param file_path: The config file path. - :param legacy_file: True if the config file is legacy format, otherwise - False. - :return: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. - :raises ValueError: If the file_path does not exist. - :raises TypeError: When the TrainingConfig object cannot be created - using the values from the config file read from ``file_path``. - """ - if not isinstance(file_path, Path): - file_path = Path(file_path) - if file_path.exists(): - with open(file_path, "r") as file: - config = yaml.safe_load(file) - _LOGGER.debug(f"Loading training config file: {file_path}") - if legacy_file: - try: - config = convert_legacy_training_config_dict(config) - - except KeyError as e: - msg = ( - f"Failed to convert training config file {file_path} " - f"from legacy format. Attempting to use file as is." - ) - _LOGGER.error(msg) - raise e - try: - return TrainingConfig.from_dict(config) - except TypeError as e: - msg = f"Error when creating an instance of {TrainingConfig} " f"from the training config file {file_path}" - _LOGGER.critical(msg, exc_info=True) - raise e - msg = f"Cannot load the training config as it does not exist: {file_path}" - _LOGGER.error(msg) - raise ValueError(msg) - - -def convert_legacy_training_config_dict( - legacy_config_dict: Dict[str, Any], - agent_framework: AgentFramework = AgentFramework.SB3, - agent_identifier: AgentIdentifier = AgentIdentifier.PPO, - action_type: ActionType = ActionType.ANY, - num_train_steps: int = 256, - num_eval_steps: int = 256, - num_train_episodes: int = 10, - num_eval_episodes: int = 1, -) -> Dict[str, Any]: - """ - Convert a legacy training config dict to the new format. - - :param legacy_config_dict: A legacy training config dict. - :param agent_framework: The agent framework to use as legacy training - configs don't have agent_framework values. - :param agent_identifier: The red agent identifier to use as legacy - training configs don't have agent_identifier values. - :param action_type: The action space type to set as legacy training configs - don't have action_type values. - :param num_train_steps: The number of train steps to set as legacy training configs - don't have num_train_steps values. - :param num_eval_steps: The number of eval steps to set as legacy training configs - don't have num_eval_steps values. - :param num_train_episodes: The number of train episodes to set as legacy training configs - don't have num_train_episodes values. - :param num_eval_episodes: The number of eval episodes to set as legacy training configs - don't have num_eval_episodes values. - :return: The converted training config dict. - """ - config_dict = { - "agent_framework": agent_framework.name, - "agent_identifier": agent_identifier.name, - "action_type": action_type.name, - "num_train_steps": num_train_steps, - "num_eval_steps": num_eval_steps, - "num_train_episodes": num_train_episodes, - "num_eval_episodes": num_eval_episodes, - "sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name, - } - session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"} - legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]] - for legacy_key, value in legacy_config_dict.items(): - new_key = _get_new_key_from_legacy(legacy_key) - if new_key: - config_dict[new_key] = value - return config_dict - - -def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]: - """ - Maps legacy training config keys to the new format keys. - - :param legacy_key: A legacy training config key. - :return: The mapped key. - """ - key_mapping = { - "agentIdentifier": None, - "numEpisodes": "num_train_episodes", - "numSteps": "num_train_steps", - "timeDelay": "time_delay", - "configFilename": None, - "sessionType": "session_type", - "loadAgent": "load_agent", - "agentLoadFile": "agent_load_file", - "observationSpaceHighValue": "observation_space_high_value", - "allOk": "all_ok", - "offShouldBeOn": "off_should_be_on", - "offShouldBeResetting": "off_should_be_resetting", - "onShouldBeOff": "on_should_be_off", - "onShouldBeResetting": "on_should_be_resetting", - "resettingShouldBeOn": "resetting_should_be_on", - "resettingShouldBeOff": "resetting_should_be_off", - "resetting": "resetting", - "goodShouldBePatching": "good_should_be_patching", - "goodShouldBeCompromised": "good_should_be_compromised", - "goodShouldBeOverwhelmed": "good_should_be_overwhelmed", - "patchingShouldBeGood": "patching_should_be_good", - "patchingShouldBeCompromised": "patching_should_be_compromised", - "patchingShouldBeOverwhelmed": "patching_should_be_overwhelmed", - "patching": "patching", - "compromisedShouldBeGood": "compromised_should_be_good", - "compromisedShouldBePatching": "compromised_should_be_patching", - "compromisedShouldBeOverwhelmed": "compromised_should_be_overwhelmed", - "compromised": "compromised", - "overwhelmedShouldBeGood": "overwhelmed_should_be_good", - "overwhelmedShouldBePatching": "overwhelmed_should_be_patching", - "overwhelmedShouldBeCompromised": "overwhelmed_should_be_compromised", - "overwhelmed": "overwhelmed", - "goodShouldBeRepairing": "good_should_be_repairing", - "goodShouldBeRestoring": "good_should_be_restoring", - "goodShouldBeCorrupt": "good_should_be_corrupt", - "goodShouldBeDestroyed": "good_should_be_destroyed", - "repairingShouldBeGood": "repairing_should_be_good", - "repairingShouldBeRestoring": "repairing_should_be_restoring", - "repairingShouldBeCorrupt": "repairing_should_be_corrupt", - "repairingShouldBeDestroyed": "repairing_should_be_destroyed", - "repairing": "repairing", - "restoringShouldBeGood": "restoring_should_be_good", - "restoringShouldBeRepairing": "restoring_should_be_repairing", - "restoringShouldBeCorrupt": "restoring_should_be_corrupt", - "restoringShouldBeDestroyed": "restoring_should_be_destroyed", - "restoring": "restoring", - "corruptShouldBeGood": "corrupt_should_be_good", - "corruptShouldBeRepairing": "corrupt_should_be_repairing", - "corruptShouldBeRestoring": "corrupt_should_be_restoring", - "corruptShouldBeDestroyed": "corrupt_should_be_destroyed", - "corrupt": "corrupt", - "destroyedShouldBeGood": "destroyed_should_be_good", - "destroyedShouldBeRepairing": "destroyed_should_be_repairing", - "destroyedShouldBeRestoring": "destroyed_should_be_restoring", - "destroyedShouldBeCorrupt": "destroyed_should_be_corrupt", - "destroyed": "destroyed", - "scanning": "scanning", - "redIerRunning": "red_ier_running", - "greenIerBlocked": "green_ier_blocked", - "osPatchingDuration": "os_patching_duration", - "nodeResetDuration": "node_reset_duration", - "nodeBootingDuration": "node_booting_duration", - "nodeShutdownDuration": "node_shutdown_duration", - "servicePatchingDuration": "service_patching_duration", - "fileSystemRepairingLimit": "file_system_repairing_limit", - "fileSystemRestoringLimit": "file_system_restoring_limit", - "fileSystemScanningLimit": "file_system_scanning_limit", - } - return key_mapping[legacy_key] diff --git a/src/primaite/data_viz/__init__.py b/src/primaite/data_viz/__init__.py deleted file mode 100644 index 260579da..00000000 --- a/src/primaite/data_viz/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Utility to generate plots of sessions metrics after PrimAITE.""" -from enum import Enum - - -class PlotlyTemplate(Enum): - """The built-in plotly templates.""" - - PLOTLY = "plotly" - PLOTLY_WHITE = "plotly_white" - PLOTLY_DARK = "plotly_dark" - GGPLOT2 = "ggplot2" - SEABORN = "seaborn" - SIMPLE_WHITE = "simple_white" - NONE = "none" diff --git a/src/primaite/data_viz/session_plots.py b/src/primaite/data_viz/session_plots.py deleted file mode 100644 index 37750353..00000000 --- a/src/primaite/data_viz/session_plots.py +++ /dev/null @@ -1,73 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -from pathlib import Path -from typing import Dict, Optional, Union - -import plotly.graph_objects as go -import polars as pl -import yaml -from plotly.graph_objs import Figure - -from primaite import PRIMAITE_PATHS - - -def get_plotly_config() -> Dict: - """Get the plotly config from primaite_config.yaml.""" - with open(PRIMAITE_PATHS.app_config_file_path, "r") as file: - primaite_config = yaml.safe_load(file) - return primaite_config["session"]["outputs"]["plots"] - - -def plot_av_reward_per_episode( - av_reward_per_episode_csv: Union[str, Path], - title: Optional[str] = None, - subtitle: Optional[str] = None, -) -> Figure: - """ - Plot the average reward per episode from a csv session output. - - :param av_reward_per_episode_csv: The average reward per episode csv - file path. - :param title: The plot title. This is optional. - :param subtitle: The plot subtitle. This is optional. - :return: The plot as an instance of ``plotly.graph_objs._figure.Figure``. - """ - df = pl.read_csv(av_reward_per_episode_csv) - - if title: - if subtitle: - title = f"{title}
{subtitle}" - else: - if subtitle: - title = subtitle - - config = get_plotly_config() - layout = go.Layout( - autosize=config["size"]["auto_size"], - width=config["size"]["width"], - height=config["size"]["height"], - ) - # Create the line graph with a colored line - fig = go.Figure(layout=layout) - fig.update_layout(template=config["template"]) - fig.add_trace( - go.Scatter( - x=df["Episode"], - y=df["Average Reward"], - mode="lines", - name="Mean Reward per Episode", - ) - ) - - # Set the layout of the graph - fig.update_layout( - xaxis={ - "title": "Episode", - "type": "linear", - "rangeslider": {"visible": config["range_slider"]}, - }, - yaxis={"title": "Average Reward"}, - title=title, - showlegend=False, - ) - - return fig diff --git a/src/primaite/environment/__init__.py b/src/primaite/environment/__init__.py deleted file mode 100644 index f0fd21b9..00000000 --- a/src/primaite/environment/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network.""" diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py deleted file mode 100644 index 73b9e998..00000000 --- a/src/primaite/environment/observations.py +++ /dev/null @@ -1,735 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Module for handling configurable observation spaces in PrimAITE.""" -import logging -from abc import ABC, abstractmethod -from logging import Logger -from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union - -import numpy as np -from gymnasium import spaces - -from primaite.acl.acl_rule import ACLRule -from primaite.common.enums import FileSystemState, HardwareState, RulePermissionType, SoftwareState -from primaite.nodes.active_node import ActiveNode -from primaite.nodes.service_node import ServiceNode - -# This dependency is only needed for type hints, -# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking -# Therefore, this avoids circular dependency problem. -if TYPE_CHECKING: - from primaite.environment.primaite_env import Primaite - - -_LOGGER: Logger = logging.getLogger(__name__) - - -class AbstractObservationComponent(ABC): - """Represents a part of the PrimAITE observation space.""" - - @abstractmethod - def __init__(self, env: "Primaite") -> None: - """ - Initialise observation component. - - :param env: Primaite training environment. - :type env: Primaite - """ - _LOGGER.info(f"Initialising {self} observation component") - self.env: "Primaite" = env - self.space: spaces.Space - self.current_observation: np.ndarray # type might be too restrictive? - self.structure: List[str] - return NotImplemented - - @abstractmethod - def update(self) -> None: - """Update the observation based on the current state of the environment.""" - self.current_observation = NotImplemented - - @abstractmethod - def generate_structure(self) -> List[str]: - """Return a list of labels for the components of the flattened observation space.""" - return NotImplemented - - -class NodeLinkTable(AbstractObservationComponent): - """ - Table with nodes and links as rows and hardware/software status as cols. - - This will create the observation space formatted as a table of integers. - There is one row per node, followed by one row per link. - The number of columns is 4 plus one per service. They are: - - * node/link ID - * node hardware status / 0 for links - * node operating system status (if active/service) / 0 for links - * node file system status (active/service only) / 0 for links - * node service1 status / traffic load from that service for links - * node service2 status / traffic load from that service for links - * ... - * node serviceN status / traffic load from that service for links - - For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be - ``(12, 7)`` - """ - - _FIXED_PARAMETERS: int = 4 - _MAX_VAL: int = 1_000_000_000 - _DATA_TYPE: type = np.int64 - - def __init__(self, env: "Primaite") -> None: - """ - Initialise a NodeLinkTable observation space component. - - :param env: Training environment. - :type env: Primaite - """ - super().__init__(env) - - # 1. Define the shape of your observation space component - num_items = self.env.num_links + self.env.num_nodes - num_columns = self.env.num_services + self._FIXED_PARAMETERS - observation_shape = (num_items, num_columns) - - # 2. Create Observation space - self.space = spaces.Box( - low=0, - high=self._MAX_VAL, - shape=observation_shape, - dtype=self._DATA_TYPE, - ) - - # 3. Initialise Observation with zeroes - self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) - - self.structure = self.generate_structure() - - def update(self) -> None: - """ - Update the observation based on current environment state. - - The structure of the observation space is described in :class:`.NodeLinkTable` - """ - item_index = 0 - nodes = self.env.nodes - links = self.env.links - # Do nodes first - for _, node in nodes.items(): - self.current_observation[item_index][0] = int(node.node_id) - self.current_observation[item_index][1] = node.hardware_state.value - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.current_observation[item_index][2] = node.software_state.value - self.current_observation[item_index][3] = node.file_system_state_observed.value - else: - self.current_observation[item_index][2] = 0 - self.current_observation[item_index][3] = 0 - service_index = 4 - if isinstance(node, ServiceNode): - for service in self.env.services_list: - if node.has_service(service): - self.current_observation[item_index][service_index] = node.get_service_state(service).value - else: - self.current_observation[item_index][service_index] = 0 - service_index += 1 - else: - # Not a service node - for service in self.env.services_list: - self.current_observation[item_index][service_index] = 0 - service_index += 1 - item_index += 1 - - # Now do links - for _, link in links.items(): - self.current_observation[item_index][0] = int(link.get_id()) - self.current_observation[item_index][1] = 0 - self.current_observation[item_index][2] = 0 - self.current_observation[item_index][3] = 0 - protocol_list = link.get_protocol_list() - protocol_index = 0 - for protocol in protocol_list: - self.current_observation[item_index][protocol_index + 4] = protocol.get_load() - protocol_index += 1 - item_index += 1 - - def generate_structure(self) -> List[str]: - """Return a list of labels for the components of the flattened observation space.""" - nodes = self.env.nodes.values() - links = self.env.links.values() - - structure = [] - - for i, node in enumerate(nodes): - node_id = node.node_id - node_labels = [ - f"node_{node_id}_id", - f"node_{node_id}_hardware_status", - f"node_{node_id}_os_status", - f"node_{node_id}_fs_status", - ] - for j, serv in enumerate(self.env.services_list): - node_labels.append(f"node_{node_id}_service_{serv}_status") - - structure.extend(node_labels) - - for i, link in enumerate(links): - link_id = link.id - link_labels = [ - f"link_{link_id}_id", - f"link_{link_id}_n/a", - f"link_{link_id}_n/a", - f"link_{link_id}_n/a", - ] - for j, serv in enumerate(self.env.services_list): - link_labels.append(f"link_{link_id}_service_{serv}_load") - - structure.extend(link_labels) - return structure - - -class NodeStatuses(AbstractObservationComponent): - """ - Flat list of nodes' hardware, OS, file system, and service states. - - The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by - integers. - Each node has 3 elements plus 1 per service. It will have the following structure: - .. code-block:: - - [ - node1 hardware state, - node1 OS state, - node1 file system state, - node1 service1 state, - node1 service2 state, - node1 serviceN state (one for each service), - node2 hardware state, - node2 OS state, - node2 file system state, - node2 service1 state, - node2 service2 state, - node2 serviceN state (one for each service), - ... - ] - """ - - _DATA_TYPE: type = np.int64 - - def __init__(self, env: "Primaite") -> None: - """ - Initialise a NodeStatuses observation component. - - :param env: Training environment. - :type env: Primaite - """ - super().__init__(env) - - # 1. Define the shape of your observation space component - node_shape = [ - len(HardwareState) + 1, - len(SoftwareState) + 1, - len(FileSystemState) + 1, - ] - services_shape = [len(SoftwareState) + 1] * self.env.num_services - node_shape = node_shape + services_shape - - shape = node_shape * self.env.num_nodes - # 2. Create Observation space - self.space = spaces.MultiDiscrete(shape) - - # 3. Initialise observation with zeroes - self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) - self.structure = self.generate_structure() - - def update(self) -> None: - """ - Update the observation based on current environment state. - - The structure of the observation space is described in :class:`.NodeStatuses` - """ - obs = [] - for _, node in self.env.nodes.items(): - hardware_state = node.hardware_state.value - software_state = 0 - file_system_state = 0 - service_states = [0] * self.env.num_services - - if isinstance(node, ActiveNode): - software_state = node.software_state.value - file_system_state = node.file_system_state_observed.value - - if isinstance(node, ServiceNode): - for i, service in enumerate(self.env.services_list): - if node.has_service(service): - service_states[i] = node.get_service_state(service).value - obs.extend( - [ - hardware_state, - software_state, - file_system_state, - *service_states, - ] - ) - self.current_observation[:] = obs - - def generate_structure(self) -> List[str]: - """Return a list of labels for the components of the flattened observation space.""" - services = self.env.services_list - - structure = [] - - for _, node in self.env.nodes.items(): - node_id = node.node_id - structure.append(f"node_{node_id}_hardware_state_NONE") - for state in HardwareState: - structure.append(f"node_{node_id}_hardware_state_{state.name}") - structure.append(f"node_{node_id}_software_state_NONE") - for state in SoftwareState: - structure.append(f"node_{node_id}_software_state_{state.name}") - structure.append(f"node_{node_id}_file_system_state_NONE") - for state in FileSystemState: - structure.append(f"node_{node_id}_file_system_state_{state.name}") - for service in services: - structure.append(f"node_{node_id}_service_{service}_state_NONE") - for state in SoftwareState: - structure.append(f"node_{node_id}_service_{service}_state_{state.name}") - return structure - - -class LinkTrafficLevels(AbstractObservationComponent): - """ - Flat list of traffic levels encoded into banded categories. - - For each link, total traffic or traffic per service is encoded into a categorical value. - For example, if ``quantisation_levels=5``, the traffic levels represent these values: - - * 0 = No traffic (0% of bandwidth) - * 1 = No traffic (0%-33% of bandwidth) - * 2 = No traffic (33%-66% of bandwidth) - * 3 = No traffic (66%-100% of bandwidth) - * 4 = No traffic (100% of bandwidth) - - .. note:: - The lowest category always corresponds to no traffic and the highest category to the link being at max capacity. - Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories. - - """ - - _DATA_TYPE: type = np.int64 - - def __init__( - self, - env: "Primaite", - combine_service_traffic: bool = False, - quantisation_levels: int = 5, - ) -> None: - """ - Initialise a LinkTrafficLevels observation component. - - :param env: The environment that forms the basis of the observations - :type env: Primaite - :param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually, - defaults to False - :type combine_service_traffic: bool, optional - :param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical - value, defaults to 5 - :type quantisation_levels: int, optional - """ - if quantisation_levels < 3: - _msg = ( - f"quantisation_levels must be 3 or more because the lowest and highest levels are " - f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. " - f"Resetting to default value (5)" - ) - _LOGGER.warning(_msg) - quantisation_levels = 5 - - super().__init__(env) - - self._combine_service_traffic: bool = combine_service_traffic - self._quantisation_levels: int = quantisation_levels - self._entries_per_link: int = 1 - - if not self._combine_service_traffic: - self._entries_per_link = self.env.num_services - - # 1. Define the shape of your observation space component - shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link - - # 2. Create Observation space - self.space = spaces.MultiDiscrete(shape) - - # 3. Initialise observation with zeroes - self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) - - self.structure = self.generate_structure() - - def update(self) -> None: - """ - Update the observation based on current environment state. - - The structure of the observation space is described in :class:`.LinkTrafficLevels` - """ - obs = [] - for _, link in self.env.links.items(): - bandwidth = link.bandwidth - if self._combine_service_traffic: - loads = [link.get_current_load()] - else: - loads = [protocol.get_load() for protocol in link.protocol_list] - - for load in loads: - if load <= 0: - traffic_level = 0 - elif load >= bandwidth: - traffic_level = self._quantisation_levels - 1 - else: - traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1 - - obs.append(int(traffic_level)) - - self.current_observation[:] = obs - - def generate_structure(self) -> List[str]: - """Return a list of labels for the components of the flattened observation space.""" - structure = [] - for _, link in self.env.links.items(): - link_id = link.id - if self._combine_service_traffic: - protocols = ["overall"] - else: - protocols = [protocol.name for protocol in link.protocol_list] - - for p in protocols: - for i in range(self._quantisation_levels): - structure.append(f"link_{link_id}_{p}_traffic_level_{i}") - return structure - - -class AccessControlList(AbstractObservationComponent): - """Flat list of all the Access Control Rules in the Access Control List. - - The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by - integers. - - Each ACL Rule has 6 elements. It will have the following structure: - .. code-block:: - [ - acl_rule1 permission, - acl_rule1 source_ip, - acl_rule1 dest_ip, - acl_rule1 protocol, - acl_rule1 port, - acl_rule1 position, - acl_rule2 permission, - acl_rule2 source_ip, - acl_rule2 dest_ip, - acl_rule2 protocol, - acl_rule2 port, - acl_rule2 position, - ... - ] - - - Terms (for ACL Observation Space): - [0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW) - [0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs) - [0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs) - [0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol) - [0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port) - [0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list) - - NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object. - """ - - _DATA_TYPE: type = np.int64 - - def __init__(self, env: "Primaite") -> None: - """ - Initialise an AccessControlList observation component. - - :param env: The environment that forms the basis of the observations - :type env: Primaite - """ - super().__init__(env) - - # 1. Define the shape of your observation space component - # The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports. - # Number of ACL rules incremented by 1 for positions starting at index 0. - acl_shape = [ - len(RulePermissionType), - len(env.nodes) + 2, - len(env.nodes) + 2, - len(env.services_list) + 2, - len(env.ports_list) + 2, - env.max_number_acl_rules, - ] - shape = acl_shape * self.env.max_number_acl_rules - - # 2. Create Observation space - self.space = spaces.MultiDiscrete(shape) - - # 3. Initialise observation with zeroes - self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) - - self.structure = self.generate_structure() - - def update(self) -> None: - """Update the observation based on current environment state. - - The structure of the observation space is described in :class:`.AccessControlList` - """ - obs = [] - - for index in range(0, len(self.env.acl.acl)): - acl_rule = self.env.acl.acl[index] - if isinstance(acl_rule, ACLRule): - permission = acl_rule.permission - source_ip = acl_rule.source_ip - dest_ip = acl_rule.dest_ip - protocol = acl_rule.protocol - port = acl_rule.port - position = index - # Map each ACL attribute from what it was to an integer to fit the observation space - source_ip_int = None - dest_ip_int = None - if permission == RulePermissionType.DENY: - permission_int = 1 - else: - permission_int = 2 - if source_ip == "ANY": - source_ip_int = 1 - else: - # Map Node ID (+ 1) to source IP address - nodes = list(self.env.nodes.values()) - for node in nodes: - if ( - isinstance(node, ServiceNode) or isinstance(node, ActiveNode) - ) and node.ip_address == source_ip: - source_ip_int = int(node.node_id) + 1 - break - if dest_ip == "ANY": - dest_ip_int = 1 - else: - # Map Node ID (+ 1) to dest IP address - # Index of Nodes start at 1 so + 1 is needed so NA can be added. - nodes = list(self.env.nodes.values()) - for node in nodes: - if ( - isinstance(node, ServiceNode) or isinstance(node, ActiveNode) - ) and node.ip_address == dest_ip: - dest_ip_int = int(node.node_id) + 1 - if protocol == "ANY": - protocol_int = 1 - else: - # Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY - try: - protocol_int = self.env.services_list.index(protocol) + 2 - except AttributeError: - _LOGGER.info(f"Service {protocol} could not be found") - protocol_int = None - if port == "ANY": - port_int = 1 - else: - if port in self.env.ports_list: - port_int = self.env.ports_list.index(port) + 2 - else: - _LOGGER.info(f"Port {port} could not be found.") - port_int = None - # Add to current obs - obs.extend( - [ - permission_int, - source_ip_int, - dest_ip_int, - protocol_int, - port_int, - position, - ] - ) - - else: - # The Nothing or NA representation of 'NONE' ACL rules - obs.extend([0, 0, 0, 0, 0, 0]) - - self.current_observation[:] = obs - - def generate_structure(self) -> List[str]: - """Return a list of labels for the components of the flattened observation space.""" - structure = [] - for acl_rule in self.env.acl.acl: - acl_rule_id = self.env.acl.acl.index(acl_rule) - - for permission in RulePermissionType: - structure.append(f"acl_rule_{acl_rule_id}_permission_{permission.name}") - - structure.append(f"acl_rule_{acl_rule_id}_source_ip_ANY") - for node in self.env.nodes.keys(): - structure.append(f"acl_rule_{acl_rule_id}_source_ip_{node}") - - structure.append(f"acl_rule_{acl_rule_id}_dest_ip_ANY") - for node in self.env.nodes.keys(): - structure.append(f"acl_rule_{acl_rule_id}_dest_ip_{node}") - - structure.append(f"acl_rule_{acl_rule_id}_service_ANY") - for service in self.env.services_list: - structure.append(f"acl_rule_{acl_rule_id}_service_{service}") - - structure.append(f"acl_rule_{acl_rule_id}_port_ANY") - for port in self.env.ports_list: - structure.append(f"acl_rule_{acl_rule_id}_port_{port}") - - return structure - - -class ObservationsHandler: - """ - Component-based observation space handler. - - This allows users to configure observation spaces by mixing and matching components. Each component can also define - further parameters to make them more flexible. - """ - - _REGISTRY: Final[Dict[str, type]] = { - "NODE_LINK_TABLE": NodeLinkTable, - "NODE_STATUSES": NodeStatuses, - "LINK_TRAFFIC_LEVELS": LinkTrafficLevels, - "ACCESS_CONTROL_LIST": AccessControlList, - } - - def __init__(self) -> None: - """Initialise the observation handler.""" - self.registered_obs_components: List[AbstractObservationComponent] = [] - - # internal the observation space (unflattened version of space if flatten=True) - self._space: spaces.Space - # flattened version of the observation space - self._flat_space: spaces.Space - - self._observation: Union[Tuple[np.ndarray], np.ndarray] - # used for transactions and when flatten=true - self._flat_observation: np.ndarray - - def update_obs(self) -> None: - """Fetch fresh information about the environment.""" - current_obs = [] - for obs in self.registered_obs_components: - obs.update() - current_obs.append(obs.current_observation) - - if len(current_obs) == 1: - self._observation = current_obs[0] - else: - self._observation = tuple(current_obs) - self._flat_observation = spaces.flatten(self._space, self._observation) - - def register(self, obs_component: AbstractObservationComponent) -> None: - """ - Add a component for this handler to track. - - :param obs_component: The component to add. - :type obs_component: AbstractObservationComponent - """ - self.registered_obs_components.append(obs_component) - self.update_space() - - def deregister(self, obs_component: AbstractObservationComponent) -> None: - """ - Remove a component from this handler. - - :param obs_component: Which component to remove. It must exist within this object's - ``registered_obs_components`` attribute. - :type obs_component: AbstractObservationComponent - """ - self.registered_obs_components.remove(obs_component) - self.update_space() - - def update_space(self) -> None: - """Rebuild the handler's composite observation space from its components.""" - component_spaces = [] - for obs_comp in self.registered_obs_components: - component_spaces.append(obs_comp.space) - - # if there are multiple components, build a composite tuple space - if len(component_spaces) == 1: - self._space = component_spaces[0] - else: - self._space = spaces.Tuple(component_spaces) - if len(component_spaces) > 0: - self._flat_space = spaces.flatten_space(self._space) - else: - self._flat_space = spaces.Box(0, 1, (0,)) - - @property - def space(self) -> spaces.Space: - """Observation space, return the flattened version if flatten is True.""" - if len(self.registered_obs_components) > 1: - return self._flat_space - else: - return self._space - - @property - def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]: - """Current observation, return the flattened version if flatten is True.""" - if len(self.registered_obs_components) > 1: - return self._flat_observation - else: - return self._observation - - @classmethod - def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler": - """ - Parse a config dictinary, return a new observation handler populated with new observation component objects. - - The expected format for the config dictionary is: - - .. code-block:: python - - config = { - components: [ - { - "name": "" - }, - { - "name": "" - "options": {"opt1": val1, "opt2": val2} - }, - { - ... - }, - ] - } - - :return: Observation handler - :rtype: primaite.environment.observations.ObservationsHandler - """ - # Instantiate the handler - handler = cls() - - for component_cfg in obs_space_config["components"]: - # Figure out which class can instantiate the desired component - comp_type = component_cfg["name"] - comp_class = cls._REGISTRY[comp_type] - - # Create the component with options from the YAML - options = component_cfg.get("options") or {} - component = comp_class(env, **options) - - handler.register(component) - - handler.update_obs() - return handler - - def describe_structure(self) -> List[str]: - """ - Create a list of names for the features of the obs space. - - The order of labels follows the flattened version of the space. - """ - # as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have - # to fake it. each component has to just hard-code the expected label order after flattening... - - labels = [] - for obs_comp in self.registered_obs_components: - labels.extend(obs_comp.structure) - - return labels diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py deleted file mode 100644 index a809772f..00000000 --- a/src/primaite/environment/primaite_env.py +++ /dev/null @@ -1,1408 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class.""" -import copy -import logging -import uuid as uuid -from logging import Logger -from pathlib import Path -from random import choice, randint, sample, uniform -from typing import Any, Dict, Final, List, Tuple, Union - -import networkx as nx -import numpy as np -from gym import Env, spaces -from matplotlib import pyplot as plt - -from primaite import getLogger -from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action -from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import ( # AgentFramework, - ActionType, - AgentIdentifier, - FileSystemState, - HardwareState, - NodePOLInitiator, - NodePOLType, - NodeType, - ObservationType, - Priority, - SessionType, - SoftwareState, -) -from primaite.common.service import Service -from primaite.config import training_config -from primaite.config.lay_down_config import load -from primaite.config.training_config import TrainingConfig -from primaite.environment.observations import ObservationsHandler -from primaite.environment.reward import calculate_reward_function -from primaite.links.link import Link -from primaite.nodes.active_node import ActiveNode -from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen -from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed -from primaite.nodes.passive_node import PassiveNode -from primaite.nodes.service_node import ServiceNode -from primaite.pol.green_pol import apply_iers, apply_node_pol -from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol -from primaite.transactions.transaction import Transaction -from primaite.utils.session_output_writer import SessionOutputWriter - -_LOGGER: Logger = getLogger(__name__) - - -class Primaite(Env): - """PRIMmary AI Training Evironment (Primaite) class.""" - - # Action Space contants - ACTION_SPACE_NODE_PROPERTY_VALUES: int = 5 - ACTION_SPACE_NODE_ACTION_VALUES: int = 4 - ACTION_SPACE_ACL_ACTION_VALUES: int = 3 - ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 - - def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - session_path: Path, - timestamp_str: str, - legacy_training_config: bool = False, - legacy_lay_down_config: bool = False, - ) -> None: - """ - The Primaite constructor. - - :param training_config_path: The training config filepath. - :param lay_down_config_path: The lay down config filepath. - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: _. - :param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0, - otherwise False. - :param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0, - otherwise False. - """ - self.session_path: Final[Path] = session_path - self.timestamp_str: Final[str] = timestamp_str - self._training_config_path: Union[str, Path] = training_config_path - self._lay_down_config_path: Union[str, Path] = lay_down_config_path - self.legacy_training_config = legacy_training_config - self.legacy_lay_down_config = legacy_lay_down_config - - self.training_config: TrainingConfig = training_config.load(training_config_path, self.legacy_training_config) - _LOGGER.info(f"Using: {str(self.training_config)}") - - # Number of steps in an episode - self.episode_steps: int - if self.training_config.session_type == SessionType.TRAIN: - self.episode_steps = self.training_config.num_train_steps - elif self.training_config.session_type == SessionType.EVAL: - self.episode_steps = self.training_config.num_eval_steps - else: - self.episode_steps = self.training_config.num_train_steps - - super(Primaite, self).__init__() - - # The agent in use - self.agent_identifier: AgentIdentifier = self.training_config.agent_identifier - - # Create a dictionary to hold all the nodes - self.nodes: Dict[str, NodeUnion] = {} - - # Create a dictionary to hold a reference set of nodes - self.nodes_reference: Dict[str, NodeUnion] = {} - - # Create a dictionary to hold all the links - self.links: Dict[str, Link] = {} - - # Create a dictionary to hold a reference set of links - self.links_reference: Dict[str, Link] = {} - - # Create a dictionary to hold all the green IERs (this will come from an external source) - self.green_iers: Dict[str, IER] = {} - self.green_iers_reference: Dict[str, IER] = {} - - # Create a dictionary to hold all the node PoLs (this will come from an external source) - self.node_pol: Dict[str, NodeStateInstructionGreen] = {} - - # Create a dictionary to hold all the red agent IERs (this will come from an external source) - self.red_iers: Dict[str, IER] = {} - - # Create a dictionary to hold all the red agent node PoLs (this will come from an external source) - self.red_node_pol: Dict[str, NodeStateInstructionRed] = {} - - # Create the Access Control List - self.acl: AccessControlList = AccessControlList( - self.training_config.implicit_acl_rule, - self.training_config.max_number_acl_rules, - ) - # Sets limit for number of ACL rules in environment - self.max_number_acl_rules: int = self.training_config.max_number_acl_rules - - # Create a list of services (enums) - self.services_list: List[str] = [] - - # Create a list of ports - self.ports_list: List[str] = [] - - # Create graph (network) - self.network: nx.Graph = nx.MultiGraph() - - # Create a graph (network) reference - self.network_reference: nx.Graph = nx.MultiGraph() - - # Create step count - self.step_count: int = 0 - - self.total_step_count: int = 0 - """The total number of time steps completed.""" - - # Create step info dictionary - self.step_info: Dict[Any] = {} - - # Total reward - self.total_reward: float = 0 - - # Average reward - self.average_reward: float = 0 - - # Episode count - self.episode_count: int = 0 - - # Number of nodes - gets a value by examining the nodes dictionary after it's been populated - self.num_nodes: int = 0 - - # Number of links - gets a value by examining the links dictionary after it's been populated - self.num_links: int = 0 - - # Number of services - gets a value when config is loaded - self.num_services: int = 0 - - # Number of ports - gets a value when config is loaded - self.num_ports: int = 0 - - # The action type - # TODO: confirm type - self.action_type: int = 0 - - # TODO fix up with TrainingConfig - # stores the observation config from the yaml, default is NODE_LINK_TABLE - self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]} - if self.training_config.observation_space is not None: - self.obs_config = self.training_config.observation_space - - # Observation Handler manages the user-configurable observation space. - # It will be initialised later. - self.obs_handler: ObservationsHandler - - self._obs_space_description: List[str] = None - "The env observation space description for transactions writing" - - self.lay_down_config = load(self._lay_down_config_path, self.legacy_lay_down_config) - self.load_lay_down_config() - - # Store the node objects as node attributes - # (This is so we can access them as objects) - for node in self.network: - self.network.nodes[node]["self"] = node - - for node in self.network_reference: - self.network_reference.nodes[node]["self"] = node - - self.num_nodes = len(self.nodes) - self.num_links = len(self.links) - - # Visualise in PNG - try: - plt.tight_layout() - nx.draw_networkx(self.network, with_labels=True) - - file_path = session_path / f"network_{timestamp_str}.png" - plt.savefig(file_path, format="PNG") - plt.clf() - except Exception: - _LOGGER.error("Could not save network diagram", exc_info=True) - - # Initiate observation space - self.observation_space: spaces.Space - self.env_obs: np.ndarray - self.observation_space, self.env_obs = self.init_observations() - - # Define Action Space - depends on action space type (Node or ACL) - self.action_dict: Dict[int, List[int]] - self.action_space: spaces.Space - if self.training_config.action_type == ActionType.NODE: - _LOGGER.debug("Action space type NODE selected") - # Terms (for node action space): - # [0, num nodes] - node ID (0 = nothing, node ID) - # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, # noqa - # service state, file system state) # noqa - # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa - # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa - self.action_dict = self.create_node_action_dict() - self.action_space = spaces.Discrete(len(self.action_dict)) - elif self.training_config.action_type == ActionType.ACL: - _LOGGER.debug("Action space type ACL selected") - # Terms (for ACL action space): - # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) - # [0, 1] - Permission (0 = DENY, 1 = ALLOW) - # [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses) - # [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) - # [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) - # [0, num ports] - Port (0 = any, then 1 -> x resolving to port) - self.action_dict = self.create_acl_action_dict() - self.action_space = spaces.Discrete(len(self.action_dict)) - elif self.training_config.action_type == ActionType.ANY: - _LOGGER.debug("Action space type ANY selected - Node + ACL") - self.action_dict = self.create_node_and_acl_action_dict() - self.action_space = spaces.Discrete(len(self.action_dict)) - else: - _LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}") - - self.episode_av_reward_writer: SessionOutputWriter = SessionOutputWriter( - self, transaction_writer=False, learning_session=True - ) - self.transaction_writer: SessionOutputWriter = SessionOutputWriter( - self, transaction_writer=True, learning_session=True - ) - - self.is_eval = False - - @property - def actual_episode_count(self) -> int: - """Shifts the episode_count by -1 for RLlib learning session.""" - # if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval: - # return self.episode_count - 1 - return self.episode_count - - def set_as_eval(self) -> None: - """Set the writers to write to eval directories.""" - self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False) - self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False) - self.episode_count = 0 - self.step_count = 0 - self.total_step_count = 0 - self.episode_steps = self.training_config.num_eval_steps - self.is_eval = True - - def _write_av_reward_per_episode(self) -> None: - if self.actual_episode_count > 0: - csv_data = self.actual_episode_count, self.average_reward - self.episode_av_reward_writer.write(csv_data) - - def reset(self) -> np.ndarray: - """ - AI Gym Reset function. - - Returns: - Environment observation space (reset) - """ - self._write_av_reward_per_episode() - self.episode_count += 1 - - # Don't need to reset links, as they are cleared and recalculated every - # step - - # Clear the ACL - self.init_acl() - - # Reset the node statuses and recreate the ACL from config - # Does this for both live and reference nodes - self.reset_environment() - - # Create a random red agent to use for this episode - if self.training_config.random_red_agent: - self._create_random_red_agent() - - # Reset counters and totals - self.total_reward = 0.0 - self.step_count = 0 - self.average_reward = 0.0 - - # Update observations space and return - self.update_environent_obs() - - return self.env_obs - - def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]: - """ - AI Gym Step function. - - Args: - action: Action space from agent - - Returns: - env_obs: Observation space - reward: Reward value for this step - done: Indicates episode is complete if True - step_info: Additional information relating to this step - """ - # TEMP - done = False - self.step_count += 1 - self.total_step_count += 1 - - # Need to clear traffic on all links first - for link_key, link_value in self.links.items(): - link_value.clear_traffic() - - for link in self.links_reference.values(): - link.clear_traffic() - - # Create a Transaction (metric) object for this step - transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count) - # Load the initial observation space into the transaction - transaction.obs_space = self.obs_handler._flat_observation - - # Set the transaction obs space description - transaction.obs_space_description = self._obs_space_description - - # Load the action space into the transaction - transaction.action_space = copy.deepcopy(action) - - # 1. Implement Blue Action - self.interpret_action_and_apply(action) - # Take snapshots of nodes and links - self.nodes_post_blue = copy.deepcopy(self.nodes) - self.links_post_blue = copy.deepcopy(self.links) - - # 2. Perform any time-based activities (e.g. a component moving from patching to good) - self.apply_time_based_updates() - - # 3. Apply PoL - apply_node_pol(self.nodes, self.node_pol, self.step_count) # Node PoL - apply_iers( - self.network, - self.nodes, - self.links, - self.green_iers, - self.acl, - self.step_count, - ) # Network PoL - # Take snapshots of nodes and links - self.nodes_post_pol = copy.deepcopy(self.nodes) - self.links_post_pol = copy.deepcopy(self.links) - # Reference - apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL - apply_iers( - self.network_reference, - self.nodes_reference, - self.links_reference, - self.green_iers_reference, - self.acl, - self.step_count, - ) # Network PoL - - # 4. Implement Red Action - apply_red_agent_iers( - self.network, - self.nodes, - self.links, - self.red_iers, - self.acl, - self.step_count, - ) - apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count) - # Take snapshots of nodes and links - self.nodes_post_red = copy.deepcopy(self.nodes) - self.links_post_red = copy.deepcopy(self.links) - - # 5. Calculate reward signal (for RL) - reward = calculate_reward_function( - self.nodes_post_pol, - self.nodes_post_red, - self.nodes_reference, - self.green_iers, - self.green_iers_reference, - self.red_iers, - self.step_count, - self.training_config, - ) - _LOGGER.debug(f"Episode: {self.actual_episode_count}, " f"Step {self.step_count}, " f"Reward: {reward}") - self.total_reward += reward - if self.step_count == self.episode_steps: - self.average_reward = self.total_reward / self.step_count - if self.training_config.session_type is SessionType.EVAL: - # For evaluation, need to trigger the done value = True when - # step count is reached in order to prevent neverending episode - done = True - _LOGGER.info(f"Episode: {self.actual_episode_count}, " f"Average Reward: {self.average_reward}") - # Load the reward into the transaction - transaction.reward = reward - - # 6. Output Verbose - # self.output_link_status() - - # 7. Update env_obs - self.update_environent_obs() - - # Write transaction to file - if self.actual_episode_count > 0: - self.transaction_writer.write(transaction) - - # Return - return self.env_obs, reward, done, self.step_info - - def close(self) -> None: - """Override parent close and close writers.""" - # Close files if last episode/step - # if self.can_finish: - super().close() - - self.transaction_writer.close() - self.episode_av_reward_writer.close() - - def init_acl(self) -> None: - """Initialise the Access Control List.""" - self.acl.remove_all_rules() - - def output_link_status(self) -> None: - """Output the link status of all links to the console.""" - for link_key, link_value in self.links.items(): - _LOGGER.debug("Link ID: " + link_value.get_id()) - for protocol in link_value.protocol_list: - print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load())) - - def interpret_action_and_apply(self, _action: int) -> None: - """ - Applies agent actions to the nodes and Access Control List. - - Args: - _action: The action space from the agent - """ - # At the moment, actions are only affecting nodes - if self.training_config.action_type == ActionType.NODE: - self.apply_actions_to_nodes(_action) - elif self.training_config.action_type == ActionType.ACL: - self.apply_actions_to_acl(_action) - elif len(self.action_dict[_action]) == 7: # ACL actions in multidiscrete form have len 7 - self.apply_actions_to_acl(_action) - elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4 - self.apply_actions_to_nodes(_action) - else: - logging.error("Invalid action type found") - - def apply_actions_to_nodes(self, _action: int) -> None: - """ - Applies agent actions to the nodes. - - Args: - _action: The action space from the agent - """ - readable_action = self.action_dict[_action] - node_id = readable_action[0] - node_property = readable_action[1] - property_action = readable_action[2] - service_index = readable_action[3] - - # Check that the action is requesting a valid node - try: - node = self.nodes[str(node_id)] - except Exception: - return - - if node_property == 0: - # This is the do nothing action - return - elif node_property == 1: - # This is an action on the node Hardware State - if property_action == 0: - # Do nothing - return - elif property_action == 1: - # Turn on (only applicable if it's OFF, not if it's patching) - if node.hardware_state == HardwareState.OFF: - node.turn_on() - elif property_action == 2: - # Turn off - node.turn_off() - elif property_action == 3: - # Reset (only applicable if it's ON) - if node.hardware_state == HardwareState.ON: - node.reset() - else: - return - elif node_property == 2: - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - # This is an action on the node Software State - if property_action == 0: - # Do nothing - return - elif property_action == 1: - # Patch (valid action if it's good or compromised) - node.software_state = SoftwareState.PATCHING - else: - # Node is not of Active or Service Type - return - elif node_property == 3: - # This is an action on a node Service State - if isinstance(node, ServiceNode): - # This is an action on a node Service State - if property_action == 0: - # Do nothing - return - elif property_action == 1: - # Patch (valid action if it's good or compromised) - node.set_service_state(self.services_list[service_index], SoftwareState.PATCHING) - else: - # Node is not of Service Type - return - elif node_property == 4: - # This is an action on a node file system state - if isinstance(node, ActiveNode): - if property_action == 0: - # Do nothing - return - elif property_action == 1: - # Scan - node.start_file_system_scan() - elif property_action == 2: - # Repair - # You cannot repair a destroyed file system - it needs restoring - if node.file_system_state_actual != FileSystemState.DESTROYED: - node.set_file_system_state(FileSystemState.REPAIRING) - elif property_action == 3: - # Restore - node.set_file_system_state(FileSystemState.RESTORING) - else: - # Node is not of Active Type - return - else: - return - - def apply_actions_to_acl(self, _action: int) -> None: - """ - Applies agent actions to the Access Control List [TO DO]. - - Args: - _action: The action space from the agent - """ - # Convert discrete value back to multidiscrete - readable_action = self.action_dict[_action] - - action_decision = readable_action[0] - action_permission = readable_action[1] - action_source_ip = readable_action[2] - action_destination_ip = readable_action[3] - action_protocol = readable_action[4] - action_port = readable_action[5] - acl_rule_position = readable_action[6] - - if action_decision == 0: - # It's decided to do nothing - return - else: - # It's decided to create a new ACL rule or remove an existing rule - # Permission value - if action_permission == 0: - acl_rule_permission = "DENY" - else: - acl_rule_permission = "ALLOW" - # Source IP value - if action_source_ip == 0: - acl_rule_source = "ANY" - else: - node = list(self.nodes.values())[action_source_ip - 1] - if isinstance(node, ServiceNode) or isinstance(node, ActiveNode): - acl_rule_source = node.ip_address - else: - return - # Destination IP value - if action_destination_ip == 0: - acl_rule_destination = "ANY" - else: - node = list(self.nodes.values())[action_destination_ip - 1] - if isinstance(node, ServiceNode) or isinstance(node, ActiveNode): - acl_rule_destination = node.ip_address - else: - return - # Protocol value - if action_protocol == 0: - acl_rule_protocol = "ANY" - else: - acl_rule_protocol = self.services_list[action_protocol - 1] - # Port value - if action_port == 0: - acl_rule_port = "ANY" - else: - acl_rule_port = self.ports_list[action_port - 1] - - # Now add or remove - if action_decision == 1: - # Add the rule - self.acl.add_rule( - acl_rule_permission, - acl_rule_source, - acl_rule_destination, - acl_rule_protocol, - acl_rule_port, - acl_rule_position, - ) - elif action_decision == 2: - # Remove the rule - self.acl.remove_rule( - acl_rule_permission, - acl_rule_source, - acl_rule_destination, - acl_rule_protocol, - acl_rule_port, - ) - else: - return - - def apply_time_based_updates(self) -> None: - """ - Updates anything that needs to count down and then change state. - - e.g. reset / patching status - """ - for node_key, node in self.nodes.items(): - if node.hardware_state == HardwareState.RESETTING: - node.update_resetting_status() - else: - pass - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - node.update_file_system_state() - if node.software_state == SoftwareState.PATCHING: - node.update_os_patching_status() - else: - pass - else: - pass - if isinstance(node, ServiceNode): - node.update_services_patching_status() - else: - pass - - for node_key, node in self.nodes_reference.items(): - if node.hardware_state == HardwareState.RESETTING: - node.update_resetting_status() - else: - pass - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - node.update_file_system_state() - if node.software_state == SoftwareState.PATCHING: - node.update_os_patching_status() - else: - pass - else: - pass - if isinstance(node, ServiceNode): - node.update_services_patching_status() - else: - pass - - def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """ - Create the environment's observation handler. - - :return: The observation space, initial observation (zeroed out array with the correct shape) - :rtype: Tuple[spaces.Space, np.ndarray] - """ - self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) - - if not self._obs_space_description: - self._obs_space_description = self.obs_handler.describe_structure() - - return self.obs_handler.space, self.obs_handler.current_observation - - def update_environent_obs(self) -> None: - """Updates the observation space based on the node and link status.""" - self.obs_handler.update_obs() - self.env_obs = self.obs_handler.current_observation - - def load_lay_down_config(self) -> None: - """Loads config data in order to build the environment configuration.""" - for item in self.lay_down_config: - if item["item_type"] == "NODE": - # Create a node - self.create_node(item) - elif item["item_type"] == "LINK": - # Create a link - self.create_link(item) - elif item["item_type"] == "GREEN_IER": - # Create a Green IER - self.create_green_ier(item) - elif item["item_type"] == "GREEN_POL": - # Create a Green PoL - self.create_green_pol(item) - elif item["item_type"] == "RED_IER": - # Create a Red IER - self.create_red_ier(item) - elif item["item_type"] == "RED_POL": - # Create a Red PoL - self.create_red_pol(item) - elif item["item_type"] == "ACL_RULE": - # Create an ACL rule - self.create_acl_rule(item) - elif item["item_type"] == "SERVICES": - # Create the list of services - self.create_services_list(item) - elif item["item_type"] == "PORTS": - # Create the list of ports - self.create_ports_list(item) - else: - item_type = item["item_type"] - _LOGGER.error(f"Invalid item_type: {item_type}") - pass - - _LOGGER.info("Environment configuration loaded") - print("Environment configuration loaded") - - def create_node(self, item: Dict) -> None: - """ - Creates a node from config data. - - Args: - item: A config data item - """ - # All nodes have these parameters - node_id = item["node_id"] - node_name = item["name"] - node_class = item["node_class"] - node_type = NodeType[item["node_type"]] - node_priority = Priority[item["priority"]] - node_hardware_state = HardwareState[item["hardware_state"]] - - if node_class == "PASSIVE": - node = PassiveNode( - node_id, - node_name, - node_type, - node_priority, - node_hardware_state, - self.training_config, - ) - elif node_class == "ACTIVE": - # Active nodes have IP address, Software State and file system state - node_ip_address = item["ip_address"] - node_software_state = SoftwareState[item["software_state"]] - node_file_system_state = FileSystemState[item["file_system_state"]] - node = ActiveNode( - node_id, - node_name, - node_type, - node_priority, - node_hardware_state, - node_ip_address, - node_software_state, - node_file_system_state, - self.training_config, - ) - elif node_class == "SERVICE": - # Service nodes have IP address, Software State, file system state and list of services - node_ip_address = item["ip_address"] - node_software_state = SoftwareState[item["software_state"]] - node_file_system_state = FileSystemState[item["file_system_state"]] - node = ServiceNode( - node_id, - node_name, - node_type, - node_priority, - node_hardware_state, - node_ip_address, - node_software_state, - node_file_system_state, - self.training_config, - ) - node_services = item["services"] - for service in node_services: - service_protocol = service["name"] - service_port = service["port"] - service_state = SoftwareState[service["state"]] - node.add_service(Service(service_protocol, service_port, service_state)) - else: - # Bad formatting - pass - - # Copy the node for the reference version - node_ref = copy.deepcopy(node) - - # Add node to node dictionary - self.nodes[node_id] = node - - # Add reference node to reference node dictionary - self.nodes_reference[node_id] = node_ref - - # Add node to network - self.network.add_nodes_from([node]) - - # Add node to network (reference) - self.network_reference.add_nodes_from([node_ref]) - - def create_link(self, item: Dict) -> None: - """ - Creates a link from config data. - - Args: - item: A config data item - """ - link_id = item["id"] - link_name = item["name"] - link_bandwidth = item["bandwidth"] - link_source = item["source"] - link_destination = item["destination"] - - source_node: Node = self.nodes[link_source] - dest_node: Node = self.nodes[link_destination] - - # Add link to network - self.network.add_edge(source_node, dest_node, id=link_name) - - # Add link to link dictionary - self.links[link_name] = Link( - link_id, - link_bandwidth, - source_node.name, - dest_node.name, - self.services_list, - ) - - # Reference - source_node_ref: Node = self.nodes_reference[link_source] - dest_node_ref: Node = self.nodes_reference[link_destination] - - # Add link to network (reference) - self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name) - - # Add link to link dictionary (reference) - self.links_reference[link_name] = Link( - link_id, - link_bandwidth, - source_node_ref.name, - dest_node_ref.name, - self.services_list, - ) - - def create_green_ier(self, item: Dict) -> None: - """ - Creates a green IER from config data. - - Args: - item: A config data item - """ - ier_id = item["id"] - ier_start_step = item["start_step"] - ier_end_step = item["end_step"] - ier_load = item["load"] - ier_protocol = item["protocol"] - ier_port = item["port"] - ier_source = item["source"] - ier_destination = item["destination"] - ier_mission_criticality = item["mission_criticality"] - - # Create IER and add to green IER dictionary - self.green_iers[ier_id] = IER( - ier_id, - ier_start_step, - ier_end_step, - ier_load, - ier_protocol, - ier_port, - ier_source, - ier_destination, - ier_mission_criticality, - ) - self.green_iers_reference[ier_id] = IER( - ier_id, - ier_start_step, - ier_end_step, - ier_load, - ier_protocol, - ier_port, - ier_source, - ier_destination, - ier_mission_criticality, - ) - - def create_red_ier(self, item: Dict) -> None: - """ - Creates a red IER from config data. - - Args: - item: A config data item - """ - ier_id = item["id"] - ier_start_step = item["start_step"] - ier_end_step = item["end_step"] - ier_load = item["load"] - ier_protocol = item["protocol"] - ier_port = item["port"] - ier_source = item["source"] - ier_destination = item["destination"] - ier_mission_criticality = item["mission_criticality"] - - # Create IER and add to red IER dictionary - self.red_iers[ier_id] = IER( - ier_id, - ier_start_step, - ier_end_step, - ier_load, - ier_protocol, - ier_port, - ier_source, - ier_destination, - ier_mission_criticality, - ) - - def create_green_pol(self, item: Dict) -> None: - """ - Creates a green PoL object from config data. - - Args: - item: A config data item - """ - pol_id = item["id"] - pol_start_step = item["start_step"] - pol_end_step = item["end_step"] - pol_node = item["nodeId"] - pol_type = NodePOLType[item["type"]] - - # State depends on whether this is Operating, Software, file system or Service PoL type - if pol_type == NodePOLType.OPERATING: - pol_state = HardwareState[item["state"]] - pol_protocol = "" - elif pol_type == NodePOLType.FILE: - pol_state = FileSystemState[item["state"]] - pol_protocol = "" - else: - pol_protocol = item["protocol"] - pol_state = SoftwareState[item["state"]] - - self.node_pol[pol_id] = NodeStateInstructionGreen( - pol_id, - pol_start_step, - pol_end_step, - pol_node, - pol_type, - pol_protocol, - pol_state, - ) - - def create_red_pol(self, item: Dict) -> None: - """ - Creates a red PoL object from config data. - - Args: - item: A config data item - """ - pol_id = item["id"] - pol_start_step = item["start_step"] - pol_end_step = item["end_step"] - pol_target_node_id = item["targetNodeId"] - pol_initiator = NodePOLInitiator[item["initiator"]] - pol_type = NodePOLType[item["type"]] - pol_protocol = item["protocol"] - - # State depends on whether this is Operating, Software, file system or Service PoL type - if pol_type == NodePOLType.OPERATING: - pol_state = HardwareState[item["state"]] - elif pol_type == NodePOLType.FILE: - pol_state = FileSystemState[item["state"]] - else: - pol_state = SoftwareState[item["state"]] - - pol_source_node_id = item["sourceNodeId"] - pol_source_node_service = item["sourceNodeService"] - pol_source_node_service_state = item["sourceNodeServiceState"] - - self.red_node_pol[pol_id] = NodeStateInstructionRed( - pol_id, - pol_start_step, - pol_end_step, - pol_target_node_id, - pol_initiator, - pol_type, - pol_protocol, - pol_state, - pol_source_node_id, - pol_source_node_service, - pol_source_node_service_state, - ) - - def create_acl_rule(self, item: Dict) -> None: - """ - Creates an ACL rule from config data. - - Args: - item: A config data item - """ - acl_rule_permission = item["permission"] - acl_rule_source = item["source"] - acl_rule_destination = item["destination"] - acl_rule_protocol = item["protocol"] - acl_rule_port = item["port"] - acl_rule_position = item.get("position") - - self.acl.add_rule( - acl_rule_permission, - acl_rule_source, - acl_rule_destination, - acl_rule_protocol, - acl_rule_port, - acl_rule_position, - ) - - # TODO: confirm typehint using runtime - def create_services_list(self, services: Dict) -> None: - """ - Creates a list of services (enum) from config data. - - Args: - item: A config data item representing the services - """ - service_list = services["service_list"] - - for service in service_list: - service_name = service["name"] - self.services_list.append(service_name) - - # Set the number of services - self.num_services = len(self.services_list) - - def create_ports_list(self, ports: Dict) -> None: - """ - Creates a list of ports from config data. - - Args: - item: A config data item representing the ports - """ - ports_list = ports["ports_list"] - - for port in ports_list: - port_value = port["port"] - self.ports_list.append(port_value) - - # Set the number of ports - self.num_ports = len(self.ports_list) - - # TODO: this is not used anymore, write a ticket to delete it - def get_observation_info(self, observation_info: Dict) -> None: - """ - Extracts observation_info. - - :param observation_info: Config item that defines which type of observation space to use - :type observation_info: str - """ - self.observation_type = ObservationType[observation_info["type"]] - - # TODO: this is not used anymore, write a ticket to delete it. - def get_action_info(self, action_info: Dict) -> None: - """ - Extracts action_info. - - Args: - item: A config data item representing action info - """ - self.action_type = ActionType[action_info["type"]] - - def save_obs_config(self, obs_config: dict) -> None: - """ - Cache the config for the observation space. - - This is necessary as the observation space can't be built while reading the config, - it must be done after all the nodes, links, and services have been initialised. - - :param obs_config: Parsed config relating to the observation space. The format is described in - :py:meth:`primaite.environment.observations.ObservationsHandler.from_config` - :type obs_config: dict - """ - self.obs_config = obs_config - - def reset_environment(self) -> None: - """ - Resets environment. - - Uses config data config data in order to build the environment configuration. - """ - for item in self.lay_down_config: - if item["item_type"] == "NODE": - # Reset a node's state (normal and reference) - self.reset_node(item) - elif item["item_type"] == "ACL_RULE": - # Create an ACL rule (these are cleared on reset, so just need to recreate them) - self.create_acl_rule(item) - else: - # Do nothing (bad formatting or not relevant to reset) - pass - - # Reset the IER status so they are not running initially - # Green IERs - for ier_key, ier_value in self.green_iers.items(): - ier_value.set_is_running(False) - # Red IERs - for ier_key, ier_value in self.red_iers.items(): - ier_value.set_is_running(False) - - def reset_node(self, item: Dict) -> None: - """ - Resets the statuses of a node. - - Args: - item: A config data item - """ - # All nodes have these parameters - node_id = item["node_id"] - node_class = item["node_class"] - node_hardware_state: HardwareState = HardwareState[item["hardware_state"]] - - node: NodeUnion = self.nodes[node_id] - node_ref = self.nodes_reference[node_id] - - # Reset the hardware state (common for all node types) - node.hardware_state = node_hardware_state - node_ref.hardware_state = node_hardware_state - - if node_class == "ACTIVE": - # Active nodes have Software State - node_software_state = SoftwareState[item["software_state"]] - node_file_system_state = FileSystemState[item["file_system_state"]] - node.software_state = node_software_state - node_ref.software_state = node_software_state - node.set_file_system_state(node_file_system_state) - node_ref.set_file_system_state(node_file_system_state) - elif node_class == "SERVICE": - # Service nodes have Software State and list of services - node_software_state = SoftwareState[item["software_state"]] - node_file_system_state = FileSystemState[item["file_system_state"]] - node.software_state = node_software_state - node_ref.software_state = node_software_state - node.set_file_system_state(node_file_system_state) - node_ref.set_file_system_state(node_file_system_state) - # Update service states - node_services = item["services"] - for service in node_services: - service_protocol = service["name"] - service_state = SoftwareState[service["state"]] - # Update node service state - node.set_service_state(service_protocol, service_state) - # Update reference node service state - node_ref.set_service_state(service_protocol, service_state) - else: - # Bad formatting - pass - - def create_node_action_dict(self) -> Dict[int, List[int]]: - """ - Creates a dictionary mapping each possible discrete action to more readable multidiscrete action. - - Note: Only actions that have the potential to change the state exist in the mapping (except for key 0) - - example return: - {0: [1, 0, 0, 0], - 1: [1, 1, 1, 0], - 2: [1, 1, 2, 0], - 3: [1, 1, 3, 0], - 4: [1, 2, 1, 0], - 5: [1, 3, 1, 0], - ... - } - """ - # Terms (for node action space): - # [0, num nodes] - node ID (0 = nothing, node ID) - # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa - # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa - # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa - # reserve 0 action to be a nothing action - actions = {0: [1, 0, 0, 0]} - action_key = 1 - for node in range(1, self.num_nodes + 1): - # 4 node properties (NONE, OPERATING, OS, SERVICE) - for node_property in range(4): - # Node Actions either: - # (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state - # Use MAX to ensure we get them all - for node_action in range(4): - for service_state in range(self.num_services): - action = [node, node_property, node_action, service_state] - # check to see if it's a nothing action (has no effect) - if is_valid_node_action(action): - actions[action_key] = action - action_key += 1 - - return actions - - def create_acl_action_dict(self) -> Dict[int, List[int]]: - """Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.""" - # Terms (for ACL action space): - # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) - # [0, 1] - Permission (0 = DENY, 1 = ALLOW) - # [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses) - # [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) - # [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) - # [0, num ports] - Port (0 = any, then 1 -> x resolving to port) - # [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list) - # reserve 0 action to be a nothing action - actions = {0: [0, 0, 0, 0, 0, 0, 0]} - - action_key = 1 - # 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE - for action_decision in range(3): - # 2 possible action permissions 0 = DENY, 1 = CREATE - for action_permission in range(2): - # Number of nodes + 1 (for any) - for source_ip in range(self.num_nodes + 1): - for dest_ip in range(self.num_nodes + 1): - for protocol in range(self.num_services + 1): - for port in range(self.num_ports + 1): - for position in range(self.max_number_acl_rules - 1): - action = [ - action_decision, - action_permission, - source_ip, - dest_ip, - protocol, - port, - position, - ] - # Check to see if it is an action we want to include as possible - # i.e. not a nothing action - if is_valid_acl_action_extra(action): - actions[action_key] = action - action_key += 1 - - return actions - - def create_node_and_acl_action_dict(self) -> Dict[int, List[int]]: - """ - Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action. - - The dictionary contains actions of both Node and ACL action types. - """ - node_action_dict = self.create_node_action_dict() - acl_action_dict = self.create_acl_action_dict() - - # Change node keys to not overlap with acl keys - # Only 1 nothing action (key 0) is required, remove the other - new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0} - - # Combine the Node dict and ACL dict - combined_action_dict = {**acl_action_dict, **new_node_action_dict} - return combined_action_dict - - def _create_random_red_agent(self) -> None: - """Decide on random red agent for the episode to be called in env.reset().""" - # Reset the current red iers and red node pol - self.red_iers = {} - self.red_node_pol = {} - - # Decide how many nodes become compromised - node_list = list(self.nodes.values()) - computers = [node for node in node_list if node.node_type == NodeType.COMPUTER] - max_num_nodes_compromised = len(computers) # only computers can become compromised - # random select between 1 and max_num_nodes_compromised - num_nodes_to_compromise = randint(1, max_num_nodes_compromised) - - # Decide which of the nodes to compromise - nodes_to_be_compromised = sample(computers, num_nodes_to_compromise) - - # choose a random compromise node to be source of attacks - source_node = choice(nodes_to_be_compromised) - - # For each of the nodes to be compromised decide which step they become compromised - max_step_compromised = self.episode_steps // 2 # always compromise in first half of episode - - # Bandwidth for all links - bandwidths = [i.get_bandwidth() for i in list(self.links.values())] - - if len(bandwidths) < 1: - msg = "Random red agent cannot be used on a network without any links" - _LOGGER.error(msg) - raise Exception(msg) - - servers = [node for node in node_list if node.node_type == NodeType.SERVER] - - for n, node in enumerate(nodes_to_be_compromised): - # 1: Use Node PoL to set node to compromised - - _id = str(uuid.uuid4()) - _start_step = randint(2, max_step_compromised + 1) # step compromised - pol_service_name = choice(list(node.services.keys())) - - source_node_service = choice(list(source_node.services.values())) - - red_pol = NodeStateInstructionRed( - _id=_id, - _start_step=_start_step, - _end_step=_start_step, # only run for 1 step - _target_node_id=node.node_id, - _pol_initiator="DIRECT", - _pol_type=NodePOLType["SERVICE"], - pol_protocol=pol_service_name, - _pol_state=SoftwareState.COMPROMISED, - _pol_source_node_id=source_node.node_id, - _pol_source_node_service=source_node_service.name, - _pol_source_node_service_state=source_node_service.software_state, - ) - - self.red_node_pol[_id] = red_pol - - # 2: Launch the attack from compromised node - set the IER - - ier_id = str(uuid.uuid4()) - # Launch the attack after node is compromised, and not right at the end of the episode - ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8)) - ier_end_step = self.episode_steps - - # Randomise the load, as a percentage of a random link bandwith - ier_load = uniform(0.4, 0.8) * choice(bandwidths) - ier_protocol = pol_service_name # Same protocol as compromised node - ier_service = node.services[pol_service_name] - ier_port = ier_service.port - ier_mission_criticality = 0 # Red IER will never be important to green agent success - # We choose a node to attack based on the first that applies: - # a. Green IERs, select dest node of the red ier based on dest node of green IER - # b. Attack a random server that doesn't have a DENY acl rule in default config - # c. Attack a random server - possible_ier_destinations = [ - ier.get_dest_node_id() - for ier in list(self.green_iers.values()) - if ier.get_source_node_id() == node.node_id - ] - if len(possible_ier_destinations) < 1: - for server in servers: - if not self.acl.is_blocked( - node.ip_address, - server.ip_address, - ier_service, - ier_port, - ): - possible_ier_destinations.append(server.node_id) - if len(possible_ier_destinations) < 1: - # If still none found choose from all servers - possible_ier_destinations = [server.node_id for server in servers] - ier_dest = choice(possible_ier_destinations) - self.red_iers[ier_id] = IER( - ier_id, - ier_start_step, - ier_end_step, - ier_load, - ier_protocol, - ier_port, - node.node_id, - ier_dest, - ier_mission_criticality, - ) - - overwhelm_pol = red_pol - overwhelm_pol.id = str(uuid.uuid4()) - overwhelm_pol.end_step = self.episode_steps - - # 3: Make sure the targetted node can be set to overwhelmed - with node pol - # # TODO remove duplicate red pol for same targetted service - must take into account start step - - o_pol_id = str(uuid.uuid4()) - o_red_pol = NodeStateInstructionRed( - _id=o_pol_id, - _start_step=ier_start_step, - _end_step=self.episode_steps, - _target_node_id=ier_dest, - _pol_initiator="DIRECT", - _pol_type=NodePOLType["SERVICE"], - pol_protocol=ier_protocol, - _pol_state=SoftwareState.OVERWHELMED, - _pol_source_node_id=source_node.node_id, - _pol_source_node_service=source_node_service.name, - _pol_source_node_service_state=source_node_service.software_state, - ) - self.red_node_pol[o_pol_id] = o_red_pol diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py deleted file mode 100644 index aa9dc97d..00000000 --- a/src/primaite/environment/reward.py +++ /dev/null @@ -1,386 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Implements reward function.""" -from logging import Logger -from typing import Dict, TYPE_CHECKING, Union - -from primaite import getLogger -from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import FileSystemState, HardwareState, SoftwareState -from primaite.common.service import Service -from primaite.nodes.active_node import ActiveNode -from primaite.nodes.service_node import ServiceNode - -if TYPE_CHECKING: - from primaite.config.training_config import TrainingConfig - from primaite.pol.ier import IER - -_LOGGER: Logger = getLogger(__name__) - - -def calculate_reward_function( - initial_nodes: Dict[str, NodeUnion], - final_nodes: Dict[str, NodeUnion], - reference_nodes: Dict[str, NodeUnion], - green_iers: Dict[str, "IER"], - green_iers_reference: Dict[str, "IER"], - red_iers: Dict[str, "IER"], - step_count: int, - config_values: "TrainingConfig", -) -> float: - """ - Compares the states of the initial and final nodes/links to get a reward. - - Args: - initial_nodes: The nodes before red and blue agents take effect - final_nodes: The nodes after red and blue agents take effect - reference_nodes: The nodes if there had been no red or blue effect - green_iers: The green IERs (should be running) - red_iers: Should be stopeed (ideally) by the blue agent - step_count: current step - config_values: Config values - """ - reward_value: float = 0.0 - - # For each node, compare hardware state, SoftwareState, service states - for node_key, final_node in final_nodes.items(): - initial_node = initial_nodes[node_key] - reference_node = reference_nodes[node_key] - - # Hardware State - reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values) - - # Software State - if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode): - reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values) - - # Service State - if isinstance(final_node, ServiceNode): - reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values) - - # File System State - if isinstance(final_node, ActiveNode): - reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values) - - # Go through each red IER - penalise if it is running - for ier_key, ier_value in red_iers.items(): - start_step = ier_value.get_start_step() - stop_step = ier_value.get_end_step() - if step_count >= start_step and step_count <= stop_step: - if ier_value.get_is_running(): - reward_value += config_values.red_ier_running - - # Go through each green IER - penalise if it's not running (weighted) - # but only if it's supposed to be running (it's running in reference) - for ier_key, ier_value in green_iers.items(): - reference_ier = green_iers_reference[ier_key] - start_step = ier_value.get_start_step() - stop_step = ier_value.get_end_step() - if step_count >= start_step and step_count <= stop_step: - reference_blocked = not reference_ier.get_is_running() - live_blocked = not ier_value.get_is_running() - ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality() - - if live_blocked and not reference_blocked: - reward_value += ier_reward - elif live_blocked and reference_blocked: - _LOGGER.debug( - ( - f"IER {ier_key} is blocked in the reference and live environments. " - f"Penalty of {ier_reward} was NOT applied." - ) - ) - elif not live_blocked and reference_blocked: - _LOGGER.debug( - ( - f"IER {ier_key} is blocked in the reference env but not in the live one. " - f"Penalty of {ier_reward} was NOT applied." - ) - ) - return reward_value - - -def score_node_operating_state( - final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig" -) -> float: - """ - Calculates score relating to the hardware state of a node. - - Args: - final_node: The node after red and blue agents take effect - initial_node: The node before red and blue agents take effect - reference_node: The node if there had been no red or blue effect - config_values: Config values - """ - score: float = 0.0 - final_node_operating_state = final_node.hardware_state - reference_node_operating_state = reference_node.hardware_state - - if final_node_operating_state == reference_node_operating_state: - # All is well - we're no different from the reference situation - score += config_values.all_ok - else: - # We're different from the reference situation - # Need to compare reference and final (current) state of node (i.e. at every step) - if reference_node_operating_state == HardwareState.ON: - if final_node_operating_state == HardwareState.OFF: - score += config_values.off_should_be_on - elif final_node_operating_state == HardwareState.RESETTING: - score += config_values.resetting_should_be_on - else: - pass - elif reference_node_operating_state == HardwareState.OFF: - if final_node_operating_state == HardwareState.ON: - score += config_values.on_should_be_off - elif final_node_operating_state == HardwareState.RESETTING: - score += config_values.resetting_should_be_off - else: - pass - elif reference_node_operating_state == HardwareState.RESETTING: - if final_node_operating_state == HardwareState.ON: - score += config_values.on_should_be_resetting - elif final_node_operating_state == HardwareState.OFF: - score += config_values.off_should_be_resetting - elif final_node_operating_state == HardwareState.RESETTING: - score += config_values.resetting - else: - pass - else: - pass - - return score - - -def score_node_os_state( - final_node: Union[ActiveNode, ServiceNode], - initial_node: Union[ActiveNode, ServiceNode], - reference_node: Union[ActiveNode, ServiceNode], - config_values: "TrainingConfig", -) -> float: - """ - Calculates score relating to the Software State of a node. - - Args: - final_node: The node after red and blue agents take effect - initial_node: The node before red and blue agents take effect - reference_node: The node if there had been no red or blue effect - config_values: Config values - """ - score: float = 0.0 - final_node_os_state = final_node.software_state - reference_node_os_state = reference_node.software_state - - if final_node_os_state == reference_node_os_state: - # All is well - we're no different from the reference situation - score += config_values.all_ok - else: - # We're different from the reference situation - # Need to compare reference and final (current) state of node (i.e. at every step) - if reference_node_os_state == SoftwareState.GOOD: - if final_node_os_state == SoftwareState.PATCHING: - score += config_values.patching_should_be_good - elif final_node_os_state == SoftwareState.COMPROMISED: - score += config_values.compromised_should_be_good - else: - pass - elif reference_node_os_state == SoftwareState.PATCHING: - if final_node_os_state == SoftwareState.GOOD: - score += config_values.good_should_be_patching - elif final_node_os_state == SoftwareState.COMPROMISED: - score += config_values.compromised_should_be_patching - elif final_node_os_state == SoftwareState.PATCHING: - score += config_values.patching - else: - pass - elif reference_node_os_state == SoftwareState.COMPROMISED: - if final_node_os_state == SoftwareState.GOOD: - score += config_values.good_should_be_compromised - elif final_node_os_state == SoftwareState.PATCHING: - score += config_values.patching_should_be_compromised - elif final_node_os_state == SoftwareState.COMPROMISED: - score += config_values.compromised - else: - pass - else: - pass - - return score - - -def score_node_service_state( - final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig" -) -> float: - """ - Calculates score relating to the service state(s) of a node. - - Args: - final_node: The node after red and blue agents take effect - initial_node: The node before red and blue agents take effect - reference_node: The node if there had been no red or blue effect - config_values: Config values - """ - score: float = 0.0 - final_node_services: Dict[str, Service] = final_node.services - reference_node_services: Dict[str, Service] = reference_node.services - - for service_key, final_service in final_node_services.items(): - reference_service = reference_node_services[service_key] - final_service = final_node_services[service_key] - - if final_service.software_state == reference_service.software_state: - # All is well - we're no different from the reference situation - score += config_values.all_ok - else: - # We're different from the reference situation - # Need to compare reference and final state of node (i.e. at every step) - if reference_service.software_state == SoftwareState.GOOD: - if final_service.software_state == SoftwareState.PATCHING: - score += config_values.patching_should_be_good - elif final_service.software_state == SoftwareState.COMPROMISED: - score += config_values.compromised_should_be_good - elif final_service.software_state == SoftwareState.OVERWHELMED: - score += config_values.overwhelmed_should_be_good - else: - pass - elif reference_service.software_state == SoftwareState.PATCHING: - if final_service.software_state == SoftwareState.GOOD: - score += config_values.good_should_be_patching - elif final_service.software_state == SoftwareState.COMPROMISED: - score += config_values.compromised_should_be_patching - elif final_service.software_state == SoftwareState.OVERWHELMED: - score += config_values.overwhelmed_should_be_patching - elif final_service.software_state == SoftwareState.PATCHING: - score += config_values.patching - else: - pass - elif reference_service.software_state == SoftwareState.COMPROMISED: - if final_service.software_state == SoftwareState.GOOD: - score += config_values.good_should_be_compromised - elif final_service.software_state == SoftwareState.PATCHING: - score += config_values.patching_should_be_compromised - elif final_service.software_state == SoftwareState.COMPROMISED: - score += config_values.compromised - elif final_service.software_state == SoftwareState.OVERWHELMED: - score += config_values.overwhelmed_should_be_compromised - else: - pass - elif reference_service.software_state == SoftwareState.OVERWHELMED: - if final_service.software_state == SoftwareState.GOOD: - score += config_values.good_should_be_overwhelmed - elif final_service.software_state == SoftwareState.PATCHING: - score += config_values.patching_should_be_overwhelmed - elif final_service.software_state == SoftwareState.COMPROMISED: - score += config_values.compromised_should_be_overwhelmed - elif final_service.software_state == SoftwareState.OVERWHELMED: - score += config_values.overwhelmed - else: - pass - else: - pass - - return score - - -def score_node_file_system( - final_node: Union[ActiveNode, ServiceNode], - initial_node: Union[ActiveNode, ServiceNode], - reference_node: Union[ActiveNode, ServiceNode], - config_values: "TrainingConfig", -) -> float: - """ - Calculates score relating to the file system state of a node. - - Args: - final_node: The node after red and blue agents take effect - initial_node: The node before red and blue agents take effect - reference_node: The node if there had been no red or blue effect - """ - score: float = 0.0 - final_node_file_system_state = final_node.file_system_state_actual - reference_node_file_system_state = reference_node.file_system_state_actual - - final_node_scanning_state = final_node.file_system_scanning - reference_node_scanning_state = reference_node.file_system_scanning - - # File System State - if final_node_file_system_state == reference_node_file_system_state: - # All is well - we're no different from the reference situation - score += config_values.all_ok - else: - # We're different from the reference situation - # Need to compare reference and final state of node (i.e. at every step) - if reference_node_file_system_state == FileSystemState.GOOD: - if final_node_file_system_state == FileSystemState.REPAIRING: - score += config_values.repairing_should_be_good - elif final_node_file_system_state == FileSystemState.RESTORING: - score += config_values.restoring_should_be_good - elif final_node_file_system_state == FileSystemState.CORRUPT: - score += config_values.corrupt_should_be_good - elif final_node_file_system_state == FileSystemState.DESTROYED: - score += config_values.destroyed_should_be_good - else: - pass - elif reference_node_file_system_state == FileSystemState.REPAIRING: - if final_node_file_system_state == FileSystemState.GOOD: - score += config_values.good_should_be_repairing - elif final_node_file_system_state == FileSystemState.RESTORING: - score += config_values.restoring_should_be_repairing - elif final_node_file_system_state == FileSystemState.CORRUPT: - score += config_values.corrupt_should_be_repairing - elif final_node_file_system_state == FileSystemState.DESTROYED: - score += config_values.destroyed_should_be_repairing - elif final_node_file_system_state == FileSystemState.REPAIRING: - score += config_values.repairing - else: - pass - elif reference_node_file_system_state == FileSystemState.RESTORING: - if final_node_file_system_state == FileSystemState.GOOD: - score += config_values.good_should_be_restoring - elif final_node_file_system_state == FileSystemState.REPAIRING: - score += config_values.repairing_should_be_restoring - elif final_node_file_system_state == FileSystemState.CORRUPT: - score += config_values.corrupt_should_be_restoring - elif final_node_file_system_state == FileSystemState.DESTROYED: - score += config_values.destroyed_should_be_restoring - elif final_node_file_system_state == FileSystemState.RESTORING: - score += config_values.restoring - else: - pass - elif reference_node_file_system_state == FileSystemState.CORRUPT: - if final_node_file_system_state == FileSystemState.GOOD: - score += config_values.good_should_be_corrupt - elif final_node_file_system_state == FileSystemState.REPAIRING: - score += config_values.repairing_should_be_corrupt - elif final_node_file_system_state == FileSystemState.RESTORING: - score += config_values.restoring_should_be_corrupt - elif final_node_file_system_state == FileSystemState.DESTROYED: - score += config_values.destroyed_should_be_corrupt - elif final_node_file_system_state == FileSystemState.CORRUPT: - score += config_values.corrupt - else: - pass - elif reference_node_file_system_state == FileSystemState.DESTROYED: - if final_node_file_system_state == FileSystemState.GOOD: - score += config_values.good_should_be_destroyed - elif final_node_file_system_state == FileSystemState.REPAIRING: - score += config_values.repairing_should_be_destroyed - elif final_node_file_system_state == FileSystemState.RESTORING: - score += config_values.restoring_should_be_destroyed - elif final_node_file_system_state == FileSystemState.CORRUPT: - score += config_values.corrupt_should_be_destroyed - elif final_node_file_system_state == FileSystemState.DESTROYED: - score += config_values.destroyed - else: - pass - else: - pass - - # Scanning State - if final_node_scanning_state == reference_node_scanning_state: - # All is well - we're no different from the reference situation - score += config_values.all_ok - else: - # We're different from the reference situation - # We're scanning the file system which incurs a penalty (as it slows down systems) - score += config_values.scanning - - return score diff --git a/src/primaite/links/__init__.py b/src/primaite/links/__init__.py deleted file mode 100644 index c91b6951..00000000 --- a/src/primaite/links/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Network connections between nodes in the simulation.""" diff --git a/src/primaite/links/link.py b/src/primaite/links/link.py deleted file mode 100644 index 3830a15b..00000000 --- a/src/primaite/links/link.py +++ /dev/null @@ -1,114 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""The link class.""" -from typing import List - -from primaite.common.protocol import Protocol - - -class Link(object): - """Link class.""" - - def __init__(self, _id: str, _bandwidth: int, _source_node_name: str, _dest_node_name: str, _services: str) -> None: - """ - Initialise a Link within the simulated network. - - :param _id: The IER id - :param _bandwidth: The bandwidth of the link (bps) - :param _source_node_name: The name of the source node - :param _dest_node_name: The name of the destination node - :param _protocols: The protocols to add to the link - """ - self.id: str = _id - self.bandwidth: int = _bandwidth - self.source_node_name: str = _source_node_name - self.dest_node_name: str = _dest_node_name - self.protocol_list: List[Protocol] = [] - - # Add the default protocols - for protocol_name in _services: - self.add_protocol(protocol_name) - - def add_protocol(self, _protocol: str) -> None: - """ - Adds a new protocol to the list of protocols on this link. - - Args: - _protocol: The protocol to be added (enum) - """ - self.protocol_list.append(Protocol(_protocol)) - - def get_id(self) -> str: - """ - Gets link ID. - - Returns: - Link ID - """ - return self.id - - def get_source_node_name(self) -> str: - """ - Gets source node name. - - Returns: - Source node name - """ - return self.source_node_name - - def get_dest_node_name(self) -> str: - """ - Gets destination node name. - - Returns: - Destination node name - """ - return self.dest_node_name - - def get_bandwidth(self) -> int: - """ - Gets bandwidth of link. - - Returns: - Link bandwidth (bps) - """ - return self.bandwidth - - def get_protocol_list(self) -> List[Protocol]: - """ - Gets list of protocols on this link. - - Returns: - List of protocols on this link - """ - return self.protocol_list - - def get_current_load(self) -> int: - """ - Gets current total load on this link. - - Returns: - Total load on this link (bps) - """ - total_load = 0 - for protocol in self.protocol_list: - total_load += protocol.get_load() - return total_load - - def add_protocol_load(self, _protocol: str, _load: int) -> None: - """ - Adds a loading to a protocol on this link. - - Args: - _protocol: The protocol to load - _load: The amount to load (bps) - """ - for protocol in self.protocol_list: - if protocol.get_name() == _protocol: - protocol.add_load(_load) - else: - pass - - def clear_traffic(self) -> None: - """Clears all traffic on this link.""" - for protocol in self.protocol_list: - protocol.clear_load() diff --git a/src/primaite/nodes/__init__.py b/src/primaite/nodes/__init__.py deleted file mode 100644 index 231b8d92..00000000 --- a/src/primaite/nodes/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Nodes represent network hosts in the simulation.""" diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py deleted file mode 100644 index 8f472e86..00000000 --- a/src/primaite/nodes/active_node.py +++ /dev/null @@ -1,208 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""An Active Node (i.e. not an actuator).""" -import logging -from typing import Final - -from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState -from primaite.config.training_config import TrainingConfig -from primaite.nodes.node import Node - -_LOGGER: Final[logging.Logger] = logging.getLogger(__name__) - - -class ActiveNode(Node): - """Active Node class.""" - - def __init__( - self, - node_id: str, - name: str, - node_type: NodeType, - priority: Priority, - hardware_state: HardwareState, - ip_address: str, - software_state: SoftwareState, - file_system_state: FileSystemState, - config_values: TrainingConfig, - ) -> None: - """ - Initialise an active node. - - :param node_id: The node ID - :param name: The node name - :param node_type: The node type (enum) - :param priority: The node priority (enum) - :param hardware_state: The node Hardware State - :param ip_address: The node IP address - :param software_state: The node Software State - :param file_system_state: The node file system state - :param config_values: The config values - """ - super().__init__(node_id, name, node_type, priority, hardware_state, config_values) - self.ip_address: str = ip_address - # Related to Software - self._software_state: SoftwareState = software_state - self.patching_count: int = 0 - # Related to File System - self.file_system_state_actual: FileSystemState = file_system_state - self.file_system_state_observed: FileSystemState = file_system_state - self.file_system_scanning: bool = False - self.file_system_scanning_count: int = 0 - self.file_system_action_count: int = 0 - - @property - def software_state(self) -> SoftwareState: - """ - Get the software_state. - - :return: The software_state. - """ - return self._software_state - - @software_state.setter - def software_state(self, software_state: SoftwareState) -> None: - """ - Get the software_state. - - :param software_state: Software State. - """ - if self.hardware_state != HardwareState.OFF: - self._software_state = software_state - if software_state == SoftwareState.PATCHING: - self.patching_count = self.config_values.os_patching_duration - else: - _LOGGER.info( - f"The Nodes hardware state is OFF so OS State cannot be " - f"changed. " - f"Node.node_id:{self.node_id}, " - f"Node.hardware_state:{self.hardware_state}, " - f"Node.software_state:{self._software_state}" - ) - - def set_software_state_if_not_compromised(self, software_state: SoftwareState) -> None: - """ - Sets Software State if the node is not compromised. - - Args: - software_state: Software State - """ - if self.hardware_state != HardwareState.OFF: - if self._software_state != SoftwareState.COMPROMISED: - self._software_state = software_state - if software_state == SoftwareState.PATCHING: - self.patching_count = self.config_values.os_patching_duration - else: - _LOGGER.info( - f"The Nodes hardware state is OFF so OS State cannot be changed." - f"Node.node_id:{self.node_id}, " - f"Node.hardware_state:{self.hardware_state}, " - f"Node.software_state:{self._software_state}" - ) - - def update_os_patching_status(self) -> None: - """Updates operating system status based on patching cycle.""" - self.patching_count -= 1 - if self.patching_count <= 0: - self.patching_count = 0 - self._software_state = SoftwareState.GOOD - - def set_file_system_state(self, file_system_state: FileSystemState) -> None: - """ - Sets the file system state (actual and observed). - - Args: - file_system_state: File system state - """ - if self.hardware_state != HardwareState.OFF: - self.file_system_state_actual = file_system_state - - if file_system_state == FileSystemState.REPAIRING: - self.file_system_action_count = self.config_values.file_system_repairing_limit - self.file_system_state_observed = FileSystemState.REPAIRING - elif file_system_state == FileSystemState.RESTORING: - self.file_system_action_count = self.config_values.file_system_restoring_limit - self.file_system_state_observed = FileSystemState.RESTORING - elif file_system_state == FileSystemState.GOOD: - self.file_system_state_observed = FileSystemState.GOOD - else: - _LOGGER.info( - f"The Nodes hardware state is OFF so File System State " - f"cannot be changed. " - f"Node.node_id:{self.node_id}, " - f"Node.hardware_state:{self.hardware_state}, " - f"Node.file_system_state.actual:{self.file_system_state_actual}" - ) - - def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState) -> None: - """ - Sets the file system state (actual and observed) if not in a compromised state. - - Use for green PoL to prevent it overturning a compromised state - - Args: - file_system_state: File system state - """ - if self.hardware_state != HardwareState.OFF: - if ( - self.file_system_state_actual != FileSystemState.CORRUPT - and self.file_system_state_actual != FileSystemState.DESTROYED - ): - self.file_system_state_actual = file_system_state - - if file_system_state == FileSystemState.REPAIRING: - self.file_system_action_count = self.config_values.file_system_repairing_limit - self.file_system_state_observed = FileSystemState.REPAIRING - elif file_system_state == FileSystemState.RESTORING: - self.file_system_action_count = self.config_values.file_system_restoring_limit - self.file_system_state_observed = FileSystemState.RESTORING - elif file_system_state == FileSystemState.GOOD: - self.file_system_state_observed = FileSystemState.GOOD - else: - _LOGGER.info( - f"The Nodes hardware state is OFF so File System State (if not " - f"compromised) cannot be changed. " - f"Node.node_id:{self.node_id}, " - f"Node.hardware_state:{self.hardware_state}, " - f"Node.file_system_state.actual:{self.file_system_state_actual}" - ) - - def start_file_system_scan(self) -> None: - """Starts a file system scan.""" - self.file_system_scanning = True - self.file_system_scanning_count = self.config_values.file_system_scanning_limit - - def update_file_system_state(self) -> None: - """Updates file system status based on scanning/restore/repair cycle.""" - # Deprecate both the action count (for restoring or reparing) and the scanning count - self.file_system_action_count -= 1 - self.file_system_scanning_count -= 1 - - # Reparing / Restoring updates - if self.file_system_action_count <= 0: - self.file_system_action_count = 0 - if ( - self.file_system_state_actual == FileSystemState.REPAIRING - or self.file_system_state_actual == FileSystemState.RESTORING - ): - self.file_system_state_actual = FileSystemState.GOOD - self.file_system_state_observed = FileSystemState.GOOD - - # Scanning updates - if self.file_system_scanning == True and self.file_system_scanning_count < 0: - self.file_system_state_observed = self.file_system_state_actual - self.file_system_scanning = False - self.file_system_scanning_count = 0 - - def update_resetting_status(self) -> None: - """Updates the reset count & makes software and file state to GOOD.""" - super().update_resetting_status() - if self.resetting_count <= 0: - self.file_system_state_actual = FileSystemState.GOOD - self.software_state = SoftwareState.GOOD - - def update_booting_status(self) -> None: - """Updates the booting software and file state to GOOD.""" - super().update_booting_status() - if self.booting_count <= 0: - self.file_system_state_actual = FileSystemState.GOOD - self.software_state = SoftwareState.GOOD diff --git a/src/primaite/nodes/node.py b/src/primaite/nodes/node.py deleted file mode 100644 index fc4d41d3..00000000 --- a/src/primaite/nodes/node.py +++ /dev/null @@ -1,79 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""The base Node class.""" -from typing import Final - -from primaite.common.enums import HardwareState, NodeType, Priority -from primaite.config.training_config import TrainingConfig - - -class Node: - """Node class.""" - - def __init__( - self, - node_id: str, - name: str, - node_type: NodeType, - priority: Priority, - hardware_state: HardwareState, - config_values: TrainingConfig, - ) -> None: - """ - Initialise a node. - - :param node_id: The node id. - :param name: The name of the node. - :param node_type: The type of the node. - :param priority: The priority of the node. - :param hardware_state: The state of the node. - :param config_values: Config values. - """ - self.node_id: Final[str] = node_id - self.name: Final[str] = name - self.node_type: Final[NodeType] = node_type - self.priority = priority - self.hardware_state: HardwareState = hardware_state - self.resetting_count: int = 0 - self.config_values: TrainingConfig = config_values - self.booting_count: int = 0 - self.shutting_down_count: int = 0 - - def __repr__(self) -> str: - """Returns the name of the node.""" - return self.name - - def turn_on(self) -> None: - """Sets the node state to ON.""" - self.hardware_state = HardwareState.BOOTING - self.booting_count = self.config_values.node_booting_duration - - def turn_off(self) -> None: - """Sets the node state to OFF.""" - self.hardware_state = HardwareState.OFF - self.shutting_down_count = self.config_values.node_shutdown_duration - - def reset(self) -> None: - """Sets the node state to Resetting and starts the reset count.""" - self.hardware_state = HardwareState.RESETTING - self.resetting_count = self.config_values.node_reset_duration - - def update_resetting_status(self) -> None: - """Updates the resetting count.""" - self.resetting_count -= 1 - if self.resetting_count <= 0: - self.resetting_count = 0 - self.hardware_state = HardwareState.ON - - def update_booting_status(self) -> None: - """Updates the booting count.""" - self.booting_count -= 1 - if self.booting_count <= 0: - self.booting_count = 0 - self.hardware_state = HardwareState.ON - - def update_shutdown_status(self) -> None: - """Updates the shutdown count.""" - self.shutting_down_count -= 1 - if self.shutting_down_count <= 0: - self.shutting_down_count = 0 - self.hardware_state = HardwareState.OFF diff --git a/src/primaite/nodes/node_state_instruction_green.py b/src/primaite/nodes/node_state_instruction_green.py deleted file mode 100644 index 6e35d0ec..00000000 --- a/src/primaite/nodes/node_state_instruction_green.py +++ /dev/null @@ -1,94 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Defines node behaviour for Green PoL.""" -from typing import TYPE_CHECKING, Union - -if TYPE_CHECKING: - from primaite.common.enums import FileSystemState, HardwareState, NodePOLType, SoftwareState - - -class NodeStateInstructionGreen(object): - """The Node State Instruction class.""" - - def __init__( - self, - _id: str, - _start_step: int, - _end_step: int, - _node_id: str, - _node_pol_type: "NodePOLType", - _service_name: str, - _state: Union["HardwareState", "SoftwareState", "FileSystemState"], - ) -> None: - """ - Initialise the Node State Instruction. - - :param _id: The node state instruction id - :param _start_step: The start step of the instruction - :param _end_step: The end step of the instruction - :param _node_id: The id of the associated node - :param _node_pol_type: The pattern of life type - :param _service_name: The service name - :param _state: The state (node or service) - """ - self.id = _id - self.start_step = _start_step - self.end_step = _end_step - self.node_id = _node_id - self.node_pol_type: "NodePOLType" = _node_pol_type - self.service_name: str = _service_name # Not used when not a service instruction - # TODO: confirm type of state - self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _state - - def get_start_step(self) -> int: - """ - Gets the start step. - - Returns: - The start step - """ - return self.start_step - - def get_end_step(self) -> int: - """ - Gets the end step. - - Returns: - The end step - """ - return self.end_step - - def get_node_id(self) -> str: - """ - Gets the node ID. - - Returns: - The node ID - """ - return self.node_id - - def get_node_pol_type(self) -> "NodePOLType": - """ - Gets the node pattern of life type (enum). - - Returns: - The node pattern of life type (enum) - """ - return self.node_pol_type - - def get_service_name(self) -> str: - """ - Gets the service name. - - Returns: - The service name - """ - return self.service_name - - def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]: - """ - Gets the state (node or service). - - Returns: - The state (node or service) - """ - return self.state diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py deleted file mode 100644 index eb87924b..00000000 --- a/src/primaite/nodes/node_state_instruction_red.py +++ /dev/null @@ -1,143 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Defines node behaviour for Green PoL.""" -from typing import TYPE_CHECKING, Union - -from primaite.common.enums import NodePOLType - -if TYPE_CHECKING: - from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState - - -class NodeStateInstructionRed: - """The Node State Instruction class.""" - - def __init__( - self, - _id: str, - _start_step: int, - _end_step: int, - _target_node_id: str, - _pol_initiator: "NodePOLInitiator", - _pol_type: NodePOLType, - pol_protocol: str, - _pol_state: Union["HardwareState", "SoftwareState", "FileSystemState"], - _pol_source_node_id: str, - _pol_source_node_service: str, - _pol_source_node_service_state: str, - ) -> None: - """ - Initialise the Node State Instruction for the red agent. - - :param _id: The node state instruction id - :param _start_step: The start step of the instruction - :param _end_step: The end step of the instruction - :param _target_node_id: The id of the associated node - :param -pol_initiator: The way the PoL is applied (DIRECT, IER or SERVICE) - :param _pol_type: The pattern of life type - :param pol_protocol: The pattern of life protocol/service affected - :param _pol_state: The state (node or service) - :param _pol_source_node_id: The source node Id (used for initiator type SERVICE) - :param _pol_source_node_service: The source node service (used for initiator type SERVICE) - :param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE) - """ - self.id: str = _id - self.start_step: int = _start_step - self.end_step: int = _end_step - self.target_node_id: str = _target_node_id - self.initiator: "NodePOLInitiator" = _pol_initiator - self.pol_type: NodePOLType = _pol_type - self.service_name: str = pol_protocol # Not used when not a service instruction - self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _pol_state - self.source_node_id: str = _pol_source_node_id - self.source_node_service: str = _pol_source_node_service - self.source_node_service_state = _pol_source_node_service_state - - def get_start_step(self) -> int: - """ - Gets the start step. - - Returns: - The start step - """ - return self.start_step - - def get_end_step(self) -> int: - """ - Gets the end step. - - Returns: - The end step - """ - return self.end_step - - def get_target_node_id(self) -> str: - """ - Gets the node ID. - - Returns: - The node ID - """ - return self.target_node_id - - def get_initiator(self) -> "NodePOLInitiator": - """ - Gets the initiator. - - Returns: - The initiator - """ - return self.initiator - - def get_pol_type(self) -> NodePOLType: - """ - Gets the node pattern of life type (enum). - - Returns: - The node pattern of life type (enum) - """ - return self.pol_type - - def get_service_name(self) -> str: - """ - Gets the service name. - - Returns: - The service name - """ - return self.service_name - - def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]: - """ - Gets the state (node or service). - - Returns: - The state (node or service) - """ - return self.state - - def get_source_node_id(self) -> str: - """ - Gets the source node id (used for initiator type SERVICE). - - Returns: - The source node id - """ - return self.source_node_id - - def get_source_node_service(self) -> str: - """ - Gets the source node service (used for initiator type SERVICE). - - Returns: - The source node service - """ - return self.source_node_service - - def get_source_node_service_state(self) -> str: - """ - Gets the source node service state (used for initiator type SERVICE). - - Returns: - The source node service state - """ - return self.source_node_service_state diff --git a/src/primaite/nodes/passive_node.py b/src/primaite/nodes/passive_node.py deleted file mode 100644 index 08dcbfa2..00000000 --- a/src/primaite/nodes/passive_node.py +++ /dev/null @@ -1,42 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""The Passive Node class (i.e. an actuator).""" -from primaite.common.enums import HardwareState, NodeType, Priority -from primaite.config.training_config import TrainingConfig -from primaite.nodes.node import Node - - -class PassiveNode(Node): - """The Passive Node class.""" - - def __init__( - self, - node_id: str, - name: str, - node_type: NodeType, - priority: Priority, - hardware_state: HardwareState, - config_values: TrainingConfig, - ) -> None: - """ - Initialise a passive node. - - :param node_id: The node id. - :param name: The name of the node. - :param node_type: The type of the node. - :param priority: The priority of the node. - :param hardware_state: The state of the node. - :param config_values: Config values. - """ - # Pass through to Super for now - super().__init__(node_id, name, node_type, priority, hardware_state, config_values) - - @property - def ip_address(self) -> str: - """ - Gets the node IP address as an empty string. - - No concept of IP address for passive nodes for now. - - :return: The node IP address. - """ - return "" diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py deleted file mode 100644 index b0d42785..00000000 --- a/src/primaite/nodes/service_node.py +++ /dev/null @@ -1,190 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""A Service Node (i.e. not an actuator).""" -import logging -from typing import Dict, Final - -from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState -from primaite.common.service import Service -from primaite.config.training_config import TrainingConfig -from primaite.nodes.active_node import ActiveNode - -_LOGGER: Final[logging.Logger] = logging.getLogger(__name__) - - -class ServiceNode(ActiveNode): - """ServiceNode class.""" - - def __init__( - self, - node_id: str, - name: str, - node_type: NodeType, - priority: Priority, - hardware_state: HardwareState, - ip_address: str, - software_state: SoftwareState, - file_system_state: FileSystemState, - config_values: TrainingConfig, - ) -> None: - """ - Initialise a Service Node. - - :param node_id: The node ID - :param name: The node name - :param node_type: The node type (enum) - :param priority: The node priority (enum) - :param hardware_state: The node Hardware State - :param ip_address: The node IP address - :param software_state: The node Software State - :param file_system_state: The node file system state - :param config_values: The config values - """ - super().__init__( - node_id, - name, - node_type, - priority, - hardware_state, - ip_address, - software_state, - file_system_state, - config_values, - ) - self.services: Dict[str, Service] = {} - - def add_service(self, service: Service) -> None: - """ - Adds a service to the node. - - :param service: The service to add - """ - self.services[service.name] = service - - def has_service(self, protocol_name: str) -> bool: - """ - Indicates whether a service is on a node. - - :param protocol_name: The service (protocol)e. - :return: True if service (protocol) is on the node, otherwise False. - """ - for service_key, service_value in self.services.items(): - if service_key == protocol_name: - return True - return False - - def service_running(self, protocol_name: str) -> bool: - """ - Indicates whether a service is in a running state on the node. - - :param protocol_name: The service (protocol) - :return: True if service (protocol) is in a running state on the node, otherwise False. - """ - for service_key, service_value in self.services.items(): - if service_key == protocol_name: - if service_value.software_state != SoftwareState.PATCHING: - return True - else: - return False - return False - - def service_is_overwhelmed(self, protocol_name: str) -> bool: - """ - Indicates whether a service is in an overwhelmed state on the node. - - :param protocol_name: The service (protocol) - :return: True if service (protocol) is in an overwhelmed state on the node, otherwise False. - """ - for service_key, service_value in self.services.items(): - if service_key == protocol_name: - if service_value.software_state == SoftwareState.OVERWHELMED: - return True - else: - return False - return False - - def set_service_state(self, protocol_name: str, software_state: SoftwareState) -> None: - """ - Sets the software_state of a service (protocol) on the node. - - :param protocol_name: The service (protocol). - :param software_state: The software_state. - """ - if self.hardware_state != HardwareState.OFF: - service_key = protocol_name - service_value = self.services.get(service_key) - if service_value: - # Can't set to compromised if you're in a patching state - if ( - software_state == SoftwareState.COMPROMISED - and service_value.software_state != SoftwareState.PATCHING - ) or software_state != SoftwareState.COMPROMISED: - service_value.software_state = software_state - if software_state == SoftwareState.PATCHING: - service_value.patching_count = self.config_values.service_patching_duration - else: - _LOGGER.info( - f"The Nodes hardware state is OFF so the state of a service " - f"cannot be changed. " - f"Node.node_id:{self.node_id}, " - f"Node.hardware_state:{self.hardware_state}, " - f"Node.services[]:{protocol_name}, " - f"Node.services[].software_state:{software_state}" - ) - - def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState) -> None: - """ - Sets the software_state of a service (protocol) on the node. - - Done if the software_state is not "compromised". - - :param protocol_name: The service (protocol). - :param software_state: The software_state. - """ - if self.hardware_state != HardwareState.OFF: - service_key = protocol_name - service_value = self.services.get(service_key) - if service_value: - if service_value.software_state != SoftwareState.COMPROMISED: - service_value.software_state = software_state - if software_state == SoftwareState.PATCHING: - service_value.patching_count = self.config_values.service_patching_duration - else: - _LOGGER.info( - f"The Nodes hardware state is OFF so the state of a service " - f"cannot be changed. " - f"Node.node_id:{self.node_id}, " - f"Node.hardware_state:{self.hardware_state}, " - f"Node.services[]:{protocol_name}, " - f"Node.services[].software_state:{software_state}" - ) - - def get_service_state(self, protocol_name: str) -> SoftwareState: - """ - Gets the state of a service. - - :return: The software_state of the service. - """ - service_key = protocol_name - service_value = self.services.get(service_key) - if service_value: - return service_value.software_state - - def update_services_patching_status(self) -> None: - """Updates the patching counter for any service that are patching.""" - for service_key, service_value in self.services.items(): - if service_value.software_state == SoftwareState.PATCHING: - service_value.reduce_patching_count() - - def update_resetting_status(self) -> None: - """Update resetting counter and set software state if it reached 0.""" - super().update_resetting_status() - if self.resetting_count <= 0: - for service in self.services.values(): - service.software_state = SoftwareState.GOOD - - def update_booting_status(self) -> None: - """Update booting counter and set software to good if it reached 0.""" - super().update_booting_status() - if self.booting_count <= 0: - for service in self.services.values(): - service.software_state = SoftwareState.GOOD diff --git a/src/primaite/pol/__init__.py b/src/primaite/pol/__init__.py deleted file mode 100644 index d0d9f616..00000000 --- a/src/primaite/pol/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Pattern of Life- Represents the actions of users on the network.""" diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py deleted file mode 100644 index 814aa314..00000000 --- a/src/primaite/pol/green_pol.py +++ /dev/null @@ -1,264 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Implements Pattern of Life on the network (nodes and links).""" -from typing import Dict - -from networkx import MultiGraph, shortest_path - -from primaite.acl.access_control_list import AccessControlList -from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState -from primaite.links.link import Link -from primaite.nodes.active_node import ActiveNode -from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen -from primaite.nodes.service_node import ServiceNode -from primaite.pol.ier import IER - -_VERBOSE: bool = False - - -def apply_iers( - network: MultiGraph, - nodes: Dict[str, NodeUnion], - links: Dict[str, Link], - iers: Dict[str, IER], - acl: AccessControlList, - step: int, -) -> None: - """ - Applies IERs to the links (link pattern of life). - - Args: - network: The network modelled in the environment - nodes: The nodes within the environment - links: The links within the environment - iers: The IERs to apply to the links - acl: The Access Control List - step: The step number. - """ - if _VERBOSE: - print("Applying IERs") - - # Go through each IER and check the conditions for it being applied - # If everything is in place, apply the IER protocol load to the relevant links - for ier_key, ier_value in iers.items(): - start_step = ier_value.get_start_step() - stop_step = ier_value.get_end_step() - protocol = ier_value.get_protocol() - port = ier_value.get_port() - load = ier_value.get_load() - source_node_id = ier_value.get_source_node_id() - dest_node_id = ier_value.get_dest_node_id() - - # Need to set the running status to false first for all IERs - ier_value.set_is_running(False) - - source_valid = True - dest_valid = True - acl_block = False - - if step >= start_step and step <= stop_step: - # continue -------------------------- - - # Get the source and destination node for this link - source_node = nodes[source_node_id] - dest_node = nodes[dest_node_id] - - # 1. Check the source node situation - # TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch - # doesn't mean it has a software state? It could be a PassiveNode or ActiveNode - if source_node.node_type == NodeType.SWITCH: - # It's a switch - if ( - source_node.hardware_state == HardwareState.ON - and source_node.software_state != SoftwareState.PATCHING - ): - source_valid = True - else: - # IER no longer valid - source_valid = False - elif source_node.node_type == NodeType.ACTUATOR: - # It's an actuator - # TO DO - pass - else: - # It's not a switch or an actuator (so active node) - if ( - source_node.hardware_state == HardwareState.ON - and source_node.software_state != SoftwareState.PATCHING - ): - if source_node.has_service(protocol): - if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol): - source_valid = True - else: - source_valid = False - else: - # Do nothing - IER is not valid on this node - # (This shouldn't happen if the IER has been written correctly) - source_valid = False - else: - # Do nothing - IER no longer valid - source_valid = False - - # 2. Check the dest node situation - if dest_node.node_type == NodeType.SWITCH: - # It's a switch - if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING: - dest_valid = True - else: - # IER no longer valid - dest_valid = False - elif dest_node.node_type == NodeType.ACTUATOR: - # It's an actuator - pass - else: - # It's not a switch or an actuator (so active node) - if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING: - if dest_node.has_service(protocol): - if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol): - dest_valid = True - else: - dest_valid = False - else: - # Do nothing - IER is not valid on this node - # (This shouldn't happen if the IER has been written correctly) - dest_valid = False - else: - # Do nothing - IER no longer valid - dest_valid = False - - # 3. Check that the ACL doesn't block it - acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port) - if acl_block: - if _VERBOSE: - print( - "ACL block on source: " - + source_node.ip_address - + ", dest: " - + dest_node.ip_address - + ", protocol: " - + protocol - + ", port: " - + port - ) - else: - if _VERBOSE: - print("No ACL block") - - # Check whether both the source and destination are valid, and there's no ACL block - if source_valid and dest_valid and not acl_block: - # Load up the link(s) with the traffic - - if _VERBOSE: - print("Source, Dest and ACL valid") - - # Get the shortest path (i.e. nodes) between source and destination - path_node_list = shortest_path(network, source_node, dest_node) - path_node_list_length = len(path_node_list) - path_valid = True - - # We might have a switch in the path, so check all nodes are operational - for node in path_node_list: - if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING: - path_valid = False - - if path_valid: - if _VERBOSE: - print("Applying IER to link(s)") - count = 0 - link_capacity_exceeded = False - - # Check that the link capacity is not exceeded by the new load - while count < path_node_list_length - 1: - # Get the link between the next two nodes - edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1]) - link_id = edge_dict[0].get("id") - link = links[link_id] - # Check whether the new load exceeds the bandwidth - if (link.get_current_load() + load) > link.get_bandwidth(): - link_capacity_exceeded = True - if _VERBOSE: - print("Link capacity exceeded") - pass - count += 1 - - # Check whether the link capacity for any links on this path have been exceeded - if link_capacity_exceeded == False: - # Now apply the new loads to the links - count = 0 - while count < path_node_list_length - 1: - # Get the link between the next two nodes - edge_dict = network.get_edge_data( - path_node_list[count], - path_node_list[count + 1], - ) - link_id = edge_dict[0].get("id") - link = links[link_id] - # Add the load from this IER - link.add_protocol_load(protocol, load) - count += 1 - # This IER is now valid, so set it to running - ier_value.set_is_running(True) - else: - # One of the nodes is not operational - if _VERBOSE: - print("Path not valid - one or more nodes not operational") - pass - - else: - if _VERBOSE: - print("Source, Dest or ACL were not valid") - pass - # ------------------------------------ - else: - # Do nothing - IER no longer valid - pass - - -def apply_node_pol( - nodes: Dict[str, NodeUnion], - node_pol: Dict[str, NodeStateInstructionGreen], - step: int, -) -> None: - """ - Applies node pattern of life. - - Args: - nodes: The nodes within the environment - node_pol: The node pattern of life to apply - step: The step number. - """ - if _VERBOSE: - print("Applying Node PoL") - - for key, node_instruction in node_pol.items(): - start_step = node_instruction.get_start_step() - stop_step = node_instruction.get_end_step() - node_id = node_instruction.get_node_id() - node_pol_type = node_instruction.get_node_pol_type() - service_name = node_instruction.get_service_name() - state = node_instruction.get_state() - - if step >= start_step and step <= stop_step: - # continue -------------------------- - node = nodes[node_id] - - if node_pol_type == NodePOLType.OPERATING: - # Change hardware state - node.hardware_state = state - elif node_pol_type == NodePOLType.OS: - # Change OS state - # Don't allow PoL to fix something that is compromised. Only the Blue agent can do this - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - node.set_software_state_if_not_compromised(state) - elif node_pol_type == NodePOLType.SERVICE: - # Change a service state - # Don't allow PoL to fix something that is compromised. Only the Blue agent can do this - if isinstance(node, ServiceNode): - node.set_service_state_if_not_compromised(service_name, state) - else: - # Change the file system status - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - node.set_file_system_state_if_not_compromised(state) - else: - # PoL is not valid in this time step - pass diff --git a/src/primaite/pol/ier.py b/src/primaite/pol/ier.py deleted file mode 100644 index b8da28c0..00000000 --- a/src/primaite/pol/ier.py +++ /dev/null @@ -1,147 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -""" -Information Exchange Requirements for APE. - -Used to represent an information flow from source to destination. -""" - - -class IER(object): - """Information Exchange Requirement class.""" - - def __init__( - self, - _id: str, - _start_step: int, - _end_step: int, - _load: int, - _protocol: str, - _port: str, - _source_node_id: str, - _dest_node_id: str, - _mission_criticality: int, - _running: bool = False, - ) -> None: - """ - Initialise an Information Exchange Request. - - :param _id: The IER id - :param _start_step: The step when this IER should start - :param _end_step: The step when this IER should end - :param _load: The load this IER should put on a link (bps) - :param _protocol: The protocol of this IER - :param _port: The port this IER runs on - :param _source_node_id: The source node ID - :param _dest_node_id: The destination node ID - :param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical) - :param _running: Indicates whether the IER is currently running - """ - self.id: str = _id - self.start_step: int = _start_step - self.end_step: int = _end_step - self.source_node_id: str = _source_node_id - self.dest_node_id: str = _dest_node_id - self.load: int = _load - self.protocol: str = _protocol - self.port: str = _port - self.mission_criticality: int = _mission_criticality - self.running: bool = _running - - def get_id(self) -> str: - """ - Gets IER ID. - - Returns: - IER ID - """ - return self.id - - def get_start_step(self) -> int: - """ - Gets IER start step. - - Returns: - IER start step - """ - return self.start_step - - def get_end_step(self) -> int: - """ - Gets IER end step. - - Returns: - IER end step - """ - return self.end_step - - def get_load(self) -> int: - """ - Gets IER load. - - Returns: - IER load - """ - return self.load - - def get_protocol(self) -> str: - """ - Gets IER protocol. - - Returns: - IER protocol - """ - return self.protocol - - def get_port(self) -> str: - """ - Gets IER port. - - Returns: - IER port - """ - return self.port - - def get_source_node_id(self) -> str: - """ - Gets IER source node ID. - - Returns: - IER source node ID - """ - return self.source_node_id - - def get_dest_node_id(self) -> str: - """ - Gets IER destination node ID. - - Returns: - IER destination node ID - """ - return self.dest_node_id - - def get_is_running(self) -> bool: - """ - Informs whether the IER is currently running. - - Returns: - True if running - """ - return self.running - - def set_is_running(self, _value: bool) -> None: - """ - Sets the running state of the IER. - - Args: - _value: running status - """ - self.running = _value - - def get_mission_criticality(self) -> int: - """ - Gets the IER mission criticality (used in the reward function). - - Returns: - Mission criticality value (0 lowest to 5 highest) - """ - return self.mission_criticality diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py deleted file mode 100644 index ca1a58da..00000000 --- a/src/primaite/pol/red_agent_pol.py +++ /dev/null @@ -1,353 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Implements POL on the network (nodes and links) resulting from the red agent attack.""" -from typing import Dict - -from networkx import MultiGraph, shortest_path - -from primaite import getLogger -from primaite.acl.access_control_list import AccessControlList -from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState -from primaite.links.link import Link -from primaite.nodes.active_node import ActiveNode -from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed -from primaite.nodes.service_node import ServiceNode -from primaite.pol.ier import IER - -_LOGGER = getLogger(__name__) - -_VERBOSE: bool = False - - -def apply_red_agent_iers( - network: MultiGraph, - nodes: Dict[str, NodeUnion], - links: Dict[str, Link], - iers: Dict[str, IER], - acl: AccessControlList, - step: int, -) -> None: - """ - Applies IERs to the links (link POL) resulting from red agent attack. - - Args: - network: The network modelled in the environment - nodes: The nodes within the environment - links: The links within the environment - iers: The red agent IERs to apply to the links - acl: The Access Control List - step: The step number. - """ - # Go through each IER and check the conditions for it being applied - # If everything is in place, apply the IER protocol load to the relevant links - for ier_key, ier_value in iers.items(): - start_step = ier_value.get_start_step() - stop_step = ier_value.get_end_step() - protocol = ier_value.get_protocol() - port = ier_value.get_port() - load = ier_value.get_load() - source_node_id = ier_value.get_source_node_id() - dest_node_id = ier_value.get_dest_node_id() - - # Need to set the running status to false first for all IERs - ier_value.set_is_running(False) - - source_valid = True - dest_valid = True - acl_block = False - - if step >= start_step and step <= stop_step: - # continue -------------------------- - - # Get the source and destination node for this link - source_node = nodes[source_node_id] - dest_node = nodes[dest_node_id] - - # 1. Check the source node situation - if source_node.node_type == NodeType.SWITCH: - # It's a switch - if source_node.hardware_state == HardwareState.ON: - source_valid = True - else: - # IER no longer valid - source_valid = False - elif source_node.node_type == NodeType.ACTUATOR: - # It's an actuator - # TO DO - pass - else: - # It's not a switch or an actuator (so active node) - # TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it - # could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs - # to change according to duck typing. - if source_node.hardware_state == HardwareState.ON: - if source_node.has_service(protocol): - # Red agents IERs can only be valid if the source service is in a compromised state - if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED: - source_valid = True - else: - source_valid = False - else: - # Do nothing - IER is not valid on this node - # (This shouldn't happen if the IER has been written correctly) - source_valid = False - else: - # Do nothing - IER no longer valid - source_valid = False - - # 2. Check the dest node situation - if dest_node.node_type == NodeType.SWITCH: - # It's a switch - if dest_node.hardware_state == HardwareState.ON: - dest_valid = True - else: - # IER no longer valid - dest_valid = False - elif dest_node.node_type == NodeType.ACTUATOR: - # It's an actuator - pass - else: - # It's not a switch or an actuator (so active node) - if dest_node.hardware_state == HardwareState.ON: - if dest_node.has_service(protocol): - # We don't care what state the destination service is in for an IER - dest_valid = True - else: - # Do nothing - IER is not valid on this node - # (This shouldn't happen if the IER has been written correctly) - dest_valid = False - else: - # Do nothing - IER no longer valid - dest_valid = False - - # 3. Check that the ACL doesn't block it - acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port) - if acl_block: - if _VERBOSE: - print( - "ACL block on source: " - + source_node.ip_address - + ", dest: " - + dest_node.ip_address - + ", protocol: " - + protocol - + ", port: " - + port - ) - else: - if _VERBOSE: - print("No ACL block") - - # Check whether both the source and destination are valid, and there's no ACL block - if source_valid and dest_valid and not acl_block: - # Load up the link(s) with the traffic - - if _VERBOSE: - print("Source, Dest and ACL valid") - - # Get the shortest path (i.e. nodes) between source and destination - path_node_list = shortest_path(network, source_node, dest_node) - path_node_list_length = len(path_node_list) - path_valid = True - - # We might have a switch in the path, so check all nodes are operational - # We're assuming here that red agents can get past switches that are patching - for node in path_node_list: - if node.hardware_state != HardwareState.ON: - path_valid = False - - if path_valid: - if _VERBOSE: - print("Applying IER to link(s)") - count = 0 - link_capacity_exceeded = False - - # Check that the link capacity is not exceeded by the new load - while count < path_node_list_length - 1: - # Get the link between the next two nodes - edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1]) - link_id = edge_dict[0].get("id") - link = links[link_id] - # Check whether the new load exceeds the bandwidth - if (link.get_current_load() + load) > link.get_bandwidth(): - link_capacity_exceeded = True - if _VERBOSE: - print("Link capacity exceeded") - pass - count += 1 - - # Check whether the link capacity for any links on this path have been exceeded - if link_capacity_exceeded == False: - # Now apply the new loads to the links - count = 0 - while count < path_node_list_length - 1: - # Get the link between the next two nodes - edge_dict = network.get_edge_data( - path_node_list[count], - path_node_list[count + 1], - ) - link_id = edge_dict[0].get("id") - link = links[link_id] - # Add the load from this IER - link.add_protocol_load(protocol, load) - count += 1 - # This IER is now valid, so set it to running - ier_value.set_is_running(True) - if _VERBOSE: - print("Red IER was allowed to run in step " + str(step)) - else: - # One of the nodes is not operational - if _VERBOSE: - print("Path not valid - one or more nodes not operational") - pass - - else: - if _VERBOSE: - print("Red IER was NOT allowed to run in step " + str(step)) - print("Source, Dest or ACL were not valid") - pass - # ------------------------------------ - else: - # Do nothing - IER no longer valid - pass - - pass - - -def apply_red_agent_node_pol( - nodes: Dict[str, NodeUnion], - iers: Dict[str, IER], - node_pol: Dict[str, NodeStateInstructionRed], - step: int, -) -> None: - """ - Applies node pattern of life. - - Args: - nodes: The nodes within the environment - iers: The red agent IERs - node_pol: The red agent node pattern of life to apply - step: The step number. - """ - if _VERBOSE: - print("Applying Node Red Agent PoL") - - for key, node_instruction in node_pol.items(): - start_step = node_instruction.get_start_step() - stop_step = node_instruction.get_end_step() - target_node_id = node_instruction.get_target_node_id() - initiator = node_instruction.get_initiator() - pol_type = node_instruction.get_pol_type() - service_name = node_instruction.get_service_name() - state = node_instruction.get_state() - source_node_id = node_instruction.get_source_node_id() - source_node_service_name = node_instruction.get_source_node_service() - source_node_service_state_value = node_instruction.get_source_node_service_state() - - passed_checks = False - - if step >= start_step and step <= stop_step: - # continue -------------------------- - target_node: NodeUnion = nodes[target_node_id] - - # check if the initiator type is a str, and if so, cast it as - # NodePOLInitiator - if isinstance(initiator, str): - initiator = NodePOLInitiator[initiator] - - # Based the action taken on the initiator type - if initiator == NodePOLInitiator.DIRECT: - # No conditions required, just apply the change - passed_checks = True - elif initiator == NodePOLInitiator.IER: - # Need to check there is a red IER incoming - passed_checks = is_red_ier_incoming(target_node, iers, pol_type) - elif initiator == NodePOLInitiator.SERVICE: - # Need to check the condition of a service on another node - source_node = nodes[source_node_id] - if source_node.has_service(source_node_service_name): - if ( - source_node.get_service_state(source_node_service_name) - == SoftwareState[source_node_service_state_value] - ): - passed_checks = True - else: - # Do nothing, no matching state value - pass - else: - # Do nothing, service not on this node - pass - else: - _LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration") - - # Only apply the PoL if the checks have passed (based on the initiator type) - if passed_checks: - # Apply the change - if pol_type == NodePOLType.OPERATING: - # Change hardware state - target_node.hardware_state = state - elif pol_type == NodePOLType.OS: - # Change OS state - if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode): - target_node.software_state = state - elif pol_type == NodePOLType.SERVICE: - # Change a service state - if isinstance(target_node, ServiceNode): - target_node.set_service_state(service_name, state) - else: - # Change the file system status - if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode): - target_node.set_file_system_state(state) - else: - _LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks") - else: - # PoL is not valid in this time step - pass - - -def is_red_ier_incoming(node: NodeUnion, iers: Dict[str, IER], node_pol_type: NodePOLType) -> bool: - """Checks if the RED IER is incoming. - - :param node: Destination node of the IER - :type node: NodeUnion - :param iers: Directory of IERs - :type iers: Dict[str,IER] - :param node_pol_type: Type of Pattern-Of-Life - :type node_pol_type: NodePOLType - :return: Whether the RED IER is incoming. - :rtype: bool - """ - node_id = node.node_id - - for ier_key, ier_value in iers.items(): - if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id: - if ( - node_pol_type == NodePOLType.OPERATING - or node_pol_type == NodePOLType.OS - or node_pol_type == NodePOLType.FILE - ): - # It's looking to change hardware state, file system or SoftwareState, so valid - return True - elif node_pol_type == NodePOLType.SERVICE: - # Check if the service is present on the node and running - ier_protocol = ier_value.get_protocol() - if isinstance(node, ServiceNode): - if node.has_service(ier_protocol): - if node.service_running(ier_protocol): - # Matching service is present and running, so valid - return True - else: - # Service is present, but not running - return False - else: - # Service is not present - return False - else: - # Not a service node - return False - else: - # Shouldn't get here - instruction type is undefined - return False - else: - # The IER destination is not this node, or the IER is not running - return False diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py deleted file mode 100644 index 7d5b2709..00000000 --- a/src/primaite/primaite_session.py +++ /dev/null @@ -1,228 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output.""" -from __future__ import annotations - -import json -from pathlib import Path -from typing import Any, Dict, Final, Optional, Tuple, Union - -from primaite import getLogger -from primaite.agents.agent_abc import AgentSessionABC -from primaite.agents.hardcoded_acl import HardCodedACLAgent -from primaite.agents.hardcoded_node import HardCodedNodeAgent - -# from primaite.agents.rllib import RLlibAgent -from primaite.agents.sb3 import SB3Agent -from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent -from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType -from primaite.config import lay_down_config, training_config -from primaite.config.training_config import TrainingConfig -from primaite.utils.session_metadata_parser import parse_session_metadata -from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict - -_LOGGER = getLogger(__name__) - - -class PrimaiteSession: - """ - The PrimaiteSession class. - - Provides a single learning and evaluation entry point for all training and lay down configurations. - """ - - def __init__( - self, - training_config_path: Optional[Union[str, Path]] = "", - lay_down_config_path: Optional[Union[str, Path]] = "", - session_path: Optional[Union[str, Path]] = None, - legacy_training_config: bool = False, - legacy_lay_down_config: bool = False, - ) -> None: - """ - The PrimaiteSession constructor. - - :param training_config_path: YAML file containing configurable items defined in - `primaite.config.training_config.TrainingConfig` - :type training_config_path: Union[path, str] - :param lay_down_config_path: YAML file containing configurable items for generating network laydown. - :type lay_down_config_path: Union[path, str] - :param session_path: directory path of the session to load - :param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0, - otherwise False. - :param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0, - otherwise False. - """ - self._agent_session: AgentSessionABC = None # noqa - self.session_path: Path = session_path # noqa - self.timestamp_str: str = None # noqa - self.learning_path: Path = None # noqa - self.evaluation_path: Path = None # noqa - self.legacy_training_config = legacy_training_config - self.legacy_lay_down_config = legacy_lay_down_config - - # check if session path is provided - if session_path is not None: - # set load_session to true - self.is_load_session = True - if not isinstance(session_path, Path): - session_path = Path(session_path) - - # if a session path is provided, load it - if not session_path.exists(): - raise Exception(f"Session could not be loaded. Path does not exist: {session_path}") - - md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path) - - if not isinstance(training_config_path, Path): - training_config_path = Path(training_config_path) - self._training_config_path: Final[Union[Path, str]] = training_config_path - self._training_config: Final[TrainingConfig] = training_config.load( - self._training_config_path, legacy_training_config - ) - - if not isinstance(lay_down_config_path, Path): - lay_down_config_path = Path(lay_down_config_path) - self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path - self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config) # noqa - - def setup(self) -> None: - """Performs the session setup.""" - if self._training_config.agent_framework == AgentFramework.CUSTOM: - _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}") - if self._training_config.agent_identifier == AgentIdentifier.HARDCODED: - _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}") - if self._training_config.action_type == ActionType.NODE: - # Deterministic Hardcoded Agent with Node Action Space - self._agent_session = HardCodedNodeAgent( - self._training_config_path, self._lay_down_config_path, self.session_path - ) - - elif self._training_config.action_type == ActionType.ACL: - # Deterministic Hardcoded Agent with ACL Action Space - self._agent_session = HardCodedACLAgent( - self._training_config_path, self._lay_down_config_path, self.session_path - ) - - elif self._training_config.action_type == ActionType.ANY: - # Deterministic Hardcoded Agent with ANY Action Space - raise NotImplementedError - - else: - # Invalid AgentIdentifier ActionType combo - raise ValueError - - elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING: - _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}") - if self._training_config.action_type == ActionType.NODE: - self._agent_session = DoNothingNodeAgent( - self._training_config_path, self._lay_down_config_path, self.session_path - ) - - elif self._training_config.action_type == ActionType.ACL: - # Deterministic Hardcoded Agent with ACL Action Space - self._agent_session = DoNothingACLAgent( - self._training_config_path, self._lay_down_config_path, self.session_path - ) - - elif self._training_config.action_type == ActionType.ANY: - # Deterministic Hardcoded Agent with ANY Action Space - raise NotImplementedError - - else: - # Invalid AgentIdentifier ActionType combo - raise ValueError - - elif self._training_config.agent_identifier == AgentIdentifier.RANDOM: - _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}") - self._agent_session = RandomAgent( - self._training_config_path, self._lay_down_config_path, self.session_path - ) - elif self._training_config.agent_identifier == AgentIdentifier.DUMMY: - _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}") - self._agent_session = DummyAgent( - self._training_config_path, self._lay_down_config_path, self.session_path - ) - - else: - # Invalid AgentFramework AgentIdentifier combo - raise ValueError - - elif self._training_config.agent_framework == AgentFramework.SB3: - _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}") - # Stable Baselines3 Agent - self._agent_session = SB3Agent( - self._training_config_path, - self._lay_down_config_path, - self.session_path, - self.legacy_training_config, - self.legacy_lay_down_config, - ) - - # elif self._training_config.agent_framework == AgentFramework.RLLIB: - # _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}") - # # Ray RLlib Agent - # self._agent_session = RLlibAgent( - # self._training_config_path, self._lay_down_config_path, self.session_path - # ) - - else: - # Invalid AgentFramework - raise ValueError - - self.session_path: Path = self._agent_session.session_path - self.timestamp_str: str = self._agent_session.timestamp_str - self.learning_path: Path = self._agent_session.learning_path - self.evaluation_path: Path = self._agent_session.evaluation_path - - def learn( - self, - **kwargs: Any, - ) -> None: - """ - Train the agent. - - :param kwargs: Any agent-framework specific key word args. - """ - if not self._training_config.session_type == SessionType.EVAL: - self._agent_session.learn(**kwargs) - - def evaluate( - self, - **kwargs: Any, - ) -> None: - """ - Evaluate the agent. - - :param kwargs: Any agent-framework specific key word args. - """ - if not self._training_config.session_type == SessionType.TRAIN: - self._agent_session.evaluate(**kwargs) - - def close(self) -> None: - """Closes the agent.""" - self._agent_session.close() - - def learn_av_reward_per_episode_dict(self) -> Dict[int, float]: - """Get the learn av reward per episode from file.""" - csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" - return av_rewards_dict(self.learning_path / csv_file) - - def eval_av_reward_per_episode_dict(self) -> Dict[int, float]: - """Get the eval av reward per episode from file.""" - csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" - return av_rewards_dict(self.evaluation_path / csv_file) - - def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]: - """Get the learn all transactions from file.""" - csv_file = f"all_transactions_{self.timestamp_str}.csv" - return all_transactions_dict(self.learning_path / csv_file) - - def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]: - """Get the eval all transactions from file.""" - csv_file = f"all_transactions_{self.timestamp_str}.csv" - return all_transactions_dict(self.evaluation_path / csv_file) - - def metadata_file_as_dict(self) -> Dict[str, Any]: - """Read the session_metadata.json file and return as a dict.""" - with open(self.session_path / "session_metadata.json", "r") as file: - return json.load(file) diff --git a/src/primaite/setup/old_installation_clean_up.py b/src/primaite/setup/old_installation_clean_up.py deleted file mode 100644 index 412aed60..00000000 --- a/src/primaite/setup/old_installation_clean_up.py +++ /dev/null @@ -1,14 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK - -from primaite import getLogger - -_LOGGER = getLogger(__name__) - - -def run() -> None: - """Perform the full clean-up.""" - pass - - -if __name__ == "__main__": - run() diff --git a/src/primaite/transactions/__init__.py b/src/primaite/transactions/__init__.py deleted file mode 100644 index 505c5080..00000000 --- a/src/primaite/transactions/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""Record data of the system's state and agent's observations and actions.""" diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py deleted file mode 100644 index 6b973ca3..00000000 --- a/src/primaite/transactions/transaction.py +++ /dev/null @@ -1,102 +0,0 @@ -# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -"""The Transaction class.""" -from datetime import datetime -from typing import List, Optional, Tuple, TYPE_CHECKING, Union - -from primaite.common.enums import AgentIdentifier - -if TYPE_CHECKING: - import numpy as np - from gymnasium import spaces - - -class Transaction(object): - """Transaction class.""" - - def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int) -> None: - """ - Transaction constructor. - - :param agent_identifier: An identifier for the agent in use - :param episode_number: The episode number - :param step_number: The step number - """ - self.timestamp: datetime = datetime.now() - "The datetime of the transaction" - self.agent_identifier: AgentIdentifier = agent_identifier - "The agent identifier" - self.episode_number: int = episode_number - "The episode number" - self.step_number: int = step_number - "The step number" - self.obs_space: "spaces.Space" = None - "The observation space (pre)" - self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None - "The observation space before any actions are taken" - self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None - "The observation space after any actions are taken" - self.reward: Optional[float] = None - "The reward value" - self.action_space: Optional[int] = None - "The action space invoked by the agent" - self.obs_space_description: Optional[List[str]] = None - "The env observation space description" - - def as_csv_data(self) -> Tuple[List, List]: - """ - Converts the Transaction to a csv data row and provides a header. - - :return: A tuple consisting of (header, data). - """ - if isinstance(self.action_space, int): - action_length = self.action_space - else: - action_length = self.action_space.size - - # Create the action space headers array - action_header = [] - for x in range(action_length): - action_header.append("AS_" + str(x)) - - # Open up a csv file - header = ["Timestamp", "Episode", "Step", "Reward"] - header = header + action_header + self.obs_space_description - - row = [ - str(self.timestamp), - str(self.episode_number), - str(self.step_number), - str(self.reward), - ] - row = row + _turn_action_space_to_array(self.action_space) + self.obs_space.tolist() - return header, row - - -def _turn_action_space_to_array(action_space: Union[int, List[int]]) -> List[str]: - """ - Turns action space into a string array so it can be saved to csv. - - :param action_space: The action space - :return: The action space as an array of strings - """ - if isinstance(action_space, list): - return [str(i) for i in action_space] - else: - return [str(action_space)] - - -def _turn_obs_space_to_array(obs_space: "np.ndarray", obs_assets: int, obs_features: int) -> List[str]: - """ - Turns observation space into a string array so it can be saved to csv. - - :param obs_space: The observation space - :param obs_assets: The number of assets (i.e. nodes or links) in the observation space - :param obs_features: The number of features associated with the asset - :return: The observation space as an array of strings - """ - return_array = [] - for x in range(obs_assets): - for y in range(obs_features): - return_array.append(str(obs_space[x][y])) - - return return_array