diff --git a/src/primaite/config/config_1_DDOS_BASIC.yaml b/src/primaite/config/config_1_DDOS_BASIC.yaml index ada813f3..a1961df3 100644 --- a/src/primaite/config/config_1_DDOS_BASIC.yaml +++ b/src/primaite/config/config_1_DDOS_BASIC.yaml @@ -1,5 +1,13 @@ - itemType: ACTIONS type: NODE +- itemType: OBSERVATION_SPACE + components: + - name: NODE_LINK_TABLE + - name: NODE_STATUSES + - name: LINK_TRAFFIC_LEVELS + options: + - combine_service_traffic : False + - quantisation_levels : 7 - itemType: STEPS steps: 128 - itemType: PORTS diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 338c11a1..94c2730f 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -35,7 +35,23 @@ class AbstractObservationComponent(ABC): class NodeLinkTable(AbstractObservationComponent): """Table with nodes/links as rows and hardware/software status as cols. - #todo: write full description + Initialise the observation space with the BOX option chosen. + + This will create the observation space formatted as a table of integers. + There is one row per node, followed by one row per link. + Columns are as follows: + * node/link ID + * node hardware status / 0 for links + * node operating system status (if active/service) / 0 for links + * node file system status (active/service only) / 0 for links + * node service1 status / traffic load from that service for links + * node service2 status / traffic load from that service for links + * ... + * node serviceN status / traffic load from that service for links + + For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be + ``(12, 7)`` + #todo: clean up description """ @@ -44,7 +60,7 @@ class NodeLinkTable(AbstractObservationComponent): _DATA_TYPE = np.int64 def __init__(self, env: Primaite): - super().__init__() + super().__init__(env) # 1. Define the shape of your observation space component num_items = self.env.num_links + self.env.num_nodes @@ -65,6 +81,10 @@ class NodeLinkTable(AbstractObservationComponent): def update_obs(self): """Update the observation. + Update the environment's observation state based on the current status of nodes and links. + + The structure of the observation space is described in :func:`~_init_box_observations` + This function can only be called if the observation space setting is set to BOX. todo: complete description.. """ item_index = 0 @@ -116,12 +136,20 @@ class NodeLinkTable(AbstractObservationComponent): class NodeStatuses(AbstractObservationComponent): - """todo: complete description.""" + """todo: complete description. + + This will create the observation space with node observations followed by link observations. + Each node has 3 elements in the observation space plus 1 per service, more specifically: + * hardware state + * operating system state + * file system state + * service states (one per service) + """ _DATA_TYPE = np.int64 - def __init__(self): - super().__init__() + def __init__(self, env: Primaite): + super().__init__(env) # 1. Define the shape of your observation space component shape = [ @@ -139,7 +167,15 @@ class NodeStatuses(AbstractObservationComponent): self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) def update_obs(self): - """todo: complete description.""" + """todo: complete description. + + Update the environment's observation state based on the current status of nodes and links. + + The structure of the observation space is described in :func:`~_init_multidiscrete_observations` + This function can only be called if the observation space setting is set to MULTIDISCRETE. + + + """ obs = [] for _, node in self.env.nodes.items(): hardware_state = node.hardware_state.value @@ -160,14 +196,26 @@ class NodeStatuses(AbstractObservationComponent): class LinkTrafficLevels(AbstractObservationComponent): - """todo: complete description.""" + """todo: complete description. + + Each link has one element in the observation space, corresponding to the traffic load, + it can take the following values: + 0 = No traffic (0% of bandwidth) + 1 = No traffic (0%-33% of bandwidth) + 2 = No traffic (33%-66% of bandwidth) + 3 = No traffic (66%-100% of bandwidth) + 4 = No traffic (100% of bandwidth) + """ _DATA_TYPE = np.int64 def __init__( - self, combine_service_traffic: bool = False, quantisation_levels: int = 5 + self, + env: Primaite, + combine_service_traffic: bool = False, + quantisation_levels: int = 5, ): - super().__init__() + super().__init__(env) self._combine_service_traffic: bool = combine_service_traffic self._quantisation_levels: int = quantisation_levels self._entries_per_link: int = 1 @@ -212,7 +260,7 @@ class LinkTrafficLevels(AbstractObservationComponent): class ObservationsHandler: - """todo: complete description.""" + """Component-based observation space handler.""" class registry(Enum): """todo: complete description.""" @@ -254,3 +302,25 @@ class ObservationsHandler: for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) self.space = spaces.Tuple(component_spaces) + + @classmethod + def from_config(cls, obs_space_config): + """todo: complete description. + + This method parses config items related to the observation space, then + creates the necessary components and adds them to the observation handler. + """ + # Instantiate the handler + handler = cls() + + for component_cfg in obs_space_config["components"]: + # Figure out which class can instantiate the desired component + comp_type = component_cfg["name"] + comp_class = cls.registry[comp_type].value + + # Create the component with options from the YAML + component = comp_class(**component_cfg["options"]) + + handler.register(component) + + return handler diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 56893ee9..afa04060 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -6,7 +6,7 @@ import csv import logging import os.path from datetime import datetime -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import networkx as nx import numpy as np @@ -23,11 +23,11 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, - ObservationType, Priority, SoftwareState, ) from primaite.common.service import Service +from primaite.environment.observations import ObservationsHandler from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode @@ -149,8 +149,8 @@ class Primaite(Env): # The action type self.action_type = 0 - # Observation type, by default box. - self.observation_type = ObservationType.BOX + # todo: proper description here + self.obs_handler: ObservationsHandler # Open the config file and build the environment laydown try: @@ -161,6 +161,10 @@ class Primaite(Env): _LOGGER.error("Could not load the environment configuration") _LOGGER.error("Exception occured", exc_info=True) + # If it doesn't exist after parsing config, create default obs space. + if self.get("obs_handler") is None: + self.configure_obs_space() + # Store the node objects as node attributes # (This is so we can access them as objects) for node in self.network: @@ -641,252 +645,17 @@ class Primaite(Env): else: pass - def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Initialise the observation space with the BOX option chosen. - - This will create the observation space formatted as a table of integers. - There is one row per node, followed by one row per link. - Columns are as follows: - * node/link ID - * node hardware status / 0 for links - * node operating system status (if active/service) / 0 for links - * node file system status (active/service only) / 0 for links - * node service1 status / traffic load from that service for links - * node service2 status / traffic load from that service for links - * ... - * node serviceN status / traffic load from that service for links - - For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be - ``(12, 7)`` - - :return: Box gym observation - :rtype: gym.spaces.Box - :return: Initial observation with all entires set to 0 - :rtype: numpy.Array - """ - _LOGGER.info("Observation space type BOX selected") - - # 1. Determine observation shape from laydown - num_items = self.num_links + self.num_nodes - num_observation_parameters = ( - self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS - ) - observation_shape = (num_items, num_observation_parameters) - - # 2. Create observation space & zeroed out sample from space. - observation_space = spaces.Box( - low=0, - high=self.OBSERVATION_SPACE_HIGH_VALUE, - shape=observation_shape, - dtype=np.int64, - ) - initial_observation = np.zeros(observation_shape, dtype=np.int64) - - return observation_space, initial_observation - - def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Initialise the observation space with the MULTIDISCRETE option chosen. - - This will create the observation space with node observations followed by link observations. - Each node has 3 elements in the observation space plus 1 per service, more specifically: - * hardware state - * operating system state - * file system state - * service states (one per service) - Each link has one element in the observation space, corresponding to the traffic load, - it can take the following values: - 0 = No traffic (0% of bandwidth) - 1 = No traffic (0%-33% of bandwidth) - 2 = No traffic (33%-66% of bandwidth) - 3 = No traffic (66%-100% of bandwidth) - 4 = No traffic (100% of bandwidth) - - For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be - ``(37,)`` - - :return: MultiDiscrete gym observation - :rtype: gym.spaces.MultiDiscrete - :return: Initial observation with all entires set to 0 - :rtype: numpy.Array - """ - _LOGGER.info("Observation space MULTIDISCRETE selected") - - # 1. Determine observation shape from laydown - node_obs_shape = [ - len(HardwareState) + 1, - len(SoftwareState) + 1, - len(FileSystemState) + 1, - ] - node_services = [len(SoftwareState) + 1] * self.num_services - node_obs_shape = node_obs_shape + node_services - # the magic number 5 refers to 5 states of quantisation of traffic amount. - # (zero, low, medium, high, fully utilised/overwhelmed) - link_obs_shape = [5] * self.num_links - observation_shape = node_obs_shape * self.num_nodes + link_obs_shape - - # 2. Create observation space & zeroed out sample from space. - observation_space = spaces.MultiDiscrete(observation_shape) - initial_observation = np.zeros(len(observation_shape), dtype=np.int64) - - return observation_space, initial_observation - def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Build the observation space based on network laydown and provide initial obs. - - This method uses the object's `num_links`, `num_nodes`, `num_services`, - `OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type` - attributes to figure out the correct shape and format for the observation space. - - :raises ValueError: If the env's `observation_type` attribute is not set to a valid `enums.ObservationType` - :return: Gym observation space - :rtype: gym.spaces.Space - :return: Initial observation with all entires set to 0 - :rtype: numpy.Array - """ - if self.observation_type == ObservationType.BOX: - observation_space, initial_observation = self._init_box_observations() - return observation_space, initial_observation - elif self.observation_type == ObservationType.MULTIDISCRETE: - ( - observation_space, - initial_observation, - ) = self._init_multidiscrete_observations() - return observation_space, initial_observation - else: - errmsg = ( - f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}" - f", got {self.observation_type} instead" - ) - _LOGGER.error(errmsg) - raise ValueError(errmsg) - - def _update_env_obs_box(self): - """Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_box_observations` - This function can only be called if the observation space setting is set to BOX. - - :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` - """ - assert self.observation_type == ObservationType.BOX - item_index = 0 - - # Do nodes first - for node_key, node in self.nodes.items(): - self.env_obs[item_index][0] = int(node.node_id) - self.env_obs[item_index][1] = node.hardware_state.value - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.env_obs[item_index][2] = node.software_state.value - self.env_obs[item_index][3] = node.file_system_state_observed.value - else: - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - service_index = 4 - if isinstance(node, ServiceNode): - for service in self.services_list: - if node.has_service(service): - self.env_obs[item_index][ - service_index - ] = node.get_service_state(service).value - else: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - else: - # Not a service node - for service in self.services_list: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - item_index += 1 - - # Now do links - for link_key, link in self.links.items(): - self.env_obs[item_index][0] = int(link.get_id()) - self.env_obs[item_index][1] = 0 - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - protocol_list = link.get_protocol_list() - protocol_index = 0 - for protocol in protocol_list: - self.env_obs[item_index][protocol_index + 4] = protocol.get_load() - protocol_index += 1 - item_index += 1 - - def _update_env_obs_multidiscrete(self): - """Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_multidiscrete_observations` - This function can only be called if the observation space setting is set to MULTIDISCRETE. - - :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` - """ - assert self.observation_type == ObservationType.MULTIDISCRETE - obs = [] - # 1. Set nodes - # Each node has the following variables in the observation space: - # - Hardware state - # - Software state - # - File System state - # - Service 1 state - # - Service 2 state - # - ... - # - Service N state - for node_key, node in self.nodes.items(): - hardware_state = node.hardware_state.value - software_state = 0 - file_system_state = 0 - services_states = [0] * self.num_services - - if isinstance( - node, ActiveNode - ): # ServiceNode is a subclass of ActiveNode so no need to check that also - software_state = node.software_state.value - file_system_state = node.file_system_state_observed.value - - if isinstance(node, ServiceNode): - for i, service in enumerate(self.services_list): - if node.has_service(service): - services_states[i] = node.get_service_state(service).value - - obs.extend( - [ - hardware_state, - software_state, - file_system_state, - *services_states, - ] - ) - - # 2. Set links - # Each link has just one variable in the observation space, it represents the traffic amount - # In order for the space to be fully MultiDiscrete, the amount of - # traffic on each link is quantised into a few levels: - # 0: no traffic (0% of bandwidth) - # 1: low traffic (0-33% of bandwidth) - # 2: medium traffic (33-66% of bandwidth) - # 3: high traffic (66-100% of bandwidth) - # 4: max traffic/overloaded (100% of bandwidth) - - for link_key, link in self.links.items(): - bandwidth = link.bandwidth - load = link.get_current_load() - - if load <= 0: - traffic_level = 0 - elif load >= bandwidth: - traffic_level = 4 - else: - traffic_level = (load / bandwidth) // (1 / 3) + 1 - - obs.append(int(traffic_level)) - - self.env_obs = np.asarray(obs) + """todo: write docstring.""" + return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): - """Updates the observation space based on the node and link status.""" - if self.observation_type == ObservationType.BOX: - self._update_env_obs_box() - elif self.observation_type == ObservationType.MULTIDISCRETE: - self._update_env_obs_multidiscrete() + """Updates the observation space based on the node and link status. + + todo: better docstring + """ + self.obs_handler.update_obs() + self.env_obs = self.obs_handler.current_observation def load_config(self): """Loads config data in order to build the environment configuration.""" @@ -921,9 +690,9 @@ class Primaite(Env): elif item["itemType"] == "ACTIONS": # Get the action information self.get_action_info(item) - elif item["itemType"] == "OBSERVATIONS": + elif item["itemType"] == "OBSERVATION_SPACE": # Get the observation information - self.get_observation_info(item) + self.configure_obs_space(item) elif item["itemType"] == "STEPS": # Get the steps information self.get_steps_info(item) @@ -1256,13 +1025,16 @@ class Primaite(Env): """ self.action_type = ActionType[action_info["type"]] - def get_observation_info(self, observation_info): - """Extracts observation_info. + def configure_obs_space(self, observation_config: Optional[Dict] = None): + """todo: better docstring.""" + if observation_config is None: + observation_config = { + "components": [ + {"name": "NODE_LINK_TABLE"}, + ] + } - :param observation_info: Config item that defines which type of observation space to use - :type observation_info: str - """ - self.observation_type = ObservationType[observation_info["type"]] + self.obs_handler = ObservationsHandler[observation_config] def get_steps_info(self, steps_info): """