Merge remote-tracking branch 'origin/dev' into feature/917_Integrate_with_RLLib

# Conflicts:
#	src/primaite/config/_package_data/training/training_config_main.yaml
#	src/primaite/environment/primaite_env.py
#	src/primaite/main.py
#	src/primaite/transactions/transaction.py
#	src/primaite/transactions/transactions_to_file.py
This commit is contained in:
Chris McCarthy
2023-07-03 19:51:52 +01:00
6 changed files with 251 additions and 18 deletions

View File

@@ -78,10 +78,9 @@ PrimAITE automatically creates two sets of results from each session:
* Timestamp
* Episode number
* Step number
* Initial observation space (before red and blue agent actions have been taken). Individual elements of the observation space are presented in the format OSI_X_Y
* Resulting observation space (after the red and blue agent actions have been taken). Individual elements of the observation space are presented in the format OSN_X_Y
* Initial observation space (what the blue agent observed when it decided its action)
* Reward value
* Action space (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
**Diagrams**

View File

@@ -41,8 +41,14 @@ hard_coded_agent_view: FULL
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: ANY
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes to run per session
num_episodes: 10

View File

@@ -29,6 +29,7 @@ class AbstractObservationComponent(ABC):
self.env: "Primaite" = env
self.space: spaces.Space
self.current_observation: np.ndarray # type might be too restrictive?
self.structure: List[str]
return NotImplemented
@abstractmethod
@@ -36,6 +37,11 @@ class AbstractObservationComponent(ABC):
"""Update the observation based on the current state of the environment."""
self.current_observation = NotImplemented
@abstractmethod
def generate_structure(self) -> List[str]:
    """Return a list of labels for the components of the flattened observation space.

    Concrete observation components override this to name each element of
    their observation, in the order the elements appear after flattening.

    :return: The labels for this component's flattened observation.
    """
    return NotImplemented
class NodeLinkTable(AbstractObservationComponent):
"""Table with nodes and links as rows and hardware/software status as cols.
@@ -79,6 +85,8 @@ class NodeLinkTable(AbstractObservationComponent):
# 3. Initialise Observation with zeroes
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self):
"""Update the observation based on current environment state.
@@ -125,6 +133,40 @@ class NodeLinkTable(AbstractObservationComponent):
protocol_index += 1
item_index += 1
def generate_structure(self):
    """Build the labels for the flattened node/link table.

    Node labels come first, followed by link labels. Each node contributes
    an id, hardware, OS and file-system label plus one status label per
    service. Each link contributes an id label, three "n/a" placeholder
    labels and one load label per service.

    :return: List of label strings.
    """
    all_nodes = self.env.nodes.values()
    all_links = self.env.links.values()
    labels = []

    for node in all_nodes:
        node_id = node.node_id
        labels += [
            f"node_{node_id}_id",
            f"node_{node_id}_hardware_status",
            f"node_{node_id}_os_status",
            f"node_{node_id}_fs_status",
        ]
        # One status column per service for this node.
        labels += [
            f"node_{node_id}_service_{serv}_status"
            for serv in self.env.services_list
        ]

    for link in all_links:
        link_id = link.id
        # The three "n/a" entries are emitted in the positions where node
        # rows carry their hardware / OS / file-system labels.
        labels += [
            f"link_{link_id}_id",
            f"link_{link_id}_n/a",
            f"link_{link_id}_n/a",
            f"link_{link_id}_n/a",
        ]
        # One load column per service for this link.
        labels += [
            f"link_{link_id}_service_{serv}_load"
            for serv in self.env.services_list
        ]

    return labels
class NodeStatuses(AbstractObservationComponent):
"""Flat list of nodes' hardware, OS, file system, and service states.
@@ -173,6 +215,7 @@ class NodeStatuses(AbstractObservationComponent):
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self):
"""Update the observation based on current environment state.
@@ -204,6 +247,30 @@ class NodeStatuses(AbstractObservationComponent):
)
self.current_observation[:] = obs
def generate_structure(self):
    """Build the labels for the flattened node-status observation.

    For every node, and in this order, a ``NONE`` label followed by one
    label per enum member is emitted for the hardware state, the software
    state and the file-system state. The same ``NONE``-plus-members group
    (using ``SoftwareState``) is then emitted for each service.

    :return: List of label strings.
    """
    services = self.env.services_list
    labels = []

    for node in self.env.nodes.values():
        node_id = node.node_id

        labels.append(f"node_{node_id}_hardware_state_NONE")
        labels.extend(
            f"node_{node_id}_hardware_state_{state.name}" for state in HardwareState
        )

        labels.append(f"node_{node_id}_software_state_NONE")
        labels.extend(
            f"node_{node_id}_software_state_{state.name}" for state in SoftwareState
        )

        labels.append(f"node_{node_id}_file_system_state_NONE")
        labels.extend(
            f"node_{node_id}_file_system_state_{state.name}"
            for state in FileSystemState
        )

        for service in services:
            labels.append(f"node_{node_id}_service_{service}_state_NONE")
            labels.extend(
                f"node_{node_id}_service_{service}_state_{state.name}"
                for state in SoftwareState
            )

    return labels
class LinkTrafficLevels(AbstractObservationComponent):
"""Flat list of traffic levels encoded into banded categories.
@@ -265,6 +332,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self):
"""Update the observation based on current environment state.
@@ -290,6 +359,21 @@ class LinkTrafficLevels(AbstractObservationComponent):
self.current_observation[:] = obs
def generate_structure(self):
    """Build the labels for the flattened link-traffic observation.

    Each link yields one label per protocol per quantisation level. When
    service traffic is combined, the single pseudo-protocol ``overall`` is
    used in place of the link's own protocol names.

    :return: List of label strings.
    """
    labels = []
    for link in self.env.links.values():
        link_id = link.id
        if not self._combine_service_traffic:
            protocols = [protocol.name for protocol in link.protocol_list]
        else:
            protocols = ["overall"]
        labels.extend(
            f"link_{link_id}_{p}_traffic_level_{i}"
            for p in protocols
            for i in range(self._quantisation_levels)
        )
    return labels
class ObservationsHandler:
"""Component-based observation space handler.
@@ -306,8 +390,17 @@ class ObservationsHandler:
def __init__(self):
    """Create a handler with no registered components and flattening disabled.

    Apart from ``registered_obs_components`` and ``flatten``, the attributes
    below are only annotated here; no value is assigned in this method.
    """
    # Observation components tracked by this handler.
    self.registered_obs_components: List[AbstractObservationComponent] = []
    # NOTE(review): `space` and `current_observation` are also defined as
    # properties on this class, so these two annotations look redundant - confirm.
    self.space: spaces.Space
    self.current_observation: Union[Tuple[np.ndarray], np.ndarray]

    # internal observation space (the unflattened version of `space` when flatten=True)
    self._space: spaces.Space
    # flattened version of the observation space
    self._flat_space: spaces.Space

    # internal (unflattened) observation
    self._observation: Union[Tuple[np.ndarray], np.ndarray]
    # used for transactions and when flatten=true
    self._flat_observation: np.ndarray

    # when True, the `space` and `current_observation` properties return the flattened versions
    self.flatten: bool = False
def update_obs(self):
"""Fetch fresh information about the environment."""
@@ -316,12 +409,11 @@ class ObservationsHandler:
obs.update()
current_obs.append(obs.current_observation)
# If there is only one component, don't use a tuple, just pass through that component's obs.
if len(current_obs) == 1:
self.current_observation = current_obs[0]
self._observation = current_obs[0]
else:
self.current_observation = tuple(current_obs)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
self._observation = tuple(current_obs)
self._flat_observation = spaces.flatten(self._space, self._observation)
def register(self, obs_component: AbstractObservationComponent):
"""Add a component for this handler to track.
@@ -348,12 +440,31 @@ class ObservationsHandler:
for obs_comp in self.registered_obs_components:
component_spaces.append(obs_comp.space)
# If there is only one component, don't use a tuple space, just pass through that component's space.
# if there are multiple components, build a composite tuple space
if len(component_spaces) == 1:
self.space = component_spaces[0]
self._space = component_spaces[0]
else:
self.space = spaces.Tuple(component_spaces)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
self._space = spaces.Tuple(component_spaces)
if len(component_spaces) > 0:
self._flat_space = spaces.flatten_space(self._space)
else:
self._flat_space = spaces.Box(0, 1, (0,))
@property
def space(self):
    """Observation space; the flattened space is returned when ``flatten`` is set."""
    return self._flat_space if self.flatten else self._space
@property
def current_observation(self):
    """Latest observation; the flattened observation is returned when ``flatten`` is set."""
    return self._flat_observation if self.flatten else self._observation
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):
@@ -383,6 +494,9 @@ class ObservationsHandler:
# Instantiate the handler
handler = cls()
if obs_space_config.get("flatten"):
handler.flatten = True
for component_cfg in obs_space_config["components"]:
# Figure out which class can instantiate the desired component
comp_type = component_cfg["name"]
@@ -396,3 +510,17 @@ class ObservationsHandler:
handler.update_obs()
return handler
def describe_structure(self):
    """Return the feature names of the observation space.

    Labels are concatenated component by component, in registration order,
    so that they follow the flattened version of the space.
    """
    # gym's flattening function cannot be applied to the labels themselves,
    # so each component hard-codes its own label order as expected after
    # flattening, and those lists are simply chained together here.
    return [
        label
        for obs_comp in self.registered_obs_components
        for label in obs_comp.structure
    ]

View File

@@ -320,7 +320,8 @@ class Primaite(Env):
# Create a Transaction (metric) object for this step
transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
# Load the initial observation space into the transaction
transaction.obs_space_pre = copy.deepcopy(self.env_obs)
transaction.set_obs_space(self.obs_handler._flat_observation)
# Load the action space into the transaction
transaction.action_space = copy.deepcopy(action)
@@ -399,8 +400,6 @@ class Primaite(Env):
# 7. Update env_obs
self.update_environent_obs()
# Load the new observation space into the transaction
transaction.obs_space_post = copy.deepcopy(self.env_obs)
# Write transaction to file
if self.actual_episode_count > 0:

View File

@@ -31,6 +31,16 @@ class Transaction(object):
self.action_space = None
"The action space invoked by the agent"
def set_obs_space(self, _obs_space):
    """
    Sets the observation space (pre).

    The value is stored on ``self.obs_space`` as given.

    Args:
        _obs_space: The observation space before any actions are taken
    """
    self.obs_space = _obs_space
def as_csv_data(self) -> Tuple[List, List]:
"""
Converts the Transaction to a csv data row and provides a header.

View File

@@ -0,0 +1,91 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Writes the Transaction log list out to file for evaluation to utilise."""
import csv
from pathlib import Path
from primaite import getLogger
_LOGGER = getLogger(__name__)
def turn_action_space_to_array(_action_space):
    """
    Converts an action space into a list of strings so it can be saved to csv.

    A list is converted element by element; any other value becomes a
    single-element list holding its string form.

    Args:
        _action_space: The action space.
    """
    elements = _action_space if isinstance(_action_space, list) else [_action_space]
    return list(map(str, elements))
def write_transaction_to_file(
    transaction_list,
    session_path: Path,
    timestamp_str: str,
    obs_space_description: list,
):
    """
    Writes transaction logs to file to support training evaluation.

    One CSV row is written per transaction: timestamp, episode, step and
    reward, followed by the action columns and then the observation values.
    Failures while writing the file are logged rather than raised.

    :param transaction_list: The list of transactions from all steps and all
        episodes.
    :param session_path: The directory path the session is writing to.
    :param timestamp_str: The session timestamp in the format:
        <yyyy-mm-dd>_<hh-mm-ss>.
    :param obs_space_description: The column labels for the observation
        values, in the same order as each transaction's observation.
    """
    # Nothing to write: avoid indexing into an empty list below.
    if not transaction_list:
        _LOGGER.warning("No transactions to save; transaction file not written")
        return

    # Use the first transaction to determine how many action columns are
    # needed; action columns are labelled "AS_0", "AS_1", ...
    template_transaction = transaction_list[0]
    action_length = template_transaction.action_space.size
    action_header = [f"AS_{x}" for x in range(action_length)]

    header = ["Timestamp", "Episode", "Step", "Reward"]
    header = header + action_header + obs_space_description

    try:
        filename = session_path / f"all_transactions_{timestamp_str}.csv"
        _LOGGER.debug(f"Saving transaction logs: {filename}")
        # The context manager closes the file even if a row fails to write.
        with open(filename, "w", encoding="UTF8", newline="") as csv_file:
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(header)
            for transaction in transaction_list:
                csv_data = [
                    str(transaction.timestamp),
                    str(transaction.episode_number),
                    str(transaction.step_number),
                    str(transaction.reward),
                ]
                csv_data = (
                    csv_data
                    + turn_action_space_to_array(transaction.action_space)
                    + transaction.obs_space.tolist()
                )
                csv_writer.writerow(csv_data)
    except Exception:
        # Best-effort: a failed log write must not abort the session.
        _LOGGER.error("Could not save the transaction file", exc_info=True)