Fix observation representation in transactions
This commit is contained in:
@@ -29,6 +29,7 @@ class AbstractObservationComponent(ABC):
|
||||
self.env: "Primaite" = env
|
||||
self.space: spaces.Space
|
||||
self.current_observation: np.ndarray # type might be too restrictive?
|
||||
self.structure: list[str]
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
@@ -36,6 +37,11 @@ class AbstractObservationComponent(ABC):
|
||||
"""Update the observation based on the current state of the environment."""
|
||||
self.current_observation = NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class NodeLinkTable(AbstractObservationComponent):
|
||||
"""Table with nodes and links as rows and hardware/software status as cols.
|
||||
@@ -79,6 +85,8 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
# 3. Initialise Observation with zeroes
|
||||
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
@@ -131,6 +139,40 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
def generate_structure(self):
    """Return a list of labels for the components of the flattened observation space.

    Every node contributes an id label, a hardware/OS/file-system status
    label, and one status label per entry of ``self.env.services_list``.
    Every link contributes an id label, three ``n/a`` labels, and one load
    label per service.

    :return: Labels for all nodes (in ``self.env.nodes`` order) followed by
        labels for all links (in ``self.env.links`` order).
    """
    services = self.env.services_list
    structure = []

    for node in self.env.nodes.values():
        node_id = node.node_id
        structure.extend(
            [
                f"node_{node_id}_id",
                f"node_{node_id}_hardware_status",
                f"node_{node_id}_os_status",
                f"node_{node_id}_fs_status",
            ]
        )
        structure.extend(
            f"node_{node_id}_service_{serv}_status" for serv in services
        )

    for link in self.env.links.values():
        link_id = link.id
        # Three "n/a" labels keep a link row the same width as a node row;
        # presumably links have no hardware/OS/file-system status - TODO confirm
        # against update().
        structure.extend(
            [
                f"link_{link_id}_id",
                f"link_{link_id}_n/a",
                f"link_{link_id}_n/a",
                f"link_{link_id}_n/a",
            ]
        )
        # Label the load columns with the link's own id. The previous code
        # reused ``node_id`` left over from the node loop, which mislabelled
        # every link and raised NameError when there were no nodes.
        structure.extend(
            f"link_{link_id}_service_{serv}_load" for serv in services
        )

    return structure
|
||||
|
||||
|
||||
class NodeStatuses(AbstractObservationComponent):
|
||||
"""Flat list of nodes' hardware, OS, file system, and service states.
|
||||
@@ -179,6 +221,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
@@ -205,6 +248,30 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
)
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self):
    """Return a list of labels for the components of the flattened observation space.

    For each node the labels are grouped as hardware state, software state,
    file system state, then one group per service. Each group starts with a
    ``NONE`` label followed by one label per member of the matching enum.
    """
    service_names = self.env.services_list
    labels = []

    for node in self.env.nodes.values():
        nid = node.node_id

        # (label prefix, enum providing the state names) for every group
        groups = [
            (f"node_{nid}_hardware_state", HardwareState),
            (f"node_{nid}_software_state", SoftwareState),
            (f"node_{nid}_file_system_state", FileSystemState),
        ]
        groups.extend(
            (f"node_{nid}_service_{svc}_state", SoftwareState)
            for svc in service_names
        )

        for prefix, states in groups:
            labels.append(f"{prefix}_NONE")
            labels.extend(f"{prefix}_{member.name}" for member in states)

    return labels
|
||||
|
||||
|
||||
class LinkTrafficLevels(AbstractObservationComponent):
|
||||
"""Flat list of traffic levels encoded into banded categories.
|
||||
@@ -268,6 +335,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
@@ -295,6 +364,21 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self):
    """Return a list of labels for the components of the flattened observation space.

    Each link yields ``self._quantisation_levels`` labels per protocol name.
    When ``self._combine_service_traffic`` is truthy the single pseudo-protocol
    ``overall`` is used instead of the names in ``link.protocol_list``.
    """
    labels = []
    for link in self.env.links.values():
        link_id = link.id
        names = (
            ["overall"]
            if self._combine_service_traffic
            else [proto.name for proto in link.protocol_list]
        )
        labels.extend(
            f"link_{link_id}_{name}_traffic_level_{level}"
            for name in names
            for level in range(self._quantisation_levels)
        )
    return labels
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
"""Component-based observation space handler.
|
||||
@@ -312,11 +396,15 @@ class ObservationsHandler:
|
||||
def __init__(self):
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
|
||||
# need to keep track of the flattened and unflattened version of the space (if there is one)
|
||||
self.space: spaces.Space
|
||||
self.unflattened_space: spaces.Space
|
||||
# the internal observation space (the unflattened version of space if flatten=True)
|
||||
self._space: spaces.Space
|
||||
# flattened version of the observation space
|
||||
self._flat_space: spaces.Space
|
||||
|
||||
self._observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||
# used for transactions and when flatten=true
|
||||
self._flat_observation: np.ndarray
|
||||
|
||||
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||
self.flatten: bool = False
|
||||
|
||||
def update_obs(self):
|
||||
@@ -326,17 +414,11 @@ class ObservationsHandler:
|
||||
obs.update()
|
||||
current_obs.append(obs.current_observation)
|
||||
|
||||
# If there is only one component, don't use a tuple, just pass through that component's obs.
|
||||
if len(current_obs) == 1:
|
||||
self.current_observation = current_obs[0]
|
||||
# If there are many components, the space may need to be flattened
|
||||
self._observation = current_obs[0]
|
||||
else:
|
||||
if self.flatten:
|
||||
self.current_observation = spaces.flatten(
|
||||
self.unflattened_space, tuple(current_obs)
|
||||
)
|
||||
else:
|
||||
self.current_observation = tuple(current_obs)
|
||||
self._observation = tuple(current_obs)
|
||||
self._flat_observation = spaces.flatten(self._space, self._observation)
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent):
|
||||
"""Add a component for this handler to track.
|
||||
@@ -363,15 +445,28 @@ class ObservationsHandler:
|
||||
for obs_comp in self.registered_obs_components:
|
||||
component_spaces.append(obs_comp.space)
|
||||
|
||||
# If there is only one component, don't use a tuple space, just pass through that component's space.
|
||||
# if there are multiple components, build a composite tuple space
|
||||
if len(component_spaces) == 1:
|
||||
self.space = component_spaces[0]
|
||||
self._space = component_spaces[0]
|
||||
else:
|
||||
self.unflattened_space = spaces.Tuple(component_spaces)
|
||||
if self.flatten:
|
||||
self.space = spaces.flatten_space(spaces.Tuple(component_spaces))
|
||||
else:
|
||||
self.space = self.unflattened_space
|
||||
self._space = spaces.Tuple(component_spaces)
|
||||
self._flat_space = spaces.flatten_space(self._space)
|
||||
|
||||
@property
def space(self):
    """Observation space; the flattened space when ``flatten`` is set, otherwise the internal one."""
    return self._flat_space if self.flatten else self._space
|
||||
|
||||
@property
def current_observation(self):
    """Current observation; the flattened one when ``flatten`` is set, otherwise the internal one."""
    return self._flat_observation if self.flatten else self._observation
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||
@@ -417,3 +512,17 @@ class ObservationsHandler:
|
||||
|
||||
handler.update_obs()
|
||||
return handler
|
||||
|
||||
def describe_structure(self):
    """Create a list of names for the features of the obs space.

    The order of labels follows the flattened version of the space.
    """
    # The gym flattening function cannot be applied to labels, so each
    # registered component supplies its own labels in the order expected after
    # flattening; they are concatenated here in registration order.
    return [
        label
        for component in self.registered_obs_components
        for label in component.structure
    ]
|
||||
|
||||
@@ -318,7 +318,8 @@ class Primaite(Env):
|
||||
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
|
||||
)
|
||||
# Load the initial observation space into the transaction
|
||||
transaction.set_obs_space_pre(copy.deepcopy(self.env_obs))
|
||||
transaction.set_obs_space_pre(self.obs_handler._flat_observation)
|
||||
|
||||
# Load the action space into the transaction
|
||||
transaction.set_action_space(copy.deepcopy(action))
|
||||
|
||||
@@ -400,7 +401,7 @@ class Primaite(Env):
|
||||
# 7. Update env_obs
|
||||
self.update_environent_obs()
|
||||
# Load the new observation space into the transaction
|
||||
transaction.set_obs_space_post(copy.deepcopy(self.env_obs))
|
||||
transaction.set_obs_space_post(self.obs_handler._flat_observation)
|
||||
|
||||
# 8. Add the transaction to the list of transactions
|
||||
self.transaction_list.append(copy.deepcopy(transaction))
|
||||
|
||||
@@ -325,6 +325,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str,
|
||||
transaction_list=transaction_list,
|
||||
session_path=session_dir,
|
||||
timestamp_str=timestamp_str,
|
||||
obs_space_description=env.obs_handler.describe_structure(),
|
||||
)
|
||||
|
||||
print("Updating Session Metadata file...")
|
||||
|
||||
@@ -22,24 +22,12 @@ def turn_action_space_to_array(_action_space):
|
||||
return [str(_action_space)]
|
||||
|
||||
|
||||
def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
    """
    Turns observation space into a string array so it can be saved to csv.

    Values are emitted row by row: every feature of asset 0, then asset 1, etc.

    Args:
        _obs_space: The observation space
        _obs_assets: The number of assets (i.e. nodes or links) in the observation space
        _obs_features: The number of features associated with the asset
    """
    return [
        str(_obs_space[asset][feature])
        for asset in range(_obs_assets)
        for feature in range(_obs_features)
    ]
|
||||
|
||||
|
||||
def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str):
|
||||
def write_transaction_to_file(
|
||||
transaction_list,
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
obs_space_description: list,
|
||||
):
|
||||
"""
|
||||
Writes transaction logs to file to support training evaluation.
|
||||
|
||||
@@ -56,13 +44,13 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
|
||||
# This will be tied into the PrimAITE Use Case so that they make sense
|
||||
template_transation = transaction_list[0]
|
||||
action_length = template_transation.action_space.size
|
||||
obs_shape = template_transation.obs_space_post.shape
|
||||
obs_assets = template_transation.obs_space_post.shape[0]
|
||||
if len(obs_shape) == 1:
|
||||
# bit of a workaround but I think the way transactions are written will change soon
|
||||
obs_features = 1
|
||||
else:
|
||||
obs_features = template_transation.obs_space_post.shape[1]
|
||||
# obs_shape = template_transation.obs_space_post.shape
|
||||
# obs_assets = template_transation.obs_space_post.shape[0]
|
||||
# if len(obs_shape) == 1:
|
||||
# bit of a workaround but I think the way transactions are written will change soon
|
||||
# obs_features = 1
|
||||
# else:
|
||||
# obs_features = template_transation.obs_space_post.shape[1]
|
||||
|
||||
# Create the action space headers array
|
||||
action_header = []
|
||||
@@ -70,12 +58,8 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
|
||||
action_header.append("AS_" + str(x))
|
||||
|
||||
# Create the observation space headers array
|
||||
obs_header_initial = []
|
||||
obs_header_new = []
|
||||
for x in range(obs_assets):
|
||||
for y in range(obs_features):
|
||||
obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
|
||||
obs_header_new.append("OSN_" + str(x) + "_" + str(y))
|
||||
obs_header_initial = [f"pre_{o}" for o in obs_space_description]
|
||||
obs_header_new = [f"post_{o}" for o in obs_space_description]
|
||||
|
||||
# Open up a csv file
|
||||
header = ["Timestamp", "Episode", "Step", "Reward"]
|
||||
@@ -98,12 +82,8 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
|
||||
csv_data = (
|
||||
csv_data
|
||||
+ turn_action_space_to_array(transaction.action_space)
|
||||
+ turn_obs_space_to_array(
|
||||
transaction.obs_space_pre, obs_assets, obs_features
|
||||
)
|
||||
+ turn_obs_space_to_array(
|
||||
transaction.obs_space_post, obs_assets, obs_features
|
||||
)
|
||||
+ transaction.obs_space_pre.tolist()
|
||||
+ transaction.obs_space_post.tolist()
|
||||
)
|
||||
csv_writer.writerow(csv_data)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user