diff --git a/src/primaite/utils/session_output_reader.py b/src/primaite/utils/session_output_reader.py index ad3dd4f4..2ff4a16a 100644 --- a/src/primaite/utils/session_output_reader.py +++ b/src/primaite/utils/session_output_reader.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, Union +from typing import Any, Dict, Tuple, Union # Using polars as it's faster than Pandas; it will speed things up when # files get big! @@ -13,8 +13,33 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]: The dictionary keys are the episode number, and the values are the mean reward that episode. :param av_rewards_csv_file: The average rewards per episode csv file path. - :return: The average rewards per episode cdv as a dict. + :return: The average rewards per episode csv as a dict. """ - df = pl.read_csv(av_rewards_csv_file).to_dict() + df_dict = pl.read_csv(av_rewards_csv_file).to_dict() - return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])} + return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])} + + +def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]: + """ + Read an all transactions csv file and return as a dict. + + The dict keys are a tuple with the structure (episode, step). The dict + values are the remaining columns as a dict. + + :param all_transactions_csv_file: The all transactions csv file path. + :return: The all transactions csv file as a dict. + """ + df_dict = pl.read_csv(all_transactions_csv_file).to_dict() + new_dict = {} + + episodes = df_dict["Episode"] + steps = df_dict["Step"] + keys = list(df_dict.keys()) + + for i in range(len(episodes)): + key = (episodes[i], steps[i]) + value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]} + new_dict[key] = value_dict + + return new_dict diff --git a/tests/config/training_config_main_rllib.yaml b/tests/config/training_config_main_rllib.yaml new file mode 100644 index 00000000..88f82890 --- /dev/null +++ b/tests/config/training_config_main_rllib.yaml @@ -0,0 +1,163 @@ +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: RLLIB + +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: PPO + +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False +random_red_agent: False + +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: null + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS + + +# Number of episodes for training to run per session +num_train_episodes: 10 + +# Number of time_steps for training per episode +num_train_steps: 256 + +# Number of episodes for evaluation to run per session +num_eval_episodes: 1 + +# Number of time_steps for evaluation per episode +num_eval_steps: 256 + +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. Default is 10 +checkpoint_every_n_episodes: 10 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN_EVAL + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 +# Node Software or Service State +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 +# Node File System State +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 +# IER status +red_ier_running: -0.0005 +green_ier_blocked: -0.001 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/conftest.py b/tests/conftest.py index 32a7edcf..3f022b6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import shutil import tempfile from datetime import datetime from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Tuple, Union from unittest.mock import patch import pytest @@ -13,7 +13,7 @@ import pytest from primaite import getLogger from primaite.environment.primaite_env import Primaite from primaite.primaite_session import PrimaiteSession -from primaite.utils.session_output_reader import av_rewards_dict +from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict from tests.mock_and_patch.get_session_path_mock import get_temp_session_path ACTION_SPACE_NODE_VALUES = 1 @@ -37,16 +37,26 @@ class TempPrimaiteSession(PrimaiteSession): super().__init__(training_config_path, lay_down_config_path) self.setup() - def learn_av_reward_per_episode(self) -> Dict[int, float]: + def learn_av_reward_per_episode_dict(self) -> Dict[int, float]: """Get the learn av reward per episode from file.""" csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" return av_rewards_dict(self.learning_path / csv_file) - def eval_av_reward_per_episode_csv(self) -> Dict[int, float]: + def eval_av_reward_per_episode_dict(self) -> Dict[int, float]: """Get the eval av reward per episode from file.""" csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" return av_rewards_dict(self.evaluation_path / csv_file) + def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]: + """Get the learn all transactions from file.""" + csv_file = f"all_transactions_{self.timestamp_str}.csv" + return all_transactions_dict(self.learning_path / csv_file) + + def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]: + """Get the eval all transactions from file.""" + csv_file = f"all_transactions_{self.timestamp_str}.csv" + return all_transactions_dict(self.evaluation_path / csv_file) + def metadata_file_as_dict(self) -> Dict[str, Any]: """Read the session_metadata.json file and return as a dict.""" with open(self.session_path / "session_metadata.json", "r") as file: diff --git a/tests/test_reward.py b/tests/test_reward.py index bb6eb1b0..2edfd44a 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -48,5 +48,5 @@ def test_rewards_are_being_penalised_at_each_step_function( """ with temp_primaite_session as session: session.evaluate() - ev_rewards = session.eval_av_reward_per_episode_csv() + ev_rewards = session.eval_av_reward_per_episode_dict() assert ev_rewards[1] == -8.0 diff --git a/tests/test_rllib_agent.py b/tests/test_rllib_agent.py new file mode 100644 index 00000000..645214e3 --- /dev/null +++ b/tests/test_rllib_agent.py @@ -0,0 +1,23 @@ +import pytest + +from primaite import getLogger +from primaite.config.lay_down_config import dos_very_basic_config_path +from tests import TEST_CONFIG_ROOT + +_LOGGER = getLogger(__name__) + + +@pytest.mark.parametrize( + "temp_primaite_session", + [[TEST_CONFIG_ROOT / "training_config_main_rllib.yaml", dos_very_basic_config_path()]], + indirect=True, +) +def test_primaite_session(temp_primaite_session): + """Test the training_config_main_rllib.yaml training config file.""" + with temp_primaite_session as session: + session_path = session.session_path + assert session_path.exists() + session.learn() + + assert len(session.learn_av_reward_per_episode_dict().keys()) == 10 + assert len(session.learn_all_transactions_dict().keys()) == 10 * 256 diff --git a/tests/test_seeding_and_deterministic_session.py b/tests/test_seeding_and_deterministic_session.py index 34cb43fb..f52e9eee 100644 --- a/tests/test_seeding_and_deterministic_session.py +++ b/tests/test_seeding_and_deterministic_session.py @@ -28,7 +28,7 @@ def test_seeded_learning(temp_primaite_session): "Expected output is based upon a agent that was trained with " "seed 67890" ) session.learn() - actual_mean_reward_per_episode = session.learn_av_reward_per_episode() + actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict() assert actual_mean_reward_per_episode == expected_mean_reward_per_episode @@ -45,5 +45,5 @@ def test_deterministic_evaluation(temp_primaite_session): # do stuff session.learn() session.evaluate() - eval_mean_reward = session.eval_av_reward_per_episode_csv() + eval_mean_reward = session.eval_av_reward_per_episode_dict() assert len(set(eval_mean_reward.values())) == 1