diff --git a/docs/source/about.rst b/docs/source/about.rst index 60f621db..47511c1b 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -182,16 +182,13 @@ All ACL rules are considered when applying an IER. Logic follows the order of ru Observation Spaces ****************** +The observation space provides the blue agent with information about the current status of nodes and links. -The OpenAI Gym observation space provides the status of all nodes and links across the whole system: +PrimAITE builds on top of Gym Spaces to create an observation space that is easily configurable for users. It's made up of components which are managed by the :py:class:`primaite.environment.observations.ObservationHandler`. Each training scenario can define its own observation space, and the user can choose which information to inlude, and how it should be formatted. -* Nodes (in terms of hardware state, Software State, file system state and services state) -* Links (in terms of current loading for each service/protocol) - -The observation space can be configured as a ``gym.spaces.Box`` or ``gym.spaces.MultiDiscrete``, by setting the ``OBSERVATIONS`` parameter in the laydown config. - -Box-type observation space --------------------------- +NodeLinkTable component +----------------------- +For example, the :py:class:`primaite.environment.observations.NodeLinkTable` component represents the status of nodes and links as a ``gym.spaces.Box`` with an example format shown below: An example observation space is provided below: @@ -249,8 +246,6 @@ An example observation space is provided below: - 5000 - 0 -The observation space is a 6 x 6 Box type (OpenAI Gym Space) in this example. This is made up from the node and link information detailed below. - For the nodes, the following values are represented: * ID @@ -290,9 +285,9 @@ For the links, the following statuses are represented: * SoftwareState = N/A * Protocol = loading in bits/s -MultiDiscrete-type observation space ------------------------------------- -The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by integers. +NodeStatus component +---------------------- +This is a MultiDiscrete observation space that can be though of as a one-dimensional vector of discrete states, represented by integers. The example above would have the following structure: .. code-block:: @@ -301,9 +296,6 @@ The example above would have the following structure: node1_info node2_info node3_info - link1_status - link2_status - link3_status ] Each ``node_info`` contains the following: @@ -318,7 +310,25 @@ Each ``node_info`` contains the following: service2_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED) ] -Each ``link_status`` is just a number from 0-4 representing the network load in relation to bandwidth. +In a network with three nodes and two services, the full observation space would have 15 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example: + +.. code-block:: + + gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4]) + +LinkTrafficLevels +----------------- +This component is a MultiDiscrete space showing the traffic flow levels on the links in the network, after applying a threshold to convert it from a continuous to a discrete value. +The number of bins can be customised with 5 being the default. It has the following strucutre: +.. code-block:: + + [ + link1_status + link2_status + link3_status + ] + +Each ``link_status`` is a number from 0-4 representing the network load in relation to bandwidth. .. code-block:: @@ -328,11 +338,11 @@ Each ``link_status`` is just a number from 0-4 representing the network load in 3 = high traffic (<100%) 4 = max traffic/ overwhelmed (100%) -The full observation space would have 15 node-related elements and 3 link-related elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example: +If the network has three links, the full observation space would have 3 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example: .. code-block:: - gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4,5,5,5]) + gym.spaces.MultiDiscrete([5,5,5]) Action Spaces ************** diff --git a/docs/source/config.rst b/docs/source/config.rst index 3c46be1d..64d9d27f 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -296,6 +296,34 @@ The Lay Down Config The lay down config file consists of the following attributes: +* **itemType: ACTIONS** [enum] + + Determines whether a NODE or ACL action space format is adopted for the session + +* **itemType: OBSERVATION_SPACE** [dict] + + Allows for user to configure observation space by combining one or more observation components. List of available + components is is :py:mod:'primaite.environment.observations'. + + The observation space config item should have a ``components`` key which is a list of components. Each component + config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the + component while it is being initialised. + + This example illustrates the correct format for the observation space config item + +.. code-block::yaml + + - itemType: OBSERVATION_SPACE + components: + - name: LINK_TRAFFIC_LEVELS + options: + combine_service_traffic: false + quantisation_levels: 8 + - name: NODE_STATUSES + - name: LINK_TRAFFIC_LEVELS + +* **itemType: STEPS** [int] + * **item_type: PORTS** [int] Provides a list of ports modelled in this session diff --git a/pytest.ini b/pytest.ini index e618d7a5..b5fae8d0 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,5 @@ [pytest] testpaths = tests +markers = + env_config_paths diff --git a/src/primaite/common/config_values_main.py b/src/primaite/common/config_values_main.py new file mode 100644 index 00000000..f822b77f --- /dev/null +++ b/src/primaite/common/config_values_main.py @@ -0,0 +1,91 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +"""The config class.""" + + +class ConfigValuesMain(object): + """Class to hold main config values.""" + + def __init__(self): + """Init.""" + # Generic + self.agent_identifier = "" # the agent in use + self.observation_config = None # observation space config + self.num_episodes = 0 # number of episodes to train over + self.num_steps = 0 # number of steps in an episode + self.time_delay = 0 # delay between steps (ms) - applies to generic agents only + self.config_filename_use_case = "" # the filename for the Use Case config file + self.session_type = "" # the session type to run (TRAINING or EVALUATION) + + # Environment + self.observation_space_high_value = ( + 0 # The high value for the observation space + ) + + # Reward values + # Generic + self.all_ok = 0 + # Node Hardware State + self.off_should_be_on = 0 + self.off_should_be_resetting = 0 + self.on_should_be_off = 0 + self.on_should_be_resetting = 0 + self.resetting_should_be_on = 0 + self.resetting_should_be_off = 0 + self.resetting = 0 + # Node Software or Service State + self.good_should_be_patching = 0 + self.good_should_be_compromised = 0 + self.good_should_be_overwhelmed = 0 + self.patching_should_be_good = 0 + self.patching_should_be_compromised = 0 + self.patching_should_be_overwhelmed = 0 + self.patching = 0 + self.compromised_should_be_good = 0 + self.compromised_should_be_patching = 0 + self.compromised_should_be_overwhelmed = 0 + self.compromised = 0 + self.overwhelmed_should_be_good = 0 + self.overwhelmed_should_be_patching = 0 + self.overwhelmed_should_be_compromised = 0 + self.overwhelmed = 0 + # Node File System State + self.good_should_be_repairing = 0 + self.good_should_be_restoring = 0 + self.good_should_be_corrupt = 0 + self.good_should_be_destroyed = 0 + self.repairing_should_be_good = 0 + self.repairing_should_be_restoring = 0 + self.repairing_should_be_corrupt = 0 + self.repairing_should_be_destroyed = ( + 0 # Repairing does not fix destroyed state - you need to restore + ) + self.repairing = 0 + self.restoring_should_be_good = 0 + self.restoring_should_be_repairing = 0 + self.restoring_should_be_corrupt = ( + 0 # Not the optimal method (as repair will fix corruption) + ) + self.restoring_should_be_destroyed = 0 + self.restoring = 0 + self.corrupt_should_be_good = 0 + self.corrupt_should_be_repairing = 0 + self.corrupt_should_be_restoring = 0 + self.corrupt_should_be_destroyed = 0 + self.corrupt = 0 + self.destroyed_should_be_good = 0 + self.destroyed_should_be_repairing = 0 + self.destroyed_should_be_restoring = 0 + self.destroyed_should_be_corrupt = 0 + self.destroyed = 0 + self.scanning = 0 + # IER status + self.red_ier_running = 0 + self.green_ier_blocked = 0 + + # Patching / Reset + self.os_patching_duration = 0 # The time taken to patch the OS + self.node_reset_duration = 0 # The time taken to reset a node (hardware) + self.service_patching_duration = 0 # The time taken to patch a service + self.file_system_repairing_limit = 0 # The time take to repair a file + self.file_system_restoring_limit = 0 # The time take to restore a file + self.file_system_scanning_limit = 0 # The time taken to scan the file system diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py new file mode 100644 index 00000000..9e71ef1b --- /dev/null +++ b/src/primaite/environment/observations.py @@ -0,0 +1,403 @@ +"""Module for handling configurable observation spaces in PrimAITE.""" +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union + +import numpy as np +from gym import spaces + +from primaite.common.enums import FileSystemState, HardwareState, SoftwareState +from primaite.nodes.active_node import ActiveNode +from primaite.nodes.service_node import ServiceNode + +# This dependency is only needed for type hints, +# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking +# Therefore, this avoids circular dependency problem. +if TYPE_CHECKING: + from primaite.environment.primaite_env import Primaite + + +_LOGGER = logging.getLogger(__name__) + + +class AbstractObservationComponent(ABC): + """Represents a part of the PrimAITE observation space.""" + + @abstractmethod + def __init__(self, env: "Primaite"): + _LOGGER.info(f"Initialising {self} observation component") + self.env: "Primaite" = env + self.space: spaces.Space + self.current_observation: np.ndarray # type might be too restrictive? + return NotImplemented + + @abstractmethod + def update(self): + """Update the observation based on the current state of the environment.""" + self.current_observation = NotImplemented + + +class NodeLinkTable(AbstractObservationComponent): + """Table with nodes and links as rows and hardware/software status as cols. + + This will create the observation space formatted as a table of integers. + There is one row per node, followed by one row per link. + The number of columns is 4 plus one per service. They are: + * node/link ID + * node hardware status / 0 for links + * node operating system status (if active/service) / 0 for links + * node file system status (active/service only) / 0 for links + * node service1 status / traffic load from that service for links + * node service2 status / traffic load from that service for links + * ... + * node serviceN status / traffic load from that service for links + + For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be + ``(12, 7)`` + """ + + _FIXED_PARAMETERS: int = 4 + _MAX_VAL: int = 1_000_000 + _DATA_TYPE: type = np.int64 + + def __init__(self, env: "Primaite"): + super().__init__(env) + + # 1. Define the shape of your observation space component + num_items = self.env.num_links + self.env.num_nodes + num_columns = self.env.num_services + self._FIXED_PARAMETERS + observation_shape = (num_items, num_columns) + + # 2. Create Observation space + self.space = spaces.Box( + low=0, + high=self._MAX_VAL, + shape=observation_shape, + dtype=self._DATA_TYPE, + ) + + # 3. Initialise Observation with zeroes + self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) + + def update(self): + """Update the observation based on current environment state. + + The structure of the observation space is described in :class:`.NodeLinkTable` + """ + item_index = 0 + nodes = self.env.nodes + links = self.env.links + # Do nodes first + for _, node in nodes.items(): + self.current_observation[item_index][0] = int(node.node_id) + self.current_observation[item_index][1] = node.hardware_state.value + if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): + self.current_observation[item_index][2] = node.software_state.value + self.current_observation[item_index][ + 3 + ] = node.file_system_state_observed.value + else: + self.current_observation[item_index][2] = 0 + self.current_observation[item_index][3] = 0 + service_index = 4 + if isinstance(node, ServiceNode): + for service in self.env.services_list: + if node.has_service(service): + self.current_observation[item_index][ + service_index + ] = node.get_service_state(service).value + else: + self.current_observation[item_index][service_index] = 0 + service_index += 1 + else: + # Not a service node + for service in self.env.services_list: + self.current_observation[item_index][service_index] = 0 + service_index += 1 + item_index += 1 + + # Now do links + for _, link in links.items(): + self.current_observation[item_index][0] = int(link.get_id()) + self.current_observation[item_index][1] = 0 + self.current_observation[item_index][2] = 0 + self.current_observation[item_index][3] = 0 + protocol_list = link.get_protocol_list() + protocol_index = 0 + for protocol in protocol_list: + self.current_observation[item_index][ + protocol_index + 4 + ] = protocol.get_load() + protocol_index += 1 + item_index += 1 + + +class NodeStatuses(AbstractObservationComponent): + """Flat list of nodes' hardware, OS, file system, and service states. + + The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by + integers. + Each node has 3 elements plus 1 per service. It will have the following structure: + .. code-block:: + [ + node1 hardware state, + node1 OS state, + node1 file system state, + node1 service1 state, + node1 service2 state, + node1 serviceN state (one for each service), + node2 hardware state, + node2 OS state, + node2 file system state, + node2 service1 state, + node2 service2 state, + node2 serviceN state (one for each service), + ... + ] + + :param env: The environment that forms the basis of the observations + :type env: Primaite + """ + + _DATA_TYPE: type = np.int64 + + def __init__(self, env: "Primaite"): + super().__init__(env) + + # 1. Define the shape of your observation space component + node_shape = [ + len(HardwareState) + 1, + len(SoftwareState) + 1, + len(FileSystemState) + 1, + ] + services_shape = [len(SoftwareState) + 1] * self.env.num_services + node_shape = node_shape + services_shape + + shape = node_shape * self.env.num_nodes + # 2. Create Observation space + self.space = spaces.MultiDiscrete(shape) + + # 3. Initialise observation with zeroes + self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + + def update(self): + """Update the observation based on current environment state. + + The structure of the observation space is described in :class:`.NodeStatuses` + """ + obs = [] + for _, node in self.env.nodes.items(): + hardware_state = node.hardware_state.value + software_state = 0 + file_system_state = 0 + service_states = [0] * self.env.num_services + + if isinstance(node, ActiveNode): + software_state = node.software_state.value + file_system_state = node.file_system_state_observed.value + + if isinstance(node, ServiceNode): + for i, service in enumerate(self.env.services_list): + if node.has_service(service): + service_states[i] = node.get_service_state(service).value + obs.extend( + [hardware_state, software_state, file_system_state, *service_states] + ) + self.current_observation[:] = obs + + +class LinkTrafficLevels(AbstractObservationComponent): + """Flat list of traffic levels encoded into banded categories. + + For each link, total traffic or traffic per service is encoded into a categorical value. + For example, if ``quantisation_levels=5``, the traffic levels represent these values: + 0 = No traffic (0% of bandwidth) + 1 = No traffic (0%-33% of bandwidth) + 2 = No traffic (33%-66% of bandwidth) + 3 = No traffic (66%-100% of bandwidth) + 4 = No traffic (100% of bandwidth) + + .. note:: + The lowest category always corresponds to no traffic and the highest category to the link being at max capacity. + Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories. + + :param env: The environment that forms the basis of the observations + :type env: Primaite + :param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually, + defaults to False + :type combine_service_traffic: bool, optional + :param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value , + defaults to 5 + :type quantisation_levels: int, optional + """ + + _DATA_TYPE: type = np.int64 + + def __init__( + self, + env: "Primaite", + combine_service_traffic: bool = False, + quantisation_levels: int = 5, + ): + if quantisation_levels < 3: + _msg = ( + f"quantisation_levels must be 3 or more because the lowest and highest levels are " + f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. " + f"Resetting to default value (5)" + ) + _LOGGER.warning(_msg) + quantisation_levels = 5 + + super().__init__(env) + + self._combine_service_traffic: bool = combine_service_traffic + self._quantisation_levels: int = quantisation_levels + self._entries_per_link: int = 1 + + if not self._combine_service_traffic: + self._entries_per_link = self.env.num_services + + # 1. Define the shape of your observation space component + shape = ( + [self._quantisation_levels] * self.env.num_links * self._entries_per_link + ) + + # 2. Create Observation space + self.space = spaces.MultiDiscrete(shape) + + # 3. Initialise observation with zeroes + self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + + def update(self): + """Update the observation based on current environment state. + + The structure of the observation space is described in :class:`.LinkTrafficLevels` + """ + obs = [] + for _, link in self.env.links.items(): + bandwidth = link.bandwidth + if self._combine_service_traffic: + loads = [link.get_current_load()] + else: + loads = [protocol.get_load() for protocol in link.protocol_list] + + for load in loads: + if load <= 0: + traffic_level = 0 + elif load >= bandwidth: + traffic_level = self._quantisation_levels - 1 + else: + traffic_level = (load / bandwidth) // ( + 1 / (self._quantisation_levels - 2) + ) + 1 + + obs.append(int(traffic_level)) + + self.current_observation[:] = obs + + +class ObservationsHandler: + """Component-based observation space handler. + + This allows users to configure observation spaces by mixing and matching components. + Each component can also define further parameters to make them more flexible. + """ + + _REGISTRY: Final[Dict[str, type]] = { + "NODE_LINK_TABLE": NodeLinkTable, + "NODE_STATUSES": NodeStatuses, + "LINK_TRAFFIC_LEVELS": LinkTrafficLevels, + } + + def __init__(self): + self.registered_obs_components: List[AbstractObservationComponent] = [] + self.space: spaces.Space + self.current_observation: Union[Tuple[np.ndarray], np.ndarray] + + def update_obs(self): + """Fetch fresh information about the environment.""" + current_obs = [] + for obs in self.registered_obs_components: + obs.update() + current_obs.append(obs.current_observation) + + # If there is only one component, don't use a tuple, just pass through that component's obs. + if len(current_obs) == 1: + self.current_observation = current_obs[0] + else: + self.current_observation = tuple(current_obs) + # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. + + def register(self, obs_component: AbstractObservationComponent): + """Add a component for this handler to track. + + :param obs_component: The component to add. + :type obs_component: AbstractObservationComponent + """ + self.registered_obs_components.append(obs_component) + self.update_space() + + def deregister(self, obs_component: AbstractObservationComponent): + """Remove a component from this handler. + + :param obs_component: Which component to remove. It must exist within this object's + ``registered_obs_components`` attribute. + :type obs_component: AbstractObservationComponent + """ + self.registered_obs_components.remove(obs_component) + self.update_space() + + def update_space(self): + """Rebuild the handler's composite observation space from its components.""" + component_spaces = [] + for obs_comp in self.registered_obs_components: + component_spaces.append(obs_comp.space) + + # If there is only one component, don't use a tuple space, just pass through that component's space. + if len(component_spaces) == 1: + self.space = component_spaces[0] + else: + self.space = spaces.Tuple(component_spaces) + # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. + + @classmethod + def from_config(cls, env: "Primaite", obs_space_config: dict): + """Parse a config dictinary, return a new observation handler populated with new observation component objects. + + The expected format for the config dictionary is: + + ..code-block::python + config = { + components: [ + { + "name": "" + }, + { + "name": "" + "options": {"opt1": val1, "opt2": val2} + }, + { + ... + }, + ] + } + + :return: Observation handler + :rtype: primaite.environment.observations.ObservationsHandler + """ + # Instantiate the handler + handler = cls() + + for component_cfg in obs_space_config["components"]: + # Figure out which class can instantiate the desired component + comp_type = component_cfg["name"] + comp_class = cls._REGISTRY[comp_type] + + # Create the component with options from the YAML + options = component_cfg.get("options") or {} + component = comp_class(env, **options) + + handler.register(component) + + handler.update_obs() + return handler diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 4dc08ac3..35901198 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -24,11 +24,11 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, - ObservationType, Priority, SoftwareState, ) from primaite.common.service import Service +from primaite.environment.observations import ObservationsHandler from primaite.config import training_config from primaite.config.training_config import TrainingConfig from primaite.environment.reward import calculate_reward_function @@ -51,14 +51,11 @@ _LOGGER.setLevel(logging.INFO) class Primaite(Env): """PRIMmary AI Training Evironment (Primaite) class.""" - # Observation / Action Space contants - OBSERVATION_SPACE_FIXED_PARAMETERS = 4 - ACTION_SPACE_NODE_PROPERTY_VALUES = 5 - ACTION_SPACE_NODE_ACTION_VALUES = 4 - ACTION_SPACE_ACL_ACTION_VALUES = 3 - ACTION_SPACE_ACL_PERMISSION_VALUES = 2 - - OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space + # Action Space contants + ACTION_SPACE_NODE_PROPERTY_VALUES: int = 5 + ACTION_SPACE_NODE_ACTION_VALUES: int = 4 + ACTION_SPACE_ACL_ACTION_VALUES: int = 3 + ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( self, @@ -165,8 +162,18 @@ class Primaite(Env): # Number of ports - gets a value when config is loaded self.num_ports = 0 - # Observation type, by default box. - self.observation_type = ObservationType.BOX + # The action type + self.action_type = 0 + + # TODO fix up with TrainingConfig + # stores the observation config from the yaml, default is NODE_LINK_TABLE + self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]} + if self.config_values.observation_config is not None: + self.obs_config = self.config_values.observation_config + + # Observation Handler manages the user-configurable observation space. + # It will be initialised later. + self.obs_handler: ObservationsHandler # Open the config file and build the environment laydown @@ -229,7 +236,7 @@ class Primaite(Env): self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info(f"Invalid action type selected") + _LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}") # Set up a csv to store the results of the training try: header = ["Episode", "Average Reward"] @@ -424,9 +431,7 @@ class Primaite(Env): _action: The action space from the agent """ # At the moment, actions are only affecting nodes - print("") - print(_action) - print(self.action_dict) + if self.training_config.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) elif self.training_config.action_type == ActionType.ACL: @@ -652,252 +657,20 @@ class Primaite(Env): else: pass - def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Initialise the observation space with the BOX option chosen. - - This will create the observation space formatted as a table of integers. - There is one row per node, followed by one row per link. - Columns are as follows: - * node/link ID - * node hardware status / 0 for links - * node operating system status (if active/service) / 0 for links - * node file system status (active/service only) / 0 for links - * node service1 status / traffic load from that service for links - * node service2 status / traffic load from that service for links - * ... - * node serviceN status / traffic load from that service for links - - For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be - ``(12, 7)`` - - :return: Box gym observation - :rtype: gym.spaces.Box - :return: Initial observation with all entires set to 0 - :rtype: numpy.Array - """ - _LOGGER.info("Observation space type BOX selected") - - # 1. Determine observation shape from laydown - num_items = self.num_links + self.num_nodes - num_observation_parameters = ( - self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS - ) - observation_shape = (num_items, num_observation_parameters) - - # 2. Create observation space & zeroed out sample from space. - observation_space = spaces.Box( - low=0, - high=self.OBSERVATION_SPACE_HIGH_VALUE, - shape=observation_shape, - dtype=np.int64, - ) - initial_observation = np.zeros(observation_shape, dtype=np.int64) - - return observation_space, initial_observation - - def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Initialise the observation space with the MULTIDISCRETE option chosen. - - This will create the observation space with node observations followed by link observations. - Each node has 3 elements in the observation space plus 1 per service, more specifically: - * hardware state - * operating system state - * file system state - * service states (one per service) - Each link has one element in the observation space, corresponding to the traffic load, - it can take the following values: - 0 = No traffic (0% of bandwidth) - 1 = No traffic (0%-33% of bandwidth) - 2 = No traffic (33%-66% of bandwidth) - 3 = No traffic (66%-100% of bandwidth) - 4 = No traffic (100% of bandwidth) - - For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be - ``(37,)`` - - :return: MultiDiscrete gym observation - :rtype: gym.spaces.MultiDiscrete - :return: Initial observation with all entires set to 0 - :rtype: numpy.Array - """ - _LOGGER.info("Observation space MULTIDISCRETE selected") - - # 1. Determine observation shape from laydown - node_obs_shape = [ - len(HardwareState) + 1, - len(SoftwareState) + 1, - len(FileSystemState) + 1, - ] - node_services = [len(SoftwareState) + 1] * self.num_services - node_obs_shape = node_obs_shape + node_services - # the magic number 5 refers to 5 states of quantisation of traffic amount. - # (zero, low, medium, high, fully utilised/overwhelmed) - link_obs_shape = [5] * self.num_links - observation_shape = node_obs_shape * self.num_nodes + link_obs_shape - - # 2. Create observation space & zeroed out sample from space. - observation_space = spaces.MultiDiscrete(observation_shape) - initial_observation = np.zeros(len(observation_shape), dtype=np.int64) - - return observation_space, initial_observation - def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Build the observation space based on network laydown and provide initial obs. + """Create the environment's observation handler. - This method uses the object's `num_links`, `num_nodes`, `num_services`, - `OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type` - attributes to figure out the correct shape and format for the observation space. - - :raises ValueError: If the env's `observation_type` attribute is not set to a valid `enums.ObservationType` - :return: Gym observation space - :rtype: gym.spaces.Space - :return: Initial observation with all entires set to 0 - :rtype: numpy.Array + :return: The observation space, initial observation (zeroed out array with the correct shape) + :rtype: Tuple[spaces.Space, np.ndarray] """ - if self.observation_type == ObservationType.BOX: - observation_space, initial_observation = self._init_box_observations() - return observation_space, initial_observation - elif self.observation_type == ObservationType.MULTIDISCRETE: - ( - observation_space, - initial_observation, - ) = self._init_multidiscrete_observations() - return observation_space, initial_observation - else: - errmsg = ( - f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}" - f", got {self.observation_type} instead" - ) - _LOGGER.error(errmsg) - raise ValueError(errmsg) + self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) - def _update_env_obs_box(self): - """Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_box_observations` - This function can only be called if the observation space setting is set to BOX. - - :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` - """ - assert self.observation_type == ObservationType.BOX - item_index = 0 - - # Do nodes first - for node_key, node in self.nodes.items(): - self.env_obs[item_index][0] = int(node.node_id) - self.env_obs[item_index][1] = node.hardware_state.value - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.env_obs[item_index][2] = node.software_state.value - self.env_obs[item_index][3] = node.file_system_state_observed.value - else: - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - service_index = 4 - if isinstance(node, ServiceNode): - for service in self.services_list: - if node.has_service(service): - self.env_obs[item_index][ - service_index - ] = node.get_service_state(service).value - else: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - else: - # Not a service node - for service in self.services_list: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - item_index += 1 - - # Now do links - for link_key, link in self.links.items(): - self.env_obs[item_index][0] = int(link.get_id()) - self.env_obs[item_index][1] = 0 - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - protocol_list = link.get_protocol_list() - protocol_index = 0 - for protocol in protocol_list: - self.env_obs[item_index][protocol_index + 4] = protocol.get_load() - protocol_index += 1 - item_index += 1 - - def _update_env_obs_multidiscrete(self): - """Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_multidiscrete_observations` - This function can only be called if the observation space setting is set to MULTIDISCRETE. - - :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` - """ - assert self.observation_type == ObservationType.MULTIDISCRETE - obs = [] - # 1. Set nodes - # Each node has the following variables in the observation space: - # - Hardware state - # - Software state - # - File System state - # - Service 1 state - # - Service 2 state - # - ... - # - Service N state - for node_key, node in self.nodes.items(): - hardware_state = node.hardware_state.value - software_state = 0 - file_system_state = 0 - services_states = [0] * self.num_services - - if isinstance( - node, ActiveNode - ): # ServiceNode is a subclass of ActiveNode so no need to check that also - software_state = node.software_state.value - file_system_state = node.file_system_state_observed.value - - if isinstance(node, ServiceNode): - for i, service in enumerate(self.services_list): - if node.has_service(service): - services_states[i] = node.get_service_state(service).value - - obs.extend( - [ - hardware_state, - software_state, - file_system_state, - *services_states, - ] - ) - - # 2. Set links - # Each link has just one variable in the observation space, it represents the traffic amount - # In order for the space to be fully MultiDiscrete, the amount of - # traffic on each link is quantised into a few levels: - # 0: no traffic (0% of bandwidth) - # 1: low traffic (0-33% of bandwidth) - # 2: medium traffic (33-66% of bandwidth) - # 3: high traffic (66-100% of bandwidth) - # 4: max traffic/overloaded (100% of bandwidth) - - for link_key, link in self.links.items(): - bandwidth = link.bandwidth - load = link.get_current_load() - - if load <= 0: - traffic_level = 0 - elif load >= bandwidth: - traffic_level = 4 - else: - traffic_level = (load / bandwidth) // (1 / 3) + 1 - - obs.append(int(traffic_level)) - - self.env_obs = np.asarray(obs) + return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): """Updates the observation space based on the node and link status.""" - if self.observation_type == ObservationType.BOX: - self._update_env_obs_box() - elif self.observation_type == ObservationType.MULTIDISCRETE: - self._update_env_obs_multidiscrete() + self.obs_handler.update_obs() + self.env_obs = self.obs_handler.current_observation def load_lay_down_config(self): """Loads config data in order to build the environment configuration.""" @@ -929,11 +702,9 @@ class Primaite(Env): elif item["item_type"] == "PORTS": # Create the list of ports self.create_ports_list(item) - elif item["item_type"] == "OBSERVATIONS": - # Get the observation information - self.get_observation_info(item) else: - # Do nothing (bad formatting) + item_type = item["item_type"] + _LOGGER.error(f"Invalid item_type: {item_type}") pass _LOGGER.info("Environment configuration loaded") @@ -1260,6 +1031,28 @@ class Primaite(Env): """ self.observation_type = ObservationType[observation_info["type"]] + + def get_action_info(self, action_info): + """ + Extracts action_info. + + Args: + item: A config data item representing action info + """ + self.action_type = ActionType[action_info["type"]] + + def save_obs_config(self, obs_config: dict): + """Cache the config for the observation space. + + This is necessary as the observation space can't be built while reading the config, + it must be done after all the nodes, links, and services have been initialised. + + :param obs_config: Parsed config relating to the observation space. The format is described in + :py:meth:`primaite.environment.observations.ObservationsHandler.from_config` + :type obs_config: dict + """ + self.obs_config = obs_config + def reset_environment(self): """ # Resets environment. diff --git a/tests/config/multidiscrete_obs_space_laydown_config.yaml b/tests/config/multidiscrete_obs_space_laydown_config.yaml deleted file mode 100644 index d7b3703c..00000000 --- a/tests/config/multidiscrete_obs_space_laydown_config.yaml +++ /dev/null @@ -1,68 +0,0 @@ -- item_type: ACTIONS - type: NODE -- item_type: OBSERVATIONS - type: MULTIDISCRETE -- item_type: STEPS - steps: 5 -- item_type: PORTS - ports_list: - - port: '80' -- item_type: SERVICES - service_list: - - name: TCP - -######################################## -# Nodes -- item_type: NODE - node_id: '1' - name: PC1 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.1 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '2' - name: SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.2 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD -- item_type: NODE - node_id: '3' - name: SWITCH1 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.3 - software_state: GOOD - file_system_state: GOOD - -######################################## -# Links -- item_type: LINK - id: '4' - name: link1 - bandwidth: 1000 - source: '1' - destination: '3' -- item_type: LINK - id: '5' - name: link2 - bandwidth: 1000 - source: '3' - destination: '2' diff --git a/tests/config/box_obs_space_laydown_config.yaml b/tests/config/obs_tests/laydown.yaml similarity index 62% rename from tests/config/box_obs_space_laydown_config.yaml rename to tests/config/obs_tests/laydown.yaml index 26e353fa..b110befc 100644 --- a/tests/config/box_obs_space_laydown_config.yaml +++ b/tests/config/obs_tests/laydown.yaml @@ -1,15 +1,15 @@ - item_type: ACTIONS type: NODE -- item_type: OBSERVATIONS - type: BOX - item_type: STEPS steps: 5 - item_type: PORTS ports_list: - port: '80' + - port: '53' - item_type: SERVICES service_list: - name: TCP + - name: UDP ######################################## # Nodes @@ -21,12 +21,15 @@ priority: P5 hardware_state: 'ON' ip_address: 192.168.1.1 - software_state: GOOD + software_state: COMPROMISED file_system_state: GOOD services: - name: TCP port: '80' state: GOOD + - name: UDP + port: '53' + state: GOOD - item_type: NODE node_id: '2' name: SERVER @@ -41,6 +44,9 @@ - name: TCP port: '80' state: GOOD + - name: UDP + port: '53' + state: OVERWHELMED - item_type: NODE node_id: '3' name: SWITCH1 @@ -66,3 +72,33 @@ bandwidth: 1000 source: '3' destination: '2' + +######################################### +# IERS +- item_type: GREEN_IER + id: '5' + start_step: 0 + end_step: 5 + load: 999 + protocol: TCP + port: '80' + source: '1' + destination: '2' + mission_criticality: 5 + +######################################### +# ACL Rules +- itemType: ACL_RULE + id: '6' + permission: ALLOW + source: 192.168.1.1 + destination: 192.168.1.2 + protocol: TCP + port: 80 +- itemType: ACL_RULE + id: '7' + permission: ALLOW + source: 192.168.1.2 + destination: 192.168.1.1 + protocol: TCP + port: 80 diff --git a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml new file mode 100644 index 00000000..cdb741f3 --- /dev/null +++ b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml @@ -0,0 +1,96 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: NONE +# Number of episodes to run per session +observationSpace: + components: + - name: LINK_TRAFFIC_LEVELS + options: + combine_service_traffic: false + quantisation_levels: 8 + +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1_000_000_000 + +# Reward values +# Generic +allOk: 0 +# Node Hardware State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node Software or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml new file mode 100644 index 00000000..19d220c8 --- /dev/null +++ b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml @@ -0,0 +1,93 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: NONE +# Number of episodes to run per session +observationSpace: + components: + - name: NODE_LINK_TABLE + +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1_000_000_000 + +# Reward values +# Generic +allOk: 0 +# Node Hardware State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node Software or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml new file mode 100644 index 00000000..25520ccc --- /dev/null +++ b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml @@ -0,0 +1,93 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: NONE +# Number of episodes to run per session +observationSpace: + components: + - name: NODE_STATUSES + +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1_000_000_000 + +# Reward values +# Generic +allOk: 0 +# Node Hardware State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node Software or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/config/obs_tests/main_config_without_obs.yaml b/tests/config/obs_tests/main_config_without_obs.yaml new file mode 100644 index 00000000..43ee251f --- /dev/null +++ b/tests/config/obs_tests/main_config_without_obs.yaml @@ -0,0 +1,89 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: NONE +# Number of episodes to run per session +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1_000_000_000 + +# Reward values +# Generic +allOk: 0 +# Node Hardware State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node Software or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 6ecc5c1b..0df59b72 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -1,36 +1,220 @@ """Test env creation and behaviour with different observation spaces.""" +import numpy as np +import pytest +from primaite.environment.observations import ( + NodeLinkTable, + NodeStatuses, + ObservationsHandler, +) +from primaite.environment.primaite_env import Primaite from tests import TEST_CONFIG_ROOT from tests.conftest import _get_primaite_env_from_config -def test_creating_env_with_box_obs(): - """Try creating env with box observation space.""" - env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "box_obs_space_laydown_config.yaml", +@pytest.fixture +def env(request): + """Build Primaite environment for integration tests of observation space.""" + marker = request.node.get_closest_marker("env_config_paths") + main_config_path = marker.args[0]["main_config_path"] + lay_down_config_path = marker.args[0]["lay_down_config_path"] + env, _ = _get_primaite_env_from_config( + main_config_path=main_config_path, + lay_down_config_path=lay_down_config_path, ) + yield env + + +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ) +) +def test_default_obs_space(env: Primaite): + """Create environment with no obs space defined in config and check that the default obs space was created.""" env.update_environent_obs() - # we have three nodes and two links, with one service - # therefore the box observation space will have: - # * 5 columns (four fixed and one for the service) - # * 5 rows (3 nodes + 2 links) - assert env.env_obs.shape == (5, 5) + components = env.obs_handler.registered_obs_components + + assert len(components) == 1 + assert isinstance(components[0], NodeLinkTable) -def test_creating_env_with_multidiscrete_obs(): - """Try creating env with MultiDiscrete observation space.""" - env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "multidiscrete_obs_space_laydown_config.yaml", +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) - env.update_environent_obs() +) +def test_registering_components(env: Primaite): + """Test regitering and deregistering a component.""" + handler = ObservationsHandler() + component = NodeStatuses(env) + handler.register(component) + assert component in handler.registered_obs_components + handler.deregister(component) + assert component not in handler.registered_obs_components - # we have three nodes and two links, with one service - # the nodes have hardware, OS, FS, and service, the links just have bandwidth, - # therefore we need 3*4 + 2 observations - assert env.env_obs.shape == (3 * 4 + 2,) + +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_NODE_LINK_TABLE.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ) +) +class TestNodeLinkTable: + """Test the NodeLinkTable observation component (in isolation).""" + + def test_obs_shape(self, env: Primaite): + """Try creating env with box observation space.""" + env.update_environent_obs() + + # we have three nodes and two links, with two service + # therefore the box observation space will have: + # * 5 rows (3 nodes + 2 links) + # * 6 columns (four fixed and two for the services) + assert env.env_obs.shape == (5, 6) + + def test_value(self, env: Primaite): + """Test that the observation is generated correctly. + + The laydown has: + * 3 nodes (2 service nodes and 1 active node) + * 2 services + * 2 links + + Both nodes have both services, and all states are GOOD, therefore the expected observation value is: + + * Node 1: + * 1 (id) + * 1 (good hardware state) + * 3 (compromised OS state) + * 1 (good file system state) + * 1 (good TCP state) + * 1 (good UDP state) + * Node 2: + * 2 (id) + * 1 (good hardware state) + * 1 (good OS state) + * 1 (good file system state) + * 1 (good TCP state) + * 4 (overwhelmed UDP state) + * Node 3 (active node): + * 3 (id) + * 1 (good hardware state) + * 1 (good OS state) + * 1 (good file system state) + * 0 (doesn't have service1) + * 0 (doesn't have service2) + * Link 1: + * 4 (id) + * 0 (n/a hardware state) + * 0 (n/a OS state) + * 0 (n/a file system state) + * 999 (999 traffic for service1) + * 0 (no traffic for service2) + * Link 2: + * 5 (id) + * 0 (good hardware state) + * 0 (good OS state) + * 0 (good file system state) + * 999 (999 traffic service1) + * 0 (no traffic for service2) + """ + # act = np.asarray([0,]) + obs, reward, done, info = env.step(0) # apply the 'do nothing' action + + assert np.array_equal( + obs, + [ + [1, 1, 3, 1, 1, 1], + [2, 1, 1, 1, 1, 4], + [3, 1, 1, 1, 0, 0], + [4, 0, 0, 0, 999, 0], + [5, 0, 0, 0, 999, 0], + ], + ) + + +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ) +) +class TestNodeStatuses: + """Test the NodeStatuses observation component (in isolation).""" + + def test_obs_shape(self, env: Primaite): + """Try creating env with NodeStatuses as the only component.""" + assert env.env_obs.shape == (15,) + + def test_values(self, env: Primaite): + """Test that the hardware and software states are encoded correctly. + + The laydown has: + * one node with a compromised operating system state + * one node with two services, and the second service is overwhelmed. + * all other states are good or null + Therefore, the expected state is: + * node 1: + * hardware = good (1) + * OS = compromised (3) + * file system = good (1) + * service 1 = good (1) + * service 2 = good (1) + * node 2: + * hardware = good (1) + * OS = good (1) + * file system = good (1) + * service 1 = good (1) + * service 2 = overwhelmed (4) + * node 3 (switch): + * hardware = good (1) + * OS = good (1) + * file system = good (1) + * service 1 = n/a (0) + * service 2 = n/a (0) + """ + obs, _, _, _ = env.step(0) # apply the 'do nothing' action + assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]) + + +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ) +) +class TestLinkTrafficLevels: + """Test the LinkTrafficLevels observation component (in isolation).""" + + def test_obs_shape(self, env: Primaite): + """Try creating env with MultiDiscrete observation space.""" + env.update_environent_obs() + + # we have two links and two services, so the shape should be 2 * 2 + assert env.env_obs.shape == (2 * 2,) + + def test_values(self, env: Primaite): + """Test that traffic values are encoded correctly. + + The laydown has: + * two services + * three nodes + * two links + * an IER trying to send 999 bits of data over both links the whole time (via the first service) + * link bandwidth of 1000, therefore the utilisation is 99.9% + """ + obs, reward, done, info = env.step(0) + obs, reward, done, info = env.step(0) + + # the observation space has combine_service_traffic set to False, so the space has this format: + # [link1_service1, link1_service2, link2_service1, link2_service2] + # we send 999 bits of data via link1 and link2 on service 1. + # therefore the first and third elements should be 6 and all others 0 + # (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%) + assert np.array_equal(obs, [6, 0, 6, 0])