Merged PR 100: Flatten observation spaces and improve transactions for observations
## Summary *Replace this text with an explanation of what the changes are and how you implemented them. Can this impact any other parts of the codebase that we should keep in mind?* ## Test process I ran some training sessions to ensure that the outputted transaction list has the correct data and headers. I was also able to verify that the agent is able to train with observation spaces containing multiple components. I trained an agent on laydown 3 with NODE_LINK_TABLE both as normal and flattened spaces and the agent learned in both instances.   ## Checklist - [x] This PR is linked to a **work item** - [x] I have performed **self-review** of the code - [ ] I have written **tests** for any new functionality added with this PR - [ ] I have updated the **documentation** if this PR changes or adds functionality - [x] I have run **pre-commit** checks for code style Related work items: #1558
This commit is contained in:
@@ -78,10 +78,9 @@ PrimAITE automatically creates two sets of results from each session:
|
||||
* Timestamp
|
||||
* Episode number
|
||||
* Step number
|
||||
* Initial observation space (before red and blue agent actions have been taken). Individual elements of the observation space are presented in the format OSI_X_Y
|
||||
* Resulting observation space (after the red and blue agent actions have been taken) Individual elements of the observation space are presented in the format OSN_X_Y
|
||||
* Initial observation space (what the blue agent observed when it decided its action)
|
||||
* Reward value
|
||||
* Action space (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
|
||||
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
|
||||
|
||||
**Diagrams**
|
||||
|
||||
|
||||
@@ -16,6 +16,13 @@ random_red_agent: False
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 10
|
||||
# Number of time_steps per episode
|
||||
|
||||
@@ -29,6 +29,7 @@ class AbstractObservationComponent(ABC):
|
||||
self.env: "Primaite" = env
|
||||
self.space: spaces.Space
|
||||
self.current_observation: np.ndarray # type might be too restrictive?
|
||||
self.structure: List[str]
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
@@ -36,6 +37,11 @@ class AbstractObservationComponent(ABC):
|
||||
"""Update the observation based on the current state of the environment."""
|
||||
self.current_observation = NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class NodeLinkTable(AbstractObservationComponent):
|
||||
"""Table with nodes and links as rows and hardware/software status as cols.
|
||||
@@ -79,6 +85,8 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
# 3. Initialise Observation with zeroes
|
||||
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
@@ -131,6 +139,40 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
def generate_structure(self):
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
nodes = self.env.nodes.values()
|
||||
links = self.env.links.values()
|
||||
|
||||
structure = []
|
||||
|
||||
for i, node in enumerate(nodes):
|
||||
node_id = node.node_id
|
||||
node_labels = [
|
||||
f"node_{node_id}_id",
|
||||
f"node_{node_id}_hardware_status",
|
||||
f"node_{node_id}_os_status",
|
||||
f"node_{node_id}_fs_status",
|
||||
]
|
||||
for j, serv in enumerate(self.env.services_list):
|
||||
node_labels.append(f"node_{node_id}_service_{serv}_status")
|
||||
|
||||
structure.extend(node_labels)
|
||||
|
||||
for i, link in enumerate(links):
|
||||
link_id = link.id
|
||||
link_labels = [
|
||||
f"link_{link_id}_id",
|
||||
f"link_{link_id}_n/a",
|
||||
f"link_{link_id}_n/a",
|
||||
f"link_{link_id}_n/a",
|
||||
]
|
||||
for j, serv in enumerate(self.env.services_list):
|
||||
link_labels.append(f"link_{link_id}_service_{serv}_load")
|
||||
|
||||
structure.extend(link_labels)
|
||||
return structure
|
||||
|
||||
|
||||
class NodeStatuses(AbstractObservationComponent):
|
||||
"""Flat list of nodes' hardware, OS, file system, and service states.
|
||||
@@ -179,6 +221,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
@@ -205,6 +248,30 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
)
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self):
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
services = self.env.services_list
|
||||
|
||||
structure = []
|
||||
for _, node in self.env.nodes.items():
|
||||
node_id = node.node_id
|
||||
structure.append(f"node_{node_id}_hardware_state_NONE")
|
||||
for state in HardwareState:
|
||||
structure.append(f"node_{node_id}_hardware_state_{state.name}")
|
||||
structure.append(f"node_{node_id}_software_state_NONE")
|
||||
for state in SoftwareState:
|
||||
structure.append(f"node_{node_id}_software_state_{state.name}")
|
||||
structure.append(f"node_{node_id}_file_system_state_NONE")
|
||||
for state in FileSystemState:
|
||||
structure.append(f"node_{node_id}_file_system_state_{state.name}")
|
||||
for service in services:
|
||||
structure.append(f"node_{node_id}_service_{service}_state_NONE")
|
||||
for state in SoftwareState:
|
||||
structure.append(
|
||||
f"node_{node_id}_service_{service}_state_{state.name}"
|
||||
)
|
||||
return structure
|
||||
|
||||
|
||||
class LinkTrafficLevels(AbstractObservationComponent):
|
||||
"""Flat list of traffic levels encoded into banded categories.
|
||||
@@ -268,6 +335,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
@@ -295,6 +364,21 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self):
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
structure = []
|
||||
for _, link in self.env.links.items():
|
||||
link_id = link.id
|
||||
if self._combine_service_traffic:
|
||||
protocols = ["overall"]
|
||||
else:
|
||||
protocols = [protocol.name for protocol in link.protocol_list]
|
||||
|
||||
for p in protocols:
|
||||
for i in range(self._quantisation_levels):
|
||||
structure.append(f"link_{link_id}_{p}_traffic_level_{i}")
|
||||
return structure
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
"""Component-based observation space handler.
|
||||
@@ -311,8 +395,17 @@ class ObservationsHandler:
|
||||
|
||||
def __init__(self):
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
self.space: spaces.Space
|
||||
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||
|
||||
# internal the observation space (unflattened version of space if flatten=True)
|
||||
self._space: spaces.Space
|
||||
# flattened version of the observation space
|
||||
self._flat_space: spaces.Space
|
||||
|
||||
self._observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||
# used for transactions and when flatten=true
|
||||
self._flat_observation: np.ndarray
|
||||
|
||||
self.flatten: bool = False
|
||||
|
||||
def update_obs(self):
|
||||
"""Fetch fresh information about the environment."""
|
||||
@@ -321,12 +414,11 @@ class ObservationsHandler:
|
||||
obs.update()
|
||||
current_obs.append(obs.current_observation)
|
||||
|
||||
# If there is only one component, don't use a tuple, just pass through that component's obs.
|
||||
if len(current_obs) == 1:
|
||||
self.current_observation = current_obs[0]
|
||||
self._observation = current_obs[0]
|
||||
else:
|
||||
self.current_observation = tuple(current_obs)
|
||||
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
||||
self._observation = tuple(current_obs)
|
||||
self._flat_observation = spaces.flatten(self._space, self._observation)
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent):
|
||||
"""Add a component for this handler to track.
|
||||
@@ -353,12 +445,31 @@ class ObservationsHandler:
|
||||
for obs_comp in self.registered_obs_components:
|
||||
component_spaces.append(obs_comp.space)
|
||||
|
||||
# If there is only one component, don't use a tuple space, just pass through that component's space.
|
||||
# if there are multiple components, build a composite tuple space
|
||||
if len(component_spaces) == 1:
|
||||
self.space = component_spaces[0]
|
||||
self._space = component_spaces[0]
|
||||
else:
|
||||
self.space = spaces.Tuple(component_spaces)
|
||||
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
||||
self._space = spaces.Tuple(component_spaces)
|
||||
if len(component_spaces) > 0:
|
||||
self._flat_space = spaces.flatten_space(self._space)
|
||||
else:
|
||||
self._flat_space = spaces.Box(0, 1, (0,))
|
||||
|
||||
@property
|
||||
def space(self):
|
||||
"""Observation space, return the flattened version if flatten is True."""
|
||||
if self.flatten:
|
||||
return self._flat_space
|
||||
else:
|
||||
return self._space
|
||||
|
||||
@property
|
||||
def current_observation(self):
|
||||
"""Current observation, return the flattened version if flatten is True."""
|
||||
if self.flatten:
|
||||
return self._flat_observation
|
||||
else:
|
||||
return self._observation
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||
@@ -388,6 +499,9 @@ class ObservationsHandler:
|
||||
# Instantiate the handler
|
||||
handler = cls()
|
||||
|
||||
if obs_space_config.get("flatten"):
|
||||
handler.flatten = True
|
||||
|
||||
for component_cfg in obs_space_config["components"]:
|
||||
# Figure out which class can instantiate the desired component
|
||||
comp_type = component_cfg["name"]
|
||||
@@ -401,3 +515,17 @@ class ObservationsHandler:
|
||||
|
||||
handler.update_obs()
|
||||
return handler
|
||||
|
||||
def describe_structure(self):
|
||||
"""Create a list of names for the features of the obs space.
|
||||
|
||||
The order of labels follows the flattened version of the space.
|
||||
"""
|
||||
# as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have
|
||||
# to fake it. each component has to just hard-code the expected label order after flattening...
|
||||
|
||||
labels = []
|
||||
for obs_comp in self.registered_obs_components:
|
||||
labels.extend(obs_comp.structure)
|
||||
|
||||
return labels
|
||||
|
||||
@@ -324,7 +324,8 @@ class Primaite(Env):
|
||||
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
|
||||
)
|
||||
# Load the initial observation space into the transaction
|
||||
transaction.set_obs_space_pre(copy.deepcopy(self.env_obs))
|
||||
transaction.set_obs_space(self.obs_handler._flat_observation)
|
||||
|
||||
# Load the action space into the transaction
|
||||
transaction.set_action_space(copy.deepcopy(action))
|
||||
|
||||
@@ -405,8 +406,6 @@ class Primaite(Env):
|
||||
|
||||
# 7. Update env_obs
|
||||
self.update_environent_obs()
|
||||
# Load the new observation space into the transaction
|
||||
transaction.set_obs_space_post(copy.deepcopy(self.env_obs))
|
||||
|
||||
# 8. Add the transaction to the list of transactions
|
||||
self.transaction_list.append(copy.deepcopy(transaction))
|
||||
|
||||
@@ -325,6 +325,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str,
|
||||
transaction_list=transaction_list,
|
||||
session_path=session_dir,
|
||||
timestamp_str=timestamp_str,
|
||||
obs_space_description=env.obs_handler.describe_structure(),
|
||||
)
|
||||
|
||||
print("Updating Session Metadata file...")
|
||||
|
||||
@@ -20,23 +20,14 @@ class Transaction(object):
|
||||
self.episode_number = _episode_number
|
||||
self.step_number = _step_number
|
||||
|
||||
def set_obs_space_pre(self, _obs_space_pre):
|
||||
def set_obs_space(self, _obs_space):
|
||||
"""
|
||||
Sets the observation space (pre).
|
||||
|
||||
Args:
|
||||
_obs_space_pre: The observation space before any actions are taken
|
||||
"""
|
||||
self.obs_space_pre = _obs_space_pre
|
||||
|
||||
def set_obs_space_post(self, _obs_space_post):
|
||||
"""
|
||||
Sets the observation space (post).
|
||||
|
||||
Args:
|
||||
_obs_space_post: The observation space after any actions are taken
|
||||
"""
|
||||
self.obs_space_post = _obs_space_post
|
||||
self.obs_space = _obs_space
|
||||
|
||||
def set_reward(self, _reward):
|
||||
"""
|
||||
|
||||
@@ -22,24 +22,12 @@ def turn_action_space_to_array(_action_space):
|
||||
return [str(_action_space)]
|
||||
|
||||
|
||||
def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
|
||||
"""
|
||||
Turns observation space into a string array so it can be saved to csv.
|
||||
|
||||
Args:
|
||||
_obs_space: The observation space
|
||||
_obs_assets: The number of assets (i.e. nodes or links) in the observation space
|
||||
_obs_features: The number of features associated with the asset
|
||||
"""
|
||||
return_array = []
|
||||
for x in range(_obs_assets):
|
||||
for y in range(_obs_features):
|
||||
return_array.append(str(_obs_space[x][y]))
|
||||
|
||||
return return_array
|
||||
|
||||
|
||||
def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str):
|
||||
def write_transaction_to_file(
|
||||
transaction_list,
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
obs_space_description: list,
|
||||
):
|
||||
"""
|
||||
Writes transaction logs to file to support training evaluation.
|
||||
|
||||
@@ -56,13 +44,13 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
|
||||
# This will be tied into the PrimAITE Use Case so that they make sense
|
||||
template_transation = transaction_list[0]
|
||||
action_length = template_transation.action_space.size
|
||||
obs_shape = template_transation.obs_space_post.shape
|
||||
obs_assets = template_transation.obs_space_post.shape[0]
|
||||
if len(obs_shape) == 1:
|
||||
# bit of a workaround but I think the way transactions are written will change soon
|
||||
obs_features = 1
|
||||
else:
|
||||
obs_features = template_transation.obs_space_post.shape[1]
|
||||
# obs_shape = template_transation.obs_space_post.shape
|
||||
# obs_assets = template_transation.obs_space_post.shape[0]
|
||||
# if len(obs_shape) == 1:
|
||||
# bit of a workaround but I think the way transactions are written will change soon
|
||||
# obs_features = 1
|
||||
# else:
|
||||
# obs_features = template_transation.obs_space_post.shape[1]
|
||||
|
||||
# Create the action space headers array
|
||||
action_header = []
|
||||
@@ -70,16 +58,12 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
|
||||
action_header.append("AS_" + str(x))
|
||||
|
||||
# Create the observation space headers array
|
||||
obs_header_initial = []
|
||||
obs_header_new = []
|
||||
for x in range(obs_assets):
|
||||
for y in range(obs_features):
|
||||
obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
|
||||
obs_header_new.append("OSN_" + str(x) + "_" + str(y))
|
||||
# obs_header_initial = [f"pre_{o}" for o in obs_space_description]
|
||||
# obs_header_new = [f"post_{o}" for o in obs_space_description]
|
||||
|
||||
# Open up a csv file
|
||||
header = ["Timestamp", "Episode", "Step", "Reward"]
|
||||
header = header + action_header + obs_header_initial + obs_header_new
|
||||
header = header + action_header + obs_space_description
|
||||
|
||||
try:
|
||||
filename = session_path / f"all_transactions_{timestamp_str}.csv"
|
||||
@@ -98,12 +82,7 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
|
||||
csv_data = (
|
||||
csv_data
|
||||
+ turn_action_space_to_array(transaction.action_space)
|
||||
+ turn_obs_space_to_array(
|
||||
transaction.obs_space_pre, obs_assets, obs_features
|
||||
)
|
||||
+ turn_obs_space_to_array(
|
||||
transaction.obs_space_post, obs_assets, obs_features
|
||||
)
|
||||
+ transaction.obs_space.tolist()
|
||||
)
|
||||
csv_writer.writerow(csv_data)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user