#917 - Synced with dev and integrated the new observation space

This commit is contained in:
Chris McCarthy
2023-07-03 20:36:21 +01:00
parent 1716786441
commit e271a28bf0
3 changed files with 36 additions and 129 deletions

View File

@@ -178,6 +178,9 @@ class Primaite(Env):
# It will be initialised later.
self.obs_handler: ObservationsHandler
self._obs_space_description = None
"The env observation space description for transactions writing"
# Open the config file and build the environment laydown
with open(self._lay_down_config_path, "r") as file:
# Open the config file and build the environment laydown
@@ -318,9 +321,16 @@ class Primaite(Env):
link.clear_traffic()
# Create a Transaction (metric) object for this step
transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
transaction = Transaction(
self.agent_identifier,
self.actual_episode_count,
self.step_count
)
# Load the initial observation space into the transaction
transaction.set_obs_space(self.obs_handler._flat_observation)
transaction.obs_space = self.obs_handler._flat_observation
# Set the transaction obs space description
transaction.obs_space_description = self._obs_space_description
# Load the action space into the transaction
transaction.action_space = copy.deepcopy(action)
@@ -675,6 +685,9 @@ class Primaite(Env):
"""
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
if not self._obs_space_description:
self._obs_space_description = self.obs_handler.describe_structure()
return self.obs_handler.space, self.obs_handler.current_observation
def update_environent_obs(self):

View File

@@ -3,11 +3,18 @@
from datetime import datetime
from typing import List, Tuple
from primaite.common.enums import AgentIdentifier
class Transaction(object):
"""Transaction class."""
def __init__(self, agent_identifier, episode_number, step_number):
def __init__(
self,
agent_identifier: AgentIdentifier,
episode_number: int,
step_number: int
):
"""
Transaction constructor.
@@ -17,11 +24,14 @@ class Transaction(object):
"""
self.timestamp = datetime.now()
"The datetime of the transaction"
self.agent_identifier = agent_identifier
self.episode_number = episode_number
self.agent_identifier: AgentIdentifier = agent_identifier
"The agent identifier"
self.episode_number: int = episode_number
"The episode number"
self.step_number = step_number
self.step_number: int = step_number
"The step number"
self.obs_space = None
"The observation space (pre)"
self.obs_space_pre = None
"The observation space before any actions are taken"
self.obs_space_post = None
@@ -30,16 +40,8 @@ class Transaction(object):
"The reward value"
self.action_space = None
"The action space invoked by the agent"
def set_obs_space(self, _obs_space):
"""
Sets the observation space (pre).
Args:
_obs_space_pre: The observation space before any actions are taken
"""
self.obs_space = _obs_space
self.obs_space_description = None
"The env observation space description"
def as_csv_data(self) -> Tuple[List, List]:
"""
@@ -51,32 +53,16 @@ class Transaction(object):
action_length = self.action_space
else:
action_length = self.action_space.size
obs_shape = self.obs_space_post.shape
obs_assets = self.obs_space_post.shape[0]
if len(obs_shape) == 1:
# A bit of a workaround but I think the way transactions are
# written will change soon
obs_features = 1
else:
obs_features = self.obs_space_post.shape[1]
# Create the action space headers array
action_header = []
for x in range(action_length):
action_header.append("AS_" + str(x))
# Create the observation space headers array
obs_header_initial = []
obs_header_new = []
for x in range(obs_assets):
for y in range(obs_features):
obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
obs_header_new.append("OSN_" + str(x) + "_" + str(y))
# Open up a csv file
header = ["Timestamp", "Episode", "Step", "Reward"]
header = header + action_header + obs_header_initial + obs_header_new
header = header + action_header + self.obs_space_description
row = [
str(self.timestamp),
str(self.episode_number),
@@ -84,10 +70,9 @@ class Transaction(object):
str(self.reward),
]
row = (
row
+ _turn_action_space_to_array(self.action_space)
+ _turn_obs_space_to_array(self.obs_space_pre, obs_assets, obs_features)
+ _turn_obs_space_to_array(self.obs_space_post, obs_assets, obs_features)
row
+ _turn_action_space_to_array(self.action_space)
+ self.obs_space.tolist()
)
return header, row

View File

@@ -1,91 +0,0 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Writes the Transaction log list out to file for evaluation to utilse."""
import csv
from pathlib import Path
from primaite import getLogger
_LOGGER = getLogger(__name__)
def turn_action_space_to_array(_action_space):
"""
Turns action space into a string array so it can be saved to csv.
Args:
_action_space: The action space.
"""
if isinstance(_action_space, list):
return [str(i) for i in _action_space]
else:
return [str(_action_space)]
def write_transaction_to_file(
transaction_list,
session_path: Path,
timestamp_str: str,
obs_space_description: list,
):
"""
Writes transaction logs to file to support training evaluation.
:param transaction_list: The list of transactions from all steps and all
episodes.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
# Get the first transaction and use it to determine the makeup of the
# observation space and action space
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action
# space as "AS_1"
# This will be tied into the PrimAITE Use Case so that they make sense
template_transation = transaction_list[0]
action_length = template_transation.action_space.size
# obs_shape = template_transation.obs_space_post.shape
# obs_assets = template_transation.obs_space_post.shape[0]
# if len(obs_shape) == 1:
# bit of a workaround but I think the way transactions are written will change soon
# obs_features = 1
# else:
# obs_features = template_transation.obs_space_post.shape[1]
# Create the action space headers array
action_header = []
for x in range(action_length):
action_header.append("AS_" + str(x))
# Create the observation space headers array
# obs_header_initial = [f"pre_{o}" for o in obs_space_description]
# obs_header_new = [f"post_{o}" for o in obs_space_description]
# Open up a csv file
header = ["Timestamp", "Episode", "Step", "Reward"]
header = header + action_header + obs_space_description
try:
filename = session_path / f"all_transactions_{timestamp_str}.csv"
_LOGGER.debug(f"Saving transaction logs: {filename}")
csv_file = open(filename, "w", encoding="UTF8", newline="")
csv_writer = csv.writer(csv_file)
csv_writer.writerow(header)
for transaction in transaction_list:
csv_data = [
str(transaction.timestamp),
str(transaction.episode_number),
str(transaction.step_number),
str(transaction.reward),
]
csv_data = (
csv_data
+ turn_action_space_to_array(transaction.action_space)
+ transaction.obs_space.tolist()
)
csv_writer.writerow(csv_data)
csv_file.close()
except Exception:
_LOGGER.error("Could not save the transaction file", exc_info=True)