From 0483eeca8276d59ea574ee2e5feaf4a07f7562d7 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 30 May 2023 11:40:40 +0100 Subject: [PATCH 01/18] 1443 - changed IF statements from `if initial ... if reference` to `if reference ... if final` to compare the final state (state after red and blue actions) with the reference state (state with no red or blue action and with green normal network traffic occurring) --- src/primaite/environment/reward.py | 144 ++++++++++++++--------------- 1 file changed, 70 insertions(+), 74 deletions(-) diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 007dcdc8..e1ece848 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -93,7 +93,6 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ """ score = 0 final_node_operating_state = final_node.hardware_state - initial_node_operating_state = initial_node.hardware_state reference_node_operating_state = reference_node.hardware_state if final_node_operating_state == reference_node_operating_state: @@ -102,26 +101,26 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ else: # We're different from the reference situation # Need to compare initial and reference (current) state of node (i.e. at every step) - if initial_node_operating_state == HardwareState.ON: - if reference_node_operating_state == HardwareState.OFF: + if reference_node_operating_state == HardwareState.ON: + if final_node_operating_state == HardwareState.OFF: score += config_values.off_should_be_on - elif reference_node_operating_state == HardwareState.RESETTING: + elif final_node_operating_state == HardwareState.RESETTING: score += config_values.resetting_should_be_on else: pass - elif initial_node_operating_state == HardwareState.OFF: - if reference_node_operating_state == HardwareState.ON: + elif reference_node_operating_state == HardwareState.OFF: + if final_node_operating_state == HardwareState.ON: score += config_values.on_should_be_off - elif reference_node_operating_state == HardwareState.RESETTING: + elif final_node_operating_state == HardwareState.RESETTING: score += config_values.resetting_should_be_off else: pass - elif initial_node_operating_state == HardwareState.RESETTING: - if reference_node_operating_state == HardwareState.ON: + elif reference_node_operating_state == HardwareState.RESETTING: + if final_node_operating_state == HardwareState.ON: score += config_values.on_should_be_resetting - elif reference_node_operating_state == HardwareState.OFF: + elif final_node_operating_state == HardwareState.OFF: score += config_values.off_should_be_resetting - elif reference_node_operating_state == HardwareState.RESETTING: + elif final_node_operating_state == HardwareState.RESETTING: score += config_values.resetting else: pass @@ -143,7 +142,6 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) """ score = 0 final_node_os_state = final_node.software_state - initial_node_os_state = initial_node.software_state reference_node_os_state = reference_node.software_state if final_node_os_state == reference_node_os_state: @@ -152,28 +150,28 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) else: # We're different from the reference situation # Need to compare initial and reference (current) state of node (i.e. at every step) - if initial_node_os_state == SoftwareState.GOOD: - if reference_node_os_state == SoftwareState.PATCHING: + if reference_node_os_state == SoftwareState.GOOD: + if final_node_os_state == SoftwareState.PATCHING: score += config_values.patching_should_be_good - elif reference_node_os_state == SoftwareState.COMPROMISED: + elif final_node_os_state == SoftwareState.COMPROMISED: score += config_values.compromised_should_be_good else: pass - elif initial_node_os_state == SoftwareState.PATCHING: - if reference_node_os_state == SoftwareState.GOOD: + elif reference_node_os_state == SoftwareState.PATCHING: + if final_node_os_state == SoftwareState.GOOD: score += config_values.good_should_be_patching - elif reference_node_os_state == SoftwareState.COMPROMISED: + elif final_node_os_state == SoftwareState.COMPROMISED: score += config_values.compromised_should_be_patching - elif reference_node_os_state == SoftwareState.PATCHING: + elif final_node_os_state == SoftwareState.PATCHING: score += config_values.patching else: pass - elif initial_node_os_state == SoftwareState.COMPROMISED: - if reference_node_os_state == SoftwareState.GOOD: + elif reference_node_os_state == SoftwareState.COMPROMISED: + if final_node_os_state == SoftwareState.GOOD: score += config_values.good_should_be_compromised - elif reference_node_os_state == SoftwareState.PATCHING: + elif final_node_os_state == SoftwareState.PATCHING: score += config_values.patching_should_be_compromised - elif reference_node_os_state == SoftwareState.COMPROMISED: + elif final_node_os_state == SoftwareState.COMPROMISED: score += config_values.compromised else: pass @@ -195,12 +193,11 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va """ score = 0 final_node_services: Dict[str, Service] = final_node.services - initial_node_services: Dict[str, Service] = initial_node.services reference_node_services: Dict[str, Service] = reference_node.services for service_key, final_service in final_node_services.items(): reference_service = reference_node_services[service_key] - initial_service = initial_node_services[service_key] + final_service = final_node_services[service_key] if final_service.software_state == reference_service.software_state: # All is well - we're no different from the reference situation @@ -208,45 +205,45 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va else: # We're different from the reference situation # Need to compare initial and reference state of node (i.e. at every step) - if initial_service.software_state == SoftwareState.GOOD: - if reference_service.software_state == SoftwareState.PATCHING: + if reference_service.software_state == SoftwareState.GOOD: + if final_service.software_state == SoftwareState.PATCHING: score += config_values.patching_should_be_good - elif reference_service.software_state == SoftwareState.COMPROMISED: + elif final_service.software_state == SoftwareState.COMPROMISED: score += config_values.compromised_should_be_good - elif reference_service.software_state == SoftwareState.OVERWHELMED: + elif final_service.software_state == SoftwareState.OVERWHELMED: score += config_values.overwhelmed_should_be_good else: pass - elif initial_service.software_state == SoftwareState.PATCHING: - if reference_service.software_state == SoftwareState.GOOD: + elif reference_service.software_state == SoftwareState.PATCHING: + if final_service.software_state == SoftwareState.GOOD: score += config_values.good_should_be_patching - elif reference_service.software_state == SoftwareState.COMPROMISED: + elif final_service.software_state == SoftwareState.COMPROMISED: score += config_values.compromised_should_be_patching - elif reference_service.software_state == SoftwareState.OVERWHELMED: + elif final_service.software_state == SoftwareState.OVERWHELMED: score += config_values.overwhelmed_should_be_patching - elif reference_service.software_state == SoftwareState.PATCHING: + elif final_service.software_state == SoftwareState.PATCHING: score += config_values.patching else: pass - elif initial_service.software_state == SoftwareState.COMPROMISED: - if reference_service.software_state == SoftwareState.GOOD: + elif reference_service.software_state == SoftwareState.COMPROMISED: + if final_service.software_state == SoftwareState.GOOD: score += config_values.good_should_be_compromised - elif reference_service.software_state == SoftwareState.PATCHING: + elif final_service.software_state == SoftwareState.PATCHING: score += config_values.patching_should_be_compromised - elif reference_service.software_state == SoftwareState.COMPROMISED: + elif final_service.software_state == SoftwareState.COMPROMISED: score += config_values.compromised - elif reference_service.software_state == SoftwareState.OVERWHELMED: + elif final_service.software_state == SoftwareState.OVERWHELMED: score += config_values.overwhelmed_should_be_compromised else: pass - elif initial_service.software_state == SoftwareState.OVERWHELMED: - if reference_service.software_state == SoftwareState.GOOD: + elif reference_service.software_state == SoftwareState.OVERWHELMED: + if final_service.software_state == SoftwareState.GOOD: score += config_values.good_should_be_overwhelmed - elif reference_service.software_state == SoftwareState.PATCHING: + elif final_service.software_state == SoftwareState.PATCHING: score += config_values.patching_should_be_overwhelmed - elif reference_service.software_state == SoftwareState.COMPROMISED: + elif final_service.software_state == SoftwareState.COMPROMISED: score += config_values.compromised_should_be_overwhelmed - elif reference_service.software_state == SoftwareState.OVERWHELMED: + elif final_service.software_state == SoftwareState.OVERWHELMED: score += config_values.overwhelmed else: pass @@ -267,7 +264,6 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu """ score = 0 final_node_file_system_state = final_node.file_system_state_actual - initial_node_file_system_state = initial_node.file_system_state_actual reference_node_file_system_state = reference_node.file_system_state_actual final_node_scanning_state = final_node.file_system_scanning @@ -280,66 +276,66 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu else: # We're different from the reference situation # Need to compare initial and reference state of node (i.e. at every step) - if initial_node_file_system_state == FileSystemState.GOOD: - if reference_node_file_system_state == FileSystemState.REPAIRING: + if reference_node_file_system_state == FileSystemState.GOOD: + if final_node_file_system_state == FileSystemState.REPAIRING: score += config_values.repairing_should_be_good - elif reference_node_file_system_state == FileSystemState.RESTORING: + elif final_node_file_system_state == FileSystemState.RESTORING: score += config_values.restoring_should_be_good - elif reference_node_file_system_state == FileSystemState.CORRUPT: + elif final_node_file_system_state == FileSystemState.CORRUPT: score += config_values.corrupt_should_be_good - elif reference_node_file_system_state == FileSystemState.DESTROYED: + elif final_node_file_system_state == FileSystemState.DESTROYED: score += config_values.destroyed_should_be_good else: pass - elif initial_node_file_system_state == FileSystemState.REPAIRING: - if reference_node_file_system_state == FileSystemState.GOOD: + elif reference_node_file_system_state == FileSystemState.REPAIRING: + if final_node_file_system_state == FileSystemState.GOOD: score += config_values.good_should_be_repairing - elif reference_node_file_system_state == FileSystemState.RESTORING: + elif final_node_file_system_state == FileSystemState.RESTORING: score += config_values.restoring_should_be_repairing - elif reference_node_file_system_state == FileSystemState.CORRUPT: + elif final_node_file_system_state == FileSystemState.CORRUPT: score += config_values.corrupt_should_be_repairing - elif reference_node_file_system_state == FileSystemState.DESTROYED: + elif final_node_file_system_state == FileSystemState.DESTROYED: score += config_values.destroyed_should_be_repairing - elif reference_node_file_system_state == FileSystemState.REPAIRING: + elif final_node_file_system_state == FileSystemState.REPAIRING: score += config_values.repairing else: pass - elif initial_node_file_system_state == FileSystemState.RESTORING: - if reference_node_file_system_state == FileSystemState.GOOD: + elif reference_node_file_system_state == FileSystemState.RESTORING: + if final_node_file_system_state == FileSystemState.GOOD: score += config_values.good_should_be_restoring - elif reference_node_file_system_state == FileSystemState.REPAIRING: + elif final_node_file_system_state == FileSystemState.REPAIRING: score += config_values.repairing_should_be_restoring - elif reference_node_file_system_state == FileSystemState.CORRUPT: + elif final_node_file_system_state == FileSystemState.CORRUPT: score += config_values.corrupt_should_be_restoring - elif reference_node_file_system_state == FileSystemState.DESTROYED: + elif final_node_file_system_state == FileSystemState.DESTROYED: score += config_values.destroyed_should_be_restoring - elif reference_node_file_system_state == FileSystemState.RESTORING: + elif final_node_file_system_state == FileSystemState.RESTORING: score += config_values.restoring else: pass - elif initial_node_file_system_state == FileSystemState.CORRUPT: - if reference_node_file_system_state == FileSystemState.GOOD: + elif reference_node_file_system_state == FileSystemState.CORRUPT: + if final_node_file_system_state == FileSystemState.GOOD: score += config_values.good_should_be_corrupt - elif reference_node_file_system_state == FileSystemState.REPAIRING: + elif final_node_file_system_state == FileSystemState.REPAIRING: score += config_values.repairing_should_be_corrupt - elif reference_node_file_system_state == FileSystemState.RESTORING: + elif final_node_file_system_state == FileSystemState.RESTORING: score += config_values.restoring_should_be_corrupt - elif reference_node_file_system_state == FileSystemState.DESTROYED: + elif final_node_file_system_state == FileSystemState.DESTROYED: score += config_values.destroyed_should_be_corrupt - elif reference_node_file_system_state == FileSystemState.CORRUPT: + elif final_node_file_system_state == FileSystemState.CORRUPT: score += config_values.corrupt else: pass - elif initial_node_file_system_state == FileSystemState.DESTROYED: - if reference_node_file_system_state == FileSystemState.GOOD: + elif reference_node_file_system_state == FileSystemState.DESTROYED: + if final_node_file_system_state == FileSystemState.GOOD: score += config_values.good_should_be_destroyed - elif reference_node_file_system_state == FileSystemState.REPAIRING: + elif final_node_file_system_state == FileSystemState.REPAIRING: score += config_values.repairing_should_be_destroyed - elif reference_node_file_system_state == FileSystemState.RESTORING: + elif final_node_file_system_state == FileSystemState.RESTORING: score += config_values.restoring_should_be_destroyed - elif reference_node_file_system_state == FileSystemState.CORRUPT: + elif final_node_file_system_state == FileSystemState.CORRUPT: score += config_values.corrupt_should_be_destroyed - elif reference_node_file_system_state == FileSystemState.DESTROYED: + elif final_node_file_system_state == FileSystemState.DESTROYED: score += config_values.destroyed else: pass From 91dec9e83d5b00a03c1436acba71814aa35bba8c Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 30 May 2023 11:50:54 +0100 Subject: [PATCH 02/18] 1443 - updated test_reward.py to reflect updates to reward.py so that the correct config values are called i.e. compromisedShouldBeGood on the correct steps during the training run --- .../config/one_node_states_on_off_lay_down_config.yaml | 10 +++++----- tests/test_reward.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/config/one_node_states_on_off_lay_down_config.yaml b/tests/config/one_node_states_on_off_lay_down_config.yaml index 355760bf..00f8016e 100644 --- a/tests/config/one_node_states_on_off_lay_down_config.yaml +++ b/tests/config/one_node_states_on_off_lay_down_config.yaml @@ -1,7 +1,7 @@ - itemType: ACTIONS type: NODE - itemType: STEPS - steps: 13 + steps: 15 - itemType: PORTS portsList: - port: '21' @@ -42,7 +42,7 @@ - itemType: RED_POL id: '2' startStep: 3 - endStep: 13 + endStep: 15 targetNodeId: '1' initiator: DIRECT type: FILE @@ -66,7 +66,7 @@ - itemType: RED_POL id: '4' startStep: 6 - endStep: 13 + endStep: 15 targetNodeId: '1' initiator: DIRECT type: OPERATING @@ -90,7 +90,7 @@ - itemType: RED_POL id: '6' startStep: 9 - endStep: 13 + endStep: 15 targetNodeId: '1' initiator: DIRECT type: SERVICE @@ -114,7 +114,7 @@ - itemType: RED_POL id: '8' startStep: 12 - endStep: 13 + endStep: 15 targetNodeId: '1' initiator: DIRECT type: OS diff --git a/tests/test_reward.py b/tests/test_reward.py index 10dfb79c..4925a434 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -28,4 +28,4 @@ def test_rewards_are_being_penalised_at_each_step_function(): Average Reward: 2 (26 / 13) """ print("average reward", env.average_reward) - assert env.average_reward == 2.0 + assert env.average_reward == -8.0 From 2724838cf83ba042dba082de0bcc87976ae3f9fe Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 30 May 2023 13:14:43 +0100 Subject: [PATCH 03/18] Setup testing scripts --- src/primaite/config/config_main.yaml | 2 +- src/primaite/main.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/primaite/config/config_main.yaml b/src/primaite/config/config_main.yaml index b31a73b7..6104e25b 100644 --- a/src/primaite/config/config_main.yaml +++ b/src/primaite/config/config_main.yaml @@ -11,7 +11,7 @@ numEpisodes: 10 # Time delay between steps (for generic agents) timeDelay: 10 # Filename of the scenario / laydown -configFilename: config_5_DATA_MANIPULATION.yaml +configFilename: config_1_DDOS_BASIC.yaml # Type of session to be run (TRAINING or EVALUATION) sessionType: TRAINING # Determine whether to load an agent from file diff --git a/src/primaite/main.py b/src/primaite/main.py index 0963fa7e..1211b5cd 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -349,12 +349,12 @@ except Exception: transaction_list = [] # Create the Primaite environment -try: - env = Primaite(config_values, transaction_list) - logging.info("PrimAITE environment created") -except Exception: - logging.error("Could not create PrimAITE environment") - logging.error("Exception occured", exc_info=True) +# try: +env = Primaite(config_values, transaction_list) +# logging.info("PrimAITE environment created") +# except Exception: +# logging.error("Could not create PrimAITE environment") +# logging.error("Exception occured", exc_info=True) # Get the number of steps (which is stored in the child config file) config_values.num_steps = env.episode_steps From 375e20a67b6b8e22fa9bcd73e821df766cf4dbff Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 30 May 2023 15:11:41 +0100 Subject: [PATCH 04/18] Configure observation type MULTIDISCRETE --- src/primaite/common/enums.py | 7 + src/primaite/config/config_1_DDOS_BASIC.yaml | 2 + src/primaite/environment/primaite_env.py | 236 ++++++++++++++----- 3 files changed, 188 insertions(+), 57 deletions(-) diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 0e00c9e4..138d2742 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -83,6 +83,13 @@ class ActionType(Enum): ACL = 1 +class ObservationType(Enum): + """Observation type enumeration.""" + + BOX = 0 + MULTIDISCRETE = 1 + + class FileSystemState(Enum): """File System State.""" diff --git a/src/primaite/config/config_1_DDOS_BASIC.yaml b/src/primaite/config/config_1_DDOS_BASIC.yaml index ada813f3..2796adb4 100644 --- a/src/primaite/config/config_1_DDOS_BASIC.yaml +++ b/src/primaite/config/config_1_DDOS_BASIC.yaml @@ -1,5 +1,7 @@ - itemType: ACTIONS type: NODE +- itemType: OBSERVATIONS + type: MULTIDISCRETE - itemType: STEPS steps: 128 - itemType: PORTS diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 99c7c09f..b85410c0 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -23,6 +23,7 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, + ObservationType, Priority, SoftwareState, ) @@ -148,6 +149,9 @@ class Primaite(Env): # The action type self.action_type = 0 + # Observation type. + self.observation_type = 0 + # Open the config file and build the environment laydown try: self.config_file = open(self.config_values.config_filename_use_case, "r") @@ -203,26 +207,8 @@ class Primaite(Env): # - service E state | service E loading # - service F state | service F loading # - service G state | service G loading - - # Calculate the number of items that need to be included in the - # observation space - num_items = self.num_links + self.num_nodes - # Set the number of observation parameters, being # of services plus id, - # hardware state, file system state and SoftwareState (i.e. 4) - self.num_observation_parameters = ( - self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS - ) - # Define the observation shape - self.observation_shape = (num_items, self.num_observation_parameters) - self.observation_space = spaces.Box( - low=0, - high=self.config_values.observation_space_high_value, - shape=self.observation_shape, - dtype=np.int64, - ) - - # This is the observation that is sent back via the rest and step functions - self.env_obs = np.zeros(self.observation_shape, dtype=np.int64) + # Initiate observation space + self.observation_space, self.env_obs = self.init_observations() # Define Action Space - depends on action space type (Node or ACL) if self.action_type == ActionType.NODE: @@ -671,49 +657,172 @@ class Primaite(Env): else: pass + def init_observations(self): + """Build the observation space based on network laydown and provide initial obs. + + This method uses the object's `num_links`, `num_nodes`, `num_services`, + `OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type` + attributes to figure out the correct shape and format for the observation space. + + Returns + ------- + gym.spaces.Space + Gym observation space + numpy.Array + Initial observation with all entries set to 0 + """ + if self.observation_type == ObservationType.BOX: + _LOGGER.info("Observation space type BOX selected") + + # 1. Determine observation shape from laydown + num_items = self.num_links + self.num_nodes + num_observation_parameters = ( + self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS + ) + observation_shape = (num_items, num_observation_parameters) + + # 2. Create observation space & zeroed out sample from space. + observation_space = spaces.Box( + low=0, + high=self.OBSERVATION_SPACE_HIGH_VALUE, + shape=observation_shape, + dtype=np.int64, + ) + initial_observation = np.zeros(observation_shape, dtype=np.int64) + + elif self.observation_type == ObservationType.MULTIDISCRETE: + _LOGGER.info("Observation space MULTIDISCRETE selected") + + # 1. Determine observation shape from laydown + node_obs_shape = [ + len(HardwareState) + 1, + len(SoftwareState) + 1, + len(FileSystemState) + 1, + ] + node_services = [len(SoftwareState) + 1] * self.num_services + node_obs_shape = node_obs_shape + node_services + # the magic number 5 refers to 5 states of quantisation of traffic amount. + # (zero, low, medium, high, fully utilised/overwhelmed) + link_obs_shape = [5] * self.num_links + observation_shape = node_obs_shape + link_obs_shape + + # 2. Create observation space & zeroed out sample from space. + observation_space = spaces.MultiDiscrete(observation_shape) + initial_observation = np.zeros(len(observation_shape), dtype=np.int64) + else: + raise ValueError( + f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}" + f", got {self.observation_type} instead" + ) + + return observation_space, initial_observation + def update_environent_obs(self): """Updates the observation space based on the node and link status.""" - item_index = 0 + if self.observation_type == ObservationType.BOX: + item_index = 0 - # Do nodes first - for node_key, node in self.nodes.items(): - self.env_obs[item_index][0] = int(node.node_id) - self.env_obs[item_index][1] = node.hardware_state.value - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.env_obs[item_index][2] = node.software_state.value - self.env_obs[item_index][3] = node.file_system_state_observed.value - else: + # Do nodes first + for node_key, node in self.nodes.items(): + self.env_obs[item_index][0] = int(node.node_id) + self.env_obs[item_index][1] = node.hardware_state.value + if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): + self.env_obs[item_index][2] = node.software_state.value + self.env_obs[item_index][3] = node.file_system_state_observed.value + else: + self.env_obs[item_index][2] = 0 + self.env_obs[item_index][3] = 0 + service_index = 4 + if isinstance(node, ServiceNode): + for service in self.services_list: + if node.has_service(service): + self.env_obs[item_index][ + service_index + ] = node.get_service_state(service).value + else: + self.env_obs[item_index][service_index] = 0 + service_index += 1 + else: + # Not a service node + for service in self.services_list: + self.env_obs[item_index][service_index] = 0 + service_index += 1 + item_index += 1 + + # Now do links + for link_key, link in self.links.items(): + self.env_obs[item_index][0] = int(link.get_id()) + self.env_obs[item_index][1] = 0 self.env_obs[item_index][2] = 0 self.env_obs[item_index][3] = 0 - service_index = 4 - if isinstance(node, ServiceNode): - for service in self.services_list: - if node.has_service(service): - self.env_obs[item_index][ - service_index - ] = node.get_service_state(service).value - else: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - else: - # Not a service node - for service in self.services_list: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - item_index += 1 + protocol_list = link.get_protocol_list() + protocol_index = 0 + for protocol in protocol_list: + self.env_obs[item_index][protocol_index + 4] = protocol.get_load() + protocol_index += 1 + item_index += 1 - # Now do links - for link_key, link in self.links.items(): - self.env_obs[item_index][0] = int(link.get_id()) - self.env_obs[item_index][1] = 0 - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - protocol_list = link.get_protocol_list() - protocol_index = 0 - for protocol in protocol_list: - self.env_obs[item_index][protocol_index + 4] = protocol.get_load() - protocol_index += 1 - item_index += 1 + elif self.observation_type == ObservationType.MULTIDISCRETE: + obs = [] + # 1. Set nodes + # Each node has the following variables in the observation space: + # - Hardware state + # - Software state + # - File System state + # - Service 1 state + # - Service 2 state + # - ... + # - Service N state + for node_key, node in self.nodes.items(): + hardware_state = node.hardware_state.value + software_state = 0 + file_system_state = 0 + services_states = [0] * self.num_services + + if isinstance( + node, ActiveNode + ): # ServiceNode is a subclass of ActiveNode so no need to check that also + software_state = node.software_state.value + file_system_state = node.file_system_state_observed.value + + if isinstance(node, ServiceNode): + for i, service in enumerate(self.services_list): + if node.has_service(service): + services_states[i] = node.get_service_state(service).value + + obs.extend( + [ + hardware_state, + software_state, + file_system_state, + *services_states, + ] + ) + + # 2. Set links + # Each link has just one variable in the observation space, it represents the traffic amount + # In order for the space to be fully MultiDiscrete, the amount of + # traffic on each link is quantised into a few levels: + # 0: no traffic (0% of bandwidth) + # 1: low traffic (0-33% of bandwidth) + # 2: medium traffic (33-66% of bandwidth) + # 3: high traffic (66-100% of bandwidth) + # 4: max traffic/overloaded (100% of bandwidth) + + for link_key, link in self.links.items(): + bandwidth = link.bandwidth + load = link.get_current_load() + + if load <= 0: + traffic_level = 0 + elif load >= bandwidth: + traffic_level = 4 + else: + traffic_level = (load / bandwidth) // (1 / 3) + 1 + + obs.append(int(traffic_level)) + + self.env_obs = np.asarray(obs) def load_config(self): """Loads config data in order to build the environment configuration.""" @@ -748,6 +857,9 @@ class Primaite(Env): elif item["itemType"] == "ACTIONS": # Get the action information self.get_action_info(item) + elif item["itemType"] == "OBSERVATIONS": + # Get the observation information + self.get_observation_info(item) elif item["itemType"] == "STEPS": # Get the steps information self.get_steps_info(item) @@ -1080,6 +1192,16 @@ class Primaite(Env): """ self.action_type = ActionType[action_info["type"]] + def get_observation_info(self, observation_info): + """Extracts observation_info. + + Parameters + ---------- + observation_info : str + Config item that defines which type of observation space to use + """ + self.observation_type = ObservationType[observation_info["type"]] + def get_steps_info(self, steps_info): """ Extracts steps_info. From 0227769c341fa2463e54fca72bef93e7b67f70e3 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 30 May 2023 15:16:14 +0100 Subject: [PATCH 05/18] Fix observation node shape --- src/primaite/environment/primaite_env.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index b85410c0..7f102bd4 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -191,22 +191,6 @@ class Primaite(Env): _LOGGER.error("Exception occured", exc_info=True) print("Could not save network diagram") - # Define Observation Space - # x = number of nodes and links (i.e. items) - # y = number of parameters to be sent - # For each item, we send: - # - [For Nodes] | [For Links] - # - node ID | link ID - # - hardware state | N/A - # - Software State | N/A - # - file system state | N/A - # - service A state | service A loading - # - service B state | service B loading - # - service C state | service C loading - # - service D state | service D loading - # - service E state | service E loading - # - service F state | service F loading - # - service G state | service G loading # Initiate observation space self.observation_space, self.env_obs = self.init_observations() @@ -704,7 +688,7 @@ class Primaite(Env): # the magic number 5 refers to 5 states of quantisation of traffic amount. # (zero, low, medium, high, fully utilised/overwhelmed) link_obs_shape = [5] * self.num_links - observation_shape = node_obs_shape + link_obs_shape + observation_shape = node_obs_shape * self.num_nodes + link_obs_shape # 2. Create observation space & zeroed out sample from space. observation_space = spaces.MultiDiscrete(observation_shape) From fa44dd1a26f61ada2d6c2dcd6b9b1f859f1838fb Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 30 May 2023 15:24:13 +0100 Subject: [PATCH 06/18] Update configs and transactions to include new obs --- src/primaite/config/config_1_DDOS_BASIC.yaml | 2 +- src/primaite/config/config_2_DDOS_BASIC.yaml | 2 ++ src/primaite/config/config_3_DOS_VERY_BASIC.yaml | 2 ++ src/primaite/config/config_5_DATA_MANIPULATION.yaml | 2 ++ src/primaite/config/config_UNIT_TEST.yaml | 2 ++ src/primaite/transactions/transactions_to_file.py | 7 ++++++- tests/config/one_node_states_on_off_lay_down_config.yaml | 2 ++ 7 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/primaite/config/config_1_DDOS_BASIC.yaml b/src/primaite/config/config_1_DDOS_BASIC.yaml index 2796adb4..b6aef17b 100644 --- a/src/primaite/config/config_1_DDOS_BASIC.yaml +++ b/src/primaite/config/config_1_DDOS_BASIC.yaml @@ -1,7 +1,7 @@ - itemType: ACTIONS type: NODE - itemType: OBSERVATIONS - type: MULTIDISCRETE + type: BOX - itemType: STEPS steps: 128 - itemType: PORTS diff --git a/src/primaite/config/config_2_DDOS_BASIC.yaml b/src/primaite/config/config_2_DDOS_BASIC.yaml index 425fe013..f6e9cf52 100644 --- a/src/primaite/config/config_2_DDOS_BASIC.yaml +++ b/src/primaite/config/config_2_DDOS_BASIC.yaml @@ -1,5 +1,7 @@ - itemType: ACTIONS type: NODE +- itemType: OBSERVATIONS + type: BOX - itemType: STEPS steps: 128 - itemType: PORTS diff --git a/src/primaite/config/config_3_DOS_VERY_BASIC.yaml b/src/primaite/config/config_3_DOS_VERY_BASIC.yaml index 8c9b84a6..8ed65535 100644 --- a/src/primaite/config/config_3_DOS_VERY_BASIC.yaml +++ b/src/primaite/config/config_3_DOS_VERY_BASIC.yaml @@ -1,5 +1,7 @@ - itemType: ACTIONS type: NODE +- itemType: OBSERVATIONS + type: BOX - itemType: STEPS steps: 256 - itemType: PORTS diff --git a/src/primaite/config/config_5_DATA_MANIPULATION.yaml b/src/primaite/config/config_5_DATA_MANIPULATION.yaml index 3b29ff4a..5d48ffe4 100644 --- a/src/primaite/config/config_5_DATA_MANIPULATION.yaml +++ b/src/primaite/config/config_5_DATA_MANIPULATION.yaml @@ -1,5 +1,7 @@ - itemType: ACTIONS type: NODE +- itemType: OBSERVATIONS + type: BOX - itemType: STEPS steps: 256 - itemType: PORTS diff --git a/src/primaite/config/config_UNIT_TEST.yaml b/src/primaite/config/config_UNIT_TEST.yaml index 3b29ff4a..5d48ffe4 100644 --- a/src/primaite/config/config_UNIT_TEST.yaml +++ b/src/primaite/config/config_UNIT_TEST.yaml @@ -1,5 +1,7 @@ - itemType: ACTIONS type: NODE +- itemType: OBSERVATIONS + type: BOX - itemType: STEPS steps: 256 - itemType: PORTS diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index c4852982..7a6e212c 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -51,8 +51,13 @@ def write_transaction_to_file(_transaction_list): # This will be tied into the PrimAITE Use Case so that they make sense template_transation = _transaction_list[0] action_length = template_transation.action_space.size + obs_shape = template_transation.obs_space_post.shape obs_assets = template_transation.obs_space_post.shape[0] - obs_features = template_transation.obs_space_post.shape[1] + if len(obs_shape) == 1: + # bit of a workaround but I think the way transactions are written will change soon + obs_features = 1 + else: + obs_features = template_transation.obs_space_post.shape[1] # Create the action space headers array action_header = [] diff --git a/tests/config/one_node_states_on_off_lay_down_config.yaml b/tests/config/one_node_states_on_off_lay_down_config.yaml index 355760bf..47962102 100644 --- a/tests/config/one_node_states_on_off_lay_down_config.yaml +++ b/tests/config/one_node_states_on_off_lay_down_config.yaml @@ -1,5 +1,7 @@ - itemType: ACTIONS type: NODE +- itemType: OBSERVATIONS + type: BOX - itemType: STEPS steps: 13 - itemType: PORTS From 6507529db359690a7dc90b2786a95dcc1c4bb1d7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 30 May 2023 15:48:11 +0100 Subject: [PATCH 07/18] Add test for new multidiscrete spaces --- .../config/box_obs_space_laydown_config.yaml | 31 +++++++++++++++++++ ...ultidiscrete_obs_space_laydown_config.yaml | 31 +++++++++++++++++++ tests/test_observation_space.py | 23 ++++++++++++++ 3 files changed, 85 insertions(+) create mode 100644 tests/config/box_obs_space_laydown_config.yaml create mode 100644 tests/config/multidiscrete_obs_space_laydown_config.yaml create mode 100644 tests/test_observation_space.py diff --git a/tests/config/box_obs_space_laydown_config.yaml b/tests/config/box_obs_space_laydown_config.yaml new file mode 100644 index 00000000..92226863 --- /dev/null +++ b/tests/config/box_obs_space_laydown_config.yaml @@ -0,0 +1,31 @@ +- itemType: ACTIONS + type: NODE +- itemType: OBSERVATIONS + type: BOX +- itemType: STEPS + steps: 5 +- itemType: PORTS + portsList: + - port: '21' +- itemType: SERVICES + serviceList: + - name: ftp +- itemType: NODE + node_id: '1' + name: node + node_class: SERVICE + node_type: COMPUTER + priority: P1 + hardware_state: 'ON' + ip_address: 192.168.0.1 + software_state: GOOD + file_system_state: GOOD + services: + - name: ftp + port: '21' + state: GOOD +- itemType: POSITION + positions: + - node: '1' + x_pos: 309 + y_pos: 78 diff --git a/tests/config/multidiscrete_obs_space_laydown_config.yaml b/tests/config/multidiscrete_obs_space_laydown_config.yaml new file mode 100644 index 00000000..dba1c6de --- /dev/null +++ b/tests/config/multidiscrete_obs_space_laydown_config.yaml @@ -0,0 +1,31 @@ +- itemType: ACTIONS + type: NODE +- itemType: OBSERVATIONS + type: MULTIDISCRETE +- itemType: STEPS + steps: 5 +- itemType: PORTS + portsList: + - port: '21' +- itemType: SERVICES + serviceList: + - name: ftp +- itemType: NODE + node_id: '1' + name: node + node_class: SERVICE + node_type: COMPUTER + priority: P1 + hardware_state: 'ON' + ip_address: 192.168.0.1 + software_state: GOOD + file_system_state: GOOD + services: + - name: ftp + port: '21' + state: GOOD +- itemType: POSITION + positions: + - node: '1' + x_pos: 309 + y_pos: 78 diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py new file mode 100644 index 00000000..2f8e6a24 --- /dev/null +++ b/tests/test_observation_space.py @@ -0,0 +1,23 @@ +"""Test env creation and behaviour with different observation spaces.""" + +from tests import TEST_CONFIG_ROOT +from tests.conftest import _get_primaite_env_from_config + + +def test_creating_env_with_box_obs(): + """Try creating env with box observation space.""" + env = _get_primaite_env_from_config( + main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "box_obs_space_laydown_config.yaml", + ) + env.update_environent_obs() + + +def test_creating_env_with_multidiscrete_obs(): + """Try creating env with MultiDiscrete observation space.""" + env = _get_primaite_env_from_config( + main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "multidiscrete_obs_space_laydown_config.yaml", + ) + env.update_environent_obs() From 045e074d0f81f47c9a822d3d708ad73b03b82392 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 30 May 2023 16:54:34 +0100 Subject: [PATCH 08/18] Update docs on MultiDiscrete observation spaces. --- docs/source/about.rst | 50 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/docs/source/about.rst b/docs/source/about.rst index 809b0ebe..8cc08b13 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -188,6 +188,11 @@ The OpenAI Gym observation space provides the status of all nodes and links acro * Nodes (in terms of hardware state, Software State, file system state and services state) ​ * Links (in terms of current loading for each service/protocol) +The observation space can be configured as a ``gym.spaces.Box`` or ``gym.spaces.MultiDiscrete``, by setting the ``OBSERVATIONS`` parameter in the laydown config. + +Box-type observation space +-------------------------- + An example observation space is provided below: .. list-table:: Observation Space example @@ -285,6 +290,51 @@ For the links, the following statuses are represented: * SoftwareState = N/A * Protocol = loading in bits/s +MultiDiscrete-type observation space +------------------------------------ +The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by integers. +The example above would have the following structure: + +.. code-block:: + + [ + node1_info + node2_info + node3_info + link1_status + link2_status + link3_status + ] + +Each ``node_info`` contains the following: + +.. code-block:: + + [ + hardware_state (0=none, 1=ON, 2=OFF, 3=RESETTING) + software_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED) + file_system_state (0=none, 1=GOOD, 2=CORRUPT, 3=DESTROYED, 4=REPAIRING, 5=RESTORING) + service1_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED) + service2_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED) + ] + +Each ``link_status`` is just a number from 0-4 representing the network load in relation to bandwidth. + +.. code-block:: + + 0 = No traffic (0%) + 1 = low traffic (<33%) + 2 = medium traffic (<66%) + 3 = high traffic (<100%) + 4 = max traffic/ overwhelmed (100%) + +The full observation space would have 15 node-related elements and 3 link-related elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example: + +.. code-block:: + + gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4,5,5,5]) + + Action Spaces ************** From 83694fe53701345a74b2e5c5f8409912dd4fcaaf Mon Sep 17 00:00:00 2001 From: Sunil Samra Date: Wed, 31 May 2023 08:09:09 +0000 Subject: [PATCH 09/18] Apply suggestions from code review --- src/primaite/environment/reward.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index e1ece848..a620f9b3 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -100,7 +100,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ score += config_values.all_ok else: # We're different from the reference situation - # Need to compare initial and reference (current) state of node (i.e. at every step) + # Need to compare reference and final (current) state of node (i.e. at every step) if reference_node_operating_state == HardwareState.ON: if final_node_operating_state == HardwareState.OFF: score += config_values.off_should_be_on @@ -149,7 +149,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) score += config_values.all_ok else: # We're different from the reference situation - # Need to compare initial and reference (current) state of node (i.e. at every step) + # Need to compare reference and final (current) state of node (i.e. at every step) if reference_node_os_state == SoftwareState.GOOD: if final_node_os_state == SoftwareState.PATCHING: score += config_values.patching_should_be_good @@ -204,7 +204,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va score += config_values.all_ok else: # We're different from the reference situation - # Need to compare initial and reference state of node (i.e. at every step) + # Need to compare reference and final state of node (i.e. at every step) if reference_service.software_state == SoftwareState.GOOD: if final_service.software_state == SoftwareState.PATCHING: score += config_values.patching_should_be_good @@ -275,7 +275,7 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu score += config_values.all_ok else: # We're different from the reference situation - # Need to compare initial and reference state of node (i.e. at every step) + # Need to compare reference and final state of node (i.e. at every step) if reference_node_file_system_state == FileSystemState.GOOD: if final_node_file_system_state == FileSystemState.REPAIRING: score += config_values.repairing_should_be_good From 5ea77f3e7505f21333de98484a23c3569f8dcfa2 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 31 May 2023 09:26:40 +0000 Subject: [PATCH 10/18] Added pull_request_template.md --- .azuredevops/pull_request_template.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 .azuredevops/pull_request_template.md diff --git a/.azuredevops/pull_request_template.md b/.azuredevops/pull_request_template.md new file mode 100644 index 00000000..538baf5c --- /dev/null +++ b/.azuredevops/pull_request_template.md @@ -0,0 +1,12 @@ +## Summary +*Replace this text with an explanation of what the changes are and how you implemented them. Can this impact any other parts of the codebase that we should keep in mind?* + +## Test process +*How have you tested this (if applicable)?* + +## Checklist +- [ ] This PR is linked to a **work item** +- [ ] I have performed **self-review** of the code +- [ ] I have written **tests** for any new functionality added with this PR +- [ ] I have updated the **documentation** if this PR changes or adds functionality +- [ ] I have run **pre-commit** checks for code style From 65f2d6202f0f6dc70d4d7c1d62342fd5ed964ddb Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 31 May 2023 10:51:29 +0100 Subject: [PATCH 11/18] Add default observation type --- src/primaite/environment/primaite_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 7f102bd4..f57b274d 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -149,8 +149,8 @@ class Primaite(Env): # The action type self.action_type = 0 - # Observation type. - self.observation_type = 0 + # Observation type, by default box. + self.observation_type = ObservationType.BOX # Open the config file and build the environment laydown try: From 2260cb1668072d52ae38e6e228c59d7b0f1beaaf Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 31 May 2023 10:52:57 +0100 Subject: [PATCH 12/18] Revert config changes by removing observations --- src/primaite/config/config_1_DDOS_BASIC.yaml | 2 -- src/primaite/config/config_2_DDOS_BASIC.yaml | 2 -- src/primaite/config/config_3_DOS_VERY_BASIC.yaml | 2 -- src/primaite/config/config_5_DATA_MANIPULATION.yaml | 2 -- src/primaite/config/config_UNIT_TEST.yaml | 2 -- tests/config/one_node_states_on_off_lay_down_config.yaml | 2 -- 6 files changed, 12 deletions(-) diff --git a/src/primaite/config/config_1_DDOS_BASIC.yaml b/src/primaite/config/config_1_DDOS_BASIC.yaml index b6aef17b..ada813f3 100644 --- a/src/primaite/config/config_1_DDOS_BASIC.yaml +++ b/src/primaite/config/config_1_DDOS_BASIC.yaml @@ -1,7 +1,5 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATIONS - type: BOX - itemType: STEPS steps: 128 - itemType: PORTS diff --git a/src/primaite/config/config_2_DDOS_BASIC.yaml b/src/primaite/config/config_2_DDOS_BASIC.yaml index f6e9cf52..425fe013 100644 --- a/src/primaite/config/config_2_DDOS_BASIC.yaml +++ b/src/primaite/config/config_2_DDOS_BASIC.yaml @@ -1,7 +1,5 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATIONS - type: BOX - itemType: STEPS steps: 128 - itemType: PORTS diff --git a/src/primaite/config/config_3_DOS_VERY_BASIC.yaml b/src/primaite/config/config_3_DOS_VERY_BASIC.yaml index 8ed65535..8c9b84a6 100644 --- a/src/primaite/config/config_3_DOS_VERY_BASIC.yaml +++ b/src/primaite/config/config_3_DOS_VERY_BASIC.yaml @@ -1,7 +1,5 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATIONS - type: BOX - itemType: STEPS steps: 256 - itemType: PORTS diff --git a/src/primaite/config/config_5_DATA_MANIPULATION.yaml b/src/primaite/config/config_5_DATA_MANIPULATION.yaml index 5d48ffe4..3b29ff4a 100644 --- a/src/primaite/config/config_5_DATA_MANIPULATION.yaml +++ b/src/primaite/config/config_5_DATA_MANIPULATION.yaml @@ -1,7 +1,5 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATIONS - type: BOX - itemType: STEPS steps: 256 - itemType: PORTS diff --git a/src/primaite/config/config_UNIT_TEST.yaml b/src/primaite/config/config_UNIT_TEST.yaml index 5d48ffe4..3b29ff4a 100644 --- a/src/primaite/config/config_UNIT_TEST.yaml +++ b/src/primaite/config/config_UNIT_TEST.yaml @@ -1,7 +1,5 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATIONS - type: BOX - itemType: STEPS steps: 256 - itemType: PORTS diff --git a/tests/config/one_node_states_on_off_lay_down_config.yaml b/tests/config/one_node_states_on_off_lay_down_config.yaml index 1ab5f8c2..00f8016e 100644 --- a/tests/config/one_node_states_on_off_lay_down_config.yaml +++ b/tests/config/one_node_states_on_off_lay_down_config.yaml @@ -1,7 +1,5 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATIONS - type: BOX - itemType: STEPS steps: 15 - itemType: PORTS From c6bb855456386278b2a42e04c6e0ad36fddc1b09 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 31 May 2023 09:55:28 +0000 Subject: [PATCH 13/18] Revert unnecessary main.py change --- src/primaite/config/config_main.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/config/config_main.yaml b/src/primaite/config/config_main.yaml index 6104e25b..b31a73b7 100644 --- a/src/primaite/config/config_main.yaml +++ b/src/primaite/config/config_main.yaml @@ -11,7 +11,7 @@ numEpisodes: 10 # Time delay between steps (for generic agents) timeDelay: 10 # Filename of the scenario / laydown -configFilename: config_1_DDOS_BASIC.yaml +configFilename: config_5_DATA_MANIPULATION.yaml # Type of session to be run (TRAINING or EVALUATION) sessionType: TRAINING # Determine whether to load an agent from file From 76ec9683cb96135bd35cc297b5531b1141e0aeb8 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 09:45:46 +0100 Subject: [PATCH 14/18] Improve observation space test --- .../config/box_obs_space_laydown_config.yaml | 63 +++++++++++++++---- ...ultidiscrete_obs_space_laydown_config.yaml | 63 +++++++++++++++---- tests/test_observation_space.py | 11 ++++ 3 files changed, 111 insertions(+), 26 deletions(-) diff --git a/tests/config/box_obs_space_laydown_config.yaml b/tests/config/box_obs_space_laydown_config.yaml index 92226863..203bc0e7 100644 --- a/tests/config/box_obs_space_laydown_config.yaml +++ b/tests/config/box_obs_space_laydown_config.yaml @@ -6,26 +6,63 @@ steps: 5 - itemType: PORTS portsList: - - port: '21' + - port: '80' - itemType: SERVICES serviceList: - - name: ftp + - name: TCP + +######################################## +# Nodes - itemType: NODE node_id: '1' - name: node + name: PC1 node_class: SERVICE node_type: COMPUTER - priority: P1 + priority: P5 hardware_state: 'ON' - ip_address: 192.168.0.1 + ip_address: 192.168.1.1 software_state: GOOD file_system_state: GOOD services: - - name: ftp - port: '21' - state: GOOD -- itemType: POSITION - positions: - - node: '1' - x_pos: 309 - y_pos: 78 + - name: TCP + port: '80' + state: GOOD +- itemType: NODE + node_id: '2' + name: SERVER + node_class: SERVICE + node_type: SERVER + priority: P5 + hardware_state: 'ON' + ip_address: 192.168.1.2 + software_state: GOOD + file_system_state: GOOD + services: + - name: TCP + port: '80' + state: GOOD +- itemType: NODE + node_id: '3' + name: SWITCH1 + node_class: ACTIVE + node_type: SWITCH + priority: P2 + hardware_state: 'ON' + ip_address: 192.168.1.3 + software_state: GOOD + file_system_state: GOOD + +######################################## +# Links +- itemType: LINK + id: '4' + name: link1 + bandwidth: 1000 + source: '1' + destination: '3' +- itemType: LINK + id: '5' + name: link2 + bandwidth: 1000 + source: '3' + destination: '2' diff --git a/tests/config/multidiscrete_obs_space_laydown_config.yaml b/tests/config/multidiscrete_obs_space_laydown_config.yaml index dba1c6de..38438d6d 100644 --- a/tests/config/multidiscrete_obs_space_laydown_config.yaml +++ b/tests/config/multidiscrete_obs_space_laydown_config.yaml @@ -6,26 +6,63 @@ steps: 5 - itemType: PORTS portsList: - - port: '21' + - port: '80' - itemType: SERVICES serviceList: - - name: ftp + - name: TCP + +######################################## +# Nodes - itemType: NODE node_id: '1' - name: node + name: PC1 node_class: SERVICE node_type: COMPUTER - priority: P1 + priority: P5 hardware_state: 'ON' - ip_address: 192.168.0.1 + ip_address: 192.168.1.1 software_state: GOOD file_system_state: GOOD services: - - name: ftp - port: '21' - state: GOOD -- itemType: POSITION - positions: - - node: '1' - x_pos: 309 - y_pos: 78 + - name: TCP + port: '80' + state: GOOD +- itemType: NODE + node_id: '2' + name: SERVER + node_class: SERVICE + node_type: SERVER + priority: P5 + hardware_state: 'ON' + ip_address: 192.168.1.2 + software_state: GOOD + file_system_state: GOOD + services: + - name: TCP + port: '80' + state: GOOD +- itemType: NODE + node_id: '3' + name: SWITCH1 + node_class: ACTIVE + node_type: SWITCH + priority: P2 + hardware_state: 'ON' + ip_address: 192.168.1.3 + software_state: GOOD + file_system_state: GOOD + +######################################## +# Links +- itemType: LINK + id: '4' + name: link1 + bandwidth: 1000 + source: '1' + destination: '3' +- itemType: LINK + id: '5' + name: link2 + bandwidth: 1000 + source: '3' + destination: '2' diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 2f8e6a24..6a187761 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -12,6 +12,12 @@ def test_creating_env_with_box_obs(): ) env.update_environent_obs() + # we have three nodes and two links, with one service + # therefore the box observation space will have: + # * 5 columns (four fixed and one for the service) + # * 5 rows (3 nodes + 2 links) + assert env.env_obs.shape == (5, 5) + def test_creating_env_with_multidiscrete_obs(): """Try creating env with MultiDiscrete observation space.""" @@ -21,3 +27,8 @@ def test_creating_env_with_multidiscrete_obs(): / "multidiscrete_obs_space_laydown_config.yaml", ) env.update_environent_obs() + + # we have three nodes and two links, with one service + # the nodes have hardware, OS, FS, and service, the links just have bandwidth, + # therefore we need 3*4 + 2 observations + assert env.env_obs.shape == (3 * 4 + 2,) From a0960555fcb134633b11c59d1f7a039df317e937 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 09:54:45 +0100 Subject: [PATCH 15/18] Fix docstrings to use ReST format --- src/primaite/environment/primaite_env.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index f57b274d..99666237 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -648,12 +648,10 @@ class Primaite(Env): `OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type` attributes to figure out the correct shape and format for the observation space. - Returns - ------- - gym.spaces.Space - Gym observation space - numpy.Array - Initial observation with all entries set to 0 + :return: Gym observation space + :rtype: gym.spaces.Space + :return: Initial observation with all entires set to 0 + :rtype: numpy.Array """ if self.observation_type == ObservationType.BOX: _LOGGER.info("Observation space type BOX selected") @@ -1179,10 +1177,8 @@ class Primaite(Env): def get_observation_info(self, observation_info): """Extracts observation_info. - Parameters - ---------- - observation_info : str - Config item that defines which type of observation space to use + :param observation_info: Config item that defines which type of observation space to use + :type observation_info: str """ self.observation_type = ObservationType[observation_info["type"]] From bfd20b7a6bdd47641fd6277b931a38cb28dc997e Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 09:57:33 +0100 Subject: [PATCH 16/18] Type hint init_observations return type --- src/primaite/environment/primaite_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 99666237..6d956859 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -6,7 +6,7 @@ import csv import logging import os.path from datetime import datetime -from typing import Dict +from typing import Dict, Tuple import networkx as nx import numpy as np @@ -641,7 +641,7 @@ class Primaite(Env): else: pass - def init_observations(self): + def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: """Build the observation space based on network laydown and provide initial obs. This method uses the object's `num_links`, `num_nodes`, `num_services`, From 37d606eda60b9f9ea84b4c6a21099dac89a49f79 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 10:57:11 +0100 Subject: [PATCH 17/18] Separate obs functions and provide docstrings --- src/primaite/environment/primaite_env.py | 360 ++++++++++++++--------- 1 file changed, 220 insertions(+), 140 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 6d956859..67ab5375 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -641,6 +641,95 @@ class Primaite(Env): else: pass + def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]: + """Initialise the observation space with the BOX option chosen. + + This will create the observation space formatted as a table of integers. + There is one row per node, followed by one row per link. + Columns are as follows: + * node/link ID + * node hardware status / 0 for links + * node operating system status (if active/service) / 0 for links + * node file system status (active/service only) / 0 for links + * node service1 status / traffic load from that service for links + * node service2 status / traffic load from that service for links + * ... + * node serviceN status / traffic load from that service for links + + For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be + ``(12, 7)`` + + :return: Box gym observation + :rtype: gym.spaces.Box + :return: Initial observation with all entires set to 0 + :rtype: numpy.Array + """ + _LOGGER.info("Observation space type BOX selected") + + # 1. Determine observation shape from laydown + num_items = self.num_links + self.num_nodes + num_observation_parameters = ( + self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS + ) + observation_shape = (num_items, num_observation_parameters) + + # 2. Create observation space & zeroed out sample from space. + observation_space = spaces.Box( + low=0, + high=self.OBSERVATION_SPACE_HIGH_VALUE, + shape=observation_shape, + dtype=np.int64, + ) + initial_observation = np.zeros(observation_shape, dtype=np.int64) + + return observation_space, initial_observation + + def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]: + """Initialise the observation space with the MULTIDISCRETE option chosen. + + This will create the observation space with node observations followed by link observations. + Each node has 3 elements in the observation space plus 1 per service, more specifically: + * hardware state + * operating system state + * file system state + * service states (one per service) + Each link has one element in the observation space, corresponding to the traffic load, + it can take the following values: + 0 = No traffic (0% of bandwidth) + 1 = No traffic (0%-33% of bandwidth) + 2 = No traffic (33%-66% of bandwidth) + 3 = No traffic (66%-100% of bandwidth) + 4 = No traffic (100% of bandwidth) + + For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be + ``(37,)`` + + :return: MultiDiscrete gym observation + :rtype: gym.spaces.MultiDiscrete + :return: Initial observation with all entires set to 0 + :rtype: numpy.Array + """ + _LOGGER.info("Observation space MULTIDISCRETE selected") + + # 1. Determine observation shape from laydown + node_obs_shape = [ + len(HardwareState) + 1, + len(SoftwareState) + 1, + len(FileSystemState) + 1, + ] + node_services = [len(SoftwareState) + 1] * self.num_services + node_obs_shape = node_obs_shape + node_services + # the magic number 5 refers to 5 states of quantisation of traffic amount. + # (zero, low, medium, high, fully utilised/overwhelmed) + link_obs_shape = [5] * self.num_links + observation_shape = node_obs_shape * self.num_nodes + link_obs_shape + + # 2. Create observation space & zeroed out sample from space. + observation_space = spaces.MultiDiscrete(observation_shape) + initial_observation = np.zeros(len(observation_shape), dtype=np.int64) + + return observation_space, initial_observation + def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: """Build the observation space based on network laydown and provide initial obs. @@ -648,163 +737,154 @@ class Primaite(Env): `OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type` attributes to figure out the correct shape and format for the observation space. + :raises ValueError: If the env's `observation_type` attribute is not set to a valid `enums.ObservationType` :return: Gym observation space :rtype: gym.spaces.Space :return: Initial observation with all entires set to 0 :rtype: numpy.Array """ if self.observation_type == ObservationType.BOX: - _LOGGER.info("Observation space type BOX selected") - - # 1. Determine observation shape from laydown - num_items = self.num_links + self.num_nodes - num_observation_parameters = ( - self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS - ) - observation_shape = (num_items, num_observation_parameters) - - # 2. Create observation space & zeroed out sample from space. - observation_space = spaces.Box( - low=0, - high=self.OBSERVATION_SPACE_HIGH_VALUE, - shape=observation_shape, - dtype=np.int64, - ) - initial_observation = np.zeros(observation_shape, dtype=np.int64) - + observation_space, initial_observation = self._init_box_observations() + return observation_space, initial_observation elif self.observation_type == ObservationType.MULTIDISCRETE: - _LOGGER.info("Observation space MULTIDISCRETE selected") - - # 1. Determine observation shape from laydown - node_obs_shape = [ - len(HardwareState) + 1, - len(SoftwareState) + 1, - len(FileSystemState) + 1, - ] - node_services = [len(SoftwareState) + 1] * self.num_services - node_obs_shape = node_obs_shape + node_services - # the magic number 5 refers to 5 states of quantisation of traffic amount. - # (zero, low, medium, high, fully utilised/overwhelmed) - link_obs_shape = [5] * self.num_links - observation_shape = node_obs_shape * self.num_nodes + link_obs_shape - - # 2. Create observation space & zeroed out sample from space. - observation_space = spaces.MultiDiscrete(observation_shape) - initial_observation = np.zeros(len(observation_shape), dtype=np.int64) + ( + observation_space, + initial_observation, + ) = self._init_multidiscrete_observations() + return observation_space, initial_observation else: - raise ValueError( + errmsg = ( f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}" f", got {self.observation_type} instead" ) + _LOGGER.error(errmsg) + raise ValueError(errmsg) - return observation_space, initial_observation + def _update_env_obs_box(self): + """Update the environment's observation state based on the current status of nodes and links. + + This function can only be called if the observation space setting is set to BOX. + + :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` + """ + assert self.observation_type == ObservationType.BOX + item_index = 0 + + # Do nodes first + for node_key, node in self.nodes.items(): + self.env_obs[item_index][0] = int(node.node_id) + self.env_obs[item_index][1] = node.hardware_state.value + if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): + self.env_obs[item_index][2] = node.software_state.value + self.env_obs[item_index][3] = node.file_system_state_observed.value + else: + self.env_obs[item_index][2] = 0 + self.env_obs[item_index][3] = 0 + service_index = 4 + if isinstance(node, ServiceNode): + for service in self.services_list: + if node.has_service(service): + self.env_obs[item_index][ + service_index + ] = node.get_service_state(service).value + else: + self.env_obs[item_index][service_index] = 0 + service_index += 1 + else: + # Not a service node + for service in self.services_list: + self.env_obs[item_index][service_index] = 0 + service_index += 1 + item_index += 1 + + # Now do links + for link_key, link in self.links.items(): + self.env_obs[item_index][0] = int(link.get_id()) + self.env_obs[item_index][1] = 0 + self.env_obs[item_index][2] = 0 + self.env_obs[item_index][3] = 0 + protocol_list = link.get_protocol_list() + protocol_index = 0 + for protocol in protocol_list: + self.env_obs[item_index][protocol_index + 4] = protocol.get_load() + protocol_index += 1 + item_index += 1 + + def _update_env_obs_multidiscrete(self): + """Update the environment's observation state based on the current status of nodes and links. + + This function can only be called if the observation space setting is set to MULTIDISCRETE. + + :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` + """ + assert self.observation_type == ObservationType.MULTIDISCRETE + obs = [] + # 1. Set nodes + # Each node has the following variables in the observation space: + # - Hardware state + # - Software state + # - File System state + # - Service 1 state + # - Service 2 state + # - ... + # - Service N state + for node_key, node in self.nodes.items(): + hardware_state = node.hardware_state.value + software_state = 0 + file_system_state = 0 + services_states = [0] * self.num_services + + if isinstance( + node, ActiveNode + ): # ServiceNode is a subclass of ActiveNode so no need to check that also + software_state = node.software_state.value + file_system_state = node.file_system_state_observed.value + + if isinstance(node, ServiceNode): + for i, service in enumerate(self.services_list): + if node.has_service(service): + services_states[i] = node.get_service_state(service).value + + obs.extend( + [ + hardware_state, + software_state, + file_system_state, + *services_states, + ] + ) + + # 2. Set links + # Each link has just one variable in the observation space, it represents the traffic amount + # In order for the space to be fully MultiDiscrete, the amount of + # traffic on each link is quantised into a few levels: + # 0: no traffic (0% of bandwidth) + # 1: low traffic (0-33% of bandwidth) + # 2: medium traffic (33-66% of bandwidth) + # 3: high traffic (66-100% of bandwidth) + # 4: max traffic/overloaded (100% of bandwidth) + + for link_key, link in self.links.items(): + bandwidth = link.bandwidth + load = link.get_current_load() + + if load <= 0: + traffic_level = 0 + elif load >= bandwidth: + traffic_level = 4 + else: + traffic_level = (load / bandwidth) // (1 / 3) + 1 + + obs.append(int(traffic_level)) + + self.env_obs = np.asarray(obs) def update_environent_obs(self): """Updates the observation space based on the node and link status.""" if self.observation_type == ObservationType.BOX: - item_index = 0 - - # Do nodes first - for node_key, node in self.nodes.items(): - self.env_obs[item_index][0] = int(node.node_id) - self.env_obs[item_index][1] = node.hardware_state.value - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.env_obs[item_index][2] = node.software_state.value - self.env_obs[item_index][3] = node.file_system_state_observed.value - else: - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - service_index = 4 - if isinstance(node, ServiceNode): - for service in self.services_list: - if node.has_service(service): - self.env_obs[item_index][ - service_index - ] = node.get_service_state(service).value - else: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - else: - # Not a service node - for service in self.services_list: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - item_index += 1 - - # Now do links - for link_key, link in self.links.items(): - self.env_obs[item_index][0] = int(link.get_id()) - self.env_obs[item_index][1] = 0 - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - protocol_list = link.get_protocol_list() - protocol_index = 0 - for protocol in protocol_list: - self.env_obs[item_index][protocol_index + 4] = protocol.get_load() - protocol_index += 1 - item_index += 1 - + self._update_env_obs_box() elif self.observation_type == ObservationType.MULTIDISCRETE: - obs = [] - # 1. Set nodes - # Each node has the following variables in the observation space: - # - Hardware state - # - Software state - # - File System state - # - Service 1 state - # - Service 2 state - # - ... - # - Service N state - for node_key, node in self.nodes.items(): - hardware_state = node.hardware_state.value - software_state = 0 - file_system_state = 0 - services_states = [0] * self.num_services - - if isinstance( - node, ActiveNode - ): # ServiceNode is a subclass of ActiveNode so no need to check that also - software_state = node.software_state.value - file_system_state = node.file_system_state_observed.value - - if isinstance(node, ServiceNode): - for i, service in enumerate(self.services_list): - if node.has_service(service): - services_states[i] = node.get_service_state(service).value - - obs.extend( - [ - hardware_state, - software_state, - file_system_state, - *services_states, - ] - ) - - # 2. Set links - # Each link has just one variable in the observation space, it represents the traffic amount - # In order for the space to be fully MultiDiscrete, the amount of - # traffic on each link is quantised into a few levels: - # 0: no traffic (0% of bandwidth) - # 1: low traffic (0-33% of bandwidth) - # 2: medium traffic (33-66% of bandwidth) - # 3: high traffic (66-100% of bandwidth) - # 4: max traffic/overloaded (100% of bandwidth) - - for link_key, link in self.links.items(): - bandwidth = link.bandwidth - load = link.get_current_load() - - if load <= 0: - traffic_level = 0 - elif load >= bandwidth: - traffic_level = 4 - else: - traffic_level = (load / bandwidth) // (1 / 3) + 1 - - obs.append(int(traffic_level)) - - self.env_obs = np.asarray(obs) + self._update_env_obs_multidiscrete() def load_config(self): """Loads config data in order to build the environment configuration.""" From 3b0d05e9c97433dd380fd02cf1301208d4735e01 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 11:02:10 +0100 Subject: [PATCH 18/18] More info in docstring --- src/primaite/environment/primaite_env.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 67ab5375..56893ee9 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -763,6 +763,7 @@ class Primaite(Env): def _update_env_obs_box(self): """Update the environment's observation state based on the current status of nodes and links. + The structure of the observation space is described in :func:`~_init_box_observations` This function can only be called if the observation space setting is set to BOX. :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` @@ -813,6 +814,7 @@ class Primaite(Env): def _update_env_obs_multidiscrete(self): """Update the environment's observation state based on the current status of nodes and links. + The structure of the observation space is described in :func:`~_init_multidiscrete_observations` This function can only be called if the observation space setting is set to MULTIDISCRETE. :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``