Merged PR 121: #1629 - Added rllib test
## Summary Quick test that uses RLLIB in a session ## Test process The learning session completes then we check that the number of rows in both the average reward per episode and all transactions csv files. ## Checklist - [X] This PR is linked to a **work item** - [X] I have performed **self-review** of the code - [X] I have written **tests** for any new functionality added with this PR - [ ] I have updated the **documentation** if this PR changes or adds functionality - [X] I have run **pre-commit** checks for code style #1629 - Added rllib test Related work items: #1629
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Union
|
from typing import Any, Dict, Tuple, Union
|
||||||
|
|
||||||
# Using polars as it's faster than Pandas; it will speed things up when
|
# Using polars as it's faster than Pandas; it will speed things up when
|
||||||
# files get big!
|
# files get big!
|
||||||
@@ -13,8 +13,33 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
|||||||
The dictionary keys are the episode number, and the values are the mean reward that episode.
|
The dictionary keys are the episode number, and the values are the mean reward that episode.
|
||||||
|
|
||||||
:param av_rewards_csv_file: The average rewards per episode csv file path.
|
:param av_rewards_csv_file: The average rewards per episode csv file path.
|
||||||
:return: The average rewards per episode cdv as a dict.
|
:return: The average rewards per episode csv as a dict.
|
||||||
"""
|
"""
|
||||||
df = pl.read_csv(av_rewards_csv_file).to_dict()
|
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
|
||||||
|
|
||||||
return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])}
|
return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
|
||||||
|
|
||||||
|
|
||||||
|
def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Read an all transactions csv file and return as a dict.
|
||||||
|
|
||||||
|
The dict keys are a tuple with the structure (episode, step). The dict
|
||||||
|
values are the remaining columns as a dict.
|
||||||
|
|
||||||
|
:param all_transactions_csv_file: The all transactions csv file path.
|
||||||
|
:return: The all transactions csv file as a dict.
|
||||||
|
"""
|
||||||
|
df_dict = pl.read_csv(all_transactions_csv_file).to_dict()
|
||||||
|
new_dict = {}
|
||||||
|
|
||||||
|
episodes = df_dict["Episode"]
|
||||||
|
steps = df_dict["Step"]
|
||||||
|
keys = list(df_dict.keys())
|
||||||
|
|
||||||
|
for i in range(len(episodes)):
|
||||||
|
key = (episodes[i], steps[i])
|
||||||
|
value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]}
|
||||||
|
new_dict[key] = value_dict
|
||||||
|
|
||||||
|
return new_dict
|
||||||
|
|||||||
163
tests/config/training_config_main_rllib.yaml
Normal file
163
tests/config/training_config_main_rllib.yaml
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
# Training Config File
|
||||||
|
|
||||||
|
# Sets which agent algorithm framework will be used.
|
||||||
|
# Options are:
|
||||||
|
# "SB3" (Stable Baselines3)
|
||||||
|
# "RLLIB" (Ray RLlib)
|
||||||
|
# "CUSTOM" (Custom Agent)
|
||||||
|
agent_framework: RLLIB
|
||||||
|
|
||||||
|
# Sets which deep learning framework will be used (by RLlib ONLY).
|
||||||
|
# Default is TF (Tensorflow).
|
||||||
|
# Options are:
|
||||||
|
# "TF" (Tensorflow)
|
||||||
|
# TF2 (Tensorflow 2.X)
|
||||||
|
# TORCH (PyTorch)
|
||||||
|
deep_learning_framework: TF2
|
||||||
|
|
||||||
|
# Sets which Agent class will be used.
|
||||||
|
# Options are:
|
||||||
|
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||||
|
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||||
|
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||||
|
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||||
|
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||||
|
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||||
|
agent_identifier: PPO
|
||||||
|
|
||||||
|
# Sets whether Red Agent POL and IER is randomised.
|
||||||
|
# Options are:
|
||||||
|
# True
|
||||||
|
# False
|
||||||
|
random_red_agent: False
|
||||||
|
|
||||||
|
# The (integer) seed to be used in random number generation
|
||||||
|
# Default is None (null)
|
||||||
|
seed: null
|
||||||
|
|
||||||
|
# Set whether the agent will be deterministic instead of stochastic
|
||||||
|
# Options are:
|
||||||
|
# True
|
||||||
|
# False
|
||||||
|
deterministic: False
|
||||||
|
|
||||||
|
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
|
||||||
|
# Options are:
|
||||||
|
# "BASIC" (The current observation space only)
|
||||||
|
# "FULL" (Full environment view with actions taken and reward feedback)
|
||||||
|
hard_coded_agent_view: FULL
|
||||||
|
|
||||||
|
# Sets How the Action Space is defined:
|
||||||
|
# "NODE"
|
||||||
|
# "ACL"
|
||||||
|
# "ANY" node and acl actions
|
||||||
|
action_type: NODE
|
||||||
|
# observation space
|
||||||
|
observation_space:
|
||||||
|
# flatten: true
|
||||||
|
components:
|
||||||
|
- name: NODE_LINK_TABLE
|
||||||
|
# - name: NODE_STATUSES
|
||||||
|
# - name: LINK_TRAFFIC_LEVELS
|
||||||
|
|
||||||
|
|
||||||
|
# Number of episodes for training to run per session
|
||||||
|
num_train_episodes: 10
|
||||||
|
|
||||||
|
# Number of time_steps for training per episode
|
||||||
|
num_train_steps: 256
|
||||||
|
|
||||||
|
# Number of episodes for evaluation to run per session
|
||||||
|
num_eval_episodes: 1
|
||||||
|
|
||||||
|
# Number of time_steps for evaluation per episode
|
||||||
|
num_eval_steps: 256
|
||||||
|
|
||||||
|
# Sets how often the agent will save a checkpoint (every n time episodes).
|
||||||
|
# Set to 0 if no checkpoints are required. Default is 10
|
||||||
|
checkpoint_every_n_episodes: 10
|
||||||
|
|
||||||
|
# Time delay (milliseconds) between steps for CUSTOM agents.
|
||||||
|
time_delay: 5
|
||||||
|
|
||||||
|
# Type of session to be run. Options are:
|
||||||
|
# "TRAIN" (Trains an agent)
|
||||||
|
# "EVAL" (Evaluates an agent)
|
||||||
|
# "TRAIN_EVAL" (Trains then evaluates an agent)
|
||||||
|
session_type: TRAIN_EVAL
|
||||||
|
|
||||||
|
# Environment config values
|
||||||
|
# The high value for the observation space
|
||||||
|
observation_space_high_value: 1000000000
|
||||||
|
|
||||||
|
# The Stable Baselines3 learn/eval output verbosity level:
|
||||||
|
# Options are:
|
||||||
|
# "NONE" (No Output)
|
||||||
|
# "INFO" (Info Messages (such as devices and wrappers used))
|
||||||
|
# "DEBUG" (All Messages)
|
||||||
|
sb3_output_verbose_level: NONE
|
||||||
|
|
||||||
|
# Reward values
|
||||||
|
# Generic
|
||||||
|
all_ok: 0
|
||||||
|
# Node Hardware State
|
||||||
|
off_should_be_on: -0.001
|
||||||
|
off_should_be_resetting: -0.0005
|
||||||
|
on_should_be_off: -0.0002
|
||||||
|
on_should_be_resetting: -0.0005
|
||||||
|
resetting_should_be_on: -0.0005
|
||||||
|
resetting_should_be_off: -0.0002
|
||||||
|
resetting: -0.0003
|
||||||
|
# Node Software or Service State
|
||||||
|
good_should_be_patching: 0.0002
|
||||||
|
good_should_be_compromised: 0.0005
|
||||||
|
good_should_be_overwhelmed: 0.0005
|
||||||
|
patching_should_be_good: -0.0005
|
||||||
|
patching_should_be_compromised: 0.0002
|
||||||
|
patching_should_be_overwhelmed: 0.0002
|
||||||
|
patching: -0.0003
|
||||||
|
compromised_should_be_good: -0.002
|
||||||
|
compromised_should_be_patching: -0.002
|
||||||
|
compromised_should_be_overwhelmed: -0.002
|
||||||
|
compromised: -0.002
|
||||||
|
overwhelmed_should_be_good: -0.002
|
||||||
|
overwhelmed_should_be_patching: -0.002
|
||||||
|
overwhelmed_should_be_compromised: -0.002
|
||||||
|
overwhelmed: -0.002
|
||||||
|
# Node File System State
|
||||||
|
good_should_be_repairing: 0.0002
|
||||||
|
good_should_be_restoring: 0.0002
|
||||||
|
good_should_be_corrupt: 0.0005
|
||||||
|
good_should_be_destroyed: 0.001
|
||||||
|
repairing_should_be_good: -0.0005
|
||||||
|
repairing_should_be_restoring: 0.0002
|
||||||
|
repairing_should_be_corrupt: 0.0002
|
||||||
|
repairing_should_be_destroyed: 0.0000
|
||||||
|
repairing: -0.0003
|
||||||
|
restoring_should_be_good: -0.001
|
||||||
|
restoring_should_be_repairing: -0.0002
|
||||||
|
restoring_should_be_corrupt: 0.0001
|
||||||
|
restoring_should_be_destroyed: 0.0002
|
||||||
|
restoring: -0.0006
|
||||||
|
corrupt_should_be_good: -0.001
|
||||||
|
corrupt_should_be_repairing: -0.001
|
||||||
|
corrupt_should_be_restoring: -0.001
|
||||||
|
corrupt_should_be_destroyed: 0.0002
|
||||||
|
corrupt: -0.001
|
||||||
|
destroyed_should_be_good: -0.002
|
||||||
|
destroyed_should_be_repairing: -0.002
|
||||||
|
destroyed_should_be_restoring: -0.002
|
||||||
|
destroyed_should_be_corrupt: -0.002
|
||||||
|
destroyed: -0.002
|
||||||
|
scanning: -0.0002
|
||||||
|
# IER status
|
||||||
|
red_ier_running: -0.0005
|
||||||
|
green_ier_blocked: -0.001
|
||||||
|
|
||||||
|
# Patching / Reset durations
|
||||||
|
os_patching_duration: 5 # The time taken to patch the OS
|
||||||
|
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||||
|
service_patching_duration: 5 # The time taken to patch a service
|
||||||
|
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||||
|
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||||
|
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||||
@@ -5,7 +5,7 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Union
|
from typing import Any, Dict, Tuple, Union
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -13,7 +13,7 @@ import pytest
|
|||||||
from primaite import getLogger
|
from primaite import getLogger
|
||||||
from primaite.environment.primaite_env import Primaite
|
from primaite.environment.primaite_env import Primaite
|
||||||
from primaite.primaite_session import PrimaiteSession
|
from primaite.primaite_session import PrimaiteSession
|
||||||
from primaite.utils.session_output_reader import av_rewards_dict
|
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
|
||||||
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
|
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
|
||||||
|
|
||||||
ACTION_SPACE_NODE_VALUES = 1
|
ACTION_SPACE_NODE_VALUES = 1
|
||||||
@@ -37,16 +37,26 @@ class TempPrimaiteSession(PrimaiteSession):
|
|||||||
super().__init__(training_config_path, lay_down_config_path)
|
super().__init__(training_config_path, lay_down_config_path)
|
||||||
self.setup()
|
self.setup()
|
||||||
|
|
||||||
def learn_av_reward_per_episode(self) -> Dict[int, float]:
|
def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
|
||||||
"""Get the learn av reward per episode from file."""
|
"""Get the learn av reward per episode from file."""
|
||||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||||
return av_rewards_dict(self.learning_path / csv_file)
|
return av_rewards_dict(self.learning_path / csv_file)
|
||||||
|
|
||||||
def eval_av_reward_per_episode_csv(self) -> Dict[int, float]:
|
def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
|
||||||
"""Get the eval av reward per episode from file."""
|
"""Get the eval av reward per episode from file."""
|
||||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||||
return av_rewards_dict(self.evaluation_path / csv_file)
|
return av_rewards_dict(self.evaluation_path / csv_file)
|
||||||
|
|
||||||
|
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||||
|
"""Get the learn all transactions from file."""
|
||||||
|
csv_file = f"all_transactions_{self.timestamp_str}.csv"
|
||||||
|
return all_transactions_dict(self.learning_path / csv_file)
|
||||||
|
|
||||||
|
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||||
|
"""Get the eval all transactions from file."""
|
||||||
|
csv_file = f"all_transactions_{self.timestamp_str}.csv"
|
||||||
|
return all_transactions_dict(self.evaluation_path / csv_file)
|
||||||
|
|
||||||
def metadata_file_as_dict(self) -> Dict[str, Any]:
|
def metadata_file_as_dict(self) -> Dict[str, Any]:
|
||||||
"""Read the session_metadata.json file and return as a dict."""
|
"""Read the session_metadata.json file and return as a dict."""
|
||||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||||
|
|||||||
@@ -48,5 +48,5 @@ def test_rewards_are_being_penalised_at_each_step_function(
|
|||||||
"""
|
"""
|
||||||
with temp_primaite_session as session:
|
with temp_primaite_session as session:
|
||||||
session.evaluate()
|
session.evaluate()
|
||||||
ev_rewards = session.eval_av_reward_per_episode_csv()
|
ev_rewards = session.eval_av_reward_per_episode_dict()
|
||||||
assert ev_rewards[1] == -8.0
|
assert ev_rewards[1] == -8.0
|
||||||
|
|||||||
23
tests/test_rllib_agent.py
Normal file
23
tests/test_rllib_agent.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from primaite import getLogger
|
||||||
|
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||||
|
from tests import TEST_CONFIG_ROOT
|
||||||
|
|
||||||
|
_LOGGER = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"temp_primaite_session",
|
||||||
|
[[TEST_CONFIG_ROOT / "training_config_main_rllib.yaml", dos_very_basic_config_path()]],
|
||||||
|
indirect=True,
|
||||||
|
)
|
||||||
|
def test_primaite_session(temp_primaite_session):
|
||||||
|
"""Test the training_config_main_rllib.yaml training config file."""
|
||||||
|
with temp_primaite_session as session:
|
||||||
|
session_path = session.session_path
|
||||||
|
assert session_path.exists()
|
||||||
|
session.learn()
|
||||||
|
|
||||||
|
assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
|
||||||
|
assert len(session.learn_all_transactions_dict().keys()) == 10 * 256
|
||||||
@@ -28,7 +28,7 @@ def test_seeded_learning(temp_primaite_session):
|
|||||||
"Expected output is based upon a agent that was trained with " "seed 67890"
|
"Expected output is based upon a agent that was trained with " "seed 67890"
|
||||||
)
|
)
|
||||||
session.learn()
|
session.learn()
|
||||||
actual_mean_reward_per_episode = session.learn_av_reward_per_episode()
|
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
|
||||||
|
|
||||||
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
|
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
|
||||||
|
|
||||||
@@ -45,5 +45,5 @@ def test_deterministic_evaluation(temp_primaite_session):
|
|||||||
# do stuff
|
# do stuff
|
||||||
session.learn()
|
session.learn()
|
||||||
session.evaluate()
|
session.evaluate()
|
||||||
eval_mean_reward = session.eval_av_reward_per_episode_csv()
|
eval_mean_reward = session.eval_av_reward_per_episode_dict()
|
||||||
assert len(set(eval_mean_reward.values())) == 1
|
assert len(set(eval_mean_reward.values())) == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user