From 375e20a67b6b8e22fa9bcd73e821df766cf4dbff Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 30 May 2023 15:11:41 +0100 Subject: [PATCH] Configure observation type MULTIDISCRETE --- src/primaite/common/enums.py | 7 + src/primaite/config/config_1_DDOS_BASIC.yaml | 2 + src/primaite/environment/primaite_env.py | 236 ++++++++++++++----- 3 files changed, 188 insertions(+), 57 deletions(-) diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 0e00c9e4..138d2742 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -83,6 +83,13 @@ class ActionType(Enum): ACL = 1 +class ObservationType(Enum): + """Observation type enumeration.""" + + BOX = 0 + MULTIDISCRETE = 1 + + class FileSystemState(Enum): """File System State.""" diff --git a/src/primaite/config/config_1_DDOS_BASIC.yaml b/src/primaite/config/config_1_DDOS_BASIC.yaml index ada813f3..2796adb4 100644 --- a/src/primaite/config/config_1_DDOS_BASIC.yaml +++ b/src/primaite/config/config_1_DDOS_BASIC.yaml @@ -1,5 +1,7 @@ - itemType: ACTIONS type: NODE +- itemType: OBSERVATIONS + type: MULTIDISCRETE - itemType: STEPS steps: 128 - itemType: PORTS diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 99c7c09f..b85410c0 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -23,6 +23,7 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, + ObservationType, Priority, SoftwareState, ) @@ -148,6 +149,9 @@ class Primaite(Env): # The action type self.action_type = 0 + # Observation type. + self.observation_type = 0 + # Open the config file and build the environment laydown try: self.config_file = open(self.config_values.config_filename_use_case, "r") @@ -203,26 +207,8 @@ class Primaite(Env): # - service E state | service E loading # - service F state | service F loading # - service G state | service G loading - - # Calculate the number of items that need to be included in the - # observation space - num_items = self.num_links + self.num_nodes - # Set the number of observation parameters, being # of services plus id, - # hardware state, file system state and SoftwareState (i.e. 4) - self.num_observation_parameters = ( - self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS - ) - # Define the observation shape - self.observation_shape = (num_items, self.num_observation_parameters) - self.observation_space = spaces.Box( - low=0, - high=self.config_values.observation_space_high_value, - shape=self.observation_shape, - dtype=np.int64, - ) - - # This is the observation that is sent back via the rest and step functions - self.env_obs = np.zeros(self.observation_shape, dtype=np.int64) + # Initiate observation space + self.observation_space, self.env_obs = self.init_observations() # Define Action Space - depends on action space type (Node or ACL) if self.action_type == ActionType.NODE: @@ -671,49 +657,172 @@ class Primaite(Env): else: pass + def init_observations(self): + """Build the observation space based on network laydown and provide initial obs. + + This method uses the object's `num_links`, `num_nodes`, `num_services`, + `OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type` + attributes to figure out the correct shape and format for the observation space. + + Returns + ------- + gym.spaces.Space + Gym observation space + numpy.Array + Initial observation with all entries set to 0 + """ + if self.observation_type == ObservationType.BOX: + _LOGGER.info("Observation space type BOX selected") + + # 1. Determine observation shape from laydown + num_items = self.num_links + self.num_nodes + num_observation_parameters = ( + self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS + ) + observation_shape = (num_items, num_observation_parameters) + + # 2. Create observation space & zeroed out sample from space. + observation_space = spaces.Box( + low=0, + high=self.OBSERVATION_SPACE_HIGH_VALUE, + shape=observation_shape, + dtype=np.int64, + ) + initial_observation = np.zeros(observation_shape, dtype=np.int64) + + elif self.observation_type == ObservationType.MULTIDISCRETE: + _LOGGER.info("Observation space MULTIDISCRETE selected") + + # 1. Determine observation shape from laydown + node_obs_shape = [ + len(HardwareState) + 1, + len(SoftwareState) + 1, + len(FileSystemState) + 1, + ] + node_services = [len(SoftwareState) + 1] * self.num_services + node_obs_shape = node_obs_shape + node_services + # the magic number 5 refers to 5 states of quantisation of traffic amount. + # (zero, low, medium, high, fully utilised/overwhelmed) + link_obs_shape = [5] * self.num_links + observation_shape = node_obs_shape + link_obs_shape + + # 2. Create observation space & zeroed out sample from space. + observation_space = spaces.MultiDiscrete(observation_shape) + initial_observation = np.zeros(len(observation_shape), dtype=np.int64) + else: + raise ValueError( + f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}" + f", got {self.observation_type} instead" + ) + + return observation_space, initial_observation + def update_environent_obs(self): """Updates the observation space based on the node and link status.""" - item_index = 0 + if self.observation_type == ObservationType.BOX: + item_index = 0 - # Do nodes first - for node_key, node in self.nodes.items(): - self.env_obs[item_index][0] = int(node.node_id) - self.env_obs[item_index][1] = node.hardware_state.value - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.env_obs[item_index][2] = node.software_state.value - self.env_obs[item_index][3] = node.file_system_state_observed.value - else: + # Do nodes first + for node_key, node in self.nodes.items(): + self.env_obs[item_index][0] = int(node.node_id) + self.env_obs[item_index][1] = node.hardware_state.value + if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): + self.env_obs[item_index][2] = node.software_state.value + self.env_obs[item_index][3] = node.file_system_state_observed.value + else: + self.env_obs[item_index][2] = 0 + self.env_obs[item_index][3] = 0 + service_index = 4 + if isinstance(node, ServiceNode): + for service in self.services_list: + if node.has_service(service): + self.env_obs[item_index][ + service_index + ] = node.get_service_state(service).value + else: + self.env_obs[item_index][service_index] = 0 + service_index += 1 + else: + # Not a service node + for service in self.services_list: + self.env_obs[item_index][service_index] = 0 + service_index += 1 + item_index += 1 + + # Now do links + for link_key, link in self.links.items(): + self.env_obs[item_index][0] = int(link.get_id()) + self.env_obs[item_index][1] = 0 self.env_obs[item_index][2] = 0 self.env_obs[item_index][3] = 0 - service_index = 4 - if isinstance(node, ServiceNode): - for service in self.services_list: - if node.has_service(service): - self.env_obs[item_index][ - service_index - ] = node.get_service_state(service).value - else: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - else: - # Not a service node - for service in self.services_list: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - item_index += 1 + protocol_list = link.get_protocol_list() + protocol_index = 0 + for protocol in protocol_list: + self.env_obs[item_index][protocol_index + 4] = protocol.get_load() + protocol_index += 1 + item_index += 1 - # Now do links - for link_key, link in self.links.items(): - self.env_obs[item_index][0] = int(link.get_id()) - self.env_obs[item_index][1] = 0 - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - protocol_list = link.get_protocol_list() - protocol_index = 0 - for protocol in protocol_list: - self.env_obs[item_index][protocol_index + 4] = protocol.get_load() - protocol_index += 1 - item_index += 1 + elif self.observation_type == ObservationType.MULTIDISCRETE: + obs = [] + # 1. Set nodes + # Each node has the following variables in the observation space: + # - Hardware state + # - Software state + # - File System state + # - Service 1 state + # - Service 2 state + # - ... + # - Service N state + for node_key, node in self.nodes.items(): + hardware_state = node.hardware_state.value + software_state = 0 + file_system_state = 0 + services_states = [0] * self.num_services + + if isinstance( + node, ActiveNode + ): # ServiceNode is a subclass of ActiveNode so no need to check that also + software_state = node.software_state.value + file_system_state = node.file_system_state_observed.value + + if isinstance(node, ServiceNode): + for i, service in enumerate(self.services_list): + if node.has_service(service): + services_states[i] = node.get_service_state(service).value + + obs.extend( + [ + hardware_state, + software_state, + file_system_state, + *services_states, + ] + ) + + # 2. Set links + # Each link has just one variable in the observation space, it represents the traffic amount + # In order for the space to be fully MultiDiscrete, the amount of + # traffic on each link is quantised into a few levels: + # 0: no traffic (0% of bandwidth) + # 1: low traffic (0-33% of bandwidth) + # 2: medium traffic (33-66% of bandwidth) + # 3: high traffic (66-100% of bandwidth) + # 4: max traffic/overloaded (100% of bandwidth) + + for link_key, link in self.links.items(): + bandwidth = link.bandwidth + load = link.get_current_load() + + if load <= 0: + traffic_level = 0 + elif load >= bandwidth: + traffic_level = 4 + else: + traffic_level = (load / bandwidth) // (1 / 3) + 1 + + obs.append(int(traffic_level)) + + self.env_obs = np.asarray(obs) def load_config(self): """Loads config data in order to build the environment configuration.""" @@ -748,6 +857,9 @@ class Primaite(Env): elif item["itemType"] == "ACTIONS": # Get the action information self.get_action_info(item) + elif item["itemType"] == "OBSERVATIONS": + # Get the observation information + self.get_observation_info(item) elif item["itemType"] == "STEPS": # Get the steps information self.get_steps_info(item) @@ -1080,6 +1192,16 @@ class Primaite(Env): """ self.action_type = ActionType[action_info["type"]] + def get_observation_info(self, observation_info): + """Extracts observation_info. + + Parameters + ---------- + observation_info : str + Config item that defines which type of observation space to use + """ + self.observation_type = ObservationType[observation_info["type"]] + def get_steps_info(self, steps_info): """ Extracts steps_info.