From 6e58c01e8d0dd8894ba8ce32c82633e2fa050112 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 31 May 2023 17:03:53 +0100 Subject: [PATCH 01/16] Start creating observations module --- src/primaite/environment/observations.py | 227 +++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 src/primaite/environment/observations.py diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py new file mode 100644 index 00000000..40cc26a5 --- /dev/null +++ b/src/primaite/environment/observations.py @@ -0,0 +1,227 @@ +# """Module for handling configurable observation spaces in PrimAITE.""" +# import logging +# from abc import ABC, abstractmethod +# from enum import Enum + +# import numpy as np +# from gym import spaces + +# from primaite.common.enums import FileSystemState, HardwareState, SoftwareState +# from primaite.environment.primaite_env import Primaite +# from primaite.nodes.active_node import ActiveNode +# from primaite.nodes.service_node import ServiceNode + +# _LOGGER = logging.getLogger(__name__) + + +# class AbstractObservationComponent(ABC): +# """Represents a part of the PrimAITE observation space.""" +# @abstractmethod +# def __init__(self, env: Primaite): +# _LOGGER.info(f"Initialising {self} observation component") +# self.env: Primaite = env +# self.space: spaces.Space +# self.current_observation: np.ndarray # type might be too restrictive? +# return NotImplemented + +# @abstractmethod +# def update(self): +# """Look at the environment and update the current observation value""" +# self.current_observation = NotImplemented + +# # @abstractmethod +# # def export(self): +# # return NotImplemented + + +# class NodeLinkTable(AbstractObservationComponent): +# """Table with nodes/links as rows and hardware/software status as cols. + +# #todo: write full description + +# """ + +# _FIXED_PARAMETERS = 4 +# _MAX_VAL = 1_000_000 +# _DATA_TYPE = np.int64 + +# def __init__(self, env: Primaite): +# super().__init__() + +# # 1. Define the shape of your observation space component +# num_items = self.env.num_links + self.env.num_nodes +# num_columns = self.env.num_services + self._FIXED_PARAMETERS +# observation_shape = (num_items, num_columns) + +# # 2. Create Observation space +# self.space = spaces.Box( +# low=0, +# high=self._MAX_VAL, +# shape=observation_shape, +# dtype=self._DATA_TYPE, +# ) + +# # 3. Initialise Observation with zeroes +# self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE) + +# def update_obs(self): +# item_index = 0 +# nodes = self.env.nodes +# links = self.env.links +# # Do nodes first +# for _, node in nodes.items(): +# self.current_observation[item_index][0] = int(node.node_id) +# self.current_observation[item_index][1] = node.hardware_state.value +# if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): +# self.current_observation[item_index][2] = node.software_state.value +# self.current_observation[item_index][ +# 3 +# ] = node.file_system_state_observed.value +# else: +# self.current_observation[item_index][2] = 0 +# self.current_observation[item_index][3] = 0 +# service_index = 4 +# if isinstance(node, ServiceNode): +# for service in self.env.services_list: +# if node.has_service(service): +# self.current_observation[item_index][ +# service_index +# ] = node.get_service_state(service).value +# else: +# self.current_observation[item_index][service_index] = 0 +# service_index += 1 +# else: +# # Not a service node +# for service in self.env.services_list: +# self.current_observation[item_index][service_index] = 0 +# service_index += 1 +# item_index += 1 + +# # Now do links +# for _, link in links.items(): +# self.current_observation[item_index][0] = int(link.get_id()) +# self.current_observation[item_index][1] = 0 +# self.current_observation[item_index][2] = 0 +# self.current_observation[item_index][3] = 0 +# protocol_list = link.get_protocol_list() +# protocol_index = 0 +# for protocol in protocol_list: +# self.current_observation[item_index][ +# protocol_index + 4 +# ] = protocol.get_load() +# protocol_index += 1 +# item_index += 1 + + +# class NodeStatuses(AbstractObservationComponent): +# _DATA_TYPE = np.int64 + +# def __init__(self): +# super().__init__() + +# # 1. Define the shape of your observation space component +# shape = [ +# len(HardwareState) + 1, +# len(SoftwareState) + 1, +# len(FileSystemState) + 1, +# ] +# services_shape = [len(SoftwareState) + 1] * self.env.num_services +# shape = shape + services_shape + +# # 2. Create Observation space +# self.space = spaces.MultiDiscrete(shape) + +# # 3. Initialise observation with zeroes +# self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + +# def update_obs(self): +# obs = [] +# for _, node in self.env.nodes.items(): +# hardware_state = node.hardware_state.value +# software_state = 0 +# file_system_state = 0 +# service_states = [0] * self.env.num_services + +# if isinstance(node, ActiveNode): +# software_state = node.software_state.value +# file_system_state = node.file_system_state_observed.value + +# if isinstance(node, ServiceNode): +# for i, service in enumerate(self.env.services_list): +# if node.has_service(service): +# service_states[i] = node.get_service_state(service).value +# obs.extend([hardware_state, software_state, file_system_state, *service_states]) +# self.current_observation[:] = obs + + +# class LinkTrafficLevels(AbstractObservationComponent): +# _DATA_TYPE = np.int64 + +# def __init__( +# self, combine_service_traffic: bool = False, quantisation_levels: int = 5 +# ): +# super().__init__() +# self._combine_service_traffic: bool = combine_service_traffic +# self._quantisation_levels: int = quantisation_levels +# self._entries_per_link: int = 1 + +# if not self._combine_service_traffic: +# self._entries_per_link = self.env.num_services + +# # 1. Define the shape of your observation space component +# shape = ( +# [self._quantisation_levels] * self.env.num_links * self._entries_per_link +# ) + +# # 2. Create Observation space +# self.space = spaces.MultiDiscrete(shape) + +# # 3. Initialise observation with zeroes +# self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + +# def update_obs(self): +# obs = [] +# for _, link in self.env.links.items(): +# bandwidth = link.bandwidth +# if self._combine_service_traffic: +# loads = [link.get_current_load()] +# else: +# loads = [protocol.get_load() for protocol in link.protocol_list] + +# for load in loads: +# if load <= 0: +# traffic_level = 0 +# elif load >= bandwidth: +# traffic_level = self._quantisation_levels - 1 +# else: +# traffic_level = (load / bandwidth) // ( +# 1 / (self._quantisation_levels - 1) +# ) + 1 + +# obs.append(int(traffic_level)) + +# self.current_observation[:] = obs + + +# class ObservationsHandler: +# class registry(Enum): +# NODE_LINK_TABLE: NodeLinkTable +# NODE_STATUSES: NodeStatuses +# LINK_TRAFFIC_LEVELS: LinkTrafficLevels + +# def __init__(self): +# ... +# # i can access the registry items like this: +# # self.registry.LINK_TRAFFIC_LEVELS + +# def update_obs(self): +# ... + +# def register(self): +# ... + +# def deregister(self, observation: AbstractObservationComponent): +# ... + +# def export(self): +# ... From 46352ff9c2d9ae4c768c5b3d6aa52daaffbbd49e Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 13:28:40 +0100 Subject: [PATCH 02/16] Integrate observation handler with components --- src/primaite/environment/observations.py | 393 ++++++++++++----------- 1 file changed, 211 insertions(+), 182 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 40cc26a5..338c11a1 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,227 +1,256 @@ -# """Module for handling configurable observation spaces in PrimAITE.""" -# import logging -# from abc import ABC, abstractmethod -# from enum import Enum +"""Module for handling configurable observation spaces in PrimAITE.""" +import logging +from abc import ABC, abstractmethod +from enum import Enum +from typing import List, Tuple -# import numpy as np -# from gym import spaces +import numpy as np +from gym import spaces -# from primaite.common.enums import FileSystemState, HardwareState, SoftwareState -# from primaite.environment.primaite_env import Primaite -# from primaite.nodes.active_node import ActiveNode -# from primaite.nodes.service_node import ServiceNode +from primaite.common.enums import FileSystemState, HardwareState, SoftwareState +from primaite.environment.primaite_env import Primaite +from primaite.nodes.active_node import ActiveNode +from primaite.nodes.service_node import ServiceNode -# _LOGGER = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) -# class AbstractObservationComponent(ABC): -# """Represents a part of the PrimAITE observation space.""" -# @abstractmethod -# def __init__(self, env: Primaite): -# _LOGGER.info(f"Initialising {self} observation component") -# self.env: Primaite = env -# self.space: spaces.Space -# self.current_observation: np.ndarray # type might be too restrictive? -# return NotImplemented +class AbstractObservationComponent(ABC): + """Represents a part of the PrimAITE observation space.""" -# @abstractmethod -# def update(self): -# """Look at the environment and update the current observation value""" -# self.current_observation = NotImplemented + @abstractmethod + def __init__(self, env: Primaite): + _LOGGER.info(f"Initialising {self} observation component") + self.env: Primaite = env + self.space: spaces.Space + self.current_observation: np.ndarray # type might be too restrictive? + return NotImplemented -# # @abstractmethod -# # def export(self): -# # return NotImplemented + @abstractmethod + def update(self): + """Look at the environment and update the current observation value.""" + self.current_observation = NotImplemented -# class NodeLinkTable(AbstractObservationComponent): -# """Table with nodes/links as rows and hardware/software status as cols. +class NodeLinkTable(AbstractObservationComponent): + """Table with nodes/links as rows and hardware/software status as cols. -# #todo: write full description + #todo: write full description -# """ + """ -# _FIXED_PARAMETERS = 4 -# _MAX_VAL = 1_000_000 -# _DATA_TYPE = np.int64 + _FIXED_PARAMETERS = 4 + _MAX_VAL = 1_000_000 + _DATA_TYPE = np.int64 -# def __init__(self, env: Primaite): -# super().__init__() + def __init__(self, env: Primaite): + super().__init__() -# # 1. Define the shape of your observation space component -# num_items = self.env.num_links + self.env.num_nodes -# num_columns = self.env.num_services + self._FIXED_PARAMETERS -# observation_shape = (num_items, num_columns) + # 1. Define the shape of your observation space component + num_items = self.env.num_links + self.env.num_nodes + num_columns = self.env.num_services + self._FIXED_PARAMETERS + observation_shape = (num_items, num_columns) -# # 2. Create Observation space -# self.space = spaces.Box( -# low=0, -# high=self._MAX_VAL, -# shape=observation_shape, -# dtype=self._DATA_TYPE, -# ) + # 2. Create Observation space + self.space = spaces.Box( + low=0, + high=self._MAX_VAL, + shape=observation_shape, + dtype=self._DATA_TYPE, + ) -# # 3. Initialise Observation with zeroes -# self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE) + # 3. Initialise Observation with zeroes + self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE) -# def update_obs(self): -# item_index = 0 -# nodes = self.env.nodes -# links = self.env.links -# # Do nodes first -# for _, node in nodes.items(): -# self.current_observation[item_index][0] = int(node.node_id) -# self.current_observation[item_index][1] = node.hardware_state.value -# if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): -# self.current_observation[item_index][2] = node.software_state.value -# self.current_observation[item_index][ -# 3 -# ] = node.file_system_state_observed.value -# else: -# self.current_observation[item_index][2] = 0 -# self.current_observation[item_index][3] = 0 -# service_index = 4 -# if isinstance(node, ServiceNode): -# for service in self.env.services_list: -# if node.has_service(service): -# self.current_observation[item_index][ -# service_index -# ] = node.get_service_state(service).value -# else: -# self.current_observation[item_index][service_index] = 0 -# service_index += 1 -# else: -# # Not a service node -# for service in self.env.services_list: -# self.current_observation[item_index][service_index] = 0 -# service_index += 1 -# item_index += 1 + def update_obs(self): + """Update the observation. -# # Now do links -# for _, link in links.items(): -# self.current_observation[item_index][0] = int(link.get_id()) -# self.current_observation[item_index][1] = 0 -# self.current_observation[item_index][2] = 0 -# self.current_observation[item_index][3] = 0 -# protocol_list = link.get_protocol_list() -# protocol_index = 0 -# for protocol in protocol_list: -# self.current_observation[item_index][ -# protocol_index + 4 -# ] = protocol.get_load() -# protocol_index += 1 -# item_index += 1 + todo: complete description.. + """ + item_index = 0 + nodes = self.env.nodes + links = self.env.links + # Do nodes first + for _, node in nodes.items(): + self.current_observation[item_index][0] = int(node.node_id) + self.current_observation[item_index][1] = node.hardware_state.value + if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): + self.current_observation[item_index][2] = node.software_state.value + self.current_observation[item_index][ + 3 + ] = node.file_system_state_observed.value + else: + self.current_observation[item_index][2] = 0 + self.current_observation[item_index][3] = 0 + service_index = 4 + if isinstance(node, ServiceNode): + for service in self.env.services_list: + if node.has_service(service): + self.current_observation[item_index][ + service_index + ] = node.get_service_state(service).value + else: + self.current_observation[item_index][service_index] = 0 + service_index += 1 + else: + # Not a service node + for service in self.env.services_list: + self.current_observation[item_index][service_index] = 0 + service_index += 1 + item_index += 1 + + # Now do links + for _, link in links.items(): + self.current_observation[item_index][0] = int(link.get_id()) + self.current_observation[item_index][1] = 0 + self.current_observation[item_index][2] = 0 + self.current_observation[item_index][3] = 0 + protocol_list = link.get_protocol_list() + protocol_index = 0 + for protocol in protocol_list: + self.current_observation[item_index][ + protocol_index + 4 + ] = protocol.get_load() + protocol_index += 1 + item_index += 1 -# class NodeStatuses(AbstractObservationComponent): -# _DATA_TYPE = np.int64 +class NodeStatuses(AbstractObservationComponent): + """todo: complete description.""" -# def __init__(self): -# super().__init__() + _DATA_TYPE = np.int64 -# # 1. Define the shape of your observation space component -# shape = [ -# len(HardwareState) + 1, -# len(SoftwareState) + 1, -# len(FileSystemState) + 1, -# ] -# services_shape = [len(SoftwareState) + 1] * self.env.num_services -# shape = shape + services_shape + def __init__(self): + super().__init__() -# # 2. Create Observation space -# self.space = spaces.MultiDiscrete(shape) + # 1. Define the shape of your observation space component + shape = [ + len(HardwareState) + 1, + len(SoftwareState) + 1, + len(FileSystemState) + 1, + ] + services_shape = [len(SoftwareState) + 1] * self.env.num_services + shape = shape + services_shape -# # 3. Initialise observation with zeroes -# self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + # 2. Create Observation space + self.space = spaces.MultiDiscrete(shape) -# def update_obs(self): -# obs = [] -# for _, node in self.env.nodes.items(): -# hardware_state = node.hardware_state.value -# software_state = 0 -# file_system_state = 0 -# service_states = [0] * self.env.num_services + # 3. Initialise observation with zeroes + self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) -# if isinstance(node, ActiveNode): -# software_state = node.software_state.value -# file_system_state = node.file_system_state_observed.value + def update_obs(self): + """todo: complete description.""" + obs = [] + for _, node in self.env.nodes.items(): + hardware_state = node.hardware_state.value + software_state = 0 + file_system_state = 0 + service_states = [0] * self.env.num_services -# if isinstance(node, ServiceNode): -# for i, service in enumerate(self.env.services_list): -# if node.has_service(service): -# service_states[i] = node.get_service_state(service).value -# obs.extend([hardware_state, software_state, file_system_state, *service_states]) -# self.current_observation[:] = obs + if isinstance(node, ActiveNode): + software_state = node.software_state.value + file_system_state = node.file_system_state_observed.value + + if isinstance(node, ServiceNode): + for i, service in enumerate(self.env.services_list): + if node.has_service(service): + service_states[i] = node.get_service_state(service).value + obs.extend([hardware_state, software_state, file_system_state, *service_states]) + self.current_observation[:] = obs -# class LinkTrafficLevels(AbstractObservationComponent): -# _DATA_TYPE = np.int64 +class LinkTrafficLevels(AbstractObservationComponent): + """todo: complete description.""" -# def __init__( -# self, combine_service_traffic: bool = False, quantisation_levels: int = 5 -# ): -# super().__init__() -# self._combine_service_traffic: bool = combine_service_traffic -# self._quantisation_levels: int = quantisation_levels -# self._entries_per_link: int = 1 + _DATA_TYPE = np.int64 -# if not self._combine_service_traffic: -# self._entries_per_link = self.env.num_services + def __init__( + self, combine_service_traffic: bool = False, quantisation_levels: int = 5 + ): + super().__init__() + self._combine_service_traffic: bool = combine_service_traffic + self._quantisation_levels: int = quantisation_levels + self._entries_per_link: int = 1 -# # 1. Define the shape of your observation space component -# shape = ( -# [self._quantisation_levels] * self.env.num_links * self._entries_per_link -# ) + if not self._combine_service_traffic: + self._entries_per_link = self.env.num_services -# # 2. Create Observation space -# self.space = spaces.MultiDiscrete(shape) + # 1. Define the shape of your observation space component + shape = ( + [self._quantisation_levels] * self.env.num_links * self._entries_per_link + ) -# # 3. Initialise observation with zeroes -# self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + # 2. Create Observation space + self.space = spaces.MultiDiscrete(shape) -# def update_obs(self): -# obs = [] -# for _, link in self.env.links.items(): -# bandwidth = link.bandwidth -# if self._combine_service_traffic: -# loads = [link.get_current_load()] -# else: -# loads = [protocol.get_load() for protocol in link.protocol_list] + # 3. Initialise observation with zeroes + self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) -# for load in loads: -# if load <= 0: -# traffic_level = 0 -# elif load >= bandwidth: -# traffic_level = self._quantisation_levels - 1 -# else: -# traffic_level = (load / bandwidth) // ( -# 1 / (self._quantisation_levels - 1) -# ) + 1 + def update_obs(self): + """todo: complete description.""" + obs = [] + for _, link in self.env.links.items(): + bandwidth = link.bandwidth + if self._combine_service_traffic: + loads = [link.get_current_load()] + else: + loads = [protocol.get_load() for protocol in link.protocol_list] -# obs.append(int(traffic_level)) + for load in loads: + if load <= 0: + traffic_level = 0 + elif load >= bandwidth: + traffic_level = self._quantisation_levels - 1 + else: + traffic_level = (load / bandwidth) // ( + 1 / (self._quantisation_levels - 1) + ) + 1 -# self.current_observation[:] = obs + obs.append(int(traffic_level)) + + self.current_observation[:] = obs -# class ObservationsHandler: -# class registry(Enum): -# NODE_LINK_TABLE: NodeLinkTable -# NODE_STATUSES: NodeStatuses -# LINK_TRAFFIC_LEVELS: LinkTrafficLevels +class ObservationsHandler: + """todo: complete description.""" -# def __init__(self): -# ... -# # i can access the registry items like this: -# # self.registry.LINK_TRAFFIC_LEVELS + class registry(Enum): + """todo: complete description.""" -# def update_obs(self): -# ... + NODE_LINK_TABLE: NodeLinkTable + NODE_STATUSES: NodeStatuses + LINK_TRAFFIC_LEVELS: LinkTrafficLevels -# def register(self): -# ... + def __init__(self): + """todo: complete description.""" + """Initialise the handler without any components yet. They""" + self.registered_obs_components: List[AbstractObservationComponent] = [] + self.space: spaces.Space + self.current_observation: Tuple[np.ndarray] + # i can access the registry items like this: + # self.registry.LINK_TRAFFIC_LEVELS -# def deregister(self, observation: AbstractObservationComponent): -# ... + def update_obs(self): + """todo: complete description.""" + current_obs = [] + for obs in self.registered_obs_components: + obs.update_obs() + current_obs.append(obs.current_observation) + self.current_observation = tuple(current_obs) -# def export(self): -# ... + def register(self, obs_component: AbstractObservationComponent): + """todo: complete description.""" + self.registered_obs_components.append(obs_component) + self.update_space() + + def deregister(self, obs_component: AbstractObservationComponent): + """todo: complete description.""" + self.registered_obs_components.remove(obs_component) + self.update_space() + + def update_space(self): + """todo: complete description.""" + component_spaces = [] + for obs_comp in self.registered_obs_components: + component_spaces.append(obs_comp.space) + self.space = spaces.Tuple(component_spaces) From 2b25573378436aff5326d8a072c902163c7f870b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 16:42:10 +0100 Subject: [PATCH 03/16] Integrate obs handler with Primaite Env --- src/primaite/config/config_1_DDOS_BASIC.yaml | 8 + src/primaite/environment/observations.py | 90 +++++- src/primaite/environment/primaite_env.py | 282 ++----------------- 3 files changed, 115 insertions(+), 265 deletions(-) diff --git a/src/primaite/config/config_1_DDOS_BASIC.yaml b/src/primaite/config/config_1_DDOS_BASIC.yaml index ada813f3..a1961df3 100644 --- a/src/primaite/config/config_1_DDOS_BASIC.yaml +++ b/src/primaite/config/config_1_DDOS_BASIC.yaml @@ -1,5 +1,13 @@ - itemType: ACTIONS type: NODE +- itemType: OBSERVATION_SPACE + components: + - name: NODE_LINK_TABLE + - name: NODE_STATUSES + - name: LINK_TRAFFIC_LEVELS + options: + - combine_service_traffic : False + - quantisation_levels : 7 - itemType: STEPS steps: 128 - itemType: PORTS diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 338c11a1..94c2730f 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -35,7 +35,23 @@ class AbstractObservationComponent(ABC): class NodeLinkTable(AbstractObservationComponent): """Table with nodes/links as rows and hardware/software status as cols. - #todo: write full description + Initialise the observation space with the BOX option chosen. + + This will create the observation space formatted as a table of integers. + There is one row per node, followed by one row per link. + Columns are as follows: + * node/link ID + * node hardware status / 0 for links + * node operating system status (if active/service) / 0 for links + * node file system status (active/service only) / 0 for links + * node service1 status / traffic load from that service for links + * node service2 status / traffic load from that service for links + * ... + * node serviceN status / traffic load from that service for links + + For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be + ``(12, 7)`` + #todo: clean up description """ @@ -44,7 +60,7 @@ class NodeLinkTable(AbstractObservationComponent): _DATA_TYPE = np.int64 def __init__(self, env: Primaite): - super().__init__() + super().__init__(env) # 1. Define the shape of your observation space component num_items = self.env.num_links + self.env.num_nodes @@ -65,6 +81,10 @@ class NodeLinkTable(AbstractObservationComponent): def update_obs(self): """Update the observation. + Update the environment's observation state based on the current status of nodes and links. + + The structure of the observation space is described in :func:`~_init_box_observations` + This function can only be called if the observation space setting is set to BOX. todo: complete description.. """ item_index = 0 @@ -116,12 +136,20 @@ class NodeLinkTable(AbstractObservationComponent): class NodeStatuses(AbstractObservationComponent): - """todo: complete description.""" + """todo: complete description. + + This will create the observation space with node observations followed by link observations. + Each node has 3 elements in the observation space plus 1 per service, more specifically: + * hardware state + * operating system state + * file system state + * service states (one per service) + """ _DATA_TYPE = np.int64 - def __init__(self): - super().__init__() + def __init__(self, env: Primaite): + super().__init__(env) # 1. Define the shape of your observation space component shape = [ @@ -139,7 +167,15 @@ class NodeStatuses(AbstractObservationComponent): self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) def update_obs(self): - """todo: complete description.""" + """todo: complete description. + + Update the environment's observation state based on the current status of nodes and links. + + The structure of the observation space is described in :func:`~_init_multidiscrete_observations` + This function can only be called if the observation space setting is set to MULTIDISCRETE. + + + """ obs = [] for _, node in self.env.nodes.items(): hardware_state = node.hardware_state.value @@ -160,14 +196,26 @@ class NodeStatuses(AbstractObservationComponent): class LinkTrafficLevels(AbstractObservationComponent): - """todo: complete description.""" + """todo: complete description. + + Each link has one element in the observation space, corresponding to the traffic load, + it can take the following values: + 0 = No traffic (0% of bandwidth) + 1 = No traffic (0%-33% of bandwidth) + 2 = No traffic (33%-66% of bandwidth) + 3 = No traffic (66%-100% of bandwidth) + 4 = No traffic (100% of bandwidth) + """ _DATA_TYPE = np.int64 def __init__( - self, combine_service_traffic: bool = False, quantisation_levels: int = 5 + self, + env: Primaite, + combine_service_traffic: bool = False, + quantisation_levels: int = 5, ): - super().__init__() + super().__init__(env) self._combine_service_traffic: bool = combine_service_traffic self._quantisation_levels: int = quantisation_levels self._entries_per_link: int = 1 @@ -212,7 +260,7 @@ class LinkTrafficLevels(AbstractObservationComponent): class ObservationsHandler: - """todo: complete description.""" + """Component-based observation space handler.""" class registry(Enum): """todo: complete description.""" @@ -254,3 +302,25 @@ class ObservationsHandler: for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) self.space = spaces.Tuple(component_spaces) + + @classmethod + def from_config(cls, obs_space_config): + """todo: complete description. + + This method parses config items related to the observation space, then + creates the necessary components and adds them to the observation handler. + """ + # Instantiate the handler + handler = cls() + + for component_cfg in obs_space_config["components"]: + # Figure out which class can instantiate the desired component + comp_type = component_cfg["name"] + comp_class = cls.registry[comp_type].value + + # Create the component with options from the YAML + component = comp_class(**component_cfg["options"]) + + handler.register(component) + + return handler diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 56893ee9..afa04060 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -6,7 +6,7 @@ import csv import logging import os.path from datetime import datetime -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import networkx as nx import numpy as np @@ -23,11 +23,11 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, - ObservationType, Priority, SoftwareState, ) from primaite.common.service import Service +from primaite.environment.observations import ObservationsHandler from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode @@ -149,8 +149,8 @@ class Primaite(Env): # The action type self.action_type = 0 - # Observation type, by default box. - self.observation_type = ObservationType.BOX + # todo: proper description here + self.obs_handler: ObservationsHandler # Open the config file and build the environment laydown try: @@ -161,6 +161,10 @@ class Primaite(Env): _LOGGER.error("Could not load the environment configuration") _LOGGER.error("Exception occured", exc_info=True) + # If it doesn't exist after parsing config, create default obs space. + if self.get("obs_handler") is None: + self.configure_obs_space() + # Store the node objects as node attributes # (This is so we can access them as objects) for node in self.network: @@ -641,252 +645,17 @@ class Primaite(Env): else: pass - def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Initialise the observation space with the BOX option chosen. - - This will create the observation space formatted as a table of integers. - There is one row per node, followed by one row per link. - Columns are as follows: - * node/link ID - * node hardware status / 0 for links - * node operating system status (if active/service) / 0 for links - * node file system status (active/service only) / 0 for links - * node service1 status / traffic load from that service for links - * node service2 status / traffic load from that service for links - * ... - * node serviceN status / traffic load from that service for links - - For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be - ``(12, 7)`` - - :return: Box gym observation - :rtype: gym.spaces.Box - :return: Initial observation with all entires set to 0 - :rtype: numpy.Array - """ - _LOGGER.info("Observation space type BOX selected") - - # 1. Determine observation shape from laydown - num_items = self.num_links + self.num_nodes - num_observation_parameters = ( - self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS - ) - observation_shape = (num_items, num_observation_parameters) - - # 2. Create observation space & zeroed out sample from space. - observation_space = spaces.Box( - low=0, - high=self.OBSERVATION_SPACE_HIGH_VALUE, - shape=observation_shape, - dtype=np.int64, - ) - initial_observation = np.zeros(observation_shape, dtype=np.int64) - - return observation_space, initial_observation - - def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Initialise the observation space with the MULTIDISCRETE option chosen. - - This will create the observation space with node observations followed by link observations. - Each node has 3 elements in the observation space plus 1 per service, more specifically: - * hardware state - * operating system state - * file system state - * service states (one per service) - Each link has one element in the observation space, corresponding to the traffic load, - it can take the following values: - 0 = No traffic (0% of bandwidth) - 1 = No traffic (0%-33% of bandwidth) - 2 = No traffic (33%-66% of bandwidth) - 3 = No traffic (66%-100% of bandwidth) - 4 = No traffic (100% of bandwidth) - - For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be - ``(37,)`` - - :return: MultiDiscrete gym observation - :rtype: gym.spaces.MultiDiscrete - :return: Initial observation with all entires set to 0 - :rtype: numpy.Array - """ - _LOGGER.info("Observation space MULTIDISCRETE selected") - - # 1. Determine observation shape from laydown - node_obs_shape = [ - len(HardwareState) + 1, - len(SoftwareState) + 1, - len(FileSystemState) + 1, - ] - node_services = [len(SoftwareState) + 1] * self.num_services - node_obs_shape = node_obs_shape + node_services - # the magic number 5 refers to 5 states of quantisation of traffic amount. - # (zero, low, medium, high, fully utilised/overwhelmed) - link_obs_shape = [5] * self.num_links - observation_shape = node_obs_shape * self.num_nodes + link_obs_shape - - # 2. Create observation space & zeroed out sample from space. - observation_space = spaces.MultiDiscrete(observation_shape) - initial_observation = np.zeros(len(observation_shape), dtype=np.int64) - - return observation_space, initial_observation - def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Build the observation space based on network laydown and provide initial obs. - - This method uses the object's `num_links`, `num_nodes`, `num_services`, - `OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type` - attributes to figure out the correct shape and format for the observation space. - - :raises ValueError: If the env's `observation_type` attribute is not set to a valid `enums.ObservationType` - :return: Gym observation space - :rtype: gym.spaces.Space - :return: Initial observation with all entires set to 0 - :rtype: numpy.Array - """ - if self.observation_type == ObservationType.BOX: - observation_space, initial_observation = self._init_box_observations() - return observation_space, initial_observation - elif self.observation_type == ObservationType.MULTIDISCRETE: - ( - observation_space, - initial_observation, - ) = self._init_multidiscrete_observations() - return observation_space, initial_observation - else: - errmsg = ( - f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}" - f", got {self.observation_type} instead" - ) - _LOGGER.error(errmsg) - raise ValueError(errmsg) - - def _update_env_obs_box(self): - """Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_box_observations` - This function can only be called if the observation space setting is set to BOX. - - :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` - """ - assert self.observation_type == ObservationType.BOX - item_index = 0 - - # Do nodes first - for node_key, node in self.nodes.items(): - self.env_obs[item_index][0] = int(node.node_id) - self.env_obs[item_index][1] = node.hardware_state.value - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.env_obs[item_index][2] = node.software_state.value - self.env_obs[item_index][3] = node.file_system_state_observed.value - else: - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - service_index = 4 - if isinstance(node, ServiceNode): - for service in self.services_list: - if node.has_service(service): - self.env_obs[item_index][ - service_index - ] = node.get_service_state(service).value - else: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - else: - # Not a service node - for service in self.services_list: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - item_index += 1 - - # Now do links - for link_key, link in self.links.items(): - self.env_obs[item_index][0] = int(link.get_id()) - self.env_obs[item_index][1] = 0 - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - protocol_list = link.get_protocol_list() - protocol_index = 0 - for protocol in protocol_list: - self.env_obs[item_index][protocol_index + 4] = protocol.get_load() - protocol_index += 1 - item_index += 1 - - def _update_env_obs_multidiscrete(self): - """Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_multidiscrete_observations` - This function can only be called if the observation space setting is set to MULTIDISCRETE. - - :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` - """ - assert self.observation_type == ObservationType.MULTIDISCRETE - obs = [] - # 1. Set nodes - # Each node has the following variables in the observation space: - # - Hardware state - # - Software state - # - File System state - # - Service 1 state - # - Service 2 state - # - ... - # - Service N state - for node_key, node in self.nodes.items(): - hardware_state = node.hardware_state.value - software_state = 0 - file_system_state = 0 - services_states = [0] * self.num_services - - if isinstance( - node, ActiveNode - ): # ServiceNode is a subclass of ActiveNode so no need to check that also - software_state = node.software_state.value - file_system_state = node.file_system_state_observed.value - - if isinstance(node, ServiceNode): - for i, service in enumerate(self.services_list): - if node.has_service(service): - services_states[i] = node.get_service_state(service).value - - obs.extend( - [ - hardware_state, - software_state, - file_system_state, - *services_states, - ] - ) - - # 2. Set links - # Each link has just one variable in the observation space, it represents the traffic amount - # In order for the space to be fully MultiDiscrete, the amount of - # traffic on each link is quantised into a few levels: - # 0: no traffic (0% of bandwidth) - # 1: low traffic (0-33% of bandwidth) - # 2: medium traffic (33-66% of bandwidth) - # 3: high traffic (66-100% of bandwidth) - # 4: max traffic/overloaded (100% of bandwidth) - - for link_key, link in self.links.items(): - bandwidth = link.bandwidth - load = link.get_current_load() - - if load <= 0: - traffic_level = 0 - elif load >= bandwidth: - traffic_level = 4 - else: - traffic_level = (load / bandwidth) // (1 / 3) + 1 - - obs.append(int(traffic_level)) - - self.env_obs = np.asarray(obs) + """todo: write docstring.""" + return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): - """Updates the observation space based on the node and link status.""" - if self.observation_type == ObservationType.BOX: - self._update_env_obs_box() - elif self.observation_type == ObservationType.MULTIDISCRETE: - self._update_env_obs_multidiscrete() + """Updates the observation space based on the node and link status. + + todo: better docstring + """ + self.obs_handler.update_obs() + self.env_obs = self.obs_handler.current_observation def load_config(self): """Loads config data in order to build the environment configuration.""" @@ -921,9 +690,9 @@ class Primaite(Env): elif item["itemType"] == "ACTIONS": # Get the action information self.get_action_info(item) - elif item["itemType"] == "OBSERVATIONS": + elif item["itemType"] == "OBSERVATION_SPACE": # Get the observation information - self.get_observation_info(item) + self.configure_obs_space(item) elif item["itemType"] == "STEPS": # Get the steps information self.get_steps_info(item) @@ -1256,13 +1025,16 @@ class Primaite(Env): """ self.action_type = ActionType[action_info["type"]] - def get_observation_info(self, observation_info): - """Extracts observation_info. + def configure_obs_space(self, observation_config: Optional[Dict] = None): + """todo: better docstring.""" + if observation_config is None: + observation_config = { + "components": [ + {"name": "NODE_LINK_TABLE"}, + ] + } - :param observation_info: Config item that defines which type of observation space to use - :type observation_info: str - """ - self.observation_type = ObservationType[observation_info["type"]] + self.obs_handler = ObservationsHandler[observation_config] def get_steps_info(self, steps_info): """ From 7041b79d2ab27ea1276a9b675d79329f16fd4ad9 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 17:42:35 +0100 Subject: [PATCH 04/16] Fix trying to init obs before building network --- src/primaite/environment/observations.py | 74 +++++++++++++----------- src/primaite/environment/primaite_env.py | 39 +++++++------ 2 files changed, 61 insertions(+), 52 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 94c2730f..a1b0d9ac 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,17 +1,22 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod -from enum import Enum -from typing import List, Tuple +from typing import TYPE_CHECKING, List, Tuple import numpy as np from gym import spaces from primaite.common.enums import FileSystemState, HardwareState, SoftwareState -from primaite.environment.primaite_env import Primaite from primaite.nodes.active_node import ActiveNode from primaite.nodes.service_node import ServiceNode +# This dependency is only needed for type hints, +# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking +# Therefore, this avoids circular dependency problem. +if TYPE_CHECKING: + from primaite.environment.primaite_env import Primaite + + _LOGGER = logging.getLogger(__name__) @@ -19,9 +24,9 @@ class AbstractObservationComponent(ABC): """Represents a part of the PrimAITE observation space.""" @abstractmethod - def __init__(self, env: Primaite): + def __init__(self, env: "Primaite"): _LOGGER.info(f"Initialising {self} observation component") - self.env: Primaite = env + self.env: "Primaite" = env self.space: spaces.Space self.current_observation: np.ndarray # type might be too restrictive? return NotImplemented @@ -51,7 +56,7 @@ class NodeLinkTable(AbstractObservationComponent): For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be ``(12, 7)`` - #todo: clean up description + #TODO: clean up description """ @@ -59,7 +64,7 @@ class NodeLinkTable(AbstractObservationComponent): _MAX_VAL = 1_000_000 _DATA_TYPE = np.int64 - def __init__(self, env: Primaite): + def __init__(self, env: "Primaite"): super().__init__(env) # 1. Define the shape of your observation space component @@ -76,16 +81,16 @@ class NodeLinkTable(AbstractObservationComponent): ) # 3. Initialise Observation with zeroes - self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE) + self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) - def update_obs(self): + def update(self): """Update the observation. Update the environment's observation state based on the current status of nodes and links. The structure of the observation space is described in :func:`~_init_box_observations` This function can only be called if the observation space setting is set to BOX. - todo: complete description.. + TODO: complete description.. """ item_index = 0 nodes = self.env.nodes @@ -136,7 +141,7 @@ class NodeLinkTable(AbstractObservationComponent): class NodeStatuses(AbstractObservationComponent): - """todo: complete description. + """TODO: complete description. This will create the observation space with node observations followed by link observations. Each node has 3 elements in the observation space plus 1 per service, more specifically: @@ -148,7 +153,7 @@ class NodeStatuses(AbstractObservationComponent): _DATA_TYPE = np.int64 - def __init__(self, env: Primaite): + def __init__(self, env: "Primaite"): super().__init__(env) # 1. Define the shape of your observation space component @@ -166,8 +171,8 @@ class NodeStatuses(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) - def update_obs(self): - """todo: complete description. + def update(self): + """TODO: complete description. Update the environment's observation state based on the current status of nodes and links. @@ -196,7 +201,7 @@ class NodeStatuses(AbstractObservationComponent): class LinkTrafficLevels(AbstractObservationComponent): - """todo: complete description. + """TODO: complete description. Each link has one element in the observation space, corresponding to the traffic load, it can take the following values: @@ -211,7 +216,7 @@ class LinkTrafficLevels(AbstractObservationComponent): def __init__( self, - env: Primaite, + env: "Primaite", combine_service_traffic: bool = False, quantisation_levels: int = 5, ): @@ -234,8 +239,8 @@ class LinkTrafficLevels(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) - def update_obs(self): - """todo: complete description.""" + def update(self): + """TODO: complete description.""" obs = [] for _, link in self.env.links.items(): bandwidth = link.bandwidth @@ -262,15 +267,14 @@ class LinkTrafficLevels(AbstractObservationComponent): class ObservationsHandler: """Component-based observation space handler.""" - class registry(Enum): - """todo: complete description.""" - - NODE_LINK_TABLE: NodeLinkTable - NODE_STATUSES: NodeStatuses - LINK_TRAFFIC_LEVELS: LinkTrafficLevels + registry = { + "NODE_LINK_TABLE": NodeLinkTable, + "NODE_STATUSES": NodeStatuses, + "LINK_TRAFFIC_LEVELS": LinkTrafficLevels, + } def __init__(self): - """todo: complete description.""" + """TODO: complete description.""" """Initialise the handler without any components yet. They""" self.registered_obs_components: List[AbstractObservationComponent] = [] self.space: spaces.Space @@ -279,33 +283,33 @@ class ObservationsHandler: # self.registry.LINK_TRAFFIC_LEVELS def update_obs(self): - """todo: complete description.""" + """TODO: complete description.""" current_obs = [] for obs in self.registered_obs_components: - obs.update_obs() + obs.update() current_obs.append(obs.current_observation) self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): - """todo: complete description.""" + """TODO: complete description.""" self.registered_obs_components.append(obs_component) self.update_space() def deregister(self, obs_component: AbstractObservationComponent): - """todo: complete description.""" + """TODO: complete description.""" self.registered_obs_components.remove(obs_component) self.update_space() def update_space(self): - """todo: complete description.""" + """TODO: complete description.""" component_spaces = [] for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) self.space = spaces.Tuple(component_spaces) @classmethod - def from_config(cls, obs_space_config): - """todo: complete description. + def from_config(cls, env: "Primaite", obs_space_config: dict): + """TODO: complete description. This method parses config items related to the observation space, then creates the necessary components and adds them to the observation handler. @@ -316,11 +320,13 @@ class ObservationsHandler: for component_cfg in obs_space_config["components"]: # Figure out which class can instantiate the desired component comp_type = component_cfg["name"] - comp_class = cls.registry[comp_type].value + comp_class = cls.registry[comp_type] # Create the component with options from the YAML - component = comp_class(**component_cfg["options"]) + options = component_cfg.get("options") or {} + component = comp_class(env, **options) handler.register(component) + handler.update_obs() return handler diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index afa04060..0107920f 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -149,7 +149,8 @@ class Primaite(Env): # The action type self.action_type = 0 - # todo: proper description here + # TODO: proper description here + self.obs_config: dict self.obs_handler: ObservationsHandler # Open the config file and build the environment laydown @@ -161,10 +162,6 @@ class Primaite(Env): _LOGGER.error("Could not load the environment configuration") _LOGGER.error("Exception occured", exc_info=True) - # If it doesn't exist after parsing config, create default obs space. - if self.get("obs_handler") is None: - self.configure_obs_space() - # Store the node objects as node attributes # (This is so we can access them as objects) for node in self.network: @@ -195,6 +192,10 @@ class Primaite(Env): _LOGGER.error("Exception occured", exc_info=True) print("Could not save network diagram") + # # If it doesn't exist after parsing config, create default obs space. + # if getattr(self, "obs_handler", None) is None: + # self.configure_obs_space() + # Initiate observation space self.observation_space, self.env_obs = self.init_observations() @@ -646,13 +647,22 @@ class Primaite(Env): pass def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """todo: write docstring.""" + """TODO: write docstring.""" + if getattr(self, "obs_config", None) is None: + self.obs_config = { + "components": [ + {"name": "NODE_LINK_TABLE"}, + ] + } + + self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) + return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): """Updates the observation space based on the node and link status. - todo: better docstring + TODO: better docstring """ self.obs_handler.update_obs() self.env_obs = self.obs_handler.current_observation @@ -692,7 +702,7 @@ class Primaite(Env): self.get_action_info(item) elif item["itemType"] == "OBSERVATION_SPACE": # Get the observation information - self.configure_obs_space(item) + self.save_obs_config(item) elif item["itemType"] == "STEPS": # Get the steps information self.get_steps_info(item) @@ -1025,16 +1035,9 @@ class Primaite(Env): """ self.action_type = ActionType[action_info["type"]] - def configure_obs_space(self, observation_config: Optional[Dict] = None): - """todo: better docstring.""" - if observation_config is None: - observation_config = { - "components": [ - {"name": "NODE_LINK_TABLE"}, - ] - } - - self.obs_handler = ObservationsHandler[observation_config] + def save_obs_config(self, obs_config: Optional[Dict] = None): + """TODO: better docstring.""" + self.obs_config = obs_config def get_steps_info(self, steps_info): """ From 3e208bad9bf553acc2eb4157d4bd07da906d5227 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 17:50:18 +0100 Subject: [PATCH 05/16] Better Obs default handling --- src/primaite/environment/primaite_env.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 0107920f..81557075 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -47,15 +47,12 @@ _LOGGER = logging.getLogger(__name__) class Primaite(Env): """PRIMmary AI Training Evironment (Primaite) class.""" - # Observation / Action Space contants - OBSERVATION_SPACE_FIXED_PARAMETERS = 4 + # Action Space contants ACTION_SPACE_NODE_PROPERTY_VALUES = 5 ACTION_SPACE_NODE_ACTION_VALUES = 4 ACTION_SPACE_ACL_ACTION_VALUES = 3 ACTION_SPACE_ACL_PERMISSION_VALUES = 2 - OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space - def __init__(self, _config_values, _transaction_list): """ Init. @@ -149,8 +146,11 @@ class Primaite(Env): # The action type self.action_type = 0 - # TODO: proper description here - self.obs_config: dict + # stores the observation config from the yaml, default is NODE_LINK_TABLE + self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]} + + # Observation Handler manages the user-configurable observation space. + # It will be initialised later. self.obs_handler: ObservationsHandler # Open the config file and build the environment laydown @@ -192,10 +192,6 @@ class Primaite(Env): _LOGGER.error("Exception occured", exc_info=True) print("Could not save network diagram") - # # If it doesn't exist after parsing config, create default obs space. - # if getattr(self, "obs_handler", None) is None: - # self.configure_obs_space() - # Initiate observation space self.observation_space, self.env_obs = self.init_observations() @@ -648,13 +644,6 @@ class Primaite(Env): def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: """TODO: write docstring.""" - if getattr(self, "obs_config", None) is None: - self.obs_config = { - "components": [ - {"name": "NODE_LINK_TABLE"}, - ] - } - self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) return self.obs_handler.space, self.obs_handler.current_observation From c0b214612a6cf4b7290325a55292402544634aa7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 18:01:47 +0100 Subject: [PATCH 06/16] Let single-component spaces not use Tuple Spaces --- src/primaite/environment/observations.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index a1b0d9ac..5bad056c 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,7 +1,7 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, List, Tuple, Union import numpy as np from gym import spaces @@ -278,7 +278,7 @@ class ObservationsHandler: """Initialise the handler without any components yet. They""" self.registered_obs_components: List[AbstractObservationComponent] = [] self.space: spaces.Space - self.current_observation: Tuple[np.ndarray] + self.current_observation: Union[Tuple[np.ndarray], np.ndarray] # i can access the registry items like this: # self.registry.LINK_TRAFFIC_LEVELS @@ -288,7 +288,12 @@ class ObservationsHandler: for obs in self.registered_obs_components: obs.update() current_obs.append(obs.current_observation) - self.current_observation = tuple(current_obs) + + # If there is only one component, don't use a tuple, just pass through that component's obs. + if len(current_obs) == 1: + self.current_observation = current_obs[0] + else: + self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): """TODO: complete description.""" @@ -305,7 +310,12 @@ class ObservationsHandler: component_spaces = [] for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) - self.space = spaces.Tuple(component_spaces) + + # If there is only one component, don't use a tuple space, just pass through that component's space. + if len(component_spaces) == 1: + self.space = component_spaces[0] + else: + self.space = spaces.Tuple(component_spaces) @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): From 484a31d0822d6daa6137b5c456cac52ed39b131e Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 21:28:38 +0100 Subject: [PATCH 07/16] Add docstrings to new observation code --- src/primaite/environment/observations.py | 133 ++++++++++++++++------- src/primaite/environment/primaite_env.py | 25 +++-- 2 files changed, 109 insertions(+), 49 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 5bad056c..c4402b69 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -33,18 +33,16 @@ class AbstractObservationComponent(ABC): @abstractmethod def update(self): - """Look at the environment and update the current observation value.""" + """Update the observation based on the current state of the environment.""" self.current_observation = NotImplemented class NodeLinkTable(AbstractObservationComponent): - """Table with nodes/links as rows and hardware/software status as cols. - - Initialise the observation space with the BOX option chosen. + """Table with nodes and links as rows and hardware/software status as cols. This will create the observation space formatted as a table of integers. There is one row per node, followed by one row per link. - Columns are as follows: + The number of columns is 4 plus one per service. They are: * node/link ID * node hardware status / 0 for links * node operating system status (if active/service) / 0 for links @@ -56,8 +54,6 @@ class NodeLinkTable(AbstractObservationComponent): For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be ``(12, 7)`` - #TODO: clean up description - """ _FIXED_PARAMETERS = 4 @@ -84,13 +80,9 @@ class NodeLinkTable(AbstractObservationComponent): self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) def update(self): - """Update the observation. + """Update the observation based on current environment state. - Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_box_observations` - This function can only be called if the observation space setting is set to BOX. - TODO: complete description.. + The structure of the observation space is described in :class:`.NodeLinkTable` """ item_index = 0 nodes = self.env.nodes @@ -141,14 +133,30 @@ class NodeLinkTable(AbstractObservationComponent): class NodeStatuses(AbstractObservationComponent): - """TODO: complete description. + """Flat list of nodes' hardware, OS, file system, and service states. - This will create the observation space with node observations followed by link observations. - Each node has 3 elements in the observation space plus 1 per service, more specifically: - * hardware state - * operating system state - * file system state - * service states (one per service) + The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by + integers. + Each node has 3 elements plus 1 per service. It will have the following structure: + .. code-block:: + [ + node1 hardware state, + node1 OS state, + node1 file system state, + node1 service1 state, + node1 service2 state, + node1 serviceN state (one for each service), + node2 hardware state, + node2 OS state, + node2 file system state, + node2 service1 state, + node2 service2 state, + node2 serviceN state (one for each service), + ... + ] + + :param env: The environment that forms the basis of the observations + :type env: Primaite """ _DATA_TYPE = np.int64 @@ -172,14 +180,9 @@ class NodeStatuses(AbstractObservationComponent): self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) def update(self): - """TODO: complete description. - - Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_multidiscrete_observations` - This function can only be called if the observation space setting is set to MULTIDISCRETE. - + """Update the observation based on current environment state. + The structure of the observation space is described in :class:`.NodeStatuses` """ obs = [] for _, node in self.env.nodes.items(): @@ -201,15 +204,28 @@ class NodeStatuses(AbstractObservationComponent): class LinkTrafficLevels(AbstractObservationComponent): - """TODO: complete description. + """Flat list of traffic levels encoded into banded categories. - Each link has one element in the observation space, corresponding to the traffic load, - it can take the following values: + For each link, total traffic or traffic per service is encoded into a categorical value. + For example, if ``quantisation_levels=5``, the traffic levels represent these values: 0 = No traffic (0% of bandwidth) 1 = No traffic (0%-33% of bandwidth) 2 = No traffic (33%-66% of bandwidth) 3 = No traffic (66%-100% of bandwidth) 4 = No traffic (100% of bandwidth) + + .. note:: + The lowest category always corresponds to no traffic and the highest category to the link being at max capacity. + Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories. + + :param env: The environment that forms the basis of the observations + :type env: Primaite + :param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually, + defaults to False + :type combine_service_traffic: bool, optional + :param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value , + defaults to 5 + :type quantisation_levels: int, optional """ _DATA_TYPE = np.int64 @@ -220,7 +236,10 @@ class LinkTrafficLevels(AbstractObservationComponent): combine_service_traffic: bool = False, quantisation_levels: int = 5, ): + assert quantisation_levels >= 3 + super().__init__(env) + self._combine_service_traffic: bool = combine_service_traffic self._quantisation_levels: int = quantisation_levels self._entries_per_link: int = 1 @@ -240,7 +259,10 @@ class LinkTrafficLevels(AbstractObservationComponent): self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) def update(self): - """TODO: complete description.""" + """Update the observation based on current environment state. + + The structure of the observation space is described in :class:`.LinkTrafficLevels` + """ obs = [] for _, link in self.env.links.items(): bandwidth = link.bandwidth @@ -265,7 +287,11 @@ class LinkTrafficLevels(AbstractObservationComponent): class ObservationsHandler: - """Component-based observation space handler.""" + """Component-based observation space handler. + + This allows users to configure observation spaces by mixing and matching components. + Each component can also define further parameters to make them more flexible. + """ registry = { "NODE_LINK_TABLE": NodeLinkTable, @@ -274,8 +300,6 @@ class ObservationsHandler: } def __init__(self): - """TODO: complete description.""" - """Initialise the handler without any components yet. They""" self.registered_obs_components: List[AbstractObservationComponent] = [] self.space: spaces.Space self.current_observation: Union[Tuple[np.ndarray], np.ndarray] @@ -283,7 +307,7 @@ class ObservationsHandler: # self.registry.LINK_TRAFFIC_LEVELS def update_obs(self): - """TODO: complete description.""" + """Fetch fresh information about the environment.""" current_obs = [] for obs in self.registered_obs_components: obs.update() @@ -296,17 +320,26 @@ class ObservationsHandler: self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): - """TODO: complete description.""" + """Add a component for this handler to track. + + :param obs_component: The component to add. + :type obs_component: AbstractObservationComponent + """ self.registered_obs_components.append(obs_component) self.update_space() def deregister(self, obs_component: AbstractObservationComponent): - """TODO: complete description.""" + """Remove a component from this handler. + + :param obs_component: Which component to remove. It must exist within this object's + ``registered_obs_components`` attribute. + :type obs_component: AbstractObservationComponent + """ self.registered_obs_components.remove(obs_component) self.update_space() def update_space(self): - """TODO: complete description.""" + """Rebuild the handler's composite observation space from its components.""" component_spaces = [] for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) @@ -319,10 +352,28 @@ class ObservationsHandler: @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): - """TODO: complete description. + """Parse a config dictinary, return a new observation handler populated with new observation component objects. - This method parses config items related to the observation space, then - creates the necessary components and adds them to the observation handler. + The expected format for the config dictionary is: + + ..code-block::python + config = { + components: [ + { + "name": "" + }, + { + "name": "" + "options": {"opt1": val1, "opt2": val2} + }, + { + ... + }, + ] + } + + :return: Observation handler + :rtype: primaite.environment.observations.ObservationsHandler """ # Instantiate the handler handler = cls() diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 81557075..8cff91d8 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -6,7 +6,7 @@ import csv import logging import os.path from datetime import datetime -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple import networkx as nx import numpy as np @@ -643,16 +643,17 @@ class Primaite(Env): pass def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """TODO: write docstring.""" + """Create the environment's observation handler. + + :return: The observation space, initial observation (zeroed out array with the correct shape) + :rtype: Tuple[spaces.Space, np.ndarray] + """ self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): - """Updates the observation space based on the node and link status. - - TODO: better docstring - """ + """Updates the observation space based on the node and link status.""" self.obs_handler.update_obs() self.env_obs = self.obs_handler.current_observation @@ -1024,8 +1025,16 @@ class Primaite(Env): """ self.action_type = ActionType[action_info["type"]] - def save_obs_config(self, obs_config: Optional[Dict] = None): - """TODO: better docstring.""" + def save_obs_config(self, obs_config: dict): + """Cache the config for the observation space. + + This is necessary as the observation space can't be built while reading the config, + it must be done after all the nodes, links, and services have been initialised. + + :param obs_config: Parsed config relating to the observation space. The format is described in + :py:meth:`primaite.environment.observations.ObservationsHandler.from_config` + :type obs_config: dict + """ self.obs_config = obs_config def get_steps_info(self, steps_info): From 85c102cfc19fb7c2236e19b7e6210eee58519cc4 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 21:42:34 +0100 Subject: [PATCH 08/16] Update docs page on observations --- docs/source/about.rst | 49 +++++++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/docs/source/about.rst b/docs/source/about.rst index 8cc08b13..ee84d880 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -182,16 +182,13 @@ All ACL rules are considered when applying an IER. Logic follows the order of ru Observation Spaces ****************** +The observation space provides the blue agent with information about the current status of nodes and links. -The OpenAI Gym observation space provides the status of all nodes and links across the whole system:​ +PrimAITE builds on top of Gym Spaces to create an observation space that is easily configurable for users. It's made up of components which are managed by the :py:class:`primaite.environment.observations.ObservationHandler`. Each training scenario can define its own observation space, and the user can choose which information to inlude, and how it should be formatted. -* Nodes (in terms of hardware state, Software State, file system state and services state) ​ -* Links (in terms of current loading for each service/protocol) - -The observation space can be configured as a ``gym.spaces.Box`` or ``gym.spaces.MultiDiscrete``, by setting the ``OBSERVATIONS`` parameter in the laydown config. - -Box-type observation space --------------------------- +NodeLinkTable component +----------------------- +For example, the :py:class:`primaite.environment.observations.NodeLinkTable` component represents the status of nodes and links as a ``gym.spaces.Box`` with an example format shown below: An example observation space is provided below: @@ -249,8 +246,6 @@ An example observation space is provided below: - 5000 - 0 -The observation space is a 6 x 6 Box type (OpenAI Gym Space) in this example. This is made up from the node and link information detailed below. - For the nodes, the following values are represented: * ID @@ -290,9 +285,9 @@ For the links, the following statuses are represented: * SoftwareState = N/A * Protocol = loading in bits/s -MultiDiscrete-type observation space ------------------------------------- -The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by integers. +NodeStatus component +---------------------- +This is a MultiDiscrete observation space that can be though of as a one-dimensional vector of discrete states, represented by integers. The example above would have the following structure: .. code-block:: @@ -301,9 +296,6 @@ The example above would have the following structure: node1_info node2_info node3_info - link1_status - link2_status - link3_status ] Each ``node_info`` contains the following: @@ -318,7 +310,25 @@ Each ``node_info`` contains the following: service2_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED) ] -Each ``link_status`` is just a number from 0-4 representing the network load in relation to bandwidth. +In a network with three nodes and two services, the full observation space would have 15 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example: + +.. code-block:: + + gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4]) + +LinkTrafficLevels +----------------- +This component is a MultiDiscrete space showing the traffic flow levels on the links in the network, after applying a threshold to convert it from a continuous to a discrete value. +The number of bins can be customised with 5 being the default. It has the following strucutre: +.. code-block:: + + [ + link1_status + link2_status + link3_status + ] + +Each ``link_status`` is a number from 0-4 representing the network load in relation to bandwidth. .. code-block:: @@ -328,12 +338,11 @@ Each ``link_status`` is just a number from 0-4 representing the network load in 3 = high traffic (<100%) 4 = max traffic/ overwhelmed (100%) -The full observation space would have 15 node-related elements and 3 link-related elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example: +If the network has three links, the full observation space would have 3 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example: .. code-block:: - gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4,5,5,5]) - + gym.spaces.MultiDiscrete([5,5,5]) Action Spaces ************** From 875562c3857b6c34cb79a7aceb418c28fe7a177f Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 21:56:05 +0100 Subject: [PATCH 09/16] begin updating observations tests --- tests/test_observation_space.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 6a187761..a13121b9 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -1,5 +1,6 @@ """Test env creation and behaviour with different observation spaces.""" +from primaite.environment.observations import NodeStatuses, ObservationsHandler from tests import TEST_CONFIG_ROOT from tests.conftest import _get_primaite_env_from_config @@ -32,3 +33,13 @@ def test_creating_env_with_multidiscrete_obs(): # the nodes have hardware, OS, FS, and service, the links just have bandwidth, # therefore we need 3*4 + 2 observations assert env.env_obs.shape == (3 * 4 + 2,) + + +def test_component_registration(): + """Test that we can register and deregister a component.""" + handler = ObservationsHandler() + component = NodeStatuses() + handler.register(component) + assert component in handler.registered_obs_components + handler.deregister(component) + assert component not in handler.registered_obs_components From b6ce1cbae91857e95b4a2e2ca7af99cd1e09cc25 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 2 Jun 2023 09:10:53 +0100 Subject: [PATCH 10/16] Edit configs for observation space --- src/primaite/config/config_1_DDOS_BASIC.yaml | 8 -------- .../laydown_with_LINK_TRAFFIC_LEVELS.yaml} | 5 +++-- .../laydown_with_NODE_LINK_TABLE.yaml} | 8 ++++++-- 3 files changed, 9 insertions(+), 12 deletions(-) rename tests/config/{box_obs_space_laydown_config.yaml => obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml} (94%) rename tests/config/{multidiscrete_obs_space_laydown_config.yaml => obs_tests/laydown_with_NODE_LINK_TABLE.yaml} (87%) diff --git a/src/primaite/config/config_1_DDOS_BASIC.yaml b/src/primaite/config/config_1_DDOS_BASIC.yaml index a1961df3..ada813f3 100644 --- a/src/primaite/config/config_1_DDOS_BASIC.yaml +++ b/src/primaite/config/config_1_DDOS_BASIC.yaml @@ -1,13 +1,5 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATION_SPACE - components: - - name: NODE_LINK_TABLE - - name: NODE_STATUSES - - name: LINK_TRAFFIC_LEVELS - options: - - combine_service_traffic : False - - quantisation_levels : 7 - itemType: STEPS steps: 128 - itemType: PORTS diff --git a/tests/config/box_obs_space_laydown_config.yaml b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml similarity index 94% rename from tests/config/box_obs_space_laydown_config.yaml rename to tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml index 203bc0e7..d1909125 100644 --- a/tests/config/box_obs_space_laydown_config.yaml +++ b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml @@ -1,7 +1,8 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATIONS - type: BOX +- itemType: OBSERVATION_SPACE + components: + - name: NODE_STATUSES - itemType: STEPS steps: 5 - itemType: PORTS diff --git a/tests/config/multidiscrete_obs_space_laydown_config.yaml b/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml similarity index 87% rename from tests/config/multidiscrete_obs_space_laydown_config.yaml rename to tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml index 38438d6d..36fb8199 100644 --- a/tests/config/multidiscrete_obs_space_laydown_config.yaml +++ b/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml @@ -1,7 +1,11 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATIONS - type: MULTIDISCRETE +- itemType: OBSERVATION_SPACE + components: + - name: NODE_LINK_TABLE + options: + - combine_service_traffic: false + - quantisation_levels: 8 - itemType: STEPS steps: 5 - itemType: PORTS From f37b943f7eef3117e2cd90053553c05d7c8cb88d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 2 Jun 2023 12:59:01 +0100 Subject: [PATCH 11/16] Add tests for observations --- pytest.ini | 2 + src/primaite/environment/observations.py | 13 +- .../laydown_with_LINK_TRAFFIC_LEVELS.yaml | 43 ++++- .../laydown_with_NODE_LINK_TABLE.yaml | 11 +- .../obs_tests/laydown_with_NODE_STATUSES.yaml | 107 +++++++++++ .../obs_tests/laydown_without_obs_space.yaml | 74 ++++++++ .../obs_tests/main_config_no_agent.yaml | 89 +++++++++ tests/test_observation_space.py | 169 +++++++++++++++--- 8 files changed, 476 insertions(+), 32 deletions(-) create mode 100644 tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml create mode 100644 tests/config/obs_tests/laydown_without_obs_space.yaml create mode 100644 tests/config/obs_tests/main_config_no_agent.yaml diff --git a/pytest.ini b/pytest.ini index e618d7a5..b5fae8d0 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,5 @@ [pytest] testpaths = tests +markers = + env_config_paths diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index c4402b69..a467a5db 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -165,14 +165,15 @@ class NodeStatuses(AbstractObservationComponent): super().__init__(env) # 1. Define the shape of your observation space component - shape = [ + node_shape = [ len(HardwareState) + 1, len(SoftwareState) + 1, len(FileSystemState) + 1, ] services_shape = [len(SoftwareState) + 1] * self.env.num_services - shape = shape + services_shape + node_shape = node_shape + services_shape + shape = node_shape * self.env.num_nodes # 2. Create Observation space self.space = spaces.MultiDiscrete(shape) @@ -199,7 +200,9 @@ class NodeStatuses(AbstractObservationComponent): for i, service in enumerate(self.env.services_list): if node.has_service(service): service_states[i] = node.get_service_state(service).value - obs.extend([hardware_state, software_state, file_system_state, *service_states]) + obs.extend( + [hardware_state, software_state, file_system_state, *service_states] + ) self.current_observation[:] = obs @@ -303,8 +306,6 @@ class ObservationsHandler: self.registered_obs_components: List[AbstractObservationComponent] = [] self.space: spaces.Space self.current_observation: Union[Tuple[np.ndarray], np.ndarray] - # i can access the registry items like this: - # self.registry.LINK_TRAFFIC_LEVELS def update_obs(self): """Fetch fresh information about the environment.""" @@ -318,6 +319,7 @@ class ObservationsHandler: self.current_observation = current_obs[0] else: self.current_observation = tuple(current_obs) + # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. def register(self, obs_component: AbstractObservationComponent): """Add a component for this handler to track. @@ -349,6 +351,7 @@ class ObservationsHandler: self.space = component_spaces[0] else: self.space = spaces.Tuple(component_spaces) + # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): diff --git a/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml index d1909125..516bf5cc 100644 --- a/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml +++ b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml @@ -2,15 +2,20 @@ type: NODE - itemType: OBSERVATION_SPACE components: - - name: NODE_STATUSES + - name: LINK_TRAFFIC_LEVELS + options: + combine_service_traffic: false + quantisation_levels: 8 - itemType: STEPS steps: 5 - itemType: PORTS portsList: - port: '80' + - port: '53' - itemType: SERVICES serviceList: - name: TCP + - name: UDP ######################################## # Nodes @@ -28,6 +33,9 @@ - name: TCP port: '80' state: GOOD + - name: UDP + port: '53' + state: GOOD - itemType: NODE node_id: '2' name: SERVER @@ -42,6 +50,9 @@ - name: TCP port: '80' state: GOOD + - name: UDP + port: '53' + state: GOOD - itemType: NODE node_id: '3' name: SWITCH1 @@ -67,3 +78,33 @@ bandwidth: 1000 source: '3' destination: '2' + +######################################### +# IERS +- itemType: GREEN_IER + id: '5' + startStep: 0 + endStep: 5 + load: 20 + protocol: TCP + port: '80' + source: '1' + destination: '2' + missionCriticality: 5 + +######################################### +# ACL Rules +- itemType: ACL_RULE + id: '6' + permission: ALLOW + source: 192.168.1.1 + destination: 192.168.1.2 + protocol: TCP + port: 80 +- itemType: ACL_RULE + id: '7' + permission: ALLOW + source: 192.168.1.2 + destination: 192.168.1.1 + protocol: TCP + port: 80 diff --git a/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml index 36fb8199..0ceefbfa 100644 --- a/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml +++ b/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml @@ -3,17 +3,16 @@ - itemType: OBSERVATION_SPACE components: - name: NODE_LINK_TABLE - options: - - combine_service_traffic: false - - quantisation_levels: 8 - itemType: STEPS steps: 5 - itemType: PORTS portsList: - port: '80' + - port: '53' - itemType: SERVICES serviceList: - name: TCP + - name: UDP ######################################## # Nodes @@ -31,6 +30,9 @@ - name: TCP port: '80' state: GOOD + - name: UDP + port: '53' + state: GOOD - itemType: NODE node_id: '2' name: SERVER @@ -45,6 +47,9 @@ - name: TCP port: '80' state: GOOD + - name: UDP + port: '53' + state: GOOD - itemType: NODE node_id: '3' name: SWITCH1 diff --git a/tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml b/tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml new file mode 100644 index 00000000..56ff3725 --- /dev/null +++ b/tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml @@ -0,0 +1,107 @@ +- itemType: ACTIONS + type: NODE +- itemType: OBSERVATION_SPACE + components: + - name: NODE_STATUSES +- itemType: STEPS + steps: 5 +- itemType: PORTS + portsList: + - port: '80' + - port: '53' +- itemType: SERVICES + serviceList: + - name: TCP + - name: UDP + +######################################## +# Nodes +- itemType: NODE + node_id: '1' + name: PC1 + node_class: SERVICE + node_type: COMPUTER + priority: P5 + hardware_state: 'ON' + ip_address: 192.168.1.1 + software_state: COMPROMISED + file_system_state: GOOD + services: + - name: TCP + port: '80' + state: GOOD + - name: UDP + port: '53' + state: GOOD +- itemType: NODE + node_id: '2' + name: SERVER + node_class: SERVICE + node_type: SERVER + priority: P5 + hardware_state: 'ON' + ip_address: 192.168.1.2 + software_state: GOOD + file_system_state: GOOD + services: + - name: TCP + port: '80' + state: GOOD + - name: UDP + port: '53' + state: OVERWHELMED +- itemType: NODE + node_id: '3' + name: SWITCH1 + node_class: ACTIVE + node_type: SWITCH + priority: P2 + hardware_state: 'ON' + ip_address: 192.168.1.3 + software_state: GOOD + file_system_state: GOOD + +######################################## +# Links +- itemType: LINK + id: '4' + name: link1 + bandwidth: 1000 + source: '1' + destination: '3' +- itemType: LINK + id: '5' + name: link2 + bandwidth: 1000 + source: '3' + destination: '2' + +######################################### +# IERS +- itemType: GREEN_IER + id: '5' + startStep: 0 + endStep: 5 + load: 20 + protocol: TCP + port: '80' + source: '1' + destination: '2' + missionCriticality: 5 + +######################################### +# ACL Rules +- itemType: ACL_RULE + id: '6' + permission: ALLOW + source: 192.168.1.1 + destination: 192.168.1.2 + protocol: TCP + port: 80 +- itemType: ACL_RULE + id: '7' + permission: ALLOW + source: 192.168.1.2 + destination: 192.168.1.1 + protocol: TCP + port: 80 diff --git a/tests/config/obs_tests/laydown_without_obs_space.yaml b/tests/config/obs_tests/laydown_without_obs_space.yaml new file mode 100644 index 00000000..3ef214da --- /dev/null +++ b/tests/config/obs_tests/laydown_without_obs_space.yaml @@ -0,0 +1,74 @@ +- itemType: ACTIONS + type: NODE +- itemType: STEPS + steps: 5 +- itemType: PORTS + portsList: + - port: '80' + - port: '53' +- itemType: SERVICES + serviceList: + - name: TCP + - name: UDP + +######################################## +# Nodes +- itemType: NODE + node_id: '1' + name: PC1 + node_class: SERVICE + node_type: COMPUTER + priority: P5 + hardware_state: 'ON' + ip_address: 192.168.1.1 + software_state: GOOD + file_system_state: GOOD + services: + - name: TCP + port: '80' + state: GOOD + - name: UDP + port: '53' + state: GOOD +- itemType: NODE + node_id: '2' + name: SERVER + node_class: SERVICE + node_type: SERVER + priority: P5 + hardware_state: 'ON' + ip_address: 192.168.1.2 + software_state: GOOD + file_system_state: GOOD + services: + - name: TCP + port: '80' + state: GOOD + - name: UDP + port: '53' + state: GOOD +- itemType: NODE + node_id: '3' + name: SWITCH1 + node_class: ACTIVE + node_type: SWITCH + priority: P2 + hardware_state: 'ON' + ip_address: 192.168.1.3 + software_state: GOOD + file_system_state: GOOD + +######################################## +# Links +- itemType: LINK + id: '4' + name: link1 + bandwidth: 1000 + source: '1' + destination: '3' +- itemType: LINK + id: '5' + name: link2 + bandwidth: 1000 + source: '3' + destination: '2' diff --git a/tests/config/obs_tests/main_config_no_agent.yaml b/tests/config/obs_tests/main_config_no_agent.yaml new file mode 100644 index 00000000..f632dca9 --- /dev/null +++ b/tests/config/obs_tests/main_config_no_agent.yaml @@ -0,0 +1,89 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: NONE +# Number of episodes to run per session +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1000000000 + +# Reward values +# Generic +allOk: 0 +# Node Hardware State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node Software or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index a13121b9..314728ae 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -1,45 +1,168 @@ """Test env creation and behaviour with different observation spaces.""" +import numpy as np +import pytest -from primaite.environment.observations import NodeStatuses, ObservationsHandler +from primaite.environment.observations import ( + NodeLinkTable, + NodeStatuses, + ObservationsHandler, +) +from primaite.environment.primaite_env import Primaite from tests import TEST_CONFIG_ROOT from tests.conftest import _get_primaite_env_from_config -def test_creating_env_with_box_obs(): - """Try creating env with box observation space.""" +@pytest.fixture +def env(request): + """Build Primaite environment for integration tests of observation space.""" + marker = request.node.get_closest_marker("env_config_paths") + main_config_path = marker.args[0]["main_config_path"] + lay_down_config_path = marker.args[0]["lay_down_config_path"] env = _get_primaite_env_from_config( - main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "box_obs_space_laydown_config.yaml", + main_config_path=main_config_path, + lay_down_config_path=lay_down_config_path, ) - env.update_environent_obs() - - # we have three nodes and two links, with one service - # therefore the box observation space will have: - # * 5 columns (four fixed and one for the service) - # * 5 rows (3 nodes + 2 links) - assert env.env_obs.shape == (5, 5) + yield env -def test_creating_env_with_multidiscrete_obs(): - """Try creating env with MultiDiscrete observation space.""" - env = _get_primaite_env_from_config( +@pytest.mark.env_config_paths( + dict( main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT - / "multidiscrete_obs_space_laydown_config.yaml", + / "obs_tests/laydown_without_obs_space.yaml", ) +) +def test_default_obs_space(env: Primaite): + """Create environment with no obs space defined in config and check that the default obs space was created.""" env.update_environent_obs() - # we have three nodes and two links, with one service - # the nodes have hardware, OS, FS, and service, the links just have bandwidth, - # therefore we need 3*4 + 2 observations - assert env.env_obs.shape == (3 * 4 + 2,) + components = env.obs_handler.registered_obs_components + + assert len(components) == 1 + assert isinstance(components[0], NodeLinkTable) -def test_component_registration(): - """Test that we can register and deregister a component.""" +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "obs_tests/laydown_without_obs_space.yaml", + ) +) +def test_registering_components(env: Primaite): + """Test regitering and deregistering a component.""" handler = ObservationsHandler() - component = NodeStatuses() + component = NodeStatuses(env) handler.register(component) assert component in handler.registered_obs_components handler.deregister(component) assert component not in handler.registered_obs_components + + +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "obs_tests/laydown_with_NODE_LINK_TABLE.yaml", + ) +) +class TestNodeLinkTable: + """Test the NodeLinkTable observation component (in isolation).""" + + def test_obs_shape(self, env: Primaite): + """Try creating env with box observation space.""" + env.update_environent_obs() + + # we have three nodes and two links, with two service + # therefore the box observation space will have: + # * 5 rows (3 nodes + 2 links) + # * 6 columns (four fixed and two for the services) + assert env.env_obs.shape == (5, 6) + + # def test_value(self, env: Primaite): + # """""" + # ... + + +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "obs_tests/laydown_with_NODE_STATUSES.yaml", + ) +) +class TestNodeStatuses: + """Test the NodeStatuses observation component (in isolation).""" + + def test_obs_shape(self, env: Primaite): + """Try creating env with NodeStatuses as the only component.""" + assert env.env_obs.shape == (15) + + def test_values(self, env: Primaite): + """Test that the hardware and software states are encoded correctly. + + The laydown has: + * one node with a compromised operating system state + * one node with two services, and the second service is overwhelmed. + * all other states are good or null + Therefore, the expected state is: + * node 1: + * hardware = good (1) + * OS = compromised (3) + * file system = good (1) + * service 1 = good (1) + * service 2 = good (1) + * node 2: + * hardware = good (1) + * OS = good (1) + * file system = good (1) + * service 1 = good (1) + * service 2 = overwhelmed (4) + * node 3 (switch): + * hardware = good (1) + * OS = good (1) + * file system = good (1) + * service 1 = n/a (0) + * service 2 = n/a (0) + """ + act = np.asarray([0, 0, 0, 0]) + obs, _, _, _ = env.step(act) + assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]) + + +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml", + ) +) +class TestLinkTrafficLevels: + """Test the LinkTrafficLevels observation component (in isolation).""" + + def test_obs_shape(self, env: Primaite): + """Try creating env with MultiDiscrete observation space.""" + env.update_environent_obs() + + # we have two links and two services, so the shape should be 2 * 2 + assert env.env_obs.shape == (2 * 2,) + + def test_values(self, env: Primaite): + """Test that traffic values are encoded correctly. + + The laydown has: + * two services + * three nodes + * two links + * an IER trying to send 20 bits of data over both links the whole time (via the first service) + * link bandwidth of 1000, therefore the utilisation is 2% + """ + act = np.asarray([0, 0, 0, 0]) + obs, reward, done, info = env.step(act) + obs, reward, done, info = env.step(act) + + # the observation space has combine_service_traffic set to False, so the space has this format: + # [link1_service1, link1_service2, link2_service1, link2_service2] + # we send 20 bits of data via link1 and link2 on service 1. + # therefore the first and third elements should be 1 and all others 0 + assert np.array_equal(obs, [1, 0, 1, 0]) From 2330a30021c689f64dbba3b07300826822d514f2 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 2 Jun 2023 13:08:11 +0100 Subject: [PATCH 12/16] Get observation tests passing --- tests/test_observation_space.py | 64 ++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 314728ae..3fe71003 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -79,9 +79,65 @@ class TestNodeLinkTable: # * 6 columns (four fixed and two for the services) assert env.env_obs.shape == (5, 6) - # def test_value(self, env: Primaite): - # """""" - # ... + def test_value(self, env: Primaite): + """Test that the observation is generated correctly. + + The laydown has: + * 3 nodes (2 service nodes and 1 active node) + * 2 services + * 2 links + + Both nodes have both services, and all states are GOOD, therefore the expected observation value is: + + * Node 1: + * 1 (id) + * 1 (good hardware state) + * 1 (good OS state) + * 1 (good file system state) + * 1 (good service1 state) + * 1 (good service2 state) + * Node 2: + * 2 (id) + * 1 (good hardware state) + * 1 (good OS state) + * 1 (good file system state) + * 1 (good service1 state) + * 1 (good service2 state) + * Node 3 (active node): + * 3 (id) + * 1 (good hardware state) + * 1 (good OS state) + * 1 (good file system state) + * 0 (doesn't have service1) + * 0 (doesn't have service2) + * Link 1: + * 4 (id) + * 0 (n/a hardware state) + * 0 (n/a OS state) + * 0 (n/a file system state) + * 0 (no traffic for service1) + * 0 (no traffic for service2) + * Link 2: + * 5 (id) + * 0 (good hardware state) + * 0 (good OS state) + * 0 (good file system state) + * 0 (no traffic service1) + * 0 (no traffic for service2) + """ + act = np.asarray([0, 0, 0, 0]) + obs, reward, done, info = env.step(act) + + assert np.array_equal( + obs, + [ + [1, 1, 1, 1, 1, 1], + [2, 1, 1, 1, 1, 1], + [3, 1, 1, 1, 0, 0], + [4, 0, 0, 0, 0, 0], + [5, 0, 0, 0, 0, 0], + ], + ) @pytest.mark.env_config_paths( @@ -96,7 +152,7 @@ class TestNodeStatuses: def test_obs_shape(self, env: Primaite): """Try creating env with NodeStatuses as the only component.""" - assert env.env_obs.shape == (15) + assert env.env_obs.shape == (15,) def test_values(self, env: Primaite): """Test that the hardware and software states are encoded correctly. From 25ec0d93a965cd95bdedb9f03c04f102f03bf65d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 2 Jun 2023 13:15:38 +0100 Subject: [PATCH 13/16] Fix Link Traffic Levels observation encoding --- src/primaite/environment/observations.py | 2 +- .../obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml | 2 +- tests/test_observation_space.py | 11 ++++++----- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index a467a5db..a598d6db 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -281,7 +281,7 @@ class LinkTrafficLevels(AbstractObservationComponent): traffic_level = self._quantisation_levels - 1 else: traffic_level = (load / bandwidth) // ( - 1 / (self._quantisation_levels - 1) + 1 / (self._quantisation_levels - 2) ) + 1 obs.append(int(traffic_level)) diff --git a/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml index 516bf5cc..e65ea306 100644 --- a/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml +++ b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml @@ -85,7 +85,7 @@ id: '5' startStep: 0 endStep: 5 - load: 20 + load: 999 protocol: TCP port: '80' source: '1' diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 3fe71003..ae862c96 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -210,8 +210,8 @@ class TestLinkTrafficLevels: * two services * three nodes * two links - * an IER trying to send 20 bits of data over both links the whole time (via the first service) - * link bandwidth of 1000, therefore the utilisation is 2% + * an IER trying to send 999 bits of data over both links the whole time (via the first service) + * link bandwidth of 1000, therefore the utilisation is 99.9% """ act = np.asarray([0, 0, 0, 0]) obs, reward, done, info = env.step(act) @@ -219,6 +219,7 @@ class TestLinkTrafficLevels: # the observation space has combine_service_traffic set to False, so the space has this format: # [link1_service1, link1_service2, link2_service1, link2_service2] - # we send 20 bits of data via link1 and link2 on service 1. - # therefore the first and third elements should be 1 and all others 0 - assert np.array_equal(obs, [1, 0, 1, 0]) + # we send 999 bits of data via link1 and link2 on service 1. + # therefore the first and third elements should be 6 and all others 0 + # (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%) + assert np.array_equal(obs, [6, 0, 6, 0]) From 9d868c50905141683b30a82bb2e33e519305e5e1 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 2 Jun 2023 13:23:03 +0100 Subject: [PATCH 14/16] Update docs with configurable obs space info --- docs/source/config.rst | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/source/config.rst b/docs/source/config.rst index 88399973..8a8515ca 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -288,6 +288,28 @@ The config_[name].yaml file consists of the following attributes: Determines whether a NODE or ACL action space format is adopted for the session +* **itemType: OBSERVATION_SPACE** [dict] + + Allows for user to configure observation space by combining one or more observation components. List of available + components is is :py:mod:'primaite.environment.observations'. + + The observation space config item should have a ``components`` key which is a list of components. Each component + config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the + component while it is being initialised. + + This example illustrates the correct format for the observation space config item + +.. code-block::yaml + + - itemType: OBSERVATION_SPACE + components: + - name: LINK_TRAFFIC_LEVELS + options: + combine_service_traffic: false + quantisation_levels: 8 + - name: NODE_STATUSES + - name: LINK_TRAFFIC_LEVELS + * **itemType: STEPS** [int] Determines the number of steps to run in each episode of the session From 9417cd85abfd72d00f0ee2c84c9c95db045042f2 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 7 Jun 2023 15:25:11 +0100 Subject: [PATCH 15/16] Apply suggestions from code review. --- src/primaite/common/config_values_main.py | 1 + src/primaite/environment/observations.py | 25 ++-- src/primaite/environment/primaite_env.py | 13 +-- src/primaite/main.py | 4 + ...n_with_NODE_STATUSES.yaml => laydown.yaml} | 5 +- .../laydown_with_LINK_TRAFFIC_LEVELS.yaml | 110 ------------------ .../laydown_with_NODE_LINK_TABLE.yaml | 77 ------------ .../obs_tests/laydown_without_obs_space.yaml | 74 ------------ .../main_config_LINK_TRAFFIC_LEVELS.yaml | 96 +++++++++++++++ .../main_config_NODE_LINK_TABLE.yaml | 93 +++++++++++++++ .../obs_tests/main_config_NODE_STATUSES.yaml | 93 +++++++++++++++ ...gent.yaml => main_config_without_obs.yaml} | 2 +- tests/conftest.py | 4 + tests/test_observation_space.py | 49 ++++---- 14 files changed, 338 insertions(+), 308 deletions(-) rename tests/config/obs_tests/{laydown_with_NODE_STATUSES.yaml => laydown.yaml} (95%) delete mode 100644 tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml delete mode 100644 tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml delete mode 100644 tests/config/obs_tests/laydown_without_obs_space.yaml create mode 100644 tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml create mode 100644 tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml create mode 100644 tests/config/obs_tests/main_config_NODE_STATUSES.yaml rename tests/config/obs_tests/{main_config_no_agent.yaml => main_config_without_obs.yaml} (98%) diff --git a/src/primaite/common/config_values_main.py b/src/primaite/common/config_values_main.py index 3493f9d2..f822b77f 100644 --- a/src/primaite/common/config_values_main.py +++ b/src/primaite/common/config_values_main.py @@ -9,6 +9,7 @@ class ConfigValuesMain(object): """Init.""" # Generic self.agent_identifier = "" # the agent in use + self.observation_config = None # observation space config self.num_episodes = 0 # number of episodes to train over self.num_steps = 0 # number of steps in an episode self.time_delay = 0 # delay between steps (ms) - applies to generic agents only diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index a598d6db..9e71ef1b 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,7 +1,7 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Tuple, Union +from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union import numpy as np from gym import spaces @@ -56,9 +56,9 @@ class NodeLinkTable(AbstractObservationComponent): ``(12, 7)`` """ - _FIXED_PARAMETERS = 4 - _MAX_VAL = 1_000_000 - _DATA_TYPE = np.int64 + _FIXED_PARAMETERS: int = 4 + _MAX_VAL: int = 1_000_000 + _DATA_TYPE: type = np.int64 def __init__(self, env: "Primaite"): super().__init__(env) @@ -159,7 +159,7 @@ class NodeStatuses(AbstractObservationComponent): :type env: Primaite """ - _DATA_TYPE = np.int64 + _DATA_TYPE: type = np.int64 def __init__(self, env: "Primaite"): super().__init__(env) @@ -231,7 +231,7 @@ class LinkTrafficLevels(AbstractObservationComponent): :type quantisation_levels: int, optional """ - _DATA_TYPE = np.int64 + _DATA_TYPE: type = np.int64 def __init__( self, @@ -239,7 +239,14 @@ class LinkTrafficLevels(AbstractObservationComponent): combine_service_traffic: bool = False, quantisation_levels: int = 5, ): - assert quantisation_levels >= 3 + if quantisation_levels < 3: + _msg = ( + f"quantisation_levels must be 3 or more because the lowest and highest levels are " + f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. " + f"Resetting to default value (5)" + ) + _LOGGER.warning(_msg) + quantisation_levels = 5 super().__init__(env) @@ -296,7 +303,7 @@ class ObservationsHandler: Each component can also define further parameters to make them more flexible. """ - registry = { + _REGISTRY: Final[Dict[str, type]] = { "NODE_LINK_TABLE": NodeLinkTable, "NODE_STATUSES": NodeStatuses, "LINK_TRAFFIC_LEVELS": LinkTrafficLevels, @@ -384,7 +391,7 @@ class ObservationsHandler: for component_cfg in obs_space_config["components"]: # Figure out which class can instantiate the desired component comp_type = component_cfg["name"] - comp_class = cls.registry[comp_type] + comp_class = cls._REGISTRY[comp_type] # Create the component with options from the YAML options = component_cfg.get("options") or {} diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 0ff58100..7995c4f7 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -48,10 +48,10 @@ class Primaite(Env): """PRIMmary AI Training Evironment (Primaite) class.""" # Action Space contants - ACTION_SPACE_NODE_PROPERTY_VALUES = 5 - ACTION_SPACE_NODE_ACTION_VALUES = 4 - ACTION_SPACE_ACL_ACTION_VALUES = 3 - ACTION_SPACE_ACL_PERMISSION_VALUES = 2 + ACTION_SPACE_NODE_PROPERTY_VALUES: int = 5 + ACTION_SPACE_NODE_ACTION_VALUES: int = 4 + ACTION_SPACE_ACL_ACTION_VALUES: int = 3 + ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__(self, _config_values, _transaction_list): """ @@ -148,6 +148,8 @@ class Primaite(Env): # stores the observation config from the yaml, default is NODE_LINK_TABLE self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]} + if self.config_values.observation_config is not None: + self.obs_config = self.config_values.observation_config # Observation Handler manages the user-configurable observation space. # It will be initialised later. @@ -690,9 +692,6 @@ class Primaite(Env): elif item["itemType"] == "ACTIONS": # Get the action information self.get_action_info(item) - elif item["itemType"] == "OBSERVATION_SPACE": - # Get the observation information - self.save_obs_config(item) elif item["itemType"] == "STEPS": # Get the steps information self.get_steps_info(item) diff --git a/src/primaite/main.py b/src/primaite/main.py index c963dd00..5f8aa5e2 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -163,6 +163,10 @@ def load_config_values(): try: # Generic config_values.agent_identifier = config_data["agentIdentifier"] + if "observationSpace" in config_data: + config_values.observation_config = config_data["observationSpace"] + else: + config_values.observation_config = None config_values.num_episodes = int(config_data["numEpisodes"]) config_values.time_delay = int(config_data["timeDelay"]) config_values.config_filename_use_case = ( diff --git a/tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml b/tests/config/obs_tests/laydown.yaml similarity index 95% rename from tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml rename to tests/config/obs_tests/laydown.yaml index 56ff3725..d3b131db 100644 --- a/tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml +++ b/tests/config/obs_tests/laydown.yaml @@ -1,8 +1,5 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATION_SPACE - components: - - name: NODE_STATUSES - itemType: STEPS steps: 5 - itemType: PORTS @@ -82,7 +79,7 @@ id: '5' startStep: 0 endStep: 5 - load: 20 + load: 999 protocol: TCP port: '80' source: '1' diff --git a/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml deleted file mode 100644 index e65ea306..00000000 --- a/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml +++ /dev/null @@ -1,110 +0,0 @@ -- itemType: ACTIONS - type: NODE -- itemType: OBSERVATION_SPACE - components: - - name: LINK_TRAFFIC_LEVELS - options: - combine_service_traffic: false - quantisation_levels: 8 -- itemType: STEPS - steps: 5 -- itemType: PORTS - portsList: - - port: '80' - - port: '53' -- itemType: SERVICES - serviceList: - - name: TCP - - name: UDP - -######################################## -# Nodes -- itemType: NODE - node_id: '1' - name: PC1 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.1 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '2' - name: SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.2 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '3' - name: SWITCH1 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.3 - software_state: GOOD - file_system_state: GOOD - -######################################## -# Links -- itemType: LINK - id: '4' - name: link1 - bandwidth: 1000 - source: '1' - destination: '3' -- itemType: LINK - id: '5' - name: link2 - bandwidth: 1000 - source: '3' - destination: '2' - -######################################### -# IERS -- itemType: GREEN_IER - id: '5' - startStep: 0 - endStep: 5 - load: 999 - protocol: TCP - port: '80' - source: '1' - destination: '2' - missionCriticality: 5 - -######################################### -# ACL Rules -- itemType: ACL_RULE - id: '6' - permission: ALLOW - source: 192.168.1.1 - destination: 192.168.1.2 - protocol: TCP - port: 80 -- itemType: ACL_RULE - id: '7' - permission: ALLOW - source: 192.168.1.2 - destination: 192.168.1.1 - protocol: TCP - port: 80 diff --git a/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml deleted file mode 100644 index 0ceefbfa..00000000 --- a/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml +++ /dev/null @@ -1,77 +0,0 @@ -- itemType: ACTIONS - type: NODE -- itemType: OBSERVATION_SPACE - components: - - name: NODE_LINK_TABLE -- itemType: STEPS - steps: 5 -- itemType: PORTS - portsList: - - port: '80' - - port: '53' -- itemType: SERVICES - serviceList: - - name: TCP - - name: UDP - -######################################## -# Nodes -- itemType: NODE - node_id: '1' - name: PC1 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.1 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '2' - name: SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.2 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '3' - name: SWITCH1 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.3 - software_state: GOOD - file_system_state: GOOD - -######################################## -# Links -- itemType: LINK - id: '4' - name: link1 - bandwidth: 1000 - source: '1' - destination: '3' -- itemType: LINK - id: '5' - name: link2 - bandwidth: 1000 - source: '3' - destination: '2' diff --git a/tests/config/obs_tests/laydown_without_obs_space.yaml b/tests/config/obs_tests/laydown_without_obs_space.yaml deleted file mode 100644 index 3ef214da..00000000 --- a/tests/config/obs_tests/laydown_without_obs_space.yaml +++ /dev/null @@ -1,74 +0,0 @@ -- itemType: ACTIONS - type: NODE -- itemType: STEPS - steps: 5 -- itemType: PORTS - portsList: - - port: '80' - - port: '53' -- itemType: SERVICES - serviceList: - - name: TCP - - name: UDP - -######################################## -# Nodes -- itemType: NODE - node_id: '1' - name: PC1 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.1 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '2' - name: SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.2 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '3' - name: SWITCH1 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.3 - software_state: GOOD - file_system_state: GOOD - -######################################## -# Links -- itemType: LINK - id: '4' - name: link1 - bandwidth: 1000 - source: '1' - destination: '3' -- itemType: LINK - id: '5' - name: link2 - bandwidth: 1000 - source: '3' - destination: '2' diff --git a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml new file mode 100644 index 00000000..cdb741f3 --- /dev/null +++ b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml @@ -0,0 +1,96 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: NONE +# Number of episodes to run per session +observationSpace: + components: + - name: LINK_TRAFFIC_LEVELS + options: + combine_service_traffic: false + quantisation_levels: 8 + +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1_000_000_000 + +# Reward values +# Generic +allOk: 0 +# Node Hardware State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node Software or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml new file mode 100644 index 00000000..19d220c8 --- /dev/null +++ b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml @@ -0,0 +1,93 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: NONE +# Number of episodes to run per session +observationSpace: + components: + - name: NODE_LINK_TABLE + +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1_000_000_000 + +# Reward values +# Generic +allOk: 0 +# Node Hardware State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node Software or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml new file mode 100644 index 00000000..25520ccc --- /dev/null +++ b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml @@ -0,0 +1,93 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: NONE +# Number of episodes to run per session +observationSpace: + components: + - name: NODE_STATUSES + +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1_000_000_000 + +# Reward values +# Generic +allOk: 0 +# Node Hardware State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node Software or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/config/obs_tests/main_config_no_agent.yaml b/tests/config/obs_tests/main_config_without_obs.yaml similarity index 98% rename from tests/config/obs_tests/main_config_no_agent.yaml rename to tests/config/obs_tests/main_config_without_obs.yaml index f632dca9..43ee251f 100644 --- a/tests/config/obs_tests/main_config_no_agent.yaml +++ b/tests/config/obs_tests/main_config_without_obs.yaml @@ -21,7 +21,7 @@ agentLoadFile: C:\[Path]\[agent_saved_filename.zip] # Environment config values # The high value for the observation space -observationSpaceHighValue: 1000000000 +observationSpaceHighValue: 1_000_000_000 # Reward values # Generic diff --git a/tests/conftest.py b/tests/conftest.py index 1e987223..f3728b63 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,10 @@ def _get_primaite_env_from_config( def load_config_values(): config_values.agent_identifier = config_data["agentIdentifier"] + if "observationSpace" in config_data: + config_values.observation_config = config_data["observationSpace"] + else: + config_values.observation_config = None config_values.num_episodes = int(config_data["numEpisodes"]) config_values.time_delay = int(config_data["timeDelay"]) config_values.config_filename_use_case = lay_down_config_path diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index ae862c96..dcf98ae1 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -27,9 +27,8 @@ def env(request): @pytest.mark.env_config_paths( dict( - main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "obs_tests/laydown_without_obs_space.yaml", + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) def test_default_obs_space(env: Primaite): @@ -44,9 +43,8 @@ def test_default_obs_space(env: Primaite): @pytest.mark.env_config_paths( dict( - main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "obs_tests/laydown_without_obs_space.yaml", + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) def test_registering_components(env: Primaite): @@ -61,9 +59,9 @@ def test_registering_components(env: Primaite): @pytest.mark.env_config_paths( dict( - main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "obs_tests/laydown_with_NODE_LINK_TABLE.yaml", + main_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_NODE_LINK_TABLE.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) class TestNodeLinkTable: @@ -92,17 +90,17 @@ class TestNodeLinkTable: * Node 1: * 1 (id) * 1 (good hardware state) - * 1 (good OS state) + * 3 (compromised OS state) * 1 (good file system state) - * 1 (good service1 state) - * 1 (good service2 state) + * 1 (good TCP state) + * 1 (good UDP state) * Node 2: * 2 (id) * 1 (good hardware state) * 1 (good OS state) * 1 (good file system state) - * 1 (good service1 state) - * 1 (good service2 state) + * 1 (good TCP state) + * 4 (overwhelmed UDP state) * Node 3 (active node): * 3 (id) * 1 (good hardware state) @@ -115,14 +113,14 @@ class TestNodeLinkTable: * 0 (n/a hardware state) * 0 (n/a OS state) * 0 (n/a file system state) - * 0 (no traffic for service1) + * 999 (999 traffic for service1) * 0 (no traffic for service2) * Link 2: * 5 (id) * 0 (good hardware state) * 0 (good OS state) * 0 (good file system state) - * 0 (no traffic service1) + * 999 (999 traffic service1) * 0 (no traffic for service2) """ act = np.asarray([0, 0, 0, 0]) @@ -131,20 +129,19 @@ class TestNodeLinkTable: assert np.array_equal( obs, [ - [1, 1, 1, 1, 1, 1], - [2, 1, 1, 1, 1, 1], + [1, 1, 3, 1, 1, 1], + [2, 1, 1, 1, 1, 4], [3, 1, 1, 1, 0, 0], - [4, 0, 0, 0, 0, 0], - [5, 0, 0, 0, 0, 0], + [4, 0, 0, 0, 999, 0], + [5, 0, 0, 0, 999, 0], ], ) @pytest.mark.env_config_paths( dict( - main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "obs_tests/laydown_with_NODE_STATUSES.yaml", + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) class TestNodeStatuses: @@ -188,9 +185,9 @@ class TestNodeStatuses: @pytest.mark.env_config_paths( dict( - main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml", + main_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) class TestLinkTrafficLevels: From 64bf4bf58ad5ba4444e05df0d1e9f99668fa2684 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 9 Jun 2023 10:28:24 +0100 Subject: [PATCH 16/16] Fix obs tests with new changes --- tests/test_observation_space.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index dcf98ae1..0df59b72 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -18,7 +18,7 @@ def env(request): marker = request.node.get_closest_marker("env_config_paths") main_config_path = marker.args[0]["main_config_path"] lay_down_config_path = marker.args[0]["lay_down_config_path"] - env = _get_primaite_env_from_config( + env, _ = _get_primaite_env_from_config( main_config_path=main_config_path, lay_down_config_path=lay_down_config_path, ) @@ -123,8 +123,8 @@ class TestNodeLinkTable: * 999 (999 traffic service1) * 0 (no traffic for service2) """ - act = np.asarray([0, 0, 0, 0]) - obs, reward, done, info = env.step(act) + # act = np.asarray([0,]) + obs, reward, done, info = env.step(0) # apply the 'do nothing' action assert np.array_equal( obs, @@ -178,8 +178,7 @@ class TestNodeStatuses: * service 1 = n/a (0) * service 2 = n/a (0) """ - act = np.asarray([0, 0, 0, 0]) - obs, _, _, _ = env.step(act) + obs, _, _, _ = env.step(0) # apply the 'do nothing' action assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]) @@ -210,9 +209,8 @@ class TestLinkTrafficLevels: * an IER trying to send 999 bits of data over both links the whole time (via the first service) * link bandwidth of 1000, therefore the utilisation is 99.9% """ - act = np.asarray([0, 0, 0, 0]) - obs, reward, done, info = env.step(act) - obs, reward, done, info = env.step(act) + obs, reward, done, info = env.step(0) + obs, reward, done, info = env.step(0) # the observation space has combine_service_traffic set to False, so the space has this format: # [link1_service1, link1_service2, link2_service1, link2_service2]