From 9a231821ea185145ac0b159c059d35f8012f267c Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Wed, 31 May 2023 13:15:25 +0100 Subject: [PATCH] 1443 - added changes from ADSP to observation space in primaite_env.py --- src/primaite/environment/primaite_env.py | 41 +++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 0ebcd973..84b485bd 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -207,6 +207,7 @@ class Primaite(Env): # Calculate the number of items that need to be included in the # observation space + """ num_items = self.num_links + self.num_nodes # Set the number of observation parameters, being # of services plus id, # hardware state, file system state and SoftwareState (i.e. 4) @@ -221,6 +222,23 @@ class Primaite(Env): shape=self.observation_shape, dtype=np.int64, ) + """ + self.num_observation_parameters = ( + self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS + ) + # Define the observation space: + # There are: + # 4 Operating States (ON/OFF/RESETTING) + NONE (0) + # 4 OS States (GOOD/PATCHING/COMPROMISED) + NONE + # 5 Service States (GOOD/PATCHING/COMPROMISED/OVERWHELMED) + NONE + # There can be any number of services + # There are 5 link traffic levels: No Traffic, Low Traffic, Medium Traffic, High Traffic, Overloaded/max traffic + self.observation_space = spaces.MultiDiscrete( + ([4, 4] + [5] * self.num_services) * self.num_nodes + [5] * self.num_links + ) + + # Define the observation shape + self.observation_shape = self.observation_space.sample().shape # This is the observation that is sent back via the rest and step functions self.env_obs = np.zeros(self.observation_shape, dtype=np.int64) @@ -396,7 +414,7 @@ class Primaite(Env): self.step_count, self.config_values, ) - # print(f" Step {self.step_count} Reward: {str(reward)}") + print(f" Step {self.step_count} Reward: 
{str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -678,6 +696,19 @@ class Primaite(Env): def update_environent_obs(self): """Updates the observation space based on the node and link status.""" + # Convert back to more readable, original format + reshaped_nodes = self.env_obs[: -self.num_links].reshape( + self.num_nodes, self.num_services + 2 + ) + + # Add empty links back and add node ID back + s = np.zeros( + [reshaped_nodes.shape[0] + self.num_links, reshaped_nodes.shape[1] + 1], + dtype=np.int64, + ) + s[:, 0] = range(1, self.num_nodes + self.num_links + 1) # Adding ID back + s[: self.num_nodes, 1:] = reshaped_nodes # put values back in + self.env_obs = s item_index = 0 # Do nodes first @@ -720,6 +751,13 @@ class Primaite(Env): protocol_index += 1 item_index += 1 + # Remove ID columns, remove links and flatten to 1D array + node_obs = self.env_obs[: self.num_nodes, 1:].flatten() + # Remove ID, remove all data except link traffic status + link_obs = self.env_obs[self.num_nodes :, 3:].flatten() + # Combine nodes and links + self.env_obs = np.append(node_obs, link_obs) + def load_config(self): """Loads config data in order to build the environment configuration.""" for item in self.config_data: @@ -1187,6 +1225,7 @@ class Primaite(Env): """ # reserve 0 action to be a nothing action + # Used to be {0: [1, 0, 0, 0]} actions = {0: [1, 0, 0, 0]} action_key = 1