From f0e1cb213863f57caf33bbd65fca786b5a268ee1 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 10:57:11 +0100 Subject: [PATCH] Separate obs functions and provide docstrings --- src/primaite/environment/primaite_env.py | 360 ++++++++++++++--------- 1 file changed, 220 insertions(+), 140 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 6d956859..67ab5375 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -641,6 +641,95 @@ class Primaite(Env): else: pass + def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]: + """Initialise the observation space with the BOX option chosen. + + This will create the observation space formatted as a table of integers. + There is one row per node, followed by one row per link. + Columns are as follows: + * node/link ID + * node hardware status / 0 for links + * node operating system status (if active/service) / 0 for links + * node file system status (active/service only) / 0 for links + * node service1 status / traffic load from that service for links + * node service2 status / traffic load from that service for links + * ... + * node serviceN status / traffic load from that service for links + + For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be + ``(12, 7)`` + + :return: Box gym observation + :rtype: gym.spaces.Box + :return: Initial observation with all entires set to 0 + :rtype: numpy.Array + """ + _LOGGER.info("Observation space type BOX selected") + + # 1. Determine observation shape from laydown + num_items = self.num_links + self.num_nodes + num_observation_parameters = ( + self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS + ) + observation_shape = (num_items, num_observation_parameters) + + # 2. Create observation space & zeroed out sample from space. + observation_space = spaces.Box( + low=0, + high=self.OBSERVATION_SPACE_HIGH_VALUE, + shape=observation_shape, + dtype=np.int64, + ) + initial_observation = np.zeros(observation_shape, dtype=np.int64) + + return observation_space, initial_observation + + def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]: + """Initialise the observation space with the MULTIDISCRETE option chosen. + + This will create the observation space with node observations followed by link observations. + Each node has 3 elements in the observation space plus 1 per service, more specifically: + * hardware state + * operating system state + * file system state + * service states (one per service) + Each link has one element in the observation space, corresponding to the traffic load, + it can take the following values: + 0 = No traffic (0% of bandwidth) + 1 = No traffic (0%-33% of bandwidth) + 2 = No traffic (33%-66% of bandwidth) + 3 = No traffic (66%-100% of bandwidth) + 4 = No traffic (100% of bandwidth) + + For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be + ``(37,)`` + + :return: MultiDiscrete gym observation + :rtype: gym.spaces.MultiDiscrete + :return: Initial observation with all entires set to 0 + :rtype: numpy.Array + """ + _LOGGER.info("Observation space MULTIDISCRETE selected") + + # 1. Determine observation shape from laydown + node_obs_shape = [ + len(HardwareState) + 1, + len(SoftwareState) + 1, + len(FileSystemState) + 1, + ] + node_services = [len(SoftwareState) + 1] * self.num_services + node_obs_shape = node_obs_shape + node_services + # the magic number 5 refers to 5 states of quantisation of traffic amount. + # (zero, low, medium, high, fully utilised/overwhelmed) + link_obs_shape = [5] * self.num_links + observation_shape = node_obs_shape * self.num_nodes + link_obs_shape + + # 2. Create observation space & zeroed out sample from space. + observation_space = spaces.MultiDiscrete(observation_shape) + initial_observation = np.zeros(len(observation_shape), dtype=np.int64) + + return observation_space, initial_observation + def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: """Build the observation space based on network laydown and provide initial obs. @@ -648,163 +737,154 @@ class Primaite(Env): `OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type` attributes to figure out the correct shape and format for the observation space. + :raises ValueError: If the env's `observation_type` attribute is not set to a valid `enums.ObservationType` :return: Gym observation space :rtype: gym.spaces.Space :return: Initial observation with all entires set to 0 :rtype: numpy.Array """ if self.observation_type == ObservationType.BOX: - _LOGGER.info("Observation space type BOX selected") - - # 1. Determine observation shape from laydown - num_items = self.num_links + self.num_nodes - num_observation_parameters = ( - self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS - ) - observation_shape = (num_items, num_observation_parameters) - - # 2. Create observation space & zeroed out sample from space. - observation_space = spaces.Box( - low=0, - high=self.OBSERVATION_SPACE_HIGH_VALUE, - shape=observation_shape, - dtype=np.int64, - ) - initial_observation = np.zeros(observation_shape, dtype=np.int64) - + observation_space, initial_observation = self._init_box_observations() + return observation_space, initial_observation elif self.observation_type == ObservationType.MULTIDISCRETE: - _LOGGER.info("Observation space MULTIDISCRETE selected") - - # 1. Determine observation shape from laydown - node_obs_shape = [ - len(HardwareState) + 1, - len(SoftwareState) + 1, - len(FileSystemState) + 1, - ] - node_services = [len(SoftwareState) + 1] * self.num_services - node_obs_shape = node_obs_shape + node_services - # the magic number 5 refers to 5 states of quantisation of traffic amount. - # (zero, low, medium, high, fully utilised/overwhelmed) - link_obs_shape = [5] * self.num_links - observation_shape = node_obs_shape * self.num_nodes + link_obs_shape - - # 2. Create observation space & zeroed out sample from space. - observation_space = spaces.MultiDiscrete(observation_shape) - initial_observation = np.zeros(len(observation_shape), dtype=np.int64) + ( + observation_space, + initial_observation, + ) = self._init_multidiscrete_observations() + return observation_space, initial_observation else: - raise ValueError( + errmsg = ( f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}" f", got {self.observation_type} instead" ) + _LOGGER.error(errmsg) + raise ValueError(errmsg) - return observation_space, initial_observation + def _update_env_obs_box(self): + """Update the environment's observation state based on the current status of nodes and links. + + This function can only be called if the observation space setting is set to BOX. + + :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` + """ + assert self.observation_type == ObservationType.BOX + item_index = 0 + + # Do nodes first + for node_key, node in self.nodes.items(): + self.env_obs[item_index][0] = int(node.node_id) + self.env_obs[item_index][1] = node.hardware_state.value + if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): + self.env_obs[item_index][2] = node.software_state.value + self.env_obs[item_index][3] = node.file_system_state_observed.value + else: + self.env_obs[item_index][2] = 0 + self.env_obs[item_index][3] = 0 + service_index = 4 + if isinstance(node, ServiceNode): + for service in self.services_list: + if node.has_service(service): + self.env_obs[item_index][ + service_index + ] = node.get_service_state(service).value + else: + self.env_obs[item_index][service_index] = 0 + service_index += 1 + else: + # Not a service node + for service in self.services_list: + self.env_obs[item_index][service_index] = 0 + service_index += 1 + item_index += 1 + + # Now do links + for link_key, link in self.links.items(): + self.env_obs[item_index][0] = int(link.get_id()) + self.env_obs[item_index][1] = 0 + self.env_obs[item_index][2] = 0 + self.env_obs[item_index][3] = 0 + protocol_list = link.get_protocol_list() + protocol_index = 0 + for protocol in protocol_list: + self.env_obs[item_index][protocol_index + 4] = protocol.get_load() + protocol_index += 1 + item_index += 1 + + def _update_env_obs_multidiscrete(self): + """Update the environment's observation state based on the current status of nodes and links. + + This function can only be called if the observation space setting is set to MULTIDISCRETE. + + :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` + """ + assert self.observation_type == ObservationType.MULTIDISCRETE + obs = [] + # 1. Set nodes + # Each node has the following variables in the observation space: + # - Hardware state + # - Software state + # - File System state + # - Service 1 state + # - Service 2 state + # - ... + # - Service N state + for node_key, node in self.nodes.items(): + hardware_state = node.hardware_state.value + software_state = 0 + file_system_state = 0 + services_states = [0] * self.num_services + + if isinstance( + node, ActiveNode + ): # ServiceNode is a subclass of ActiveNode so no need to check that also + software_state = node.software_state.value + file_system_state = node.file_system_state_observed.value + + if isinstance(node, ServiceNode): + for i, service in enumerate(self.services_list): + if node.has_service(service): + services_states[i] = node.get_service_state(service).value + + obs.extend( + [ + hardware_state, + software_state, + file_system_state, + *services_states, + ] + ) + + # 2. Set links + # Each link has just one variable in the observation space, it represents the traffic amount + # In order for the space to be fully MultiDiscrete, the amount of + # traffic on each link is quantised into a few levels: + # 0: no traffic (0% of bandwidth) + # 1: low traffic (0-33% of bandwidth) + # 2: medium traffic (33-66% of bandwidth) + # 3: high traffic (66-100% of bandwidth) + # 4: max traffic/overloaded (100% of bandwidth) + + for link_key, link in self.links.items(): + bandwidth = link.bandwidth + load = link.get_current_load() + + if load <= 0: + traffic_level = 0 + elif load >= bandwidth: + traffic_level = 4 + else: + traffic_level = (load / bandwidth) // (1 / 3) + 1 + + obs.append(int(traffic_level)) + + self.env_obs = np.asarray(obs) def update_environent_obs(self): """Updates the observation space based on the node and link status.""" if self.observation_type == ObservationType.BOX: - item_index = 0 - - # Do nodes first - for node_key, node in self.nodes.items(): - self.env_obs[item_index][0] = int(node.node_id) - self.env_obs[item_index][1] = node.hardware_state.value - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.env_obs[item_index][2] = node.software_state.value - self.env_obs[item_index][3] = node.file_system_state_observed.value - else: - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - service_index = 4 - if isinstance(node, ServiceNode): - for service in self.services_list: - if node.has_service(service): - self.env_obs[item_index][ - service_index - ] = node.get_service_state(service).value - else: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - else: - # Not a service node - for service in self.services_list: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - item_index += 1 - - # Now do links - for link_key, link in self.links.items(): - self.env_obs[item_index][0] = int(link.get_id()) - self.env_obs[item_index][1] = 0 - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - protocol_list = link.get_protocol_list() - protocol_index = 0 - for protocol in protocol_list: - self.env_obs[item_index][protocol_index + 4] = protocol.get_load() - protocol_index += 1 - item_index += 1 - + self._update_env_obs_box() elif self.observation_type == ObservationType.MULTIDISCRETE: - obs = [] - # 1. Set nodes - # Each node has the following variables in the observation space: - # - Hardware state - # - Software state - # - File System state - # - Service 1 state - # - Service 2 state - # - ... - # - Service N state - for node_key, node in self.nodes.items(): - hardware_state = node.hardware_state.value - software_state = 0 - file_system_state = 0 - services_states = [0] * self.num_services - - if isinstance( - node, ActiveNode - ): # ServiceNode is a subclass of ActiveNode so no need to check that also - software_state = node.software_state.value - file_system_state = node.file_system_state_observed.value - - if isinstance(node, ServiceNode): - for i, service in enumerate(self.services_list): - if node.has_service(service): - services_states[i] = node.get_service_state(service).value - - obs.extend( - [ - hardware_state, - software_state, - file_system_state, - *services_states, - ] - ) - - # 2. Set links - # Each link has just one variable in the observation space, it represents the traffic amount - # In order for the space to be fully MultiDiscrete, the amount of - # traffic on each link is quantised into a few levels: - # 0: no traffic (0% of bandwidth) - # 1: low traffic (0-33% of bandwidth) - # 2: medium traffic (33-66% of bandwidth) - # 3: high traffic (66-100% of bandwidth) - # 4: max traffic/overloaded (100% of bandwidth) - - for link_key, link in self.links.items(): - bandwidth = link.bandwidth - load = link.get_current_load() - - if load <= 0: - traffic_level = 0 - elif load >= bandwidth: - traffic_level = 4 - else: - traffic_level = (load / bandwidth) // (1 / 3) + 1 - - obs.append(int(traffic_level)) - - self.env_obs = np.asarray(obs) + self._update_env_obs_multidiscrete() def load_config(self): """Loads config data in order to build the environment configuration."""