Merged PR 132: #1594 - Managed to get the evaluation of rllib agents working. A test has bee...

## Summary
Managed to get the evaluation of rllib agents working. A test has been added to test_primaite_session.py that now tests the full RLlib agent from end-to-end. I've also updated the tests in here to check that the mean reward per episode plot is created for both too. This will need a bit of a re-design further down the line, but for now, it works. Added a custom exception for RLlib eval only error.

Is this a hack? Yes. Does it work? Yes. we'll make this better later.

## Test process
Both a SB3 and Ray RLlib agent is tested now in the test_primaite_session.py module.

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [X] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #1594
This commit is contained in:
Christopher McCarthy
2023-07-21 10:26:31 +00:00
9 changed files with 272 additions and 41 deletions

View File

@@ -230,7 +230,6 @@ class AgentSessionABC(ABC):
self._update_session_metadata_file()
self._can_evaluate = True
self.is_eval = False
self._plot_av_reward_per_episode(learning_session=True)
@abstractmethod
def evaluate(
@@ -243,9 +242,9 @@ class AgentSessionABC(ABC):
:param kwargs: Any agent-specific key-word args to be passed.
"""
if self._can_evaluate:
self._plot_av_reward_per_episode(learning_session=False)
self._update_session_metadata_file()
self.is_eval = True
self._plot_av_reward_per_episode(learning_session=False)
_LOGGER.info("Finished evaluation")
@abstractmethod

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import json
import shutil
import zipfile
from datetime import datetime
from logging import Logger
from pathlib import Path
@@ -17,8 +18,9 @@ from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
from primaite.environment.primaite_env import Primaite
from primaite.exceptions import RLlibAgentError
_LOGGER: Logger = getLogger(__name__)
@@ -68,11 +70,14 @@ class RLlibAgent(AgentSessionABC):
# TODO: implement RLlib agent loading
if session_path is not None:
msg = "RLlib agent loading has not been implemented yet"
_LOGGER.error(msg)
print(msg)
raise NotImplementedError
_LOGGER.critical(msg)
raise NotImplementedError(msg)
super().__init__(training_config_path, lay_down_config_path)
if self._training_config.session_type == SessionType.EVAL:
msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
_LOGGER.critical(msg)
raise RLlibAgentError(msg)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
@@ -98,6 +103,7 @@ class RLlibAgent(AgentSessionABC):
f"deep_learning_framework="
f"{self._training_config.deep_learning_framework}"
)
self._train_agent = None # Required to capture the learning agent to close after eval
def _update_session_metadata_file(self) -> None:
"""
@@ -179,20 +185,73 @@ class RLlibAgent(AgentSessionABC):
self._current_result = self._agent.train()
self._save_checkpoint()
self.save()
self._agent.stop()
super().learn()
# Done this way as the RLlib eval can only be performed if the session hasn't been stopped
if self._training_config.session_type is not SessionType.TRAIN:
self._train_agent = self._agent
else:
self._agent.stop()
self._plot_av_reward_per_episode(learning_session=True)
def _unpack_saved_agent_into_eval(self) -> Path:
"""Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
agent_restore_path = self.evaluation_path / "agent_restore"
if agent_restore_path.exists():
shutil.rmtree(agent_restore_path)
agent_restore_path.mkdir()
with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
zip_file.extractall(agent_restore_path)
return agent_restore_path
def _setup_eval(self):
self._can_learn = False
self._can_evaluate = True
self._agent.restore(str(self._unpack_saved_agent_into_eval()))
def evaluate(
self,
**kwargs: None,
) -> None:
**kwargs,
):
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
raise NotImplementedError
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
self._setup_eval()
self._env: Primaite = Primaite(
self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
)
self._env.set_as_eval()
self.is_eval = True
if self._training_config.deterministic:
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
_LOGGER.info(
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
)
for episode in range(episodes):
obs = self._env.reset()
for step in range(time_steps):
action = self._agent.compute_single_action(observation=obs, explore=False)
obs, rewards, done, info = self._env.step(action)
self._env.reset()
self._env.close()
super().evaluate()
# Now we're safe to close the learning agent and write the mean rewards per episode for it
if self._training_config.session_type is not SessionType.TRAIN:
self._train_agent.stop()
self._plot_av_reward_per_episode(learning_session=True)
# Perform a clean-up of the unpacked agent
if (self.evaluation_path / "agent_restore").exists():
shutil.rmtree((self.evaluation_path / "agent_restore"))
def _get_latest_checkpoint(self) -> None:
raise NotImplementedError

View File

@@ -153,6 +153,8 @@ class SB3Agent(AgentSessionABC):
# save agent
self.save()
self._plot_av_reward_per_episode(learning_session=True)
def evaluate(
self,
**kwargs: Any,

View File

@@ -261,10 +261,12 @@ class Primaite(Env):
self, transaction_writer=True, learning_session=True
)
self.is_eval = False
@property
def actual_episode_count(self) -> int:
"""Shifts the episode_count by -1 for RLlib."""
if self.training_config.agent_framework is AgentFramework.RLLIB:
"""Shifts the episode_count by -1 for RLlib learning session."""
if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval:
return self.episode_count - 1
return self.episode_count
@@ -276,6 +278,7 @@ class Primaite(Env):
self.step_count = 0
self.total_step_count = 0
self.episode_steps = self.training_config.num_eval_steps
self.is_eval = True
def _write_av_reward_per_episode(self) -> None:
if self.actual_episode_count > 0:

View File

@@ -0,0 +1,10 @@
class PrimaiteError(Exception):
"""The root PrimAITe Error."""
pass
class RLlibAgentError(PrimaiteError):
"""Raised when there is a generic error with a RLlib agent that is specific to PRimAITE."""
pass

View File

@@ -69,7 +69,7 @@ num_train_episodes: 10
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
num_eval_episodes: 3
# Number of time_steps for evaluation per episode
num_eval_steps: 256

View File

@@ -0,0 +1,164 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes for training to run per session
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 3
# Number of time_steps for evaluation per episode
num_eval_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time take to repair the file system
file_system_restoring_limit: 5 # The time take to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -5,18 +5,25 @@ import pytest
from primaite import getLogger
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
from tests import TEST_CONFIG_ROOT
_LOGGER = getLogger(__name__)
@pytest.mark.parametrize(
"temp_primaite_session",
[[main_training_config_path(), dos_very_basic_config_path()]],
[
[TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()],
[TEST_CONFIG_ROOT / "session_test/training_config_main_sb3.yaml", dos_very_basic_config_path()],
],
indirect=True,
)
def test_primaite_session(temp_primaite_session):
"""Tests the PrimaiteSession class and its outputs."""
"""
Tests the PrimaiteSession class and all of its outputs.
This test runs for both a Stable Baselines3 agent, and a Ray RLlib agent.
"""
with temp_primaite_session as session:
session_path = session.session_path
assert session_path.exists()
@@ -47,6 +54,17 @@ def test_primaite_session(temp_primaite_session):
if file.suffix == ".csv":
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
# Check that the average reward per episode plots exist
assert (session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.png").exists()
assert (session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.png").exists()
# Check that the metadata has captured the correct number of learning and eval episodes and steps
assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
assert len(session.learn_all_transactions_dict().keys()) == 10 * 256
assert len(session.eval_av_reward_per_episode_dict().keys()) == 3
assert len(session.eval_all_transactions_dict().keys()) == 3 * 256
_LOGGER.debug("Inspecting files in temp session path...")
for dir_path, dir_names, file_names in os.walk(session_path):
for file in file_names:

View File

@@ -1,24 +0,0 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import pytest
from primaite import getLogger
from primaite.config.lay_down_config import dos_very_basic_config_path
from tests import TEST_CONFIG_ROOT
_LOGGER = getLogger(__name__)
@pytest.mark.parametrize(
"temp_primaite_session",
[[TEST_CONFIG_ROOT / "training_config_main_rllib.yaml", dos_very_basic_config_path()]],
indirect=True,
)
def test_primaite_session(temp_primaite_session):
"""Test the training_config_main_rllib.yaml training config file."""
with temp_primaite_session as session:
session_path = session.session_path
assert session_path.exists()
session.learn()
assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
assert len(session.learn_all_transactions_dict().keys()) == 10 * 256