Fix observation representation in transactions

This commit is contained in:
Marek Wolan
2023-06-29 15:26:07 +01:00
parent e086d419ad
commit c9f58fdb2a
4 changed files with 150 additions and 59 deletions

View File

@@ -29,6 +29,7 @@ class AbstractObservationComponent(ABC):
self.env: "Primaite" = env
self.space: spaces.Space
self.current_observation: np.ndarray # type might be too restrictive?
self.structure: list[str]
return NotImplemented
@abstractmethod
@@ -36,6 +37,11 @@ class AbstractObservationComponent(ABC):
"""Update the observation based on the current state of the environment."""
self.current_observation = NotImplemented
@abstractmethod
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
return NotImplemented
class NodeLinkTable(AbstractObservationComponent):
"""Table with nodes and links as rows and hardware/software status as cols.
@@ -79,6 +85,8 @@ class NodeLinkTable(AbstractObservationComponent):
# 3. Initialise Observation with zeroes
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self):
"""Update the observation based on current environment state.
@@ -131,6 +139,40 @@ class NodeLinkTable(AbstractObservationComponent):
protocol_index += 1
item_index += 1
def generate_structure(self):
"""Return a list of labels for the components of the flattened observation space."""
nodes = self.env.nodes.values()
links = self.env.links.values()
structure = []
for i, node in enumerate(nodes):
node_id = node.node_id
node_labels = [
f"node_{node_id}_id",
f"node_{node_id}_hardware_status",
f"node_{node_id}_os_status",
f"node_{node_id}_fs_status",
]
for j, serv in enumerate(self.env.services_list):
node_labels.append(f"node_{node_id}_service_{serv}_status")
structure.extend(node_labels)
for i, link in enumerate(links):
link_id = link.id
link_labels = [
f"link_{link_id}_id",
f"link_{link_id}_n/a",
f"link_{link_id}_n/a",
f"link_{link_id}_n/a",
]
for j, serv in enumerate(self.env.services_list):
link_labels.append(f"node_{node_id}_service_{serv}_load")
structure.extend(link_labels)
return structure
class NodeStatuses(AbstractObservationComponent):
"""Flat list of nodes' hardware, OS, file system, and service states.
@@ -179,6 +221,7 @@ class NodeStatuses(AbstractObservationComponent):
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self):
"""Update the observation based on current environment state.
@@ -205,6 +248,30 @@ class NodeStatuses(AbstractObservationComponent):
)
self.current_observation[:] = obs
def generate_structure(self):
"""Return a list of labels for the components of the flattened observation space."""
services = self.env.services_list
structure = []
for _, node in self.env.nodes.items():
node_id = node.node_id
structure.append(f"node_{node_id}_hardware_state_NONE")
for state in HardwareState:
structure.append(f"node_{node_id}_hardware_state_{state.name}")
structure.append(f"node_{node_id}_software_state_NONE")
for state in SoftwareState:
structure.append(f"node_{node_id}_software_state_{state.name}")
structure.append(f"node_{node_id}_file_system_state_NONE")
for state in FileSystemState:
structure.append(f"node_{node_id}_file_system_state_{state.name}")
for service in services:
structure.append(f"node_{node_id}_service_{service}_state_NONE")
for state in SoftwareState:
structure.append(
f"node_{node_id}_service_{service}_state_{state.name}"
)
return structure
class LinkTrafficLevels(AbstractObservationComponent):
"""Flat list of traffic levels encoded into banded categories.
@@ -268,6 +335,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self):
"""Update the observation based on current environment state.
@@ -295,6 +364,21 @@ class LinkTrafficLevels(AbstractObservationComponent):
self.current_observation[:] = obs
def generate_structure(self):
"""Return a list of labels for the components of the flattened observation space."""
structure = []
for _, link in self.env.links.items():
link_id = link.id
if self._combine_service_traffic:
protocols = ["overall"]
else:
protocols = [protocol.name for protocol in link.protocol_list]
for p in protocols:
for i in range(self._quantisation_levels):
structure.append(f"link_{link_id}_{p}_traffic_level_{i}")
return structure
class ObservationsHandler:
"""Component-based observation space handler.
@@ -312,11 +396,15 @@ class ObservationsHandler:
def __init__(self):
self.registered_obs_components: List[AbstractObservationComponent] = []
# need to keep track of the flattened and unflattened version of the space (if there is one)
self.space: spaces.Space
self.unflattened_space: spaces.Space
# internal the observation space (unflattened version of space if flatten=True)
self._space: spaces.Space
# flattened version of the observation space
self._flat_space: spaces.Space
self._observation: Union[Tuple[np.ndarray], np.ndarray]
# used for transactions and when flatten=true
self._flat_observation: np.ndarray
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
self.flatten: bool = False
def update_obs(self):
@@ -326,17 +414,11 @@ class ObservationsHandler:
obs.update()
current_obs.append(obs.current_observation)
# If there is only one component, don't use a tuple, just pass through that component's obs.
if len(current_obs) == 1:
self.current_observation = current_obs[0]
# If there are many compoenents, the space may need to be flattened
self._observation = current_obs[0]
else:
if self.flatten:
self.current_observation = spaces.flatten(
self.unflattened_space, tuple(current_obs)
)
else:
self.current_observation = tuple(current_obs)
self._observation = tuple(current_obs)
self._flat_observation = spaces.flatten(self._space, self._observation)
def register(self, obs_component: AbstractObservationComponent):
"""Add a component for this handler to track.
@@ -363,15 +445,28 @@ class ObservationsHandler:
for obs_comp in self.registered_obs_components:
component_spaces.append(obs_comp.space)
# If there is only one component, don't use a tuple space, just pass through that component's space.
# if there are multiple components, build a composite tuple space
if len(component_spaces) == 1:
self.space = component_spaces[0]
self._space = component_spaces[0]
else:
self.unflattened_space = spaces.Tuple(component_spaces)
if self.flatten:
self.space = spaces.flatten_space(spaces.Tuple(component_spaces))
else:
self.space = self.unflattened_space
self._space = spaces.Tuple(component_spaces)
self._flat_space = spaces.flatten_space(self._space)
@property
def space(self):
"""Observation space, return the flattened version if flatten is True."""
if self.flatten:
return self._flat_space
else:
return self._space
@property
def current_observation(self):
"""Current observation, return the flattened version if flatten is True."""
if self.flatten:
return self._flat_observation
else:
return self._observation
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):
@@ -417,3 +512,17 @@ class ObservationsHandler:
handler.update_obs()
return handler
def describe_structure(self):
"""Create a list of names for the features of the obs space.
The order of labels follows the flattened version of the space.
"""
# as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have
# to fake it. each component has to just hard-code the expected label order after flattening...
labels = []
for obs_comp in self.registered_obs_components:
labels.extend(obs_comp.structure)
return labels

View File

@@ -318,7 +318,8 @@ class Primaite(Env):
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
)
# Load the initial observation space into the transaction
transaction.set_obs_space_pre(copy.deepcopy(self.env_obs))
transaction.set_obs_space_pre(self.obs_handler._flat_observation)
# Load the action space into the transaction
transaction.set_action_space(copy.deepcopy(action))
@@ -400,7 +401,7 @@ class Primaite(Env):
# 7. Update env_obs
self.update_environent_obs()
# Load the new observation space into the transaction
transaction.set_obs_space_post(copy.deepcopy(self.env_obs))
transaction.set_obs_space_post(self.obs_handler._flat_observation)
# 8. Add the transaction to the list of transactions
self.transaction_list.append(copy.deepcopy(transaction))

View File

@@ -325,6 +325,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str,
transaction_list=transaction_list,
session_path=session_dir,
timestamp_str=timestamp_str,
obs_space_description=env.obs_handler.describe_structure(),
)
print("Updating Session Metadata file...")

View File

@@ -22,24 +22,12 @@ def turn_action_space_to_array(_action_space):
return [str(_action_space)]
def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
"""
Turns observation space into a string array so it can be saved to csv.
Args:
_obs_space: The observation space
_obs_assets: The number of assets (i.e. nodes or links) in the observation space
_obs_features: The number of features associated with the asset
"""
return_array = []
for x in range(_obs_assets):
for y in range(_obs_features):
return_array.append(str(_obs_space[x][y]))
return return_array
def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str):
def write_transaction_to_file(
transaction_list,
session_path: Path,
timestamp_str: str,
obs_space_description: list,
):
"""
Writes transaction logs to file to support training evaluation.
@@ -56,13 +44,13 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
# This will be tied into the PrimAITE Use Case so that they make sense
template_transation = transaction_list[0]
action_length = template_transation.action_space.size
obs_shape = template_transation.obs_space_post.shape
obs_assets = template_transation.obs_space_post.shape[0]
if len(obs_shape) == 1:
# bit of a workaround but I think the way transactions are written will change soon
obs_features = 1
else:
obs_features = template_transation.obs_space_post.shape[1]
# obs_shape = template_transation.obs_space_post.shape
# obs_assets = template_transation.obs_space_post.shape[0]
# if len(obs_shape) == 1:
# bit of a workaround but I think the way transactions are written will change soon
# obs_features = 1
# else:
# obs_features = template_transation.obs_space_post.shape[1]
# Create the action space headers array
action_header = []
@@ -70,12 +58,8 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
action_header.append("AS_" + str(x))
# Create the observation space headers array
obs_header_initial = []
obs_header_new = []
for x in range(obs_assets):
for y in range(obs_features):
obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
obs_header_new.append("OSN_" + str(x) + "_" + str(y))
obs_header_initial = [f"pre_{o}" for o in obs_space_description]
obs_header_new = [f"post_{o}" for o in obs_space_description]
# Open up a csv file
header = ["Timestamp", "Episode", "Step", "Reward"]
@@ -98,12 +82,8 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
csv_data = (
csv_data
+ turn_action_space_to_array(transaction.action_space)
+ turn_obs_space_to_array(
transaction.obs_space_pre, obs_assets, obs_features
)
+ turn_obs_space_to_array(
transaction.obs_space_post, obs_assets, obs_features
)
+ transaction.obs_space_pre.tolist()
+ transaction.obs_space_post.tolist()
)
csv_writer.writerow(csv_data)