diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index 4f639f11..a59b2361 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -78,10 +78,9 @@ PrimAITE automatically creates two sets of results from each session: * Timestamp * Episode number * Step number - * Initial observation space (before red and blue agent actions have been taken). Individual elements of the observation space are presented in the format OSI_X_Y - * Resulting observation space (after the red and blue agent actions have been taken) Individual elements of the observation space are presented in the format OSN_X_Y + * Initial observation space (what the blue agent observed when it decided its action) * Reward value - * Action space (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X + * Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X **Diagrams** diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 91d9af11..a638fe14 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -41,8 +41,14 @@ hard_coded_agent_view: FULL # "NODE" # "ACL" # "ANY" node and acl actions -action_type: ANY - +action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS # Number of episodes to run per session num_episodes: 10 diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 0470828e..511fb008 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -29,6 +29,7 @@ class AbstractObservationComponent(ABC): self.env: "Primaite" = env self.space: spaces.Space self.current_observation: np.ndarray # type might be too restrictive? + self.structure: List[str] return NotImplemented @abstractmethod @@ -36,6 +37,11 @@ class AbstractObservationComponent(ABC): """Update the observation based on the current state of the environment.""" self.current_observation = NotImplemented + @abstractmethod + def generate_structure(self) -> List[str]: + """Return a list of labels for the components of the flattened observation space.""" + return NotImplemented + class NodeLinkTable(AbstractObservationComponent): """Table with nodes and links as rows and hardware/software status as cols. @@ -79,6 +85,8 @@ class NodeLinkTable(AbstractObservationComponent): # 3. Initialise Observation with zeroes self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) + self.structure = self.generate_structure() + def update(self): """Update the observation based on current environment state. @@ -125,6 +133,40 @@ class NodeLinkTable(AbstractObservationComponent): protocol_index += 1 item_index += 1 + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + nodes = self.env.nodes.values() + links = self.env.links.values() + + structure = [] + + for i, node in enumerate(nodes): + node_id = node.node_id + node_labels = [ + f"node_{node_id}_id", + f"node_{node_id}_hardware_status", + f"node_{node_id}_os_status", + f"node_{node_id}_fs_status", + ] + for j, serv in enumerate(self.env.services_list): + node_labels.append(f"node_{node_id}_service_{serv}_status") + + structure.extend(node_labels) + + for i, link in enumerate(links): + link_id = link.id + link_labels = [ + f"link_{link_id}_id", + f"link_{link_id}_n/a", + f"link_{link_id}_n/a", + f"link_{link_id}_n/a", + ] + for j, serv in enumerate(self.env.services_list): + link_labels.append(f"link_{link_id}_service_{serv}_load") + + structure.extend(link_labels) + return structure + class NodeStatuses(AbstractObservationComponent): """Flat list of nodes' hardware, OS, file system, and service states. @@ -173,6 +215,7 @@ class NodeStatuses(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + self.structure = self.generate_structure() def update(self): """Update the observation based on current environment state. @@ -204,6 +247,30 @@ class NodeStatuses(AbstractObservationComponent): ) self.current_observation[:] = obs + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + services = self.env.services_list + + structure = [] + for _, node in self.env.nodes.items(): + node_id = node.node_id + structure.append(f"node_{node_id}_hardware_state_NONE") + for state in HardwareState: + structure.append(f"node_{node_id}_hardware_state_{state.name}") + structure.append(f"node_{node_id}_software_state_NONE") + for state in SoftwareState: + structure.append(f"node_{node_id}_software_state_{state.name}") + structure.append(f"node_{node_id}_file_system_state_NONE") + for state in FileSystemState: + structure.append(f"node_{node_id}_file_system_state_{state.name}") + for service in services: + structure.append(f"node_{node_id}_service_{service}_state_NONE") + for state in SoftwareState: + structure.append( + f"node_{node_id}_service_{service}_state_{state.name}" + ) + return structure + class LinkTrafficLevels(AbstractObservationComponent): """Flat list of traffic levels encoded into banded categories. @@ -265,6 +332,8 @@ class LinkTrafficLevels(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + self.structure = self.generate_structure() + def update(self): """Update the observation based on current environment state. @@ -290,6 +359,21 @@ class LinkTrafficLevels(AbstractObservationComponent): self.current_observation[:] = obs + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + structure = [] + for _, link in self.env.links.items(): + link_id = link.id + if self._combine_service_traffic: + protocols = ["overall"] + else: + protocols = [protocol.name for protocol in link.protocol_list] + + for p in protocols: + for i in range(self._quantisation_levels): + structure.append(f"link_{link_id}_{p}_traffic_level_{i}") + return structure + class ObservationsHandler: """Component-based observation space handler. @@ -306,8 +390,17 @@ class ObservationsHandler: def __init__(self): self.registered_obs_components: List[AbstractObservationComponent] = [] - self.space: spaces.Space - self.current_observation: Union[Tuple[np.ndarray], np.ndarray] + + # internal the observation space (unflattened version of space if flatten=True) + self._space: spaces.Space + # flattened version of the observation space + self._flat_space: spaces.Space + + self._observation: Union[Tuple[np.ndarray], np.ndarray] + # used for transactions and when flatten=true + self._flat_observation: np.ndarray + + self.flatten: bool = False def update_obs(self): """Fetch fresh information about the environment.""" @@ -316,12 +409,11 @@ class ObservationsHandler: obs.update() current_obs.append(obs.current_observation) - # If there is only one component, don't use a tuple, just pass through that component's obs. if len(current_obs) == 1: - self.current_observation = current_obs[0] + self._observation = current_obs[0] else: - self.current_observation = tuple(current_obs) - # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. + self._observation = tuple(current_obs) + self._flat_observation = spaces.flatten(self._space, self._observation) def register(self, obs_component: AbstractObservationComponent): """Add a component for this handler to track. @@ -348,12 +440,31 @@ class ObservationsHandler: for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) - # If there is only one component, don't use a tuple space, just pass through that component's space. + # if there are multiple components, build a composite tuple space if len(component_spaces) == 1: - self.space = component_spaces[0] + self._space = component_spaces[0] else: - self.space = spaces.Tuple(component_spaces) - # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. + self._space = spaces.Tuple(component_spaces) + if len(component_spaces) > 0: + self._flat_space = spaces.flatten_space(self._space) + else: + self._flat_space = spaces.Box(0, 1, (0,)) + + @property + def space(self): + """Observation space, return the flattened version if flatten is True.""" + if self.flatten: + return self._flat_space + else: + return self._space + + @property + def current_observation(self): + """Current observation, return the flattened version if flatten is True.""" + if self.flatten: + return self._flat_observation + else: + return self._observation @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): @@ -383,6 +494,9 @@ class ObservationsHandler: # Instantiate the handler handler = cls() + if obs_space_config.get("flatten"): + handler.flatten = True + for component_cfg in obs_space_config["components"]: # Figure out which class can instantiate the desired component comp_type = component_cfg["name"] @@ -396,3 +510,17 @@ class ObservationsHandler: handler.update_obs() return handler + + def describe_structure(self): + """Create a list of names for the features of the obs space. + + The order of labels follows the flattened version of the space. + """ + # as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have + # to fake it. each component has to just hard-code the expected label order after flattening... + + labels = [] + for obs_comp in self.registered_obs_components: + labels.extend(obs_comp.structure) + + return labels diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 36632155..75bf7310 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -320,7 +320,8 @@ class Primaite(Env): # Create a Transaction (metric) object for this step transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count) # Load the initial observation space into the transaction - transaction.obs_space_pre = copy.deepcopy(self.env_obs) + transaction.set_obs_space(self.obs_handler._flat_observation) + # Load the action space into the transaction transaction.action_space = copy.deepcopy(action) @@ -399,8 +400,6 @@ class Primaite(Env): # 7. Update env_obs self.update_environent_obs() - # Load the new observation space into the transaction - transaction.obs_space_post = copy.deepcopy(self.env_obs) # Write transaction to file if self.actual_episode_count > 0: diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index eeafe05e..69f0f545 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -31,6 +31,16 @@ class Transaction(object): self.action_space = None "The action space invoked by the agent" + def set_obs_space(self, _obs_space): + """ + Sets the observation space (pre). + + Args: + _obs_space_pre: The observation space before any actions are taken + """ + self.obs_space = _obs_space + + def as_csv_data(self) -> Tuple[List, List]: """ Converts the Transaction to a csv data row and provides a header. diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py new file mode 100644 index 00000000..4e364f0b --- /dev/null +++ b/src/primaite/transactions/transactions_to_file.py @@ -0,0 +1,91 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +"""Writes the Transaction log list out to file for evaluation to utilse.""" + +import csv +from pathlib import Path + +from primaite import getLogger + +_LOGGER = getLogger(__name__) + + +def turn_action_space_to_array(_action_space): + """ + Turns action space into a string array so it can be saved to csv. + + Args: + _action_space: The action space. + """ + if isinstance(_action_space, list): + return [str(i) for i in _action_space] + else: + return [str(_action_space)] + + +def write_transaction_to_file( + transaction_list, + session_path: Path, + timestamp_str: str, + obs_space_description: list, +): + """ + Writes transaction logs to file to support training evaluation. + + :param transaction_list: The list of transactions from all steps and all + episodes. + :param session_path: The directory path the session is writing to. + :param timestamp_str: The session timestamp in the format: + _. + """ + # Get the first transaction and use it to determine the makeup of the + # observation space and action space + # Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action + # space as "AS_1" + # This will be tied into the PrimAITE Use Case so that they make sense + template_transation = transaction_list[0] + action_length = template_transation.action_space.size + # obs_shape = template_transation.obs_space_post.shape + # obs_assets = template_transation.obs_space_post.shape[0] + # if len(obs_shape) == 1: + # bit of a workaround but I think the way transactions are written will change soon + # obs_features = 1 + # else: + # obs_features = template_transation.obs_space_post.shape[1] + + # Create the action space headers array + action_header = [] + for x in range(action_length): + action_header.append("AS_" + str(x)) + + # Create the observation space headers array + # obs_header_initial = [f"pre_{o}" for o in obs_space_description] + # obs_header_new = [f"post_{o}" for o in obs_space_description] + + # Open up a csv file + header = ["Timestamp", "Episode", "Step", "Reward"] + header = header + action_header + obs_space_description + + try: + filename = session_path / f"all_transactions_{timestamp_str}.csv" + _LOGGER.debug(f"Saving transaction logs: {filename}") + csv_file = open(filename, "w", encoding="UTF8", newline="") + csv_writer = csv.writer(csv_file) + csv_writer.writerow(header) + + for transaction in transaction_list: + csv_data = [ + str(transaction.timestamp), + str(transaction.episode_number), + str(transaction.step_number), + str(transaction.reward), + ] + csv_data = ( + csv_data + + turn_action_space_to_array(transaction.action_space) + + transaction.obs_space.tolist() + ) + csv_writer.writerow(csv_data) + + csv_file.close() + except Exception: + _LOGGER.error("Could not save the transaction file", exc_info=True)