#917 - Overhauled transaction and mean reward writing.

- Separated out learning outputs from evaluation outputs
This commit is contained in:
Chris McCarthy
2023-06-28 16:34:00 +01:00
parent 7482192046
commit 1d3778f400
13 changed files with 258 additions and 250 deletions

View File

@@ -21,6 +21,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
def _get_primaite_config():
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
if not config_path.exists():
config_path = Path(
pkg_resources.resource_filename(
"primaite", "setup/_package_data/primaite_config.yaml"
@@ -36,7 +37,7 @@ def _get_primaite_config():
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
primaite_config["log_level"] = log_level_map[primaite_config["log_level"]]
primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
return primaite_config
@@ -108,11 +109,11 @@ def _log_dir() -> Path:
_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
{
logging.DEBUG: _PRIMAITE_CONFIG["logger_format"]["DEBUG"],
logging.INFO: _PRIMAITE_CONFIG["logger_format"]["INFO"],
logging.WARNING: _PRIMAITE_CONFIG["logger_format"]["WARNING"],
logging.ERROR: _PRIMAITE_CONFIG["logger_format"]["ERROR"],
logging.CRITICAL: _PRIMAITE_CONFIG["logger_format"]["CRITICAL"]
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"]
}
)
@@ -132,10 +133,10 @@ _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
backupCount=9, # Max 100MB of logs
encoding="utf8",
)
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"])
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"])
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logger_format"]
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"]
_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER)
_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER)
@@ -145,7 +146,7 @@ _LOGGER.addHandler(_STREAM_HANDLER)
_LOGGER.addHandler(_FILE_HANDLER)
def getLogger(name: str) -> Logger:
def getLogger(name: str) -> Logger: # noqa
"""
Get a PrimAITE logger.

View File

@@ -64,7 +64,11 @@ class AgentSessionABC(ABC):
"The session timestamp"
self.session_path = _get_session_path(self.session_timestamp)
"The Session path"
self.checkpoints_path = self.session_path / "checkpoints"
self.learning_path = self.session_path / "learning"
"The learning outputs path"
self.evaluation_path = self.session_path / "evaluation"
"The evaluation outputs path"
self.checkpoints_path = self.learning_path / "checkpoints"
self.checkpoints_path.mkdir(parents=True, exist_ok=True)
"The Session checkpoints path"
@@ -205,7 +209,6 @@ class HardCodedAgentSessionABC(AgentSessionABC):
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
transaction_list=[],
session_path=self.session_path,
timestamp_str=self.timestamp_str
)

View File

@@ -21,7 +21,6 @@ def _env_creator(env_config):
return Primaite(
training_config_path=env_config["training_config_path"],
lay_down_config_path=env_config["lay_down_config_path"],
transaction_list=env_config["transaction_list"],
session_path=env_config["session_path"],
timestamp_str=env_config["timestamp_str"]
)
@@ -106,7 +105,6 @@ class RLlibAgent(AgentSessionABC):
env_config=dict(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
transaction_list=[],
session_path=self.session_path,
timestamp_str=self.timestamp_str
)

View File

@@ -34,7 +34,7 @@ class SB3Agent(AgentSessionABC):
_LOGGER.error(msg)
raise ValueError(msg)
self._tensorboard_log_path = self.session_path / "tensorboard_logs"
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
self._setup()
_LOGGER.debug(
@@ -49,7 +49,6 @@ class SB3Agent(AgentSessionABC):
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
transaction_list=[],
session_path=self.session_path,
timestamp_str=self.timestamp_str
)
@@ -108,10 +107,13 @@ class SB3Agent(AgentSessionABC):
if not episodes:
episodes = self._training_config.num_episodes
_LOGGER.info(f"Beginning evaluation for {episodes} episodes @"
f" {time_steps} time steps...")
self._env.set_as_eval()
if deterministic:
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
_LOGGER.info(f"Beginning {deterministic_str} evaluation for "
f"{episodes} episodes @ {time_steps} time steps...")
for episode in range(episodes):
obs = self._env.reset()
@@ -123,6 +125,7 @@ class SB3Agent(AgentSessionABC):
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
_LOGGER.info(f"Finished evaluation")
@classmethod
def load(self):

View File

@@ -38,7 +38,7 @@ hard_coded_agent_view: FULL
action_type: ANY
# Number of episodes to run per session
num_episodes: 10
num_episodes: 1000
# Number of time_steps per episode
num_steps: 256

View File

@@ -1,10 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
import copy
import csv
import logging
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, Tuple, Union, Final
@@ -14,6 +10,7 @@ import yaml
from gym import Env, spaces
from matplotlib import pyplot as plt
from primaite import getLogger
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, \
is_valid_node_action
@@ -27,7 +24,7 @@ from primaite.common.enums import (
NodeType,
ObservationType,
Priority,
SoftwareState,
SoftwareState, SessionType,
)
from primaite.common.service import Service
from primaite.config import training_config
@@ -47,11 +44,9 @@ from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction
from primaite.transactions.transactions_to_file import \
write_transaction_to_file
from primaite.utils.session_output_writer import SessionOutputWriter
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.INFO)
_LOGGER = getLogger(__name__)
class Primaite(Env):
@@ -67,7 +62,6 @@ class Primaite(Env):
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
transaction_list,
session_path: Path,
timestamp_str: str,
):
@@ -76,7 +70,6 @@ class Primaite(Env):
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.
:param transaction_list: The list of transactions to populate.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
@@ -96,9 +89,6 @@ class Primaite(Env):
super(Primaite, self).__init__()
# Transaction list
self.transaction_list = transaction_list
# The agent in use
self.agent_identifier = self.training_config.agent_identifier
@@ -245,20 +235,31 @@ class Primaite(Env):
_LOGGER.error(
f"Invalid action type selected: {self.training_config.action_type}"
)
# Set up a csv to store the results of the training
try:
header = ["Episode", "Average Reward"]
file_name = f"average_reward_per_episode_{timestamp_str}.csv"
file_path = session_path / file_name
self.csv_file = open(file_path, "w", encoding="UTF8", newline="")
self.csv_writer = csv.writer(self.csv_file)
self.csv_writer.writerow(header)
except Exception:
_LOGGER.error(
"Could not create csv file to hold average reward per episode"
)
_LOGGER.error("Exception occured", exc_info=True)
self.episode_av_reward_writer = SessionOutputWriter(
self,
transaction_writer=False,
learning_session=True
)
self.transaction_writer = SessionOutputWriter(
self,
transaction_writer=True,
learning_session=True
)
def set_as_eval(self):
"""Set the writers to write to eval directories."""
self.episode_av_reward_writer = SessionOutputWriter(
self,
transaction_writer=False,
learning_session=False
)
self.transaction_writer = SessionOutputWriter(
self,
transaction_writer=True,
learning_session=False
)
self.episode_count = 0
def reset(self):
"""
@@ -267,12 +268,14 @@ class Primaite(Env):
Returns:
Environment observation space (reset)
"""
csv_data = self.episode_count, self.average_reward
self.csv_writer.writerow(csv_data)
if self.episode_count > 0:
csv_data = self.episode_count, self.average_reward
self.episode_av_reward_writer.write(csv_data)
self.episode_count += 1
# Don't need to reset links, as they are cleared and recalculated every step
# Don't need to reset links, as they are cleared and recalculated every
# step
# Clear the ACL
self.init_acl()
@@ -303,12 +306,8 @@ class Primaite(Env):
done: Indicates episode is complete if True
step_info: Additional information relating to this step
"""
if self.step_count == 0:
_LOGGER.info(f"Episode: {str(self.episode_count)}")
# TEMP
done = False
self.step_count += 1
self.total_step_count += 1
@@ -321,13 +320,16 @@ class Primaite(Env):
# Create a Transaction (metric) object for this step
transaction = Transaction(
datetime.now(), self.agent_identifier, self.episode_count,
self.agent_identifier,
self.episode_count,
self.step_count
)
# Load the initial observation space into the transaction
transaction.set_obs_space_pre(copy.deepcopy(self.env_obs))
transaction.obs_space_pre = copy.deepcopy(self.env_obs)
# Load the action space into the transaction
transaction.set_action_space(copy.deepcopy(action))
transaction.action_space = copy.deepcopy(action)
initial_nodes = copy.deepcopy(self.nodes)
# 1. Implement Blue Action
self.interpret_action_and_apply(action)
@@ -381,7 +383,7 @@ class Primaite(Env):
# 5. Calculate reward signal (for RL)
reward = calculate_reward_function(
self.nodes_post_pol,
initial_nodes,
self.nodes_post_red,
self.nodes_reference,
self.green_iers,
@@ -390,17 +392,22 @@ class Primaite(Env):
self.step_count,
self.training_config,
)
_LOGGER.debug(f" Step {self.step_count} Reward: {str(reward)}")
_LOGGER.debug(
f"Episode: {self.episode_count}, "
f"Step {self.step_count}, "
f"Reward: {reward}"
)
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
if self.training_config.session_type == "EVALUATION":
if self.training_config.session_type is SessionType.EVAL:
# For evaluation, need to trigger the done value = True when
# step count is reached in order to prevent neverending episode
done = True
_LOGGER.info(f" Average Reward: {str(self.average_reward)}")
_LOGGER.info(f"Episode: {self.episode_count}, "
f"Average Reward: {self.average_reward}")
# Load the reward into the transaction
transaction.set_reward(reward)
transaction.reward = reward
# 6. Output Verbose
# self.output_link_status()
@@ -408,28 +415,14 @@ class Primaite(Env):
# 7. Update env_obs
self.update_environent_obs()
# Load the new observation space into the transaction
transaction.set_obs_space_post(copy.deepcopy(self.env_obs))
transaction.obs_space_post = copy.deepcopy(self.env_obs)
# 8. Add the transaction to the list of transactions
self.transaction_list.append(copy.deepcopy(transaction))
# Write transaction to file
self.transaction_writer.write(transaction)
# Return
return self.env_obs, reward, done, self.step_info
def close(self):
self.__close__()
def __close__(self):
"""
Override close function
"""
write_transaction_to_file(
self.transaction_list,
self.session_path,
self.timestamp_str
)
self.csv_file.close()
def init_acl(self):
"""Initialise the Access Control List."""
self.acl.remove_all_rules()
@@ -467,7 +460,7 @@ class Primaite(Env):
): # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
logging.error("Invalid action type found")
_LOGGER.error("Invalid action type found")
def apply_actions_to_nodes(self, _action):
"""

View File

@@ -85,9 +85,6 @@ def calculate_reward_function(
)
if live_blocked and not reference_blocked:
_LOGGER.debug(
f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
)
reward_value += ier_reward
elif live_blocked and reference_blocked:
_LOGGER.debug(

View File

@@ -20,6 +20,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str,
session.setup()
session.learn()
session.evaluate()
if __name__ == "__main__":

View File

@@ -1,10 +1,11 @@
# The main PrimAITE application config file
# Logging
log_level: INFO
logger_format:
DEBUG: '%(asctime)s: %(message)s'
INFO: '%(asctime)s: %(message)s'
WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
logging:
log_level: INFO
logger_format:
DEBUG: '%(asctime)s: %(message)s'
INFO: '%(asctime)s: %(message)s'
WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'

View File

@@ -1,57 +1,115 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""The Transaction class."""
from datetime import datetime
from typing import List, Tuple
class Transaction(object):
"""Transaction class."""
def __init__(self, _timestamp, _agent_identifier, _episode_number, _step_number):
def __init__(
self,
agent_identifier,
episode_number,
step_number
):
"""
Init.
Transaction constructor.
Args:
_timestamp: The time this object was created
_agent_identifier: An identifier for the agent in use
_episode_number: The episode number
_step_number: The step number
:param agent_identifier: An identifier for the agent in use
:param episode_number: The episode number
:param step_number: The step number
"""
self.timestamp = _timestamp
self.agent_identifier = _agent_identifier
self.episode_number = _episode_number
self.step_number = _step_number
self.timestamp = datetime.now()
"The datetime of the transaction"
self.agent_identifier = agent_identifier
self.episode_number = episode_number
"The episode number"
self.step_number = step_number
"The step number"
self.obs_space_pre = None
"The observation space before any actions are taken"
self.obs_space_post = None
"The observation space after any actions are taken"
self.reward = None
"The reward value"
self.action_space = None
"The action space invoked by the agent"
def set_obs_space_pre(self, _obs_space_pre):
"""
Sets the observation space (pre).
def as_csv_data(self) -> Tuple[List, List]:
if isinstance(self.action_space, int):
action_length = self.action_space
else:
action_length = self.action_space.size
obs_shape = self.obs_space_post.shape
obs_assets = self.obs_space_post.shape[0]
if len(obs_shape) == 1:
# A bit of a workaround but I think the way transactions are
# written will change soon
obs_features = 1
else:
obs_features = self.obs_space_post.shape[1]
Args:
_obs_space_pre: The observation space before any actions are taken
"""
self.obs_space_pre = _obs_space_pre
# Create the action space headers array
action_header = []
for x in range(action_length):
action_header.append("AS_" + str(x))
def set_obs_space_post(self, _obs_space_post):
"""
Sets the observation space (post).
# Create the observation space headers array
obs_header_initial = []
obs_header_new = []
for x in range(obs_assets):
for y in range(obs_features):
obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
obs_header_new.append("OSN_" + str(x) + "_" + str(y))
Args:
_obs_space_post: The observation space after any actions are taken
"""
self.obs_space_post = _obs_space_post
# Open up a csv file
header = ["Timestamp", "Episode", "Step", "Reward"]
header = header + action_header + obs_header_initial + obs_header_new
def set_reward(self, _reward):
"""
Sets the reward.
row = [
str(self.timestamp),
str(self.episode_number),
str(self.step_number),
str(self.reward),
]
row = (
row
+ _turn_action_space_to_array(self.action_space)
+ _turn_obs_space_to_array(self.obs_space_pre, obs_assets,
obs_features)
+ _turn_obs_space_to_array(self.obs_space_post, obs_assets,
obs_features)
)
return header, row
Args:
_reward: The reward value
"""
self.reward = _reward
def set_action_space(self, _action_space):
"""
Sets the action space.
def _turn_action_space_to_array(action_space) -> List[str]:
"""
Turns action space into a string array so it can be saved to csv.
Args:
_action_space: The action space invoked by the agent
"""
self.action_space = _action_space
:param action_space: The action space
:return: The action space as an array of strings
"""
if isinstance(action_space, list):
return [str(i) for i in action_space]
else:
return [str(action_space)]
def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]:
"""
Turns observation space into a string array so it can be saved to csv.
:param obs_space: The observation space
:param obs_assets: The number of assets (i.e. nodes or links) in the
observation space
:param obs_features: The number of features associated with the asset
:return: The observation space as an array of strings
"""
return_array = []
for x in range(obs_assets):
for y in range(obs_features):
return_array.append(str(obs_space[x][y]))
return return_array

View File

@@ -1,119 +0,0 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Writes the Transaction log list out to file for evaluation to utilse."""
import csv
from pathlib import Path
import numpy as np
from primaite import getLogger
_LOGGER = getLogger(__name__)
def turn_action_space_to_array(_action_space):
"""
Turns action space into a string array so it can be saved to csv.
Args:
_action_space: The action space.
"""
if isinstance(_action_space, list):
return [str(i) for i in _action_space]
else:
return [str(_action_space)]
def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
"""
Turns observation space into a string array so it can be saved to csv.
Args:
_obs_space: The observation space
_obs_assets: The number of assets (i.e. nodes or links) in the observation space
_obs_features: The number of features associated with the asset
"""
return_array = []
for x in range(_obs_assets):
for y in range(_obs_features):
return_array.append(str(_obs_space[x][y]))
return return_array
def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str):
"""
Writes transaction logs to file to support training evaluation.
:param transaction_list: The list of transactions from all steps and all
episodes.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
# Get the first transaction and use it to determine the makeup of the
# observation space and action space
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action
# space as "AS_1"
# This will be tied into the PrimAITE Use Case so that they make sense
template_transation = transaction_list[0]
if isinstance(template_transation.action_space, int):
action_length = template_transation.action_space
else:
action_length = template_transation.action_space.size
obs_shape = template_transation.obs_space_post.shape
obs_assets = template_transation.obs_space_post.shape[0]
if len(obs_shape) == 1:
# bit of a workaround but I think the way transactions are written will change soon
obs_features = 1
else:
obs_features = template_transation.obs_space_post.shape[1]
# Create the action space headers array
action_header = []
for x in range(action_length):
action_header.append("AS_" + str(x))
# Create the observation space headers array
obs_header_initial = []
obs_header_new = []
for x in range(obs_assets):
for y in range(obs_features):
obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
obs_header_new.append("OSN_" + str(x) + "_" + str(y))
# Open up a csv file
header = ["Timestamp", "Episode", "Step", "Reward"]
header = header + action_header + obs_header_initial + obs_header_new
try:
filename = session_path / f"all_transactions_{timestamp_str}.csv"
_LOGGER.debug(f"Saving transaction logs: {filename}")
csv_file = open(filename, "w", encoding="UTF8", newline="")
csv_writer = csv.writer(csv_file)
csv_writer.writerow(header)
for transaction in transaction_list:
csv_data = [
str(transaction.timestamp),
str(transaction.episode_number),
str(transaction.step_number),
str(transaction.reward),
]
csv_data = (
csv_data
+ turn_action_space_to_array(transaction.action_space)
+ turn_obs_space_to_array(
transaction.obs_space_pre, obs_assets, obs_features
)
+ turn_obs_space_to_array(
transaction.obs_space_post, obs_assets, obs_features
)
)
csv_writer.writerow(csv_data)
csv_file.close()
_LOGGER.debug("Finished writing transactions")
except Exception:
_LOGGER.error("Could not save the transaction file", exc_info=True)

View File

@@ -0,0 +1,73 @@
import csv
from logging import Logger
from typing import List, Final, IO, Union, Tuple
from typing import TYPE_CHECKING
from primaite import getLogger
from primaite.transactions.transaction import Transaction
if TYPE_CHECKING:
from primaite.environment.primaite_env import Primaite
_LOGGER: Logger = getLogger(__name__)
class SessionOutputWriter:
_AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
"Episode", "Average Reward"
]
def __init__(
self,
env: "Primaite",
transaction_writer: bool = False,
learning_session: bool = True
):
self._env = env
self.transaction_writer = transaction_writer
self.learning_session = learning_session
if self.transaction_writer:
fn = f"all_transactions_{self._env.timestamp_str}.csv"
else:
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
if self.learning_session:
self._csv_file_path = self._env.session_path / "learning" / fn
else:
self._csv_file_path = self._env.session_path / "evaluation" / fn
self._csv_file_path.parent.mkdir(exist_ok=True, parents=True)
self._csv_file = None
self._csv_writer = None
self._first_write: bool = True
def _init_csv_writer(self):
self._csv_file = open(
self._csv_file_path, "w", encoding="UTF8", newline=""
)
self._csv_writer = csv.writer(self._csv_file)
def __del__(self):
if self._csv_file:
self._csv_file.close()
_LOGGER.info(f"Finished writing file: {self._csv_file_path}")
def write(
self,
data: Union[Tuple, Transaction]
):
if isinstance(data, Transaction):
header, data = data.as_csv_data()
else:
header = self._AV_REWARD_PER_EPISODE_HEADER
if self._first_write:
self._init_csv_writer()
self._csv_writer.writerow(header)
self._first_write = False
self._csv_writer.writerow(data)

View File

@@ -37,7 +37,6 @@ def _get_primaite_env_from_config(
env = Primaite(
training_config_path=training_config_path,
lay_down_config_path=lay_down_config_path,
transaction_list=[],
session_path=session_path,
timestamp_str=timestamp_str,
)