Merge remote-tracking branch 'origin/dev' into feature/917_Integrate_with_RLLib
# Conflicts: # src/primaite/config/_package_data/training/training_config_main.yaml # src/primaite/environment/primaite_env.py # src/primaite/main.py # src/primaite/transactions/transaction.py # src/primaite/transactions/transactions_to_file.py
This commit is contained in:
@@ -78,10 +78,9 @@ PrimAITE automatically creates two sets of results from each session:
|
||||
* Timestamp
|
||||
* Episode number
|
||||
* Step number
|
||||
* Initial observation space (before red and blue agent actions have been taken). Individual elements of the observation space are presented in the format OSI_X_Y
|
||||
* Resulting observation space (after the red and blue agent actions have been taken) Individual elements of the observation space are presented in the format OSN_X_Y
|
||||
* Initial observation space (what the blue agent observed when it decided its action)
|
||||
* Reward value
|
||||
* Action space (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
|
||||
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
|
||||
|
||||
**Diagrams**
|
||||
|
||||
|
||||
@@ -41,8 +41,14 @@ hard_coded_agent_view: FULL
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: ANY
|
||||
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 10
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ class AbstractObservationComponent(ABC):
|
||||
self.env: "Primaite" = env
|
||||
self.space: spaces.Space
|
||||
self.current_observation: np.ndarray # type might be too restrictive?
|
||||
self.structure: List[str]
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
@@ -36,6 +37,11 @@ class AbstractObservationComponent(ABC):
|
||||
"""Update the observation based on the current state of the environment."""
|
||||
self.current_observation = NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class NodeLinkTable(AbstractObservationComponent):
|
||||
"""Table with nodes and links as rows and hardware/software status as cols.
|
||||
@@ -79,6 +85,8 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
# 3. Initialise Observation with zeroes
|
||||
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
@@ -125,6 +133,40 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
def generate_structure(self) -> List[str]:
    """Return a list of labels for the components of the flattened observation space.

    Labels are produced one table row at a time: first every node in
    ``self.env.nodes``, then every link in ``self.env.links``. Each row
    contributes four leading labels followed by one label per entry of
    ``self.env.services_list``.

    :return: The labels for all node rows followed by all link rows.
    """
    services = self.env.services_list
    structure = []

    for node in self.env.nodes.values():
        node_id = node.node_id
        structure.extend(
            [
                f"node_{node_id}_id",
                f"node_{node_id}_hardware_status",
                f"node_{node_id}_os_status",
                f"node_{node_id}_fs_status",
            ]
        )
        structure.extend(f"node_{node_id}_service_{serv}_status" for serv in services)

    for link in self.env.links.values():
        link_id = link.id
        # The three "n/a" labels sit in the positions that hold the hardware/OS/FS
        # status labels on node rows. NOTE: they are identical strings, so these
        # labels are not unique within a link row.
        structure.extend(
            [
                f"link_{link_id}_id",
                f"link_{link_id}_n/a",
                f"link_{link_id}_n/a",
                f"link_{link_id}_n/a",
            ]
        )
        structure.extend(f"link_{link_id}_service_{serv}_load" for serv in services)

    return structure
|
||||
|
||||
|
||||
class NodeStatuses(AbstractObservationComponent):
|
||||
"""Flat list of nodes' hardware, OS, file system, and service states.
|
||||
@@ -173,6 +215,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
@@ -204,6 +247,30 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
)
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self) -> List[str]:
    """Return a list of labels for the components of the flattened observation space.

    For every node in ``self.env.nodes`` the labels cover, in order: hardware
    state, software state, file system state, and then the state of each
    service in ``self.env.services_list``. Each of those groups starts with a
    ``NONE`` label and is followed by one label per member of the matching
    enum (``HardwareState``, ``SoftwareState`` or ``FileSystemState``).

    :return: The labels for all nodes, concatenated in iteration order.
    """
    services = self.env.services_list

    structure = []
    # Only the node objects are needed; the dictionary keys are not used.
    for node in self.env.nodes.values():
        node_id = node.node_id

        structure.append(f"node_{node_id}_hardware_state_NONE")
        structure.extend(f"node_{node_id}_hardware_state_{state.name}" for state in HardwareState)

        structure.append(f"node_{node_id}_software_state_NONE")
        structure.extend(f"node_{node_id}_software_state_{state.name}" for state in SoftwareState)

        structure.append(f"node_{node_id}_file_system_state_NONE")
        structure.extend(f"node_{node_id}_file_system_state_{state.name}" for state in FileSystemState)

        for service in services:
            structure.append(f"node_{node_id}_service_{service}_state_NONE")
            structure.extend(f"node_{node_id}_service_{service}_state_{state.name}" for state in SoftwareState)

    return structure
|
||||
|
||||
|
||||
class LinkTrafficLevels(AbstractObservationComponent):
|
||||
"""Flat list of traffic levels encoded into banded categories.
|
||||
@@ -265,6 +332,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
@@ -290,6 +359,21 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self) -> List[str]:
    """Return a list of labels for the components of the flattened observation space.

    For every link in ``self.env.links`` one label is produced per
    (protocol, level) pair, where level runs over
    ``range(self._quantisation_levels)``. When
    ``self._combine_service_traffic`` is truthy a single pseudo-protocol
    named ``"overall"`` is used instead of the link's own protocol names.

    :return: The labels for all links, concatenated in iteration order.
    """
    structure = []
    # Only the link objects are needed; the dictionary keys are not used.
    for link in self.env.links.values():
        link_id = link.id
        if self._combine_service_traffic:
            protocols = ["overall"]
        else:
            protocols = [protocol.name for protocol in link.protocol_list]

        structure.extend(
            f"link_{link_id}_{p}_traffic_level_{i}"
            for p in protocols
            for i in range(self._quantisation_levels)
        )
    return structure
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
"""Component-based observation space handler.
|
||||
@@ -306,8 +390,17 @@ class ObservationsHandler:
|
||||
|
||||
def __init__(self):
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
self.space: spaces.Space
|
||||
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||
|
||||
# internal (unflattened) observation space; differs from `space` only when flatten=True
|
||||
self._space: spaces.Space
|
||||
# flattened version of the observation space
|
||||
self._flat_space: spaces.Space
|
||||
|
||||
self._observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||
# used for transactions and when flatten=true
|
||||
self._flat_observation: np.ndarray
|
||||
|
||||
self.flatten: bool = False
|
||||
|
||||
def update_obs(self):
|
||||
"""Fetch fresh information about the environment."""
|
||||
@@ -316,12 +409,11 @@ class ObservationsHandler:
|
||||
obs.update()
|
||||
current_obs.append(obs.current_observation)
|
||||
|
||||
# If there is only one component, don't use a tuple, just pass through that component's obs.
|
||||
if len(current_obs) == 1:
|
||||
self.current_observation = current_obs[0]
|
||||
self._observation = current_obs[0]
|
||||
else:
|
||||
self.current_observation = tuple(current_obs)
|
||||
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
||||
self._observation = tuple(current_obs)
|
||||
self._flat_observation = spaces.flatten(self._space, self._observation)
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent):
|
||||
"""Add a component for this handler to track.
|
||||
@@ -348,12 +440,31 @@ class ObservationsHandler:
|
||||
for obs_comp in self.registered_obs_components:
|
||||
component_spaces.append(obs_comp.space)
|
||||
|
||||
# If there is only one component, don't use a tuple space, just pass through that component's space.
|
||||
# if there are multiple components, build a composite tuple space
|
||||
if len(component_spaces) == 1:
|
||||
self.space = component_spaces[0]
|
||||
self._space = component_spaces[0]
|
||||
else:
|
||||
self.space = spaces.Tuple(component_spaces)
|
||||
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
||||
self._space = spaces.Tuple(component_spaces)
|
||||
if len(component_spaces) > 0:
|
||||
self._flat_space = spaces.flatten_space(self._space)
|
||||
else:
|
||||
self._flat_space = spaces.Box(0, 1, (0,))
|
||||
|
||||
@property
def space(self):
    """Observation space exposed by the handler.

    Yields ``self._flat_space`` when ``self.flatten`` is truthy and
    ``self._space`` otherwise.
    """
    return self._flat_space if self.flatten else self._space
|
||||
|
||||
@property
def current_observation(self):
    """Most recent observation exposed by the handler.

    Yields ``self._flat_observation`` when ``self.flatten`` is truthy and
    ``self._observation`` otherwise.
    """
    return self._flat_observation if self.flatten else self._observation
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||
@@ -383,6 +494,9 @@ class ObservationsHandler:
|
||||
# Instantiate the handler
|
||||
handler = cls()
|
||||
|
||||
if obs_space_config.get("flatten"):
|
||||
handler.flatten = True
|
||||
|
||||
for component_cfg in obs_space_config["components"]:
|
||||
# Figure out which class can instantiate the desired component
|
||||
comp_type = component_cfg["name"]
|
||||
@@ -396,3 +510,17 @@ class ObservationsHandler:
|
||||
|
||||
handler.update_obs()
|
||||
return handler
|
||||
|
||||
def describe_structure(self):
    """Create a list of names for the features of the obs space.

    The order of labels follows the flattened version of the space.
    """
    # The gym flattening function cannot be applied to the labels themselves, so the
    # flattened order is imitated here: each registered component supplies its own
    # hard-coded label order via ``structure`` and the lists are concatenated in
    # registration order.
    return [label for component in self.registered_obs_components for label in component.structure]
|
||||
|
||||
@@ -320,7 +320,8 @@ class Primaite(Env):
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
|
||||
# Load the initial observation space into the transaction
|
||||
transaction.obs_space_pre = copy.deepcopy(self.env_obs)
|
||||
transaction.set_obs_space(self.obs_handler._flat_observation)
|
||||
|
||||
# Load the action space into the transaction
|
||||
transaction.action_space = copy.deepcopy(action)
|
||||
|
||||
@@ -399,8 +400,6 @@ class Primaite(Env):
|
||||
|
||||
# 7. Update env_obs
|
||||
self.update_environent_obs()
|
||||
# Load the new observation space into the transaction
|
||||
transaction.obs_space_post = copy.deepcopy(self.env_obs)
|
||||
|
||||
# Write transaction to file
|
||||
if self.actual_episode_count > 0:
|
||||
|
||||
@@ -31,6 +31,16 @@ class Transaction(object):
|
||||
self.action_space = None
|
||||
"The action space invoked by the agent"
|
||||
|
||||
def set_obs_space(self, _obs_space):
    """
    Sets the observation space recorded for this transaction.

    Args:
        _obs_space: The observation to record. It is stored on
            ``self.obs_space`` as-is; no copy is made.
    """
    self.obs_space = _obs_space
|
||||
|
||||
|
||||
def as_csv_data(self) -> Tuple[List, List]:
|
||||
"""
|
||||
Converts the Transaction to a csv data row and provides a header.
|
||||
|
||||
91
src/primaite/transactions/transactions_to_file.py
Normal file
91
src/primaite/transactions/transactions_to_file.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""Writes the Transaction log list out to file for evaluation to utilse."""
|
||||
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def turn_action_space_to_array(_action_space):
    """
    Convert an action space into a list of strings suitable for a csv row.

    Args:
        _action_space: The action space. A ``list`` is converted element by
            element; any other value is stringified whole and wrapped in a
            single-element list.
    """
    if not isinstance(_action_space, list):
        return [str(_action_space)]
    return list(map(str, _action_space))
|
||||
|
||||
|
||||
def write_transaction_to_file(
    transaction_list,
    session_path: Path,
    timestamp_str: str,
    obs_space_description: list,
):
    """
    Writes transaction logs to file to support training evaluation.

    The output is ``all_transactions_<timestamp_str>.csv`` inside
    ``session_path``. Each row holds the timestamp, episode, step and reward
    of one transaction, followed by its action space and its observation.

    Failures while creating or writing the file are logged and not re-raised.
    If ``transaction_list`` is empty nothing is written.

    :param transaction_list: The list of transactions from all steps and all
        episodes.
    :param session_path: The directory path the session is writing to.
    :param timestamp_str: The session timestamp in the format:
        <yyyy-mm-dd>_<hh-mm-ss>.
    :param obs_space_description: The labels for the observation columns;
        appended to the csv header after the action columns.
    """
    # Without at least one transaction the width of the action columns cannot be determined.
    if not transaction_list:
        _LOGGER.warning("No transactions to save, transaction file not written")
        return

    # The first transaction is used as a template for the number of action
    # columns, which are labelled "AS_0", "AS_1", ...
    template_transaction = transaction_list[0]
    action_length = template_transaction.action_space.size
    action_header = [f"AS_{x}" for x in range(action_length)]

    header = ["Timestamp", "Episode", "Step", "Reward"]
    header = header + action_header + obs_space_description

    try:
        filename = session_path / f"all_transactions_{timestamp_str}.csv"
        _LOGGER.debug(f"Saving transaction logs: {filename}")
        # The context manager closes the file even if writing a row raises.
        with open(filename, "w", encoding="UTF8", newline="") as csv_file:
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(header)

            for transaction in transaction_list:
                csv_data = [
                    str(transaction.timestamp),
                    str(transaction.episode_number),
                    str(transaction.step_number),
                    str(transaction.reward),
                ]
                csv_data = (
                    csv_data
                    + turn_action_space_to_array(transaction.action_space)
                    + transaction.obs_space.tolist()
                )
                csv_writer.writerow(csv_data)
    except Exception:
        # Best-effort output: saving the log must not abort the session.
        _LOGGER.error("Could not save the transaction file", exc_info=True)
|
||||
Reference in New Issue
Block a user