#917 - Overhauled transaction and mean reward writing.
- Separated out learning outputs from evaluation outputs
This commit is contained in:
@@ -21,6 +21,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
|
||||
def _get_primaite_config():
|
||||
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
|
||||
if not config_path.exists():
|
||||
|
||||
config_path = Path(
|
||||
pkg_resources.resource_filename(
|
||||
"primaite", "setup/_package_data/primaite_config.yaml"
|
||||
@@ -36,7 +37,7 @@ def _get_primaite_config():
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
}
|
||||
primaite_config["log_level"] = log_level_map[primaite_config["log_level"]]
|
||||
primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
|
||||
return primaite_config
|
||||
|
||||
|
||||
@@ -108,11 +109,11 @@ def _log_dir() -> Path:
|
||||
|
||||
_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
|
||||
{
|
||||
logging.DEBUG: _PRIMAITE_CONFIG["logger_format"]["DEBUG"],
|
||||
logging.INFO: _PRIMAITE_CONFIG["logger_format"]["INFO"],
|
||||
logging.WARNING: _PRIMAITE_CONFIG["logger_format"]["WARNING"],
|
||||
logging.ERROR: _PRIMAITE_CONFIG["logger_format"]["ERROR"],
|
||||
logging.CRITICAL: _PRIMAITE_CONFIG["logger_format"]["CRITICAL"]
|
||||
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
|
||||
logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
|
||||
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
|
||||
logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
|
||||
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"]
|
||||
}
|
||||
)
|
||||
|
||||
@@ -132,10 +133,10 @@ _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
|
||||
backupCount=9, # Max 100MB of logs
|
||||
encoding="utf8",
|
||||
)
|
||||
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"])
|
||||
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"])
|
||||
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
|
||||
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
|
||||
|
||||
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logger_format"]
|
||||
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"]
|
||||
_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER)
|
||||
_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER)
|
||||
|
||||
@@ -145,7 +146,7 @@ _LOGGER.addHandler(_STREAM_HANDLER)
|
||||
_LOGGER.addHandler(_FILE_HANDLER)
|
||||
|
||||
|
||||
def getLogger(name: str) -> Logger:
|
||||
def getLogger(name: str) -> Logger: # noqa
|
||||
"""
|
||||
Get a PrimAITE logger.
|
||||
|
||||
|
||||
@@ -64,7 +64,11 @@ class AgentSessionABC(ABC):
|
||||
"The session timestamp"
|
||||
self.session_path = _get_session_path(self.session_timestamp)
|
||||
"The Session path"
|
||||
self.checkpoints_path = self.session_path / "checkpoints"
|
||||
self.learning_path = self.session_path / "learning"
|
||||
"The learning outputs path"
|
||||
self.evaluation_path = self.session_path / "evaluation"
|
||||
"The evaluation outputs path"
|
||||
self.checkpoints_path = self.learning_path / "checkpoints"
|
||||
self.checkpoints_path.mkdir(parents=True, exist_ok=True)
|
||||
"The Session checkpoints path"
|
||||
|
||||
@@ -205,7 +209,6 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
transaction_list=[],
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str
|
||||
)
|
||||
|
||||
@@ -21,7 +21,6 @@ def _env_creator(env_config):
|
||||
return Primaite(
|
||||
training_config_path=env_config["training_config_path"],
|
||||
lay_down_config_path=env_config["lay_down_config_path"],
|
||||
transaction_list=env_config["transaction_list"],
|
||||
session_path=env_config["session_path"],
|
||||
timestamp_str=env_config["timestamp_str"]
|
||||
)
|
||||
@@ -106,7 +105,6 @@ class RLlibAgent(AgentSessionABC):
|
||||
env_config=dict(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
transaction_list=[],
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str
|
||||
)
|
||||
|
||||
@@ -34,7 +34,7 @@ class SB3Agent(AgentSessionABC):
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
self._tensorboard_log_path = self.session_path / "tensorboard_logs"
|
||||
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
|
||||
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
|
||||
self._setup()
|
||||
_LOGGER.debug(
|
||||
@@ -49,7 +49,6 @@ class SB3Agent(AgentSessionABC):
|
||||
self._env = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
transaction_list=[],
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str
|
||||
)
|
||||
@@ -108,10 +107,13 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
_LOGGER.info(f"Beginning evaluation for {episodes} episodes @"
|
||||
f" {time_steps} time steps...")
|
||||
|
||||
self._env.set_as_eval()
|
||||
if deterministic:
|
||||
deterministic_str = "deterministic"
|
||||
else:
|
||||
deterministic_str = "non-deterministic"
|
||||
_LOGGER.info(f"Beginning {deterministic_str} evaluation for "
|
||||
f"{episodes} episodes @ {time_steps} time steps...")
|
||||
for episode in range(episodes):
|
||||
obs = self._env.reset()
|
||||
|
||||
@@ -123,6 +125,7 @@ class SB3Agent(AgentSessionABC):
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
_LOGGER.info(f"Finished evaluation")
|
||||
|
||||
@classmethod
|
||||
def load(self):
|
||||
|
||||
@@ -38,7 +38,7 @@ hard_coded_agent_view: FULL
|
||||
action_type: ANY
|
||||
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 10
|
||||
num_episodes: 1000
|
||||
|
||||
# Number of time_steps per episode
|
||||
num_steps: 256
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
|
||||
import copy
|
||||
import csv
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple, Union, Final
|
||||
|
||||
@@ -14,6 +10,7 @@ import yaml
|
||||
from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, \
|
||||
is_valid_node_action
|
||||
@@ -27,7 +24,7 @@ from primaite.common.enums import (
|
||||
NodeType,
|
||||
ObservationType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
SoftwareState, SessionType,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.config import training_config
|
||||
@@ -47,11 +44,9 @@ from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
|
||||
apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
from primaite.utils.session_output_writer import SessionOutputWriter
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_LOGGER.setLevel(logging.INFO)
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class Primaite(Env):
|
||||
@@ -67,7 +62,6 @@ class Primaite(Env):
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
transaction_list,
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
):
|
||||
@@ -76,7 +70,6 @@ class Primaite(Env):
|
||||
|
||||
:param training_config_path: The training config filepath.
|
||||
:param lay_down_config_path: The lay down config filepath.
|
||||
:param transaction_list: The list of transactions to populate.
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
@@ -96,9 +89,6 @@ class Primaite(Env):
|
||||
|
||||
super(Primaite, self).__init__()
|
||||
|
||||
# Transaction list
|
||||
self.transaction_list = transaction_list
|
||||
|
||||
# The agent in use
|
||||
self.agent_identifier = self.training_config.agent_identifier
|
||||
|
||||
@@ -245,20 +235,31 @@ class Primaite(Env):
|
||||
_LOGGER.error(
|
||||
f"Invalid action type selected: {self.training_config.action_type}"
|
||||
)
|
||||
# Set up a csv to store the results of the training
|
||||
try:
|
||||
header = ["Episode", "Average Reward"]
|
||||
|
||||
file_name = f"average_reward_per_episode_{timestamp_str}.csv"
|
||||
file_path = session_path / file_name
|
||||
self.csv_file = open(file_path, "w", encoding="UTF8", newline="")
|
||||
self.csv_writer = csv.writer(self.csv_file)
|
||||
self.csv_writer.writerow(header)
|
||||
except Exception:
|
||||
_LOGGER.error(
|
||||
"Could not create csv file to hold average reward per episode"
|
||||
)
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
self.episode_av_reward_writer = SessionOutputWriter(
|
||||
self,
|
||||
transaction_writer=False,
|
||||
learning_session=True
|
||||
)
|
||||
self.transaction_writer = SessionOutputWriter(
|
||||
self,
|
||||
transaction_writer=True,
|
||||
learning_session=True
|
||||
)
|
||||
|
||||
def set_as_eval(self):
|
||||
"""Set the writers to write to eval directories."""
|
||||
self.episode_av_reward_writer = SessionOutputWriter(
|
||||
self,
|
||||
transaction_writer=False,
|
||||
learning_session=False
|
||||
)
|
||||
self.transaction_writer = SessionOutputWriter(
|
||||
self,
|
||||
transaction_writer=True,
|
||||
learning_session=False
|
||||
)
|
||||
self.episode_count = 0
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@@ -267,12 +268,14 @@ class Primaite(Env):
|
||||
Returns:
|
||||
Environment observation space (reset)
|
||||
"""
|
||||
csv_data = self.episode_count, self.average_reward
|
||||
self.csv_writer.writerow(csv_data)
|
||||
if self.episode_count > 0:
|
||||
csv_data = self.episode_count, self.average_reward
|
||||
self.episode_av_reward_writer.write(csv_data)
|
||||
|
||||
self.episode_count += 1
|
||||
|
||||
# Don't need to reset links, as they are cleared and recalculated every step
|
||||
# Don't need to reset links, as they are cleared and recalculated every
|
||||
# step
|
||||
|
||||
# Clear the ACL
|
||||
self.init_acl()
|
||||
@@ -303,12 +306,8 @@ class Primaite(Env):
|
||||
done: Indicates episode is complete if True
|
||||
step_info: Additional information relating to this step
|
||||
"""
|
||||
if self.step_count == 0:
|
||||
_LOGGER.info(f"Episode: {str(self.episode_count)}")
|
||||
|
||||
# TEMP
|
||||
done = False
|
||||
|
||||
self.step_count += 1
|
||||
self.total_step_count += 1
|
||||
|
||||
@@ -321,13 +320,16 @@ class Primaite(Env):
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(
|
||||
datetime.now(), self.agent_identifier, self.episode_count,
|
||||
self.agent_identifier,
|
||||
self.episode_count,
|
||||
self.step_count
|
||||
)
|
||||
# Load the initial observation space into the transaction
|
||||
transaction.set_obs_space_pre(copy.deepcopy(self.env_obs))
|
||||
transaction.obs_space_pre = copy.deepcopy(self.env_obs)
|
||||
# Load the action space into the transaction
|
||||
transaction.set_action_space(copy.deepcopy(action))
|
||||
transaction.action_space = copy.deepcopy(action)
|
||||
|
||||
initial_nodes = copy.deepcopy(self.nodes)
|
||||
|
||||
# 1. Implement Blue Action
|
||||
self.interpret_action_and_apply(action)
|
||||
@@ -381,7 +383,7 @@ class Primaite(Env):
|
||||
|
||||
# 5. Calculate reward signal (for RL)
|
||||
reward = calculate_reward_function(
|
||||
self.nodes_post_pol,
|
||||
initial_nodes,
|
||||
self.nodes_post_red,
|
||||
self.nodes_reference,
|
||||
self.green_iers,
|
||||
@@ -390,17 +392,22 @@ class Primaite(Env):
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
)
|
||||
_LOGGER.debug(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
_LOGGER.debug(
|
||||
f"Episode: {self.episode_count}, "
|
||||
f"Step {self.step_count}, "
|
||||
f"Reward: {reward}"
|
||||
)
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
if self.training_config.session_type == "EVALUATION":
|
||||
if self.training_config.session_type is SessionType.EVAL:
|
||||
# For evaluation, need to trigger the done value = True when
|
||||
# step count is reached in order to prevent neverending episode
|
||||
done = True
|
||||
_LOGGER.info(f" Average Reward: {str(self.average_reward)}")
|
||||
_LOGGER.info(f"Episode: {self.episode_count}, "
|
||||
f"Average Reward: {self.average_reward}")
|
||||
# Load the reward into the transaction
|
||||
transaction.set_reward(reward)
|
||||
transaction.reward = reward
|
||||
|
||||
# 6. Output Verbose
|
||||
# self.output_link_status()
|
||||
@@ -408,28 +415,14 @@ class Primaite(Env):
|
||||
# 7. Update env_obs
|
||||
self.update_environent_obs()
|
||||
# Load the new observation space into the transaction
|
||||
transaction.set_obs_space_post(copy.deepcopy(self.env_obs))
|
||||
transaction.obs_space_post = copy.deepcopy(self.env_obs)
|
||||
|
||||
# 8. Add the transaction to the list of transactions
|
||||
self.transaction_list.append(copy.deepcopy(transaction))
|
||||
# Write transaction to file
|
||||
self.transaction_writer.write(transaction)
|
||||
|
||||
# Return
|
||||
return self.env_obs, reward, done, self.step_info
|
||||
|
||||
def close(self):
|
||||
self.__close__()
|
||||
|
||||
def __close__(self):
|
||||
"""
|
||||
Override close function
|
||||
"""
|
||||
write_transaction_to_file(
|
||||
self.transaction_list,
|
||||
self.session_path,
|
||||
self.timestamp_str
|
||||
)
|
||||
self.csv_file.close()
|
||||
|
||||
def init_acl(self):
|
||||
"""Initialise the Access Control List."""
|
||||
self.acl.remove_all_rules()
|
||||
@@ -467,7 +460,7 @@ class Primaite(Env):
|
||||
): # Node actions in multdiscrete (array) from have len 4
|
||||
self.apply_actions_to_nodes(_action)
|
||||
else:
|
||||
logging.error("Invalid action type found")
|
||||
_LOGGER.error("Invalid action type found")
|
||||
|
||||
def apply_actions_to_nodes(self, _action):
|
||||
"""
|
||||
|
||||
@@ -85,9 +85,6 @@ def calculate_reward_function(
|
||||
)
|
||||
|
||||
if live_blocked and not reference_blocked:
|
||||
_LOGGER.debug(
|
||||
f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
|
||||
)
|
||||
reward_value += ier_reward
|
||||
elif live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
|
||||
@@ -20,6 +20,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str,
|
||||
|
||||
session.setup()
|
||||
session.learn()
|
||||
session.evaluate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# The main PrimAITE application config file
|
||||
|
||||
# Logging
|
||||
log_level: INFO
|
||||
logger_format:
|
||||
DEBUG: '%(asctime)s: %(message)s'
|
||||
INFO: '%(asctime)s: %(message)s'
|
||||
WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
logging:
|
||||
log_level: INFO
|
||||
logger_format:
|
||||
DEBUG: '%(asctime)s: %(message)s'
|
||||
INFO: '%(asctime)s: %(message)s'
|
||||
WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
|
||||
@@ -1,57 +1,115 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""The Transaction class."""
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
"""Transaction class."""
|
||||
|
||||
def __init__(self, _timestamp, _agent_identifier, _episode_number, _step_number):
|
||||
def __init__(
|
||||
self,
|
||||
agent_identifier,
|
||||
episode_number,
|
||||
step_number
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
Transaction constructor.
|
||||
|
||||
Args:
|
||||
_timestamp: The time this object was created
|
||||
_agent_identifier: An identifier for the agent in use
|
||||
_episode_number: The episode number
|
||||
_step_number: The step number
|
||||
:param agent_identifier: An identifier for the agent in use
|
||||
:param episode_number: The episode number
|
||||
:param step_number: The step number
|
||||
"""
|
||||
self.timestamp = _timestamp
|
||||
self.agent_identifier = _agent_identifier
|
||||
self.episode_number = _episode_number
|
||||
self.step_number = _step_number
|
||||
self.timestamp = datetime.now()
|
||||
"The datetime of the transaction"
|
||||
self.agent_identifier = agent_identifier
|
||||
self.episode_number = episode_number
|
||||
"The episode number"
|
||||
self.step_number = step_number
|
||||
"The step number"
|
||||
self.obs_space_pre = None
|
||||
"The observation space before any actions are taken"
|
||||
self.obs_space_post = None
|
||||
"The observation space after any actions are taken"
|
||||
self.reward = None
|
||||
"The reward value"
|
||||
self.action_space = None
|
||||
"The action space invoked by the agent"
|
||||
|
||||
def set_obs_space_pre(self, _obs_space_pre):
|
||||
"""
|
||||
Sets the observation space (pre).
|
||||
def as_csv_data(self) -> Tuple[List, List]:
|
||||
if isinstance(self.action_space, int):
|
||||
action_length = self.action_space
|
||||
else:
|
||||
action_length = self.action_space.size
|
||||
obs_shape = self.obs_space_post.shape
|
||||
obs_assets = self.obs_space_post.shape[0]
|
||||
if len(obs_shape) == 1:
|
||||
# A bit of a workaround but I think the way transactions are
|
||||
# written will change soon
|
||||
obs_features = 1
|
||||
else:
|
||||
obs_features = self.obs_space_post.shape[1]
|
||||
|
||||
Args:
|
||||
_obs_space_pre: The observation space before any actions are taken
|
||||
"""
|
||||
self.obs_space_pre = _obs_space_pre
|
||||
# Create the action space headers array
|
||||
action_header = []
|
||||
for x in range(action_length):
|
||||
action_header.append("AS_" + str(x))
|
||||
|
||||
def set_obs_space_post(self, _obs_space_post):
|
||||
"""
|
||||
Sets the observation space (post).
|
||||
# Create the observation space headers array
|
||||
obs_header_initial = []
|
||||
obs_header_new = []
|
||||
for x in range(obs_assets):
|
||||
for y in range(obs_features):
|
||||
obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
|
||||
obs_header_new.append("OSN_" + str(x) + "_" + str(y))
|
||||
|
||||
Args:
|
||||
_obs_space_post: The observation space after any actions are taken
|
||||
"""
|
||||
self.obs_space_post = _obs_space_post
|
||||
# Open up a csv file
|
||||
header = ["Timestamp", "Episode", "Step", "Reward"]
|
||||
header = header + action_header + obs_header_initial + obs_header_new
|
||||
|
||||
def set_reward(self, _reward):
|
||||
"""
|
||||
Sets the reward.
|
||||
row = [
|
||||
str(self.timestamp),
|
||||
str(self.episode_number),
|
||||
str(self.step_number),
|
||||
str(self.reward),
|
||||
]
|
||||
row = (
|
||||
row
|
||||
+ _turn_action_space_to_array(self.action_space)
|
||||
+ _turn_obs_space_to_array(self.obs_space_pre, obs_assets,
|
||||
obs_features)
|
||||
+ _turn_obs_space_to_array(self.obs_space_post, obs_assets,
|
||||
obs_features)
|
||||
)
|
||||
return header, row
|
||||
|
||||
Args:
|
||||
_reward: The reward value
|
||||
"""
|
||||
self.reward = _reward
|
||||
|
||||
def set_action_space(self, _action_space):
|
||||
"""
|
||||
Sets the action space.
|
||||
def _turn_action_space_to_array(action_space) -> List[str]:
|
||||
"""
|
||||
Turns action space into a string array so it can be saved to csv.
|
||||
|
||||
Args:
|
||||
_action_space: The action space invoked by the agent
|
||||
"""
|
||||
self.action_space = _action_space
|
||||
:param action_space: The action space
|
||||
:return: The action space as an array of strings
|
||||
"""
|
||||
if isinstance(action_space, list):
|
||||
return [str(i) for i in action_space]
|
||||
else:
|
||||
return [str(action_space)]
|
||||
|
||||
|
||||
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
|
||||
"""
|
||||
Turns observation space into a string array so it can be saved to csv.
|
||||
|
||||
:param obs_space: The observation space
|
||||
:param obs_assets: The number of assets (i.e. nodes or links) in the
|
||||
observation space
|
||||
:param obs_features: The number of features associated with the asset
|
||||
:return: The observation space as an array of strings
|
||||
"""
|
||||
return_array = []
|
||||
for x in range(obs_assets):
|
||||
for y in range(obs_features):
|
||||
return_array.append(str(obs_space[x][y]))
|
||||
|
||||
return return_array
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""Writes the Transaction log list out to file for evaluation to utilse."""
|
||||
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def turn_action_space_to_array(_action_space):
|
||||
"""
|
||||
Turns action space into a string array so it can be saved to csv.
|
||||
|
||||
Args:
|
||||
_action_space: The action space.
|
||||
"""
|
||||
if isinstance(_action_space, list):
|
||||
return [str(i) for i in _action_space]
|
||||
else:
|
||||
return [str(_action_space)]
|
||||
|
||||
|
||||
def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
|
||||
"""
|
||||
Turns observation space into a string array so it can be saved to csv.
|
||||
|
||||
Args:
|
||||
_obs_space: The observation space
|
||||
_obs_assets: The number of assets (i.e. nodes or links) in the observation space
|
||||
_obs_features: The number of features associated with the asset
|
||||
"""
|
||||
return_array = []
|
||||
for x in range(_obs_assets):
|
||||
for y in range(_obs_features):
|
||||
return_array.append(str(_obs_space[x][y]))
|
||||
|
||||
return return_array
|
||||
|
||||
|
||||
def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str):
|
||||
"""
|
||||
Writes transaction logs to file to support training evaluation.
|
||||
|
||||
:param transaction_list: The list of transactions from all steps and all
|
||||
episodes.
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
# Get the first transaction and use it to determine the makeup of the
|
||||
# observation space and action space
|
||||
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action
|
||||
# space as "AS_1"
|
||||
# This will be tied into the PrimAITE Use Case so that they make sense
|
||||
|
||||
template_transation = transaction_list[0]
|
||||
if isinstance(template_transation.action_space, int):
|
||||
action_length = template_transation.action_space
|
||||
else:
|
||||
action_length = template_transation.action_space.size
|
||||
obs_shape = template_transation.obs_space_post.shape
|
||||
obs_assets = template_transation.obs_space_post.shape[0]
|
||||
if len(obs_shape) == 1:
|
||||
# bit of a workaround but I think the way transactions are written will change soon
|
||||
obs_features = 1
|
||||
else:
|
||||
obs_features = template_transation.obs_space_post.shape[1]
|
||||
|
||||
# Create the action space headers array
|
||||
action_header = []
|
||||
for x in range(action_length):
|
||||
action_header.append("AS_" + str(x))
|
||||
|
||||
# Create the observation space headers array
|
||||
obs_header_initial = []
|
||||
obs_header_new = []
|
||||
for x in range(obs_assets):
|
||||
for y in range(obs_features):
|
||||
obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
|
||||
obs_header_new.append("OSN_" + str(x) + "_" + str(y))
|
||||
|
||||
# Open up a csv file
|
||||
header = ["Timestamp", "Episode", "Step", "Reward"]
|
||||
header = header + action_header + obs_header_initial + obs_header_new
|
||||
|
||||
try:
|
||||
filename = session_path / f"all_transactions_{timestamp_str}.csv"
|
||||
_LOGGER.debug(f"Saving transaction logs: {filename}")
|
||||
csv_file = open(filename, "w", encoding="UTF8", newline="")
|
||||
csv_writer = csv.writer(csv_file)
|
||||
csv_writer.writerow(header)
|
||||
|
||||
for transaction in transaction_list:
|
||||
csv_data = [
|
||||
str(transaction.timestamp),
|
||||
str(transaction.episode_number),
|
||||
str(transaction.step_number),
|
||||
str(transaction.reward),
|
||||
]
|
||||
csv_data = (
|
||||
csv_data
|
||||
+ turn_action_space_to_array(transaction.action_space)
|
||||
+ turn_obs_space_to_array(
|
||||
transaction.obs_space_pre, obs_assets, obs_features
|
||||
)
|
||||
+ turn_obs_space_to_array(
|
||||
transaction.obs_space_post, obs_assets, obs_features
|
||||
)
|
||||
)
|
||||
csv_writer.writerow(csv_data)
|
||||
|
||||
csv_file.close()
|
||||
_LOGGER.debug("Finished writing transactions")
|
||||
except Exception:
|
||||
_LOGGER.error("Could not save the transaction file", exc_info=True)
|
||||
73
src/primaite/utils/session_output_writer.py
Normal file
73
src/primaite/utils/session_output_writer.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import csv
|
||||
from logging import Logger
|
||||
from typing import List, Final, IO, Union, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
class SessionOutputWriter:
|
||||
_AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
|
||||
"Episode", "Average Reward"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: "Primaite",
|
||||
transaction_writer: bool = False,
|
||||
learning_session: bool = True
|
||||
):
|
||||
self._env = env
|
||||
self.transaction_writer = transaction_writer
|
||||
self.learning_session = learning_session
|
||||
|
||||
if self.transaction_writer:
|
||||
fn = f"all_transactions_{self._env.timestamp_str}.csv"
|
||||
else:
|
||||
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
|
||||
|
||||
if self.learning_session:
|
||||
self._csv_file_path = self._env.session_path / "learning" / fn
|
||||
else:
|
||||
self._csv_file_path = self._env.session_path / "evaluation" / fn
|
||||
|
||||
self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
self._csv_file = None
|
||||
self._csv_writer = None
|
||||
|
||||
self._first_write: bool = True
|
||||
|
||||
def _init_csv_writer(self):
|
||||
self._csv_file = open(
|
||||
self._csv_file_path, "w", encoding="UTF8", newline=""
|
||||
)
|
||||
|
||||
self._csv_writer = csv.writer(self._csv_file)
|
||||
|
||||
def __del__(self):
|
||||
if self._csv_file:
|
||||
self._csv_file.close()
|
||||
_LOGGER.info(f"Finished writing file: {self._csv_file_path}")
|
||||
|
||||
def write(
|
||||
self,
|
||||
data: Union[Tuple, Transaction]
|
||||
):
|
||||
if isinstance(data, Transaction):
|
||||
header, data = data.as_csv_data()
|
||||
else:
|
||||
header = self._AV_REWARD_PER_EPISODE_HEADER
|
||||
|
||||
if self._first_write:
|
||||
self._init_csv_writer()
|
||||
self._csv_writer.writerow(header)
|
||||
self._first_write = False
|
||||
|
||||
self._csv_writer.writerow(data)
|
||||
@@ -37,7 +37,6 @@ def _get_primaite_env_from_config(
|
||||
env = Primaite(
|
||||
training_config_path=training_config_path,
|
||||
lay_down_config_path=lay_down_config_path,
|
||||
transaction_list=[],
|
||||
session_path=session_path,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user