diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py index 3c18e1f3..87f8ed8e 100644 --- a/src/primaite/agents/agent_abc.py +++ b/src/primaite/agents/agent_abc.py @@ -230,7 +230,6 @@ class AgentSessionABC(ABC): self._update_session_metadata_file() self._can_evaluate = True self.is_eval = False - self._plot_av_reward_per_episode(learning_session=True) @abstractmethod def evaluate( @@ -243,9 +242,9 @@ class AgentSessionABC(ABC): :param kwargs: Any agent-specific key-word args to be passed. """ if self._can_evaluate: - self._plot_av_reward_per_episode(learning_session=False) self._update_session_metadata_file() self.is_eval = True + self._plot_av_reward_per_episode(learning_session=False) _LOGGER.info("Finished evaluation") @abstractmethod diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index bde3a621..fb062f54 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -3,6 +3,7 @@ from __future__ import annotations import json import shutil +import zipfile from datetime import datetime from logging import Logger from pathlib import Path @@ -17,8 +18,9 @@ from ray.tune.registry import register_env from primaite import getLogger from primaite.agents.agent_abc import AgentSessionABC -from primaite.common.enums import AgentFramework, AgentIdentifier +from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType from primaite.environment.primaite_env import Primaite +from primaite.exceptions import RLlibAgentError _LOGGER: Logger = getLogger(__name__) @@ -68,11 +70,14 @@ class RLlibAgent(AgentSessionABC): # TODO: implement RLlib agent loading if session_path is not None: msg = "RLlib agent loading has not been implemented yet" - _LOGGER.error(msg) - print(msg) - raise NotImplementedError + _LOGGER.critical(msg) + raise NotImplementedError(msg) super().__init__(training_config_path, lay_down_config_path) + if self._training_config.session_type == SessionType.EVAL: + msg = "Cannot evaluate an RLlib agent that hasn't been through training yet." + _LOGGER.critical(msg) + raise RLlibAgentError(msg) if not self._training_config.agent_framework == AgentFramework.RLLIB: msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" _LOGGER.error(msg) @@ -98,6 +103,7 @@ class RLlibAgent(AgentSessionABC): f"deep_learning_framework=" f"{self._training_config.deep_learning_framework}" ) + self._train_agent = None # Required to capture the learning agent to close after eval def _update_session_metadata_file(self) -> None: """ @@ -179,20 +185,73 @@ class RLlibAgent(AgentSessionABC): self._current_result = self._agent.train() self._save_checkpoint() self.save() - self._agent.stop() - super().learn() + # Done this way as the RLlib eval can only be performed if the session hasn't been stopped + if self._training_config.session_type is not SessionType.TRAIN: + self._train_agent = self._agent + else: + self._agent.stop() + self._plot_av_reward_per_episode(learning_session=True) + + def _unpack_saved_agent_into_eval(self) -> Path: + """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval.""" + agent_restore_path = self.evaluation_path / "agent_restore" + if agent_restore_path.exists(): + shutil.rmtree(agent_restore_path) + agent_restore_path.mkdir() + with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file: + zip_file.extractall(agent_restore_path) + return agent_restore_path + + def _setup_eval(self): + self._can_learn = False + self._can_evaluate = True + self._agent.restore(str(self._unpack_saved_agent_into_eval())) def evaluate( self, - **kwargs: None, - ) -> None: + **kwargs, + ): """ Evaluate the agent. :param kwargs: Any agent-specific key-word args to be passed. """ - raise NotImplementedError + time_steps = self._training_config.num_eval_steps + episodes = self._training_config.num_eval_episodes + + self._setup_eval() + + self._env: Primaite = Primaite( + self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str + ) + + self._env.set_as_eval() + self.is_eval = True + if self._training_config.deterministic: + deterministic_str = "deterministic" + else: + deterministic_str = "non-deterministic" + _LOGGER.info( + f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..." + ) + for episode in range(episodes): + obs = self._env.reset() + for step in range(time_steps): + action = self._agent.compute_single_action(observation=obs, explore=False) + + obs, rewards, done, info = self._env.step(action) + + self._env.reset() + self._env.close() + super().evaluate() + # Now we're safe to close the learning agent and write the mean rewards per episode for it + if self._training_config.session_type is not SessionType.TRAIN: + self._train_agent.stop() + self._plot_av_reward_per_episode(learning_session=True) + # Perform a clean-up of the unpacked agent + if (self.evaluation_path / "agent_restore").exists(): + shutil.rmtree((self.evaluation_path / "agent_restore")) def _get_latest_checkpoint(self) -> None: raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 5a9f9482..b347d44f 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -153,6 +153,8 @@ class SB3Agent(AgentSessionABC): # save agent self.save() + self._plot_av_reward_per_episode(learning_session=True) + def evaluate( self, **kwargs: Any, diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index bd9b3689..6a145498 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -261,10 +261,12 @@ class Primaite(Env): self, transaction_writer=True, learning_session=True ) + self.is_eval = False + @property def actual_episode_count(self) -> int: - """Shifts the episode_count by -1 for RLlib.""" - if self.training_config.agent_framework is AgentFramework.RLLIB: + """Shifts the episode_count by -1 for RLlib learning session.""" + if self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval: return self.episode_count - 1 return self.episode_count @@ -276,6 +278,7 @@ class Primaite(Env): self.step_count = 0 self.total_step_count = 0 self.episode_steps = self.training_config.num_eval_steps + self.is_eval = True def _write_av_reward_per_episode(self) -> None: if self.actual_episode_count > 0: diff --git a/src/primaite/exceptions.py b/src/primaite/exceptions.py new file mode 100644 index 00000000..0baf3949 --- /dev/null +++ b/src/primaite/exceptions.py @@ -0,0 +1,10 @@ +class PrimaiteError(Exception): + """The root PrimAITe Error.""" + + pass + + +class RLlibAgentError(PrimaiteError): + """Raised when there is a generic error with a RLlib agent that is specific to PRimAITE.""" + + pass diff --git a/tests/config/training_config_main_rllib.yaml b/tests/config/session_test/training_config_main_rllib.yaml similarity index 99% rename from tests/config/training_config_main_rllib.yaml rename to tests/config/session_test/training_config_main_rllib.yaml index 40cbc0fc..118b2d4e 100644 --- a/tests/config/training_config_main_rllib.yaml +++ b/tests/config/session_test/training_config_main_rllib.yaml @@ -69,7 +69,7 @@ num_train_episodes: 10 num_train_steps: 256 # Number of episodes for evaluation to run per session -num_eval_episodes: 1 +num_eval_episodes: 3 # Number of time_steps for evaluation per episode num_eval_steps: 256 diff --git a/tests/config/session_test/training_config_main_sb3.yaml b/tests/config/session_test/training_config_main_sb3.yaml new file mode 100644 index 00000000..9065bf8a --- /dev/null +++ b/tests/config/session_test/training_config_main_sb3.yaml @@ -0,0 +1,164 @@ +# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: SB3 + +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: PPO + +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False +random_red_agent: False + +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: null + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS + + +# Number of episodes for training to run per session +num_train_episodes: 10 + +# Number of time_steps for training per episode +num_train_steps: 256 + +# Number of episodes for evaluation to run per session +num_eval_episodes: 3 + +# Number of time_steps for evaluation per episode +num_eval_steps: 256 + +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. Default is 10 +checkpoint_every_n_episodes: 10 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN_EVAL + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 +# Node Software or Service State +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 +# Node File System State +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 +# IER status +red_ier_running: -0.0005 +green_ier_blocked: -0.001 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py index 210d931e..4b7b91ac 100644 --- a/tests/test_primaite_session.py +++ b/tests/test_primaite_session.py @@ -5,18 +5,25 @@ import pytest from primaite import getLogger from primaite.config.lay_down_config import dos_very_basic_config_path -from primaite.config.training_config import main_training_config_path +from tests import TEST_CONFIG_ROOT _LOGGER = getLogger(__name__) @pytest.mark.parametrize( "temp_primaite_session", - [[main_training_config_path(), dos_very_basic_config_path()]], + [ + [TEST_CONFIG_ROOT / "session_test/training_config_main_rllib.yaml", dos_very_basic_config_path()], + [TEST_CONFIG_ROOT / "session_test/training_config_main_sb3.yaml", dos_very_basic_config_path()], + ], indirect=True, ) def test_primaite_session(temp_primaite_session): - """Tests the PrimaiteSession class and its outputs.""" + """ + Tests the PrimaiteSession class and all of its outputs. + + This test runs for both a Stable Baselines3 agent, and a Ray RLlib agent. + """ with temp_primaite_session as session: session_path = session.session_path assert session_path.exists() @@ -47,6 +54,17 @@ def test_primaite_session(temp_primaite_session): if file.suffix == ".csv": assert "all_transactions" in file.name or "average_reward_per_episode" in file.name + # Check that the average reward per episode plots exist + assert (session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.png").exists() + assert (session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.png").exists() + + # Check that the metadata has captured the correct number of learning and eval episodes and steps + assert len(session.learn_av_reward_per_episode_dict().keys()) == 10 + assert len(session.learn_all_transactions_dict().keys()) == 10 * 256 + + assert len(session.eval_av_reward_per_episode_dict().keys()) == 3 + assert len(session.eval_all_transactions_dict().keys()) == 3 * 256 + _LOGGER.debug("Inspecting files in temp session path...") for dir_path, dir_names, file_names in os.walk(session_path): for file in file_names: diff --git a/tests/test_rllib_agent.py b/tests/test_rllib_agent.py deleted file mode 100644 index f494ea81..00000000 --- a/tests/test_rllib_agent.py +++ /dev/null @@ -1,24 +0,0 @@ -# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. -import pytest - -from primaite import getLogger -from primaite.config.lay_down_config import dos_very_basic_config_path -from tests import TEST_CONFIG_ROOT - -_LOGGER = getLogger(__name__) - - -@pytest.mark.parametrize( - "temp_primaite_session", - [[TEST_CONFIG_ROOT / "training_config_main_rllib.yaml", dos_very_basic_config_path()]], - indirect=True, -) -def test_primaite_session(temp_primaite_session): - """Test the training_config_main_rllib.yaml training config file.""" - with temp_primaite_session as session: - session_path = session.session_path - assert session_path.exists() - session.learn() - - assert len(session.learn_av_reward_per_episode_dict().keys()) == 10 - assert len(session.learn_all_transactions_dict().keys()) == 10 * 256