#1594 - Managed to get the evaluation of rllib agents working. A test has been added to test_primaite_session.py that now tests the full RLlib agent from end-to-end. I;ve also updated the tests in here to check that the mean reward per episode plot is created for both too. This will need a bit of a re-design further down the line, but for now, it works. Added a custom exception for RLlib eval only error.
This commit is contained in:
@@ -230,7 +230,6 @@ class AgentSessionABC(ABC):
|
||||
self._update_session_metadata_file()
|
||||
self._can_evaluate = True
|
||||
self.is_eval = False
|
||||
self._plot_av_reward_per_episode(learning_session=True)
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(
|
||||
@@ -243,9 +242,9 @@ class AgentSessionABC(ABC):
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if self._can_evaluate:
|
||||
self._plot_av_reward_per_episode(learning_session=False)
|
||||
self._update_session_metadata_file()
|
||||
self.is_eval = True
|
||||
self._plot_av_reward_per_episode(learning_session=False)
|
||||
_LOGGER.info("Finished evaluation")
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
@@ -17,8 +18,9 @@ from ray.tune.registry import register_env
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.exceptions import RLlibAgentError
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
@@ -68,11 +70,14 @@ class RLlibAgent(AgentSessionABC):
|
||||
# TODO: implement RLlib agent loading
|
||||
if session_path is not None:
|
||||
msg = "RLlib agent loading has not been implemented yet"
|
||||
_LOGGER.error(msg)
|
||||
print(msg)
|
||||
raise NotImplementedError
|
||||
_LOGGER.critical(msg)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if self._training_config.session_type == SessionType.EVAL:
|
||||
msg = "Cannot evaluate an RLlib agent that hasn't been through trainig yet."
|
||||
_LOGGER.critical(msg)
|
||||
raise RLlibAgentError(msg)
|
||||
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
@@ -98,6 +103,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
f"deep_learning_framework="
|
||||
f"{self._training_config.deep_learning_framework}"
|
||||
)
|
||||
self._train_agent = None # Required to capture the learning agent to close after eval
|
||||
|
||||
def _update_session_metadata_file(self) -> None:
|
||||
"""
|
||||
@@ -179,20 +185,69 @@ class RLlibAgent(AgentSessionABC):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
self.save()
|
||||
self._agent.stop()
|
||||
|
||||
super().learn()
|
||||
# Done this way as the RLlib eval can only be performed if the session hasn't been stopped
|
||||
if self._training_config.session_type is not SessionType.TRAIN:
|
||||
self._train_agent = self._agent
|
||||
else:
|
||||
self._agent.stop()
|
||||
self._plot_av_reward_per_episode(learning_session=True)
|
||||
|
||||
def _unpack_saved_agent_into_eval(self) -> Path:
|
||||
agent_restore_path = self.evaluation_path / "agent_restore"
|
||||
if agent_restore_path.exists():
|
||||
shutil.rmtree(agent_restore_path)
|
||||
agent_restore_path.mkdir()
|
||||
with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
|
||||
zip_file.extractall(agent_restore_path)
|
||||
return agent_restore_path
|
||||
|
||||
def _setup_eval(self):
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
self._agent.restore(str(self._unpack_saved_agent_into_eval()))
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs: None,
|
||||
) -> None:
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
|
||||
self._setup_eval()
|
||||
|
||||
self._env: Primaite = Primaite(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
|
||||
)
|
||||
|
||||
self._env.set_as_eval()
|
||||
self.is_eval = True
|
||||
if self._training_config.deterministic:
|
||||
deterministic_str = "deterministic"
|
||||
else:
|
||||
deterministic_str = "non-deterministic"
|
||||
_LOGGER.info(
|
||||
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
|
||||
)
|
||||
for episode in range(episodes):
|
||||
obs = self._env.reset()
|
||||
for step in range(time_steps):
|
||||
action = self._agent.compute_single_action(observation=obs, explore=False)
|
||||
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
|
||||
self._env.reset()
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
# Now we're safe to close the learning agent and write the mean rewards per episode for it
|
||||
if self._training_config.session_type is not SessionType.TRAIN:
|
||||
self._train_agent.stop()
|
||||
self._plot_av_reward_per_episode(learning_session=True)
|
||||
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -153,6 +153,8 @@ class SB3Agent(AgentSessionABC):
|
||||
# save agent
|
||||
self.save()
|
||||
|
||||
self._plot_av_reward_per_episode(learning_session=True)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
|
||||
@@ -261,10 +261,12 @@ class Primaite(Env):
|
||||
self, transaction_writer=True, learning_session=True
|
||||
)
|
||||
|
||||
self.is_eval = False
|
||||
|
||||
@property
|
||||
def actual_episode_count(self) -> int:
|
||||
"""Shifts the episode_count by -1 for RLlib."""
|
||||
if self.training_config.agent_framework is AgentFramework.RLLIB:
|
||||
"""Shifts the episode_count by -1 for RLlib learning session."""
|
||||
if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval:
|
||||
return self.episode_count - 1
|
||||
return self.episode_count
|
||||
|
||||
@@ -276,6 +278,7 @@ class Primaite(Env):
|
||||
self.step_count = 0
|
||||
self.total_step_count = 0
|
||||
self.episode_steps = self.training_config.num_eval_steps
|
||||
self.is_eval = True
|
||||
|
||||
def _write_av_reward_per_episode(self) -> None:
|
||||
if self.actual_episode_count > 0:
|
||||
|
||||
10
src/primaite/exceptions.py
Normal file
10
src/primaite/exceptions.py
Normal file
@@ -0,0 +1,10 @@
|
||||
class PrimaiteError(Exception):
|
||||
"""The root PrimAITe Error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RLlibAgentError(PrimaiteError):
|
||||
"""Raised when there is a generic error with a RLlib agent that is specific to PRimAITE."""
|
||||
|
||||
pass
|
||||
@@ -69,7 +69,7 @@ num_train_episodes: 10
|
||||
num_train_steps: 256
|
||||
|
||||
# Number of episodes for evaluation to run per session
|
||||
num_eval_episodes: 1
|
||||
num_eval_episodes: 3
|
||||
|
||||
# Number of time_steps for evaluation per episode
|
||||
num_eval_steps: 256
|
||||
164
tests/config/session_test/training_config_main_sb3.yaml
Normal file
164
tests/config/session_test/training_config_main_sb3.yaml
Normal file
@@ -0,0 +1,164 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: SB3
|
||||
|
||||
# Sets which deep learning framework will be used (by RLlib ONLY).
|
||||
# Default is TF (Tensorflow).
|
||||
# Options are:
|
||||
# "TF" (Tensorflow)
|
||||
# TF2 (Tensorflow 2.X)
|
||||
# TORCH (PyTorch)
|
||||
deep_learning_framework: TF2
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: PPO
|
||||
|
||||
# Sets whether Red Agent POL and IER is randomised.
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
random_red_agent: False
|
||||
|
||||
# The (integer) seed to be used in random number generation
|
||||
# Default is None (null)
|
||||
seed: null
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
deterministic: False
|
||||
|
||||
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
|
||||
# Options are:
|
||||
# "BASIC" (The current observation space only)
|
||||
# "FULL" (Full environment view with actions taken and reward feedback)
|
||||
hard_coded_agent_view: FULL
|
||||
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
|
||||
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 10
|
||||
|
||||
# Number of time_steps for training per episode
|
||||
num_train_steps: 256
|
||||
|
||||
# Number of episodes for evaluation to run per session
|
||||
num_eval_episodes: 3
|
||||
|
||||
# Number of time_steps for evaluation per episode
|
||||
num_eval_steps: 256
|
||||
|
||||
# Sets how often the agent will save a checkpoint (every n time episodes).
|
||||
# Set to 0 if no checkpoints are required. Default is 10
|
||||
checkpoint_every_n_episodes: 10
|
||||
|
||||
# Time delay (milliseconds) between steps for CUSTOM agents.
|
||||
time_delay: 5
|
||||
|
||||
# Type of session to be run. Options are:
|
||||
# "TRAIN" (Trains an agent)
|
||||
# "EVAL" (Evaluates an agent)
|
||||
# "TRAIN_EVAL" (Trains then evaluates an agent)
|
||||
session_type: TRAIN_EVAL
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
# "INFO" (Info Messages (such as devices and wrappers used))
|
||||
# "DEBUG" (All Messages)
|
||||
sb3_output_verbose_level: NONE
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
off_should_be_on: -0.001
|
||||
off_should_be_resetting: -0.0005
|
||||
on_should_be_off: -0.0002
|
||||
on_should_be_resetting: -0.0005
|
||||
resetting_should_be_on: -0.0005
|
||||
resetting_should_be_off: -0.0002
|
||||
resetting: -0.0003
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 0.0002
|
||||
good_should_be_compromised: 0.0005
|
||||
good_should_be_overwhelmed: 0.0005
|
||||
patching_should_be_good: -0.0005
|
||||
patching_should_be_compromised: 0.0002
|
||||
patching_should_be_overwhelmed: 0.0002
|
||||
patching: -0.0003
|
||||
compromised_should_be_good: -0.002
|
||||
compromised_should_be_patching: -0.002
|
||||
compromised_should_be_overwhelmed: -0.002
|
||||
compromised: -0.002
|
||||
overwhelmed_should_be_good: -0.002
|
||||
overwhelmed_should_be_patching: -0.002
|
||||
overwhelmed_should_be_compromised: -0.002
|
||||
overwhelmed: -0.002
|
||||
# Node File System State
|
||||
good_should_be_repairing: 0.0002
|
||||
good_should_be_restoring: 0.0002
|
||||
good_should_be_corrupt: 0.0005
|
||||
good_should_be_destroyed: 0.001
|
||||
repairing_should_be_good: -0.0005
|
||||
repairing_should_be_restoring: 0.0002
|
||||
repairing_should_be_corrupt: 0.0002
|
||||
repairing_should_be_destroyed: 0.0000
|
||||
repairing: -0.0003
|
||||
restoring_should_be_good: -0.001
|
||||
restoring_should_be_repairing: -0.0002
|
||||
restoring_should_be_corrupt: 0.0001
|
||||
restoring_should_be_destroyed: 0.0002
|
||||
restoring: -0.0006
|
||||
corrupt_should_be_good: -0.001
|
||||
corrupt_should_be_repairing: -0.001
|
||||
corrupt_should_be_restoring: -0.001
|
||||
corrupt_should_be_destroyed: 0.0002
|
||||
corrupt: -0.001
|
||||
destroyed_should_be_good: -0.002
|
||||
destroyed_should_be_repairing: -0.002
|
||||
destroyed_should_be_restoring: -0.002
|
||||
destroyed_should_be_corrupt: -0.002
|
||||
destroyed: -0.002
|
||||
scanning: -0.0002
|
||||
# IER status
|
||||
red_ier_running: -0.0005
|
||||
green_ier_blocked: -0.001
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||
service_patching_duration: 5 # The time taken to patch a service
|
||||
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||
@@ -5,18 +5,25 @@ import pytest
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[main_training_config_path(), dos_very_basic_config_path()]],
|
||||
[
|
||||
[TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()],
|
||||
[TEST_CONFIG_ROOT / "session_test/training_config_main_sb3.yaml", dos_very_basic_config_path()],
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_primaite_session(temp_primaite_session):
|
||||
"""Tests the PrimaiteSession class and its outputs."""
|
||||
"""
|
||||
Tests the PrimaiteSession class and all of its outputs.
|
||||
|
||||
This test runs for both a Stable Baselines3 agent, and a Ray RLlib agent.
|
||||
"""
|
||||
with temp_primaite_session as session:
|
||||
session_path = session.session_path
|
||||
assert session_path.exists()
|
||||
@@ -47,6 +54,17 @@ def test_primaite_session(temp_primaite_session):
|
||||
if file.suffix == ".csv":
|
||||
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
|
||||
|
||||
# Check that the average reward per episode plots exist
|
||||
assert (session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.png").exists()
|
||||
assert (session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.png").exists()
|
||||
|
||||
# Check that the metadata has captured the correct number of learning and eval episodes and steps
|
||||
assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
|
||||
assert len(session.learn_all_transactions_dict().keys()) == 10 * 256
|
||||
|
||||
assert len(session.eval_av_reward_per_episode_dict().keys()) == 3
|
||||
assert len(session.eval_all_transactions_dict().keys()) == 3 * 256
|
||||
|
||||
_LOGGER.debug("Inspecting files in temp session path...")
|
||||
for dir_path, dir_names, file_names in os.walk(session_path):
|
||||
for file in file_names:
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import pytest
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[TEST_CONFIG_ROOT / "training_config_main_rllib.yaml", dos_very_basic_config_path()]],
|
||||
indirect=True,
|
||||
)
|
||||
def test_primaite_session(temp_primaite_session):
|
||||
"""Test the training_config_main_rllib.yaml training config file."""
|
||||
with temp_primaite_session as session:
|
||||
session_path = session.session_path
|
||||
assert session_path.exists()
|
||||
session.learn()
|
||||
|
||||
assert len(session.learn_av_reward_per_episode_dict().keys()) == 10
|
||||
assert len(session.learn_all_transactions_dict().keys()) == 10 * 256
|
||||
Reference in New Issue
Block a user