Merged PR 121: #1629 - Added rllib test

## Summary
Quick test that uses RLLIB in a session

## Test process
The learning session completes then we check that the number of rows in both the average reward per episode and all transactions csv files.

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

#1629 - Added rllib test

Related work items: #1629
This commit is contained in:
Christopher McCarthy
2023-07-17 17:28:51 +00:00
6 changed files with 232 additions and 11 deletions

View File

@@ -1,5 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Dict, Union from typing import Any, Dict, Tuple, Union
# Using polars as it's faster than Pandas; it will speed things up when # Using polars as it's faster than Pandas; it will speed things up when
# files get big! # files get big!
@@ -13,8 +13,33 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
The dictionary keys are the episode number, and the values are the mean reward that episode. The dictionary keys are the episode number, and the values are the mean reward that episode.
:param av_rewards_csv_file: The average rewards per episode csv file path. :param av_rewards_csv_file: The average rewards per episode csv file path.
:return: The average rewards per episode cdv as a dict. :return: The average rewards per episode csv as a dict.
""" """
df = pl.read_csv(av_rewards_csv_file).to_dict() df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])} return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""
Read an all transactions csv file and return as a dict.
The dict keys are a tuple with the structure (episode, step). The dict
values are the remaining columns as a dict.
:param all_transactions_csv_file: The all transactions csv file path.
:return: The all transactions csv file as a dict.
"""
df_dict = pl.read_csv(all_transactions_csv_file).to_dict()
new_dict = {}
episodes = df_dict["Episode"]
steps = df_dict["Step"]
keys = list(df_dict.keys())
for i in range(len(episodes)):
key = (episodes[i], steps[i])
value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]}
new_dict[key] = value_dict
return new_dict

View File

@@ -0,0 +1,163 @@
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: RLLIB
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes for training to run per session
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time take to repair the file system
file_system_restoring_limit: 5 # The time take to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -5,7 +5,7 @@ import shutil
import tempfile import tempfile
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Union from typing import Any, Dict, Tuple, Union
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@@ -13,7 +13,7 @@ import pytest
from primaite import getLogger from primaite import getLogger
from primaite.environment.primaite_env import Primaite from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_VALUES = 1
@@ -37,16 +37,26 @@ class TempPrimaiteSession(PrimaiteSession):
super().__init__(training_config_path, lay_down_config_path) super().__init__(training_config_path, lay_down_config_path)
self.setup() self.setup()
def learn_av_reward_per_episode(self) -> Dict[int, float]: def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the learn av reward per episode from file.""" """Get the learn av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.learning_path / csv_file) return av_rewards_dict(self.learning_path / csv_file)
def eval_av_reward_per_episode_csv(self) -> Dict[int, float]: def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the eval av reward per episode from file.""" """Get the eval av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.evaluation_path / csv_file) return av_rewards_dict(self.evaluation_path / csv_file)
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the learn all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.learning_path / csv_file)
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the eval all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.evaluation_path / csv_file)
def metadata_file_as_dict(self) -> Dict[str, Any]: def metadata_file_as_dict(self) -> Dict[str, Any]:
"""Read the session_metadata.json file and return as a dict.""" """Read the session_metadata.json file and return as a dict."""
with open(self.session_path / "session_metadata.json", "r") as file: with open(self.session_path / "session_metadata.json", "r") as file:

View File

@@ -48,5 +48,5 @@ def test_rewards_are_being_penalised_at_each_step_function(
""" """
with temp_primaite_session as session: with temp_primaite_session as session:
session.evaluate() session.evaluate()
ev_rewards = session.eval_av_reward_per_episode_csv() ev_rewards = session.eval_av_reward_per_episode_dict()
assert ev_rewards[1] == -8.0 assert ev_rewards[1] == -8.0

23
tests/test_rllib_agent.py Normal file
View File

@@ -0,0 +1,23 @@
import pytest
from primaite import getLogger
from primaite.config.lay_down_config import dos_very_basic_config_path
from tests import TEST_CONFIG_ROOT
_LOGGER = getLogger(__name__)
@pytest.mark.parametrize(
"temp_primaite_session",
[[TEST_CONFIG_ROOT / "training_config_main_rllib.yaml", dos_very_basic_config_path()]],
indirect=True,
)
def test_primaite_session(temp_primaite_session):
"""Test the training_config_main_rllib.yaml training config file."""
with temp_primaite_session as session:
session_path = session.session_path
assert session_path.exists()
session.learn()
assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
assert len(session.learn_all_transactions_dict().keys()) == 10 * 256

View File

@@ -28,7 +28,7 @@ def test_seeded_learning(temp_primaite_session):
"Expected output is based upon a agent that was trained with " "seed 67890" "Expected output is based upon a agent that was trained with " "seed 67890"
) )
session.learn() session.learn()
actual_mean_reward_per_episode = session.learn_av_reward_per_episode() actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
@@ -45,5 +45,5 @@ def test_deterministic_evaluation(temp_primaite_session):
# do stuff # do stuff
session.learn() session.learn()
session.evaluate() session.evaluate()
eval_mean_reward = session.eval_av_reward_per_episode_csv() eval_mean_reward = session.eval_av_reward_per_episode_dict()
assert len(set(eval_mean_reward.values())) == 1 assert len(set(eval_mean_reward.values())) == 1