diff --git a/.gitignore b/.gitignore index f223391e..5d6434f1 100644 --- a/.gitignore +++ b/.gitignore @@ -138,5 +138,9 @@ dmypy.json # Cython debug symbols cython_debug/ +# IDE .idea/ docs/source/primaite-dependencies.rst + +# outputs +src/primaite/outputs/ diff --git a/README.md b/README.md index 78f36fba..f7c6efd7 100644 --- a/README.md +++ b/README.md @@ -1 +1,64 @@ # PrimAITE + +## Getting Started with PrimAITE + +### Pre-Requisites + +In order to get **PrimAITE** installed, you will need to have the following installed: + +- `python3.8+` +- `python3-pip` +- `virtualenv` + +**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS. + +### Installation from source +#### 1. Navigate to the PrimAITE folder and create a new python virtual environment (venv) + +```unix +python3 -m venv venv +``` + +#### 2. Activate the venv + +##### Unix +```bash +source venv/bin/activate +``` + +##### Windows +```powershell +.\venv\Scripts\activate +``` + +#### 3. Install `primaite` into the venv along with all of its dependencies + +```bash +python3 -m pip install -e . +``` + +### Development Installation +To install the development dependencies, postfix the command in step 3 above with the `[dev]` extra. Example: + +```bash +python3 -m pip install -e .[dev] +``` + +## Building documentation +The PrimAITE documentation can be built with the following commands: + +##### Unix +```bash +cd docs +make html +``` + +##### Windows +```powershell +cd docs +.\make.bat html +``` + +This will build the documentation as a collection of HTML files which uses the Read The Docs sphinx theme. Other build +options are available but may require additional dependencies such as LaTeX and PDF. Please refer to the Sphinx documentation +for your specific output requirements. 
diff --git a/docs/source/config.rst b/docs/source/config.rst index 71ade6c5..c80baa3c 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -82,203 +82,203 @@ The environment config file consists of the following attributes: Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment. -* **Generic [all_ok]** [int] +* **Generic [all_ok]** [float] The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken) -* **Node Hardware State [off_should_be_on]** [int] +* **Node Hardware State [off_should_be_on]** [float] The score to give when the node should be on, but is off -* **Node Hardware State [off_should_be_resetting]** [int] +* **Node Hardware State [off_should_be_resetting]** [float] The score to give when the node should be resetting, but is off -* **Node Hardware State [on_should_be_off]** [int] +* **Node Hardware State [on_should_be_off]** [float] The score to give when the node should be off, but is on -* **Node Hardware State [on_should_be_resetting]** [int] +* **Node Hardware State [on_should_be_resetting]** [float] The score to give when the node should be resetting, but is on -* **Node Hardware State [resetting_should_be_on]** [int] +* **Node Hardware State [resetting_should_be_on]** [float] The score to give when the node should be on, but is resetting -* **Node Hardware State [resetting_should_be_off]** [int] +* **Node Hardware State [resetting_should_be_off]** [float] The score to give when the node should be off, but is resetting -* **Node Hardware State [resetting]** [int] +* **Node Hardware State [resetting]** [float] The score to give when the node is resetting -* **Node Operating System or Service State [good_should_be_patching]** [int] +* **Node Operating System or Service State [good_should_be_patching]** [float] The score to give when 
the state should be patching, but is good -* **Node Operating System or Service State [good_should_be_compromised]** [int] +* **Node Operating System or Service State [good_should_be_compromised]** [float] The score to give when the state should be compromised, but is good -* **Node Operating System or Service State [good_should_be_overwhelmed]** [int] +* **Node Operating System or Service State [good_should_be_overwhelmed]** [float] The score to give when the state should be overwhelmed, but is good -* **Node Operating System or Service State [patching_should_be_good]** [int] +* **Node Operating System or Service State [patching_should_be_good]** [float] The score to give when the state should be good, but is patching -* **Node Operating System or Service State [patching_should_be_compromised]** [int] +* **Node Operating System or Service State [patching_should_be_compromised]** [float] The score to give when the state should be compromised, but is patching -* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int] +* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float] The score to give when the state should be overwhelmed, but is patching -* **Node Operating System or Service State [patching]** [int] +* **Node Operating System or Service State [patching]** [float] The score to give when the state is patching -* **Node Operating System or Service State [compromised_should_be_good]** [int] +* **Node Operating System or Service State [compromised_should_be_good]** [float] The score to give when the state should be good, but is compromised -* **Node Operating System or Service State [compromised_should_be_patching]** [int] +* **Node Operating System or Service State [compromised_should_be_patching]** [float] The score to give when the state should be patching, but is compromised -* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int] +* **Node Operating System or Service State 
[compromised_should_be_overwhelmed]** [float] The score to give when the state should be overwhelmed, but is compromised -* **Node Operating System or Service State [compromised]** [int] +* **Node Operating System or Service State [compromised]** [float] The score to give when the state is compromised -* **Node Operating System or Service State [overwhelmed_should_be_good]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_good]** [float] The score to give when the state should be good, but is overwhelmed -* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float] The score to give when the state should be patching, but is overwhelmed -* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float] The score to give when the state should be compromised, but is overwhelmed -* **Node Operating System or Service State [overwhelmed]** [int] +* **Node Operating System or Service State [overwhelmed]** [float] The score to give when the state is overwhelmed -* **Node File System State [good_should_be_repairing]** [int] +* **Node File System State [good_should_be_repairing]** [float] The score to give when the state should be repairing, but is good -* **Node File System State [good_should_be_restoring]** [int] +* **Node File System State [good_should_be_restoring]** [float] The score to give when the state should be restoring, but is good -* **Node File System State [good_should_be_corrupt]** [int] +* **Node File System State [good_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is good -* **Node File System State [good_should_be_destroyed]** [int] +* **Node File System State [good_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is good -* **Node File 
System State [repairing_should_be_good]** [int] +* **Node File System State [repairing_should_be_good]** [float] The score to give when the state should be good, but is repairing -* **Node File System State [repairing_should_be_restoring]** [int] +* **Node File System State [repairing_should_be_restoring]** [float] The score to give when the state should be restoring, but is repairing -* **Node File System State [repairing_should_be_corrupt]** [int] +* **Node File System State [repairing_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is repairing -* **Node File System State [repairing_should_be_destroyed]** [int] +* **Node File System State [repairing_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is repairing -* **Node File System State [repairing]** [int] +* **Node File System State [repairing]** [float] The score to give when the state is repairing -* **Node File System State [restoring_should_be_good]** [int] +* **Node File System State [restoring_should_be_good]** [float] The score to give when the state should be good, but is restoring -* **Node File System State [restoring_should_be_repairing]** [int] +* **Node File System State [restoring_should_be_repairing]** [float] The score to give when the state should be repairing, but is restoring -* **Node File System State [restoring_should_be_corrupt]** [int] +* **Node File System State [restoring_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is restoring -* **Node File System State [restoring_should_be_destroyed]** [int] +* **Node File System State [restoring_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is restoring -* **Node File System State [restoring]** [int] +* **Node File System State [restoring]** [float] The score to give when the state is restoring -* **Node File System State [corrupt_should_be_good]** [int] +* **Node File System State 
[corrupt_should_be_good]** [float] The score to give when the state should be good, but is corrupt -* **Node File System State [corrupt_should_be_repairing]** [int] +* **Node File System State [corrupt_should_be_repairing]** [float] The score to give when the state should be repairing, but is corrupt -* **Node File System State [corrupt_should_be_restoring]** [int] +* **Node File System State [corrupt_should_be_restoring]** [float] The score to give when the state should be restoring, but is corrupt -* **Node File System State [corrupt_should_be_destroyed]** [int] +* **Node File System State [corrupt_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is corrupt -* **Node File System State [corrupt]** [int] +* **Node File System State [corrupt]** [float] The score to give when the state is corrupt -* **Node File System State [destroyed_should_be_good]** [int] +* **Node File System State [destroyed_should_be_good]** [float] The score to give when the state should be good, but is destroyed -* **Node File System State [destroyed_should_be_repairing]** [int] +* **Node File System State [destroyed_should_be_repairing]** [float] The score to give when the state should be repairing, but is destroyed -* **Node File System State [destroyed_should_be_restoring]** [int] +* **Node File System State [destroyed_should_be_restoring]** [float] The score to give when the state should be restoring, but is destroyed -* **Node File System State [destroyed_should_be_corrupt]** [int] +* **Node File System State [destroyed_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is destroyed -* **Node File System State [destroyed]** [int] +* **Node File System State [destroyed]** [float] The score to give when the state is destroyed -* **Node File System State [scanning]** [int] +* **Node File System State [scanning]** [float] The score to give when the state is scanning -* **IER Status [red_ier_running]** [int] +* **IER 
Status [red_ier_running]** [float] The score to give when a red agent IER is permitted to run -* **IER Status [green_ier_blocked]** [int] +* **IER Status [green_ier_blocked]** [float] The score to give when a green agent IER is prevented from running @@ -308,6 +308,14 @@ Rewards are calculated based on the difference between the current state and ref The number of steps to take when scanning the file system +* **deterministic** [bool] + + Set to true if the agent evaluation should be deterministic. Default is ``False`` + +* **seed** [int] + + Seed used in the randomisation in agent training. Default is ``None`` + The Lay Down Config ******************* diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 3b093f86..a9bdfb1e 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -266,10 +266,19 @@ class AgentSessionABC(ABC): raise FileNotFoundError(msg) pass + @property + def _saved_agent_path(self) -> Path: + file_name = ( + f"{self._training_config.agent_framework}_" + f"{self._training_config.agent_identifier}_" + f"{self.timestamp_str}.zip" + ) + return self.learning_path / file_name + @abstractmethod def save(self): """Save the agent.""" - self._agent.save(self.session_path) + pass @abstractmethod def export(self): diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index bd5c8585..19939af8 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,9 +1,11 @@ from __future__ import annotations import json +import shutil from datetime import datetime from pathlib import Path from typing import Union +from uuid import uuid4 from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.a2c import A2CConfig @@ -118,6 +120,7 @@ class RLlibAgent(AgentSessionABC): timestamp_str=self.timestamp_str, ), ) + self._agent_config.seed = self._training_config.seed self._agent_config.training(train_batch_size=self._training_config.num_steps) 
self._agent_config.framework(framework="tf") @@ -132,9 +135,11 @@ class RLlibAgent(AgentSessionABC): def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._current_result["episodes_total"] - if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): - self._agent.save(str(self.checkpoints_path)) + save_checkpoint = False + if checkpoint_n: + save_checkpoint = episode_count % checkpoint_n == 0 + if episode_count and save_checkpoint: + self._agent.save(str(self.checkpoints_path)) def learn( self, @@ -152,9 +157,14 @@ class RLlibAgent(AgentSessionABC): for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() + self.save() self._agent.stop() + super().learn() + # save agent + self.save() + def evaluate( self, **kwargs, @@ -174,9 +184,25 @@ class RLlibAgent(AgentSessionABC): """Load an agent from file.""" raise NotImplementedError - def save(self): + def save(self, overwrite_existing: bool = True): """Save the agent.""" - raise NotImplementedError + # Make temp dir to save in isolation + temp_dir = self.learning_path / str(uuid4()) + temp_dir.mkdir() + + # Save the agent to the temp dir + self._agent.save(str(temp_dir)) + + # Capture the saved Rllib checkpoint inside the temp directory + for file in temp_dir.iterdir(): + checkpoint_dir = file + break + + # Zip the folder + shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa + + # Drop the temp directory + shutil.rmtree(temp_dir) def export(self): """Export the agent to transportable file format.""" diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 90a24ee2..885ff956 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -71,16 +71,19 @@ class SB3Agent(AgentSessionABC): verbose=self.sb3_output_verbose_level, n_steps=self._training_config.num_steps, 
tensorboard_log=str(self._tensorboard_log_path), + seed=self._training_config.seed, ) def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._env.episode_count - if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): - checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" - self._agent.save(checkpoint_path) - _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") + save_checkpoint = False + if checkpoint_n: + save_checkpoint = episode_count % checkpoint_n == 0 + if episode_count and save_checkpoint: + checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" + self._agent.save(checkpoint_path) + _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") def _get_latest_checkpoint(self): pass @@ -102,25 +105,27 @@ class SB3Agent(AgentSessionABC): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() self._env.reset() + self.save() self._env.close() super().learn() + # save agent + self.save() + def evaluate( self, - deterministic: bool = True, **kwargs, ): """ Evaluate the agent. - :param deterministic: Whether the evaluation is deterministic. :param kwargs: Any agent-specific key-word args to be passed. 
""" time_steps = self._training_config.num_steps episodes = self._training_config.num_episodes self._env.set_as_eval() self.is_eval = True - if deterministic: + if self._training_config.deterministic: deterministic_str = "deterministic" else: deterministic_str = "non-deterministic" @@ -131,7 +136,7 @@ class SB3Agent(AgentSessionABC): obs = self._env.reset() for step in range(time_steps): - action, _states = self._agent.predict(obs, deterministic=deterministic) + action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic) if isinstance(action, np.ndarray): action = np.int64(action) obs, rewards, done, info = self._env.step(action) @@ -146,7 +151,7 @@ class SB3Agent(AgentSessionABC): def save(self): """Save the agent.""" - raise NotImplementedError + self._agent.save(self._saved_agent_path) def export(self): """Export the agent to transportable file format.""" diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index a638fe14..15adc4dd 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -31,6 +31,16 @@ agent_identifier: PPO # False random_red_agent: False +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: null + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + # Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. 
# Options are: # "BASIC" (The current observation space only) @@ -83,58 +93,58 @@ sb3_output_verbose_level: NONE # Generic all_ok: 0 # Node Hardware State -off_should_be_on: -10 -off_should_be_resetting: -5 -on_should_be_off: -2 -on_should_be_resetting: -5 -resetting_should_be_on: -5 -resetting_should_be_off: -2 -resetting: -3 +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 # Node Software or Service State -good_should_be_patching: 2 -good_should_be_compromised: 5 -good_should_be_overwhelmed: 5 -patching_should_be_good: -5 -patching_should_be_compromised: 2 -patching_should_be_overwhelmed: 2 -patching: -3 -compromised_should_be_good: -20 -compromised_should_be_patching: -20 -compromised_should_be_overwhelmed: -20 -compromised: -20 -overwhelmed_should_be_good: -20 -overwhelmed_should_be_patching: -20 -overwhelmed_should_be_compromised: -20 -overwhelmed: -20 +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 # Node File System State -good_should_be_repairing: 2 -good_should_be_restoring: 2 -good_should_be_corrupt: 5 -good_should_be_destroyed: 10 -repairing_should_be_good: -5 -repairing_should_be_restoring: 2 -repairing_should_be_corrupt: 2 -repairing_should_be_destroyed: 0 -repairing: -3 -restoring_should_be_good: -10 -restoring_should_be_repairing: -2 -restoring_should_be_corrupt: 1 -restoring_should_be_destroyed: 2 -restoring: -6 
-corrupt_should_be_good: -10 -corrupt_should_be_repairing: -10 -corrupt_should_be_restoring: -10 -corrupt_should_be_destroyed: 2 -corrupt: -10 -destroyed_should_be_good: -20 -destroyed_should_be_repairing: -20 -destroyed_should_be_restoring: -20 -destroyed_should_be_corrupt: -20 -destroyed: -20 -scanning: -2 +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 # IER status -red_ier_running: -5 -green_ier_blocked: -10 +red_ier_running: -0.0005 +green_ier_blocked: -0.001 # Patching / Reset durations os_patching_duration: 5 # The time taken to patch the OS diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 30edb79b..8d38c0ef 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -94,64 +94,64 @@ class TrainingConfig: # Reward values # Generic - all_ok: int = 0 + all_ok: float = 0 # Node Hardware State - off_should_be_on: int = -10 - off_should_be_resetting: int = -5 - on_should_be_off: int = -2 - on_should_be_resetting: int = -5 - resetting_should_be_on: int = -5 - resetting_should_be_off: int = -2 - resetting: int = -3 + off_should_be_on: float = -0.001 + off_should_be_resetting: float = -0.0005 + on_should_be_off: float = 
-0.0002 + on_should_be_resetting: float = -0.0005 + resetting_should_be_on: float = -0.0005 + resetting_should_be_off: float = -0.0002 + resetting: float = -0.0003 # Node Software or Service State - good_should_be_patching: int = 2 - good_should_be_compromised: int = 5 - good_should_be_overwhelmed: int = 5 - patching_should_be_good: int = -5 - patching_should_be_compromised: int = 2 - patching_should_be_overwhelmed: int = 2 - patching: int = -3 - compromised_should_be_good: int = -20 - compromised_should_be_patching: int = -20 - compromised_should_be_overwhelmed: int = -20 - compromised: int = -20 - overwhelmed_should_be_good: int = -20 - overwhelmed_should_be_patching: int = -20 - overwhelmed_should_be_compromised: int = -20 - overwhelmed: int = -20 + good_should_be_patching: float = 0.0002 + good_should_be_compromised: float = 0.0005 + good_should_be_overwhelmed: float = 0.0005 + patching_should_be_good: float = -0.0005 + patching_should_be_compromised: float = 0.0002 + patching_should_be_overwhelmed: float = 0.0002 + patching: float = -0.0003 + compromised_should_be_good: float = -0.002 + compromised_should_be_patching: float = -0.002 + compromised_should_be_overwhelmed: float = -0.002 + compromised: float = -0.002 + overwhelmed_should_be_good: float = -0.002 + overwhelmed_should_be_patching: float = -0.002 + overwhelmed_should_be_compromised: float = -0.002 + overwhelmed: float = -0.002 # Node File System State - good_should_be_repairing: int = 2 - good_should_be_restoring: int = 2 - good_should_be_corrupt: int = 5 - good_should_be_destroyed: int = 10 - repairing_should_be_good: int = -5 - repairing_should_be_restoring: int = 2 - repairing_should_be_corrupt: int = 2 - repairing_should_be_destroyed: int = 0 - repairing: int = -3 - restoring_should_be_good: int = -10 - restoring_should_be_repairing: int = -2 - restoring_should_be_corrupt: int = 1 - restoring_should_be_destroyed: int = 2 - restoring: int = -6 - corrupt_should_be_good: int = -10 - 
corrupt_should_be_repairing: int = -10 - corrupt_should_be_restoring: int = -10 - corrupt_should_be_destroyed: int = 2 - corrupt: int = -10 - destroyed_should_be_good: int = -20 - destroyed_should_be_repairing: int = -20 - destroyed_should_be_restoring: int = -20 - destroyed_should_be_corrupt: int = -20 - destroyed: int = -20 - scanning: int = -2 + good_should_be_repairing: float = 0.0002 + good_should_be_restoring: float = 0.0002 + good_should_be_corrupt: float = 0.0005 + good_should_be_destroyed: float = 0.001 + repairing_should_be_good: float = -0.0005 + repairing_should_be_restoring: float = 0.0002 + repairing_should_be_corrupt: float = 0.0002 + repairing_should_be_destroyed: float = 0.0000 + repairing: float = -0.0003 + restoring_should_be_good: float = -0.001 + restoring_should_be_repairing: float = -0.0002 + restoring_should_be_corrupt: float = 0.0001 + restoring_should_be_destroyed: float = 0.0002 + restoring: float = -0.0006 + corrupt_should_be_good: float = -0.001 + corrupt_should_be_repairing: float = -0.001 + corrupt_should_be_restoring: float = -0.001 + corrupt_should_be_destroyed: float = 0.0002 + corrupt: float = -0.001 + destroyed_should_be_good: float = -0.002 + destroyed_should_be_repairing: float = -0.002 + destroyed_should_be_restoring: float = -0.002 + destroyed_should_be_corrupt: float = -0.002 + destroyed: float = -0.002 + scanning: float = -0.0002 # IER status - red_ier_running: int = -5 - green_ier_blocked: int = -10 + red_ier_running: float = -0.0005 + green_ier_blocked: float = -0.001 # Patching / Reset durations os_patching_duration: int = 5 @@ -178,6 +178,12 @@ class TrainingConfig: file_system_scanning_limit: int = 5 "The time taken to scan the file system" + deterministic: bool = False + "If true, the training will be deterministic" + + seed: Optional[int] = None + "The random number generator seed to be used while training the agent" + @classmethod def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig: 
""" diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 9a5df13a..d3c37882 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -141,10 +141,10 @@ class Primaite(Env): self.step_info = {} # Total reward - self.total_reward = 0 + self.total_reward: float = 0 # Average reward - self.average_reward = 0 + self.average_reward: float = 0 # Episode count self.episode_count = 0 @@ -282,9 +282,9 @@ class Primaite(Env): self._create_random_red_agent() # Reset counters and totals - self.total_reward = 0 + self.total_reward = 0.0 self.step_count = 0 - self.average_reward = 0 + self.average_reward = 0.0 # Update observations space and return self.update_environent_obs() diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 19094a18..e4353cb9 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -20,7 +20,7 @@ def calculate_reward_function( red_iers, step_count, config_values, -): +) -> float: """ Compares the states of the initial and final nodes/links to get a reward. @@ -33,7 +33,7 @@ def calculate_reward_function( step_count: current step config_values: Config values """ - reward_value = 0 + reward_value: float = 0.0 # For each node, compare hardware state, SoftwareState, service states for node_key, final_node in final_nodes.items(): @@ -94,7 +94,7 @@ def calculate_reward_function( return reward_value -def score_node_operating_state(final_node, initial_node, reference_node, config_values): +def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the hardware state of a node. 
@@ -104,7 +104,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ reference_node: The node if there had been no red or blue effect config_values: Config values """ - score = 0 + score: float = 0.0 final_node_operating_state = final_node.hardware_state reference_node_operating_state = reference_node.hardware_state @@ -143,7 +143,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ return score -def score_node_os_state(final_node, initial_node, reference_node, config_values): +def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the Software State of a node. @@ -153,7 +153,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) reference_node: The node if there had been no red or blue effect config_values: Config values """ - score = 0 + score: float = 0.0 final_node_os_state = final_node.software_state reference_node_os_state = reference_node.software_state @@ -194,7 +194,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) return score -def score_node_service_state(final_node, initial_node, reference_node, config_values): +def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the service state(s) of a node. 
@@ -204,7 +204,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va reference_node: The node if there had been no red or blue effect config_values: Config values """ - score = 0 + score: float = 0.0 final_node_services: Dict[str, Service] = final_node.services reference_node_services: Dict[str, Service] = reference_node.services @@ -266,7 +266,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va return score -def score_node_file_system(final_node, initial_node, reference_node, config_values): +def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the file system state of a node. @@ -275,7 +275,7 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu initial_node: The node before red and blue agents take effect reference_node: The node if there had been no red or blue effect """ - score = 0 + score: float = 0.0 final_node_file_system_state = final_node.file_system_state_actual reference_node_file_system_state = reference_node.file_system_state_actual diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 3a5a13db..f49d4ec2 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -31,7 +31,7 @@ class Transaction(object): "The observation space before any actions are taken" self.obs_space_post = None "The observation space after any actions are taken" - self.reward = None + self.reward: float = None "The reward value" self.action_space = None "The action space invoked by the agent" diff --git a/tests/config/ppo_not_seeded_training_config.yaml b/tests/config/ppo_not_seeded_training_config.yaml new file mode 100644 index 00000000..23cff44e --- /dev/null +++ b/tests/config/ppo_not_seeded_training_config.yaml @@ -0,0 +1,155 @@ +# Training Config File + +# Sets which agent algorithm framework will be used. 
+# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: SB3 + +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# "TF2" (Tensorflow 2.X) +# "TORCH" (PyTorch) +deep_learning_framework: TF2 + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: PPO + +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False +random_red_agent: False + +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: null + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: False + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS +# Number of episodes to run per session +num_episodes: 10 + +# Number of time_steps per episode +num_steps: 256 + +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. 
Default is 10 +checkpoint_every_n_episodes: 0 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN_EVAL + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + +# Reward values +# Generic +all_ok: 0.0000 +# Node Hardware State +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 +# Node Software or Service State +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 +# Node File System State +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: 
-0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 +# IER status +red_ier_running: -0.0005 +green_ier_blocked: -0.001 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/tests/config/ppo_seeded_training_config.yaml similarity index 52% rename from src/primaite/config/_package_data/training/training_config_random_red_agent.yaml rename to tests/config/ppo_seeded_training_config.yaml index 96243daf..181331d9 100644 --- a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml +++ b/tests/config/ppo_seeded_training_config.yaml @@ -1,40 +1,94 @@ -# Main Config File +# Training Config File -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: SB3 + +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TF2 + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: PPO # Sets whether Red Agent POL and IER is randomised. # Options are: # True # False -random_red_agent: True +random_red_agent: False + +# The (integer) seed to be used in random number generation +# Default is None (null) +seed: 67890 + +# Set whether the agent will be deterministic instead of stochastic +# Options are: +# True +# False +deterministic: True + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL # Sets How the Action Space is defined: # "NODE" # "ACL" # "ANY" node and acl actions action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS # Number of episodes to run per session num_episodes: 10 + # Number of time_steps per episode num_steps: 256 -# Time delay between steps (for generic agents) -time_delay: 10 -# Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING -# Determine whether to load an agent from file -load_agent: False -# File path and file name of agent if you're loading one in -agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. 
Default is 10 +checkpoint_every_n_episodes: 0 + +# Time delay (milliseconds) between steps for CUSTOM agents. +time_delay: 5 + +# Type of session to be run. Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN_EVAL # Environment config values # The high value for the observation space observation_space_high_value: 1000000000 +# The Stable Baselines3 learn/eval output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE + # Reward values # Generic all_ok: 0 diff --git a/tests/conftest.py b/tests/conftest.py index af76b314..388bc034 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,7 +58,6 @@ class TempPrimaiteSession(PrimaiteSession): def __exit__(self, type, value, tb): shutil.rmtree(self.session_path) - shutil.rmtree(self.session_path.parent) _LOGGER.debug(f"Deleted temp session directory: {self.session_path}") diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py index feff52f6..90c0cb5d 100644 --- a/tests/mock_and_patch/get_session_path_mock.py +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -1,6 +1,7 @@ import tempfile from datetime import datetime from pathlib import Path +from uuid import uuid4 from primaite import getLogger @@ -14,9 +15,7 @@ def get_temp_session_path(session_timestamp: datetime) -> Path: :param session_timestamp: This is the datetime that the session started. :return: The session directory path. 
""" - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path + session_path = Path(tempfile.gettempdir()) / "primaite" / str(uuid4()) session_path.mkdir(exist_ok=True, parents=True) _LOGGER.debug(f"Created temp session directory: {session_path}") return session_path diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py index ae0b0870..75ea5882 100644 --- a/tests/test_primaite_session.py +++ b/tests/test_primaite_session.py @@ -33,6 +33,9 @@ def test_primaite_session(temp_primaite_session): # Check that the network png file exists assert (session_path / f"network_{session.timestamp_str}.png").exists() + # Check that the saved agent exists + assert session._agent_session._saved_agent_path.exists() + # Check that both the transactions and av reward csv files exist for file in session.learning_path.iterdir(): if file.suffix == ".csv": diff --git a/tests/test_seeding_and_deterministic_session.py b/tests/test_seeding_and_deterministic_session.py new file mode 100644 index 00000000..34cb43fb --- /dev/null +++ b/tests/test_seeding_and_deterministic_session.py @@ -0,0 +1,49 @@ +import pytest as pytest + +from primaite.config.lay_down_config import dos_very_basic_config_path +from tests import TEST_CONFIG_ROOT + + +@pytest.mark.parametrize( + "temp_primaite_session", + [[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]], + indirect=True, +) +def test_seeded_learning(temp_primaite_session): + """Test running seeded learning produces the same output when ran twice.""" + expected_mean_reward_per_episode = { + 1: -90.703125, + 2: -91.15234375, + 3: -87.5, + 4: -92.2265625, + 5: -94.6875, + 6: -91.19140625, + 7: -88.984375, + 8: -88.3203125, + 9: -112.79296875, + 10: -100.01953125, + } + with temp_primaite_session as session: + assert session._training_config.seed == 67890, ( + 
"Expected output is based upon an agent that was trained with " "seed 67890" + ) + session.learn() + actual_mean_reward_per_episode = session.learn_av_reward_per_episode() + + assert actual_mean_reward_per_episode == expected_mean_reward_per_episode + + +@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL " "knowledge to investigate further.") +@pytest.mark.parametrize( + "temp_primaite_session", + [[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]], + indirect=True, +) +def test_deterministic_evaluation(temp_primaite_session): + """Test running deterministic evaluation gives same av reward per episode.""" + with temp_primaite_session as session: + # do stuff + session.learn() + session.evaluate() + eval_mean_reward = session.eval_av_reward_per_episode_csv() + assert len(set(eval_mean_reward.values())) == 1