diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 75bf7310..d7b68045 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -178,6 +178,9 @@ class Primaite(Env): # It will be initialised later. self.obs_handler: ObservationsHandler + self._obs_space_description = None + "The env observation space description for transactions writing" + # Open the config file and build the environment laydown with open(self._lay_down_config_path, "r") as file: # Open the config file and build the environment laydown @@ -318,9 +321,16 @@ class Primaite(Env): link.clear_traffic() # Create a Transaction (metric) object for this step - transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count) + transaction = Transaction( + self.agent_identifier, + self.actual_episode_count, + self.step_count + ) # Load the initial observation space into the transaction - transaction.set_obs_space(self.obs_handler._flat_observation) + transaction.obs_space = self.obs_handler._flat_observation + + # Set the transaction obs space description + transaction.obs_space_description = self._obs_space_description # Load the action space into the transaction transaction.action_space = copy.deepcopy(action) @@ -675,6 +685,9 @@ class Primaite(Env): """ self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) + if not self._obs_space_description: + self._obs_space_description = self.obs_handler.describe_structure() + return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 69f0f545..763dc458 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -3,11 +3,18 @@ from datetime import datetime from typing import List, Tuple +from primaite.common.enums import AgentIdentifier + class Transaction(object): """Transaction class.""" - def __init__(self, agent_identifier, episode_number, step_number): + def __init__( + self, + agent_identifier: AgentIdentifier, + episode_number: int, + step_number: int + ): """ Transaction constructor. @@ -17,11 +24,14 @@ class Transaction(object): """ self.timestamp = datetime.now() "The datetime of the transaction" - self.agent_identifier = agent_identifier - self.episode_number = episode_number + self.agent_identifier: AgentIdentifier = agent_identifier + "The agent identifier" + self.episode_number: int = episode_number "The episode number" - self.step_number = step_number + self.step_number: int = step_number "The step number" + self.obs_space = None + "The observation space (pre)" self.obs_space_pre = None "The observation space before any actions are taken" self.obs_space_post = None @@ -30,16 +40,8 @@ class Transaction(object): "The reward value" self.action_space = None "The action space invoked by the agent" - - def set_obs_space(self, _obs_space): - """ - Sets the observation space (pre). - - Args: - _obs_space_pre: The observation space before any actions are taken - """ - self.obs_space = _obs_space - + self.obs_space_description = None + "The env observation space description" def as_csv_data(self) -> Tuple[List, List]: """ @@ -51,32 +53,16 @@ class Transaction(object): action_length = self.action_space else: action_length = self.action_space.size - obs_shape = self.obs_space_post.shape - obs_assets = self.obs_space_post.shape[0] - if len(obs_shape) == 1: - # A bit of a workaround but I think the way transactions are - # written will change soon - obs_features = 1 - else: - obs_features = self.obs_space_post.shape[1] # Create the action space headers array action_header = [] for x in range(action_length): action_header.append("AS_" + str(x)) - # Create the observation space headers array - obs_header_initial = [] - obs_header_new = [] - for x in range(obs_assets): - for y in range(obs_features): - obs_header_initial.append("OSI_" + str(x) + "_" + str(y)) - obs_header_new.append("OSN_" + str(x) + "_" + str(y)) - # Open up a csv file header = ["Timestamp", "Episode", "Step", "Reward"] - header = header + action_header + obs_header_initial + obs_header_new - + header = header + action_header + self.obs_space_description + row = [ str(self.timestamp), str(self.episode_number), @@ -84,10 +70,9 @@ class Transaction(object): str(self.reward), ] row = ( - row - + _turn_action_space_to_array(self.action_space) - + _turn_obs_space_to_array(self.obs_space_pre, obs_assets, obs_features) - + _turn_obs_space_to_array(self.obs_space_post, obs_assets, obs_features) + row + + _turn_action_space_to_array(self.action_space) + + self.obs_space.tolist() ) return header, row diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py deleted file mode 100644 index 4e364f0b..00000000 --- a/src/primaite/transactions/transactions_to_file.py +++ /dev/null @@ -1,91 +0,0 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -"""Writes the Transaction log list out to file for evaluation to utilse.""" - -import csv -from pathlib import Path - -from primaite import getLogger - -_LOGGER = getLogger(__name__) - - -def turn_action_space_to_array(_action_space): - """ - Turns action space into a string array so it can be saved to csv. - - Args: - _action_space: The action space. - """ - if isinstance(_action_space, list): - return [str(i) for i in _action_space] - else: - return [str(_action_space)] - - -def write_transaction_to_file( - transaction_list, - session_path: Path, - timestamp_str: str, - obs_space_description: list, -): - """ - Writes transaction logs to file to support training evaluation. - - :param transaction_list: The list of transactions from all steps and all - episodes. - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - # Get the first transaction and use it to determine the makeup of the - # observation space and action space - # Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action - # space as "AS_1" - # This will be tied into the PrimAITE Use Case so that they make sense - template_transation = transaction_list[0] - action_length = template_transation.action_space.size - # obs_shape = template_transation.obs_space_post.shape - # obs_assets = template_transation.obs_space_post.shape[0] - # if len(obs_shape) == 1: - # bit of a workaround but I think the way transactions are written will change soon - # obs_features = 1 - # else: - # obs_features = template_transation.obs_space_post.shape[1] - - # Create the action space headers array - action_header = [] - for x in range(action_length): - action_header.append("AS_" + str(x)) - - # Create the observation space headers array - # obs_header_initial = [f"pre_{o}" for o in obs_space_description] - # obs_header_new = [f"post_{o}" for o in obs_space_description] - - # Open up a csv file - header = ["Timestamp", "Episode", "Step", "Reward"] - header = header + action_header + obs_space_description - - try: - filename = session_path / f"all_transactions_{timestamp_str}.csv" - _LOGGER.debug(f"Saving transaction logs: {filename}") - csv_file = open(filename, "w", encoding="UTF8", newline="") - csv_writer = csv.writer(csv_file) - csv_writer.writerow(header) - - for transaction in transaction_list: - csv_data = [ - str(transaction.timestamp), - str(transaction.episode_number), - str(transaction.step_number), - str(transaction.reward), - ] - csv_data = ( - csv_data - + turn_action_space_to_array(transaction.action_space) - + transaction.obs_space.tolist() - ) - csv_writer.writerow(csv_data) - - csv_file.close() - except Exception: - _LOGGER.error("Could not save the transaction file", exc_info=True)