Merge remote-tracking branch 'origin/dev' into 1566-configure-episode-steps-learn-eval

# Conflicts:
#	src/primaite/agents/rllib.py
This commit is contained in:
Chris McCarthy
2023-07-07 14:34:20 +01:00
17 changed files with 588 additions and 200 deletions

4
.gitignore vendored
View File

@@ -138,4 +138,8 @@ dmypy.json
# Cython debug symbols
cython_debug/
# IDE
.idea/
# outputs
src/primaite/outputs/

View File

@@ -1 +1,64 @@
# PrimAITE
## Getting Started with PrimAITE
### Pre-Requisites
In order to get **PrimAITE** installed, you will need to have the following installed:
- `python3.8+`
- `python3-pip`
- `virtualenv`
**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS.
### Installation from source
#### 1. Navigate to the PrimAITE folder and create a new python virtual environment (venv)
```unix
python3 -m venv <name_of_venv>
```
#### 2. Activate the venv
##### Unix
```bash
source <name_of_venv>/bin/activate
```
##### Windows
```powershell
.\<name_of_venv>\Scripts\activate
```
#### 3. Install `primaite` into the venv along with all of it's dependencies
```bash
python3 -m pip install -e .
```
### Development Installation
To install the development dependencies, postfix the command in step 3 above with the `[dev]` extra. Example:
```bash
python3 -m pip install -e .[dev]
```
## Building documentation
The PrimAITE documentation can be built with the following commands:
##### Unix
```bash
cd docs
make html
```
##### Windows
```powershell
cd docs
.\make.bat html
```
This will build the documentation as a collection of HTML files which uses the Read The Docs sphinx theme. Other build
options are available but may require additional dependencies such as LaTeX and PDF. Please refer to the Sphinx documentation
for your specific output requirements.

View File

@@ -82,203 +82,203 @@ The environment config file consists of the following attributes:
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
* **Generic [all_ok]** [int]
* **Generic [all_ok]** [float]
The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)
* **Node Hardware State [off_should_be_on]** [int]
* **Node Hardware State [off_should_be_on]** [float]
The score to give when the node should be on, but is off
* **Node Hardware State [off_should_be_resetting]** [int]
* **Node Hardware State [off_should_be_resetting]** [float]
The score to give when the node should be resetting, but is off
* **Node Hardware State [on_should_be_off]** [int]
* **Node Hardware State [on_should_be_off]** [float]
The score to give when the node should be off, but is on
* **Node Hardware State [on_should_be_resetting]** [int]
* **Node Hardware State [on_should_be_resetting]** [float]
The score to give when the node should be resetting, but is on
* **Node Hardware State [resetting_should_be_on]** [int]
* **Node Hardware State [resetting_should_be_on]** [float]
The score to give when the node should be on, but is resetting
* **Node Hardware State [resetting_should_be_off]** [int]
* **Node Hardware State [resetting_should_be_off]** [float]
The score to give when the node should be off, but is resetting
* **Node Hardware State [resetting]** [int]
* **Node Hardware State [resetting]** [float]
The score to give when the node is resetting
* **Node Operating System or Service State [good_should_be_patching]** [int]
* **Node Operating System or Service State [good_should_be_patching]** [float]
The score to give when the state should be patching, but is good
* **Node Operating System or Service State [good_should_be_compromised]** [int]
* **Node Operating System or Service State [good_should_be_compromised]** [float]
The score to give when the state should be compromised, but is good
* **Node Operating System or Service State [good_should_be_overwhelmed]** [int]
* **Node Operating System or Service State [good_should_be_overwhelmed]** [float]
The score to give when the state should be overwhelmed, but is good
* **Node Operating System or Service State [patching_should_be_good]** [int]
* **Node Operating System or Service State [patching_should_be_good]** [float]
The score to give when the state should be good, but is patching
* **Node Operating System or Service State [patching_should_be_compromised]** [int]
* **Node Operating System or Service State [patching_should_be_compromised]** [float]
The score to give when the state should be compromised, but is patching
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int]
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float]
The score to give when the state should be overwhelmed, but is patching
* **Node Operating System or Service State [patching]** [int]
* **Node Operating System or Service State [patching]** [float]
The score to give when the state is patching
* **Node Operating System or Service State [compromised_should_be_good]** [int]
* **Node Operating System or Service State [compromised_should_be_good]** [float]
The score to give when the state should be good, but is compromised
* **Node Operating System or Service State [compromised_should_be_patching]** [int]
* **Node Operating System or Service State [compromised_should_be_patching]** [float]
The score to give when the state should be patching, but is compromised
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int]
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [float]
The score to give when the state should be overwhelmed, but is compromised
* **Node Operating System or Service State [compromised]** [int]
* **Node Operating System or Service State [compromised]** [float]
The score to give when the state is compromised
* **Node Operating System or Service State [overwhelmed_should_be_good]** [int]
* **Node Operating System or Service State [overwhelmed_should_be_good]** [float]
The score to give when the state should be good, but is overwhelmed
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int]
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float]
The score to give when the state should be patching, but is overwhelmed
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int]
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float]
The score to give when the state should be compromised, but is overwhelmed
* **Node Operating System or Service State [overwhelmed]** [int]
* **Node Operating System or Service State [overwhelmed]** [float]
The score to give when the state is overwhelmed
* **Node File System State [good_should_be_repairing]** [int]
* **Node File System State [good_should_be_repairing]** [float]
The score to give when the state should be repairing, but is good
* **Node File System State [good_should_be_restoring]** [int]
* **Node File System State [good_should_be_restoring]** [float]
The score to give when the state should be restoring, but is good
* **Node File System State [good_should_be_corrupt]** [int]
* **Node File System State [good_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is good
* **Node File System State [good_should_be_destroyed]** [int]
* **Node File System State [good_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is good
* **Node File System State [repairing_should_be_good]** [int]
* **Node File System State [repairing_should_be_good]** [float]
The score to give when the state should be good, but is repairing
* **Node File System State [repairing_should_be_restoring]** [int]
* **Node File System State [repairing_should_be_restoring]** [float]
The score to give when the state should be restoring, but is repairing
* **Node File System State [repairing_should_be_corrupt]** [int]
* **Node File System State [repairing_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is repairing
* **Node File System State [repairing_should_be_destroyed]** [int]
* **Node File System State [repairing_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is repairing
* **Node File System State [repairing]** [int]
* **Node File System State [repairing]** [float]
The score to give when the state is repairing
* **Node File System State [restoring_should_be_good]** [int]
* **Node File System State [restoring_should_be_good]** [float]
The score to give when the state should be good, but is restoring
* **Node File System State [restoring_should_be_repairing]** [int]
* **Node File System State [restoring_should_be_repairing]** [float]
The score to give when the state should be repairing, but is restoring
* **Node File System State [restoring_should_be_corrupt]** [int]
* **Node File System State [restoring_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is restoring
* **Node File System State [restoring_should_be_destroyed]** [int]
* **Node File System State [restoring_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is restoring
* **Node File System State [restoring]** [int]
* **Node File System State [restoring]** [float]
The score to give when the state is restoring
* **Node File System State [corrupt_should_be_good]** [int]
* **Node File System State [corrupt_should_be_good]** [float]
The score to give when the state should be good, but is corrupt
* **Node File System State [corrupt_should_be_repairing]** [int]
* **Node File System State [corrupt_should_be_repairing]** [float]
The score to give when the state should be repairing, but is corrupt
* **Node File System State [corrupt_should_be_restoring]** [int]
* **Node File System State [corrupt_should_be_restoring]** [float]
The score to give when the state should be restoring, but is corrupt
* **Node File System State [corrupt_should_be_destroyed]** [int]
* **Node File System State [corrupt_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is corrupt
* **Node File System State [corrupt]** [int]
* **Node File System State [corrupt]** [float]
The score to give when the state is corrupt
* **Node File System State [destroyed_should_be_good]** [int]
* **Node File System State [destroyed_should_be_good]** [float]
The score to give when the state should be good, but is destroyed
* **Node File System State [destroyed_should_be_repairing]** [int]
* **Node File System State [destroyed_should_be_repairing]** [float]
The score to give when the state should be repairing, but is destroyed
* **Node File System State [destroyed_should_be_restoring]** [int]
* **Node File System State [destroyed_should_be_restoring]** [float]
The score to give when the state should be restoring, but is destroyed
* **Node File System State [destroyed_should_be_corrupt]** [int]
* **Node File System State [destroyed_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is destroyed
* **Node File System State [destroyed]** [int]
* **Node File System State [destroyed]** [float]
The score to give when the state is destroyed
* **Node File System State [scanning]** [int]
* **Node File System State [scanning]** [float]
The score to give when the state is scanning
* **IER Status [red_ier_running]** [int]
* **IER Status [red_ier_running]** [float]
The score to give when a red agent IER is permitted to run
* **IER Status [green_ier_blocked]** [int]
* **IER Status [green_ier_blocked]** [float]
The score to give when a green agent IER is prevented from running
@@ -308,6 +308,14 @@ Rewards are calculated based on the difference between the current state and ref
The number of steps to take when scanning the file system
* **deterministic** [bool]
Set to true if the agent evaluation should be deterministic. Default is ``False``
* **seed** [int]
Seed used in the randomisation in agent training. Default is ``None``
The Lay Down Config
*******************

View File

@@ -257,10 +257,19 @@ class AgentSessionABC(ABC):
raise FileNotFoundError(msg)
pass
@property
def _saved_agent_path(self) -> Path:
file_name = (
f"{self._training_config.agent_framework}_"
f"{self._training_config.agent_identifier}_"
f"{self.timestamp_str}.zip"
)
return self.learning_path / file_name
@abstractmethod
def save(self):
"""Save the agent."""
self._agent.save(self.session_path)
pass
@abstractmethod
def export(self):

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
@@ -106,6 +108,7 @@ class RLlibAgent(AgentSessionABC):
timestamp_str=self.timestamp_str,
),
)
self._agent_config.seed = self._training_config.seed
self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
self._agent_config.framework(framework="tf")
@@ -120,9 +123,11 @@ class RLlibAgent(AgentSessionABC):
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_train_episodes):
self._agent.save(str(self.checkpoints_path))
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
self._agent.save(str(self.checkpoints_path))
def learn(
self,
@@ -140,9 +145,12 @@ class RLlibAgent(AgentSessionABC):
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
self.save()
self._agent.stop()
super().learn()
def evaluate(
self,
**kwargs,
@@ -162,9 +170,25 @@ class RLlibAgent(AgentSessionABC):
"""Load an agent from file."""
raise NotImplementedError
def save(self):
def save(self, overwrite_existing: bool = True):
"""Save the agent."""
raise NotImplementedError
# Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
temp_dir.mkdir()
# Save the agent to the temp dir
self._agent.save(str(temp_dir))
# Capture the saved Rllib checkpoint inside the temp directory
for file in temp_dir.iterdir():
checkpoint_dir = file
break
# Zip the folder
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self):
"""Export the agent to transportable file format."""

View File

@@ -60,16 +60,19 @@ class SB3Agent(AgentSessionABC):
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_eval_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
def _get_latest_checkpoint(self):
pass
@@ -91,25 +94,27 @@ class SB3Agent(AgentSessionABC):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self._env.reset()
self.save()
self._env.close()
super().learn()
# save agent
self.save()
def evaluate(
self,
deterministic: bool = True,
**kwargs,
):
"""
Evaluate the agent.
:param deterministic: Whether the evaluation is deterministic.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
self._env.set_as_eval()
self.is_eval = True
if deterministic:
if self._training_config.deterministic:
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
@@ -120,7 +125,7 @@ class SB3Agent(AgentSessionABC):
obs = self._env.reset()
for step in range(time_steps):
action, _states = self._agent.predict(obs, deterministic=deterministic)
action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic)
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
@@ -135,7 +140,7 @@ class SB3Agent(AgentSessionABC):
def save(self):
"""Save the agent."""
raise NotImplementedError
self._agent.save(self._saved_agent_path)
def export(self):
"""Export the agent to transportable file format."""

View File

@@ -31,6 +31,16 @@ agent_identifier: PPO
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
@@ -91,58 +101,58 @@ sb3_output_verbose_level: NONE
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -5
green_ier_blocked: -10
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS

View File

@@ -100,64 +100,64 @@ class TrainingConfig:
# Reward values
# Generic
all_ok: int = 0
all_ok: float = 0
# Node Hardware State
off_should_be_on: int = -10
off_should_be_resetting: int = -5
on_should_be_off: int = -2
on_should_be_resetting: int = -5
resetting_should_be_on: int = -5
resetting_should_be_off: int = -2
resetting: int = -3
off_should_be_on: float = -0.001
off_should_be_resetting: float = -0.0005
on_should_be_off: float = -0.0002
on_should_be_resetting: float = -0.0005
resetting_should_be_on: float = -0.0005
resetting_should_be_off: float = -0.0002
resetting: float = -0.0003
# Node Software or Service State
good_should_be_patching: int = 2
good_should_be_compromised: int = 5
good_should_be_overwhelmed: int = 5
patching_should_be_good: int = -5
patching_should_be_compromised: int = 2
patching_should_be_overwhelmed: int = 2
patching: int = -3
compromised_should_be_good: int = -20
compromised_should_be_patching: int = -20
compromised_should_be_overwhelmed: int = -20
compromised: int = -20
overwhelmed_should_be_good: int = -20
overwhelmed_should_be_patching: int = -20
overwhelmed_should_be_compromised: int = -20
overwhelmed: int = -20
good_should_be_patching: float = 0.0002
good_should_be_compromised: float = 0.0005
good_should_be_overwhelmed: float = 0.0005
patching_should_be_good: float = -0.0005
patching_should_be_compromised: float = 0.0002
patching_should_be_overwhelmed: float = 0.0002
patching: float = -0.0003
compromised_should_be_good: float = -0.002
compromised_should_be_patching: float = -0.002
compromised_should_be_overwhelmed: float = -0.002
compromised: float = -0.002
overwhelmed_should_be_good: float = -0.002
overwhelmed_should_be_patching: float = -0.002
overwhelmed_should_be_compromised: float = -0.002
overwhelmed: float = -0.002
# Node File System State
good_should_be_repairing: int = 2
good_should_be_restoring: int = 2
good_should_be_corrupt: int = 5
good_should_be_destroyed: int = 10
repairing_should_be_good: int = -5
repairing_should_be_restoring: int = 2
repairing_should_be_corrupt: int = 2
repairing_should_be_destroyed: int = 0
repairing: int = -3
restoring_should_be_good: int = -10
restoring_should_be_repairing: int = -2
restoring_should_be_corrupt: int = 1
restoring_should_be_destroyed: int = 2
restoring: int = -6
corrupt_should_be_good: int = -10
corrupt_should_be_repairing: int = -10
corrupt_should_be_restoring: int = -10
corrupt_should_be_destroyed: int = 2
corrupt: int = -10
destroyed_should_be_good: int = -20
destroyed_should_be_repairing: int = -20
destroyed_should_be_restoring: int = -20
destroyed_should_be_corrupt: int = -20
destroyed: int = -20
scanning: int = -2
good_should_be_repairing: float = 0.0002
good_should_be_restoring: float = 0.0002
good_should_be_corrupt: float = 0.0005
good_should_be_destroyed: float = 0.001
repairing_should_be_good: float = -0.0005
repairing_should_be_restoring: float = 0.0002
repairing_should_be_corrupt: float = 0.0002
repairing_should_be_destroyed: float = 0.0000
repairing: float = -0.0003
restoring_should_be_good: float = -0.001
restoring_should_be_repairing: float = -0.0002
restoring_should_be_corrupt: float = 0.0001
restoring_should_be_destroyed: float = 0.0002
restoring: float = -0.0006
corrupt_should_be_good: float = -0.001
corrupt_should_be_repairing: float = -0.001
corrupt_should_be_restoring: float = -0.001
corrupt_should_be_destroyed: float = 0.0002
corrupt: float = -0.001
destroyed_should_be_good: float = -0.002
destroyed_should_be_repairing: float = -0.002
destroyed_should_be_restoring: float = -0.002
destroyed_should_be_corrupt: float = -0.002
destroyed: float = -0.002
scanning: float = -0.0002
# IER status
red_ier_running: int = -5
green_ier_blocked: int = -10
red_ier_running: float = -0.0005
green_ier_blocked: float = -0.001
# Patching / Reset durations
os_patching_duration: int = 5
@@ -184,6 +184,12 @@ class TrainingConfig:
file_system_scanning_limit: int = 5
"The time taken to scan the file system"
deterministic: bool = False
"If true, the training will be deterministic"
seed: Optional[int] = None
"The random number generator seed to be used while training the agent"
@classmethod
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
"""

View File

@@ -147,10 +147,10 @@ class Primaite(Env):
self.step_info = {}
# Total reward
self.total_reward = 0
self.total_reward: float = 0
# Average reward
self.average_reward = 0
self.average_reward: float = 0
# Episode count
self.episode_count = 0
@@ -289,9 +289,9 @@ class Primaite(Env):
self._create_random_red_agent()
# Reset counters and totals
self.total_reward = 0
self.total_reward = 0.0
self.step_count = 0
self.average_reward = 0
self.average_reward = 0.0
# Update observations space and return
self.update_environent_obs()

View File

@@ -20,7 +20,7 @@ def calculate_reward_function(
red_iers,
step_count,
config_values,
):
) -> float:
"""
Compares the states of the initial and final nodes/links to get a reward.
@@ -33,7 +33,7 @@ def calculate_reward_function(
step_count: current step
config_values: Config values
"""
reward_value = 0
reward_value: float = 0.0
# For each node, compare hardware state, SoftwareState, service states
for node_key, final_node in final_nodes.items():
@@ -94,7 +94,7 @@ def calculate_reward_function(
return reward_value
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float:
"""
Calculates score relating to the hardware state of a node.
@@ -104,7 +104,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score = 0
score: float = 0.0
final_node_operating_state = final_node.hardware_state
reference_node_operating_state = reference_node.hardware_state
@@ -143,7 +143,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
return score
def score_node_os_state(final_node, initial_node, reference_node, config_values):
def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float:
"""
Calculates score relating to the Software State of a node.
@@ -153,7 +153,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score = 0
score: float = 0.0
final_node_os_state = final_node.software_state
reference_node_os_state = reference_node.software_state
@@ -194,7 +194,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
return score
def score_node_service_state(final_node, initial_node, reference_node, config_values):
def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float:
"""
Calculates score relating to the service state(s) of a node.
@@ -204,7 +204,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score = 0
score: float = 0.0
final_node_services: Dict[str, Service] = final_node.services
reference_node_services: Dict[str, Service] = reference_node.services
@@ -266,7 +266,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
return score
def score_node_file_system(final_node, initial_node, reference_node, config_values):
def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float:
"""
Calculates score relating to the file system state of a node.
@@ -275,7 +275,7 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
"""
score = 0
score: float = 0.0
final_node_file_system_state = final_node.file_system_state_actual
reference_node_file_system_state = reference_node.file_system_state_actual

View File

@@ -31,7 +31,7 @@ class Transaction(object):
"The observation space before any actions are taken"
self.obs_space_post = None
"The observation space after any actions are taken"
self.reward = None
self.reward: float = None
"The reward value"
self.action_space = None
"The action space invoked by the agent"

View File

@@ -0,0 +1,155 @@
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: None
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes to run per session
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 0
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0.0000
# Node Hardware State
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time take to repair the file system
file_system_restoring_limit: 5 # The time take to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -1,40 +1,94 @@
# Main Config File
# Training Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: True
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: 67890
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: True
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes to run per session
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
# Time delay between steps (for generic agents)
time_delay: 10
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 0
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0

View File

@@ -58,7 +58,6 @@ class TempPrimaiteSession(PrimaiteSession):
def __exit__(self, type, value, tb):
shutil.rmtree(self.session_path)
shutil.rmtree(self.session_path.parent)
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")

View File

@@ -1,6 +1,7 @@
import tempfile
from datetime import datetime
from pathlib import Path
from uuid import uuid4
from primaite import getLogger
@@ -14,9 +15,7 @@ def get_temp_session_path(session_timestamp: datetime) -> Path:
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
session_path = Path(tempfile.gettempdir()) / "primaite" / str(uuid4())
session_path.mkdir(exist_ok=True, parents=True)
_LOGGER.debug(f"Created temp session directory: {session_path}")
return session_path

View File

@@ -33,6 +33,9 @@ def test_primaite_session(temp_primaite_session):
# Check that the network png file exists
assert (session_path / f"network_{session.timestamp_str}.png").exists()
# Check that the saved agent exists
assert session._agent_session._saved_agent_path.exists()
# Check that both the transactions and av reward csv files exist
for file in session.learning_path.iterdir():
if file.suffix == ".csv":

View File

@@ -0,0 +1,49 @@
import pytest as pytest
from primaite.config.lay_down_config import dos_very_basic_config_path
from tests import TEST_CONFIG_ROOT
@pytest.mark.parametrize(
"temp_primaite_session",
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
indirect=True,
)
def test_seeded_learning(temp_primaite_session):
"""Test running seeded learning produces the same output when ran twice."""
expected_mean_reward_per_episode = {
1: -90.703125,
2: -91.15234375,
3: -87.5,
4: -92.2265625,
5: -94.6875,
6: -91.19140625,
7: -88.984375,
8: -88.3203125,
9: -112.79296875,
10: -100.01953125,
}
with temp_primaite_session as session:
assert session._training_config.seed == 67890, (
"Expected output is based upon a agent that was trained with " "seed 67890"
)
session.learn()
actual_mean_reward_per_episode = session.learn_av_reward_per_episode()
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL " "knowledge to investigate further.")
@pytest.mark.parametrize(
"temp_primaite_session",
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
indirect=True,
)
def test_deterministic_evaluation(temp_primaite_session):
"""Test running deterministic evaluation gives same av eward per episode."""
with temp_primaite_session as session:
# do stuff
session.learn()
session.evaluate()
eval_mean_reward = session.eval_av_reward_per_episode_csv()
assert len(set(eval_mean_reward.values())) == 1