1443 - added changes from ADSP to observation space in primaite_env.py

This commit is contained in:
SunilSamra
2023-05-31 13:15:25 +01:00
parent c6db98c1c2
commit 9a231821ea

View File

@@ -207,6 +207,7 @@ class Primaite(Env):
# Calculate the number of items that need to be included in the
# observation space
"""
num_items = self.num_links + self.num_nodes
# Set the number of observation parameters, being # of services plus id,
# hardware state, file system state and SoftwareState (i.e. 4)
@@ -221,6 +222,23 @@ class Primaite(Env):
shape=self.observation_shape,
dtype=np.int64,
)
"""
self.num_observation_parameters = (
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
)
# Define the observation space:
# There are:
# 4 Operating States (ON/OFF/RESETTING) + NONE (0)
# 4 OS States (GOOD/PATCHING/COMPROMISED) + NONE
# 5 Service States (NONE/GOOD/PATCHING/COMPROMISED/OVERWHELMED) + NONE
# There can be any number of services
# There are 5 node types No Traffic, Low Traffic, Medium Traffic, High Traffic, Overloaded/max traffic
self.observation_space = spaces.MultiDiscrete(
([4, 4] + [5] * self.num_services) * self.num_nodes + [5] * self.num_links
)
# Define the observation shape
self.observation_shape = self.observation_space.sample().shape
# This is the observation that is sent back via the rest and step functions
self.env_obs = np.zeros(self.observation_shape, dtype=np.int64)
@@ -396,7 +414,7 @@ class Primaite(Env):
self.step_count,
self.config_values,
)
# print(f" Step {self.step_count} Reward: {str(reward)}")
print(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -678,6 +696,19 @@ class Primaite(Env):
def update_environent_obs(self):
"""Updates the observation space based on the node and link status."""
# Convert back to more readable, original format
reshaped_nodes = self.env_obs[: -self.num_links].reshape(
self.num_nodes, self.num_services + 2
)
# Add empty links back and add node ID back
s = np.zeros(
[reshaped_nodes.shape[0] + self.num_links, reshaped_nodes.shape[1] + 1],
dtype=np.int64,
)
s[:, 0] = range(1, self.num_nodes + self.num_links + 1) # Adding ID back
s[: self.num_nodes, 1:] = reshaped_nodes # put values back in
self.env_obs = s
item_index = 0
# Do nodes first
@@ -720,6 +751,13 @@ class Primaite(Env):
protocol_index += 1
item_index += 1
# Remove ID columns, remove links and flatten to 1D array
node_obs = self.env_obs[: self.num_nodes, 1:].flatten()
# Remove ID, remove all data except link traffic status
link_obs = self.env_obs[self.num_nodes :, 3:].flatten()
# Combine nodes and links
self.env_obs = np.append(node_obs, link_obs)
def load_config(self):
"""Loads config data in order to build the environment configuration."""
for item in self.config_data:
@@ -1187,6 +1225,7 @@ class Primaite(Env):
"""
# reserve 0 action to be a nothing action
# Used to be {0: [1, 0, 0, 0]}
actions = {0: [1, 0, 0, 0]}
action_key = 1