#917 - Synced with dev and integrated the new observation space
This commit is contained in:
@@ -178,6 +178,9 @@ class Primaite(Env):
|
||||
# It will be initialised later.
|
||||
self.obs_handler: ObservationsHandler
|
||||
|
||||
self._obs_space_description = None
|
||||
"The env observation space description for transactions writing"
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
with open(self._lay_down_config_path, "r") as file:
|
||||
# Open the config file and build the environment laydown
|
||||
@@ -318,9 +321,16 @@ class Primaite(Env):
|
||||
link.clear_traffic()
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
|
||||
transaction = Transaction(
|
||||
self.agent_identifier,
|
||||
self.actual_episode_count,
|
||||
self.step_count
|
||||
)
|
||||
# Load the initial observation space into the transaction
|
||||
transaction.set_obs_space(self.obs_handler._flat_observation)
|
||||
transaction.obs_space = self.obs_handler._flat_observation
|
||||
|
||||
# Set the transaction obs space description
|
||||
transaction.obs_space_description = self._obs_space_description
|
||||
|
||||
# Load the action space into the transaction
|
||||
transaction.action_space = copy.deepcopy(action)
|
||||
@@ -675,6 +685,9 @@ class Primaite(Env):
|
||||
"""
|
||||
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
|
||||
|
||||
if not self._obs_space_description:
|
||||
self._obs_space_description = self.obs_handler.describe_structure()
|
||||
|
||||
return self.obs_handler.space, self.obs_handler.current_observation
|
||||
|
||||
def update_environent_obs(self):
|
||||
|
||||
@@ -3,11 +3,18 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
|
||||
from primaite.common.enums import AgentIdentifier
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
"""Transaction class."""
|
||||
|
||||
def __init__(self, agent_identifier, episode_number, step_number):
|
||||
def __init__(
|
||||
self,
|
||||
agent_identifier: AgentIdentifier,
|
||||
episode_number: int,
|
||||
step_number: int
|
||||
):
|
||||
"""
|
||||
Transaction constructor.
|
||||
|
||||
@@ -17,11 +24,14 @@ class Transaction(object):
|
||||
"""
|
||||
self.timestamp = datetime.now()
|
||||
"The datetime of the transaction"
|
||||
self.agent_identifier = agent_identifier
|
||||
self.episode_number = episode_number
|
||||
self.agent_identifier: AgentIdentifier = agent_identifier
|
||||
"The agent identifier"
|
||||
self.episode_number: int = episode_number
|
||||
"The episode number"
|
||||
self.step_number = step_number
|
||||
self.step_number: int = step_number
|
||||
"The step number"
|
||||
self.obs_space = None
|
||||
"The observation space (pre)"
|
||||
self.obs_space_pre = None
|
||||
"The observation space before any actions are taken"
|
||||
self.obs_space_post = None
|
||||
@@ -30,16 +40,8 @@ class Transaction(object):
|
||||
"The reward value"
|
||||
self.action_space = None
|
||||
"The action space invoked by the agent"
|
||||
|
||||
def set_obs_space(self, _obs_space):
|
||||
"""
|
||||
Sets the observation space (pre).
|
||||
|
||||
Args:
|
||||
_obs_space_pre: The observation space before any actions are taken
|
||||
"""
|
||||
self.obs_space = _obs_space
|
||||
|
||||
self.obs_space_description = None
|
||||
"The env observation space description"
|
||||
|
||||
def as_csv_data(self) -> Tuple[List, List]:
|
||||
"""
|
||||
@@ -51,32 +53,16 @@ class Transaction(object):
|
||||
action_length = self.action_space
|
||||
else:
|
||||
action_length = self.action_space.size
|
||||
obs_shape = self.obs_space_post.shape
|
||||
obs_assets = self.obs_space_post.shape[0]
|
||||
if len(obs_shape) == 1:
|
||||
# A bit of a workaround but I think the way transactions are
|
||||
# written will change soon
|
||||
obs_features = 1
|
||||
else:
|
||||
obs_features = self.obs_space_post.shape[1]
|
||||
|
||||
# Create the action space headers array
|
||||
action_header = []
|
||||
for x in range(action_length):
|
||||
action_header.append("AS_" + str(x))
|
||||
|
||||
# Create the observation space headers array
|
||||
obs_header_initial = []
|
||||
obs_header_new = []
|
||||
for x in range(obs_assets):
|
||||
for y in range(obs_features):
|
||||
obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
|
||||
obs_header_new.append("OSN_" + str(x) + "_" + str(y))
|
||||
|
||||
# Open up a csv file
|
||||
header = ["Timestamp", "Episode", "Step", "Reward"]
|
||||
header = header + action_header + obs_header_initial + obs_header_new
|
||||
|
||||
header = header + action_header + self.obs_space_description
|
||||
|
||||
row = [
|
||||
str(self.timestamp),
|
||||
str(self.episode_number),
|
||||
@@ -84,10 +70,9 @@ class Transaction(object):
|
||||
str(self.reward),
|
||||
]
|
||||
row = (
|
||||
row
|
||||
+ _turn_action_space_to_array(self.action_space)
|
||||
+ _turn_obs_space_to_array(self.obs_space_pre, obs_assets, obs_features)
|
||||
+ _turn_obs_space_to_array(self.obs_space_post, obs_assets, obs_features)
|
||||
row
|
||||
+ _turn_action_space_to_array(self.action_space)
|
||||
+ self.obs_space.tolist()
|
||||
)
|
||||
return header, row
|
||||
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""Writes the Transaction log list out to file for evaluation to utilse."""
|
||||
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def turn_action_space_to_array(_action_space):
|
||||
"""
|
||||
Turns action space into a string array so it can be saved to csv.
|
||||
|
||||
Args:
|
||||
_action_space: The action space.
|
||||
"""
|
||||
if isinstance(_action_space, list):
|
||||
return [str(i) for i in _action_space]
|
||||
else:
|
||||
return [str(_action_space)]
|
||||
|
||||
|
||||
def write_transaction_to_file(
|
||||
transaction_list,
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
obs_space_description: list,
|
||||
):
|
||||
"""
|
||||
Writes transaction logs to file to support training evaluation.
|
||||
|
||||
:param transaction_list: The list of transactions from all steps and all
|
||||
episodes.
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
# Get the first transaction and use it to determine the makeup of the
|
||||
# observation space and action space
|
||||
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action
|
||||
# space as "AS_1"
|
||||
# This will be tied into the PrimAITE Use Case so that they make sense
|
||||
template_transation = transaction_list[0]
|
||||
action_length = template_transation.action_space.size
|
||||
# obs_shape = template_transation.obs_space_post.shape
|
||||
# obs_assets = template_transation.obs_space_post.shape[0]
|
||||
# if len(obs_shape) == 1:
|
||||
# bit of a workaround but I think the way transactions are written will change soon
|
||||
# obs_features = 1
|
||||
# else:
|
||||
# obs_features = template_transation.obs_space_post.shape[1]
|
||||
|
||||
# Create the action space headers array
|
||||
action_header = []
|
||||
for x in range(action_length):
|
||||
action_header.append("AS_" + str(x))
|
||||
|
||||
# Create the observation space headers array
|
||||
# obs_header_initial = [f"pre_{o}" for o in obs_space_description]
|
||||
# obs_header_new = [f"post_{o}" for o in obs_space_description]
|
||||
|
||||
# Open up a csv file
|
||||
header = ["Timestamp", "Episode", "Step", "Reward"]
|
||||
header = header + action_header + obs_space_description
|
||||
|
||||
try:
|
||||
filename = session_path / f"all_transactions_{timestamp_str}.csv"
|
||||
_LOGGER.debug(f"Saving transaction logs: {filename}")
|
||||
csv_file = open(filename, "w", encoding="UTF8", newline="")
|
||||
csv_writer = csv.writer(csv_file)
|
||||
csv_writer.writerow(header)
|
||||
|
||||
for transaction in transaction_list:
|
||||
csv_data = [
|
||||
str(transaction.timestamp),
|
||||
str(transaction.episode_number),
|
||||
str(transaction.step_number),
|
||||
str(transaction.reward),
|
||||
]
|
||||
csv_data = (
|
||||
csv_data
|
||||
+ turn_action_space_to_array(transaction.action_space)
|
||||
+ transaction.obs_space.tolist()
|
||||
)
|
||||
csv_writer.writerow(csv_data)
|
||||
|
||||
csv_file.close()
|
||||
except Exception:
|
||||
_LOGGER.error("Could not save the transaction file", exc_info=True)
|
||||
Reference in New Issue
Block a user