Merge remote-tracking branch 'origin/dev' into feature/901-change-functionality-acl-rules

This commit is contained in:
SunilSamra
2023-07-07 15:14:05 +01:00
17 changed files with 590 additions and 200 deletions

.gitignore vendored
View File

@@ -138,4 +138,8 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
+# IDE
 .idea/
+# outputs
+src/primaite/outputs/

View File

@@ -1 +1,64 @@
 # PrimAITE
+
+## Getting Started with PrimAITE
+
+### Pre-Requisites
+
+To install **PrimAITE**, you will need the following:
+
+- `python3.8+`
+- `python3-pip`
+- `virtualenv`
+
+**PrimAITE** is designed to be OS-agnostic, and should therefore work on most distributions of Linux, Windows, and MacOS.
+
+### Installation from source
+
+#### 1. Navigate to the PrimAITE folder and create a new Python virtual environment (venv)
+
+```bash
+python3 -m venv <name_of_venv>
+```
+
+#### 2. Activate the venv
+
+##### Unix
+
+```bash
+source <name_of_venv>/bin/activate
+```
+
+##### Windows
+
+```powershell
+.\<name_of_venv>\Scripts\activate
+```
+
+#### 3. Install `primaite` into the venv along with all of its dependencies
+
+```bash
+python3 -m pip install -e .
+```
+
+### Development Installation
+
+To install the development dependencies, append the `[dev]` extra to the command in step 3 above. For example:
+
+```bash
+python3 -m pip install -e .[dev]
+```
+
+## Building documentation
+
+The PrimAITE documentation can be built with the following commands:
+
+##### Unix
+
+```bash
+cd docs
+make html
+```
+
+##### Windows
+
+```powershell
+cd docs
+.\make.bat html
+```
+
+This will build the documentation as a collection of HTML files using the Read the Docs Sphinx theme. Other output
+formats (such as PDF) are available but may require additional dependencies such as LaTeX. Please refer to the Sphinx
+documentation for your specific output requirements.

View File

@@ -82,203 +82,203 @@ The environment config file consists of the following attributes:
 Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
-* **Generic [all_ok]** [int]
+* **Generic [all_ok]** [float]
   The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)
-* **Node Hardware State [off_should_be_on]** [int]
+* **Node Hardware State [off_should_be_on]** [float]
   The score to give when the node should be on, but is off
-* **Node Hardware State [off_should_be_resetting]** [int]
+* **Node Hardware State [off_should_be_resetting]** [float]
   The score to give when the node should be resetting, but is off
-* **Node Hardware State [on_should_be_off]** [int]
+* **Node Hardware State [on_should_be_off]** [float]
   The score to give when the node should be off, but is on
-* **Node Hardware State [on_should_be_resetting]** [int]
+* **Node Hardware State [on_should_be_resetting]** [float]
   The score to give when the node should be resetting, but is on
-* **Node Hardware State [resetting_should_be_on]** [int]
+* **Node Hardware State [resetting_should_be_on]** [float]
   The score to give when the node should be on, but is resetting
-* **Node Hardware State [resetting_should_be_off]** [int]
+* **Node Hardware State [resetting_should_be_off]** [float]
   The score to give when the node should be off, but is resetting
-* **Node Hardware State [resetting]** [int]
+* **Node Hardware State [resetting]** [float]
   The score to give when the node is resetting
-* **Node Operating System or Service State [good_should_be_patching]** [int]
+* **Node Operating System or Service State [good_should_be_patching]** [float]
   The score to give when the state should be patching, but is good
-* **Node Operating System or Service State [good_should_be_compromised]** [int]
+* **Node Operating System or Service State [good_should_be_compromised]** [float]
   The score to give when the state should be compromised, but is good
-* **Node Operating System or Service State [good_should_be_overwhelmed]** [int]
+* **Node Operating System or Service State [good_should_be_overwhelmed]** [float]
   The score to give when the state should be overwhelmed, but is good
-* **Node Operating System or Service State [patching_should_be_good]** [int]
+* **Node Operating System or Service State [patching_should_be_good]** [float]
   The score to give when the state should be good, but is patching
-* **Node Operating System or Service State [patching_should_be_compromised]** [int]
+* **Node Operating System or Service State [patching_should_be_compromised]** [float]
   The score to give when the state should be compromised, but is patching
-* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int]
+* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float]
   The score to give when the state should be overwhelmed, but is patching
-* **Node Operating System or Service State [patching]** [int]
+* **Node Operating System or Service State [patching]** [float]
   The score to give when the state is patching
-* **Node Operating System or Service State [compromised_should_be_good]** [int]
+* **Node Operating System or Service State [compromised_should_be_good]** [float]
   The score to give when the state should be good, but is compromised
-* **Node Operating System or Service State [compromised_should_be_patching]** [int]
+* **Node Operating System or Service State [compromised_should_be_patching]** [float]
   The score to give when the state should be patching, but is compromised
-* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int]
+* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [float]
   The score to give when the state should be overwhelmed, but is compromised
-* **Node Operating System or Service State [compromised]** [int]
+* **Node Operating System or Service State [compromised]** [float]
   The score to give when the state is compromised
-* **Node Operating System or Service State [overwhelmed_should_be_good]** [int]
+* **Node Operating System or Service State [overwhelmed_should_be_good]** [float]
   The score to give when the state should be good, but is overwhelmed
-* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int]
+* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float]
   The score to give when the state should be patching, but is overwhelmed
-* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int]
+* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float]
   The score to give when the state should be compromised, but is overwhelmed
-* **Node Operating System or Service State [overwhelmed]** [int]
+* **Node Operating System or Service State [overwhelmed]** [float]
   The score to give when the state is overwhelmed
-* **Node File System State [good_should_be_repairing]** [int]
+* **Node File System State [good_should_be_repairing]** [float]
   The score to give when the state should be repairing, but is good
-* **Node File System State [good_should_be_restoring]** [int]
+* **Node File System State [good_should_be_restoring]** [float]
   The score to give when the state should be restoring, but is good
-* **Node File System State [good_should_be_corrupt]** [int]
+* **Node File System State [good_should_be_corrupt]** [float]
   The score to give when the state should be corrupt, but is good
-* **Node File System State [good_should_be_destroyed]** [int]
+* **Node File System State [good_should_be_destroyed]** [float]
   The score to give when the state should be destroyed, but is good
-* **Node File System State [repairing_should_be_good]** [int]
+* **Node File System State [repairing_should_be_good]** [float]
   The score to give when the state should be good, but is repairing
-* **Node File System State [repairing_should_be_restoring]** [int]
+* **Node File System State [repairing_should_be_restoring]** [float]
   The score to give when the state should be restoring, but is repairing
-* **Node File System State [repairing_should_be_corrupt]** [int]
+* **Node File System State [repairing_should_be_corrupt]** [float]
   The score to give when the state should be corrupt, but is repairing
-* **Node File System State [repairing_should_be_destroyed]** [int]
+* **Node File System State [repairing_should_be_destroyed]** [float]
   The score to give when the state should be destroyed, but is repairing
-* **Node File System State [repairing]** [int]
+* **Node File System State [repairing]** [float]
   The score to give when the state is repairing
-* **Node File System State [restoring_should_be_good]** [int]
+* **Node File System State [restoring_should_be_good]** [float]
   The score to give when the state should be good, but is restoring
-* **Node File System State [restoring_should_be_repairing]** [int]
+* **Node File System State [restoring_should_be_repairing]** [float]
   The score to give when the state should be repairing, but is restoring
-* **Node File System State [restoring_should_be_corrupt]** [int]
+* **Node File System State [restoring_should_be_corrupt]** [float]
   The score to give when the state should be corrupt, but is restoring
-* **Node File System State [restoring_should_be_destroyed]** [int]
+* **Node File System State [restoring_should_be_destroyed]** [float]
   The score to give when the state should be destroyed, but is restoring
-* **Node File System State [restoring]** [int]
+* **Node File System State [restoring]** [float]
   The score to give when the state is restoring
-* **Node File System State [corrupt_should_be_good]** [int]
+* **Node File System State [corrupt_should_be_good]** [float]
   The score to give when the state should be good, but is corrupt
-* **Node File System State [corrupt_should_be_repairing]** [int]
+* **Node File System State [corrupt_should_be_repairing]** [float]
   The score to give when the state should be repairing, but is corrupt
-* **Node File System State [corrupt_should_be_restoring]** [int]
+* **Node File System State [corrupt_should_be_restoring]** [float]
   The score to give when the state should be restoring, but is corrupt
-* **Node File System State [corrupt_should_be_destroyed]** [int]
+* **Node File System State [corrupt_should_be_destroyed]** [float]
   The score to give when the state should be destroyed, but is corrupt
-* **Node File System State [corrupt]** [int]
+* **Node File System State [corrupt]** [float]
   The score to give when the state is corrupt
-* **Node File System State [destroyed_should_be_good]** [int]
+* **Node File System State [destroyed_should_be_good]** [float]
   The score to give when the state should be good, but is destroyed
-* **Node File System State [destroyed_should_be_repairing]** [int]
+* **Node File System State [destroyed_should_be_repairing]** [float]
   The score to give when the state should be repairing, but is destroyed
-* **Node File System State [destroyed_should_be_restoring]** [int]
+* **Node File System State [destroyed_should_be_restoring]** [float]
   The score to give when the state should be restoring, but is destroyed
-* **Node File System State [destroyed_should_be_corrupt]** [int]
+* **Node File System State [destroyed_should_be_corrupt]** [float]
   The score to give when the state should be corrupt, but is destroyed
-* **Node File System State [destroyed]** [int]
+* **Node File System State [destroyed]** [float]
   The score to give when the state is destroyed
-* **Node File System State [scanning]** [int]
+* **Node File System State [scanning]** [float]
   The score to give when the state is scanning
-* **IER Status [red_ier_running]** [int]
+* **IER Status [red_ier_running]** [float]
   The score to give when a red agent IER is permitted to run
-* **IER Status [green_ier_blocked]** [int]
+* **IER Status [green_ier_blocked]** [float]
   The score to give when a green agent IER is prevented from running
@@ -308,6 +308,14 @@ Rewards are calculated based on the difference between the current state and ref
   The number of steps to take when scanning the file system
+* **deterministic** [bool]
+  Set to true if the agent evaluation should be deterministic. Default is ``False``
+* **seed** [int]
+  Seed used for the randomisation in agent training. Default is ``None``
 
 The Lay Down Config
 *******************
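The reward attributes above can be read as a lookup table keyed by the (reference "should be" state, actual state) pair. A minimal sketch of that idea, with illustrative names and values rather than PrimAITE's actual implementation:

```python
# Hypothetical reward lookup mirroring a few of the config keys above;
# the real values come from the training config file.
HARDWARE_STATE_REWARDS = {
    # (reference "should be" state, actual state): score
    ("ON", "OFF"): -0.001,          # off_should_be_on
    ("RESETTING", "OFF"): -0.0005,  # off_should_be_resetting
    ("OFF", "ON"): -0.0002,         # on_should_be_off
}

ALL_OK = 0.0  # the generic all_ok score


def score_hardware_state(reference: str, actual: str) -> float:
    """Return the configured score for a (reference, actual) state pair."""
    if reference == actual:
        return ALL_OK
    return HARDWARE_STATE_REWARDS.get((reference, actual), ALL_OK)
```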

View File

@@ -257,10 +257,19 @@ class AgentSessionABC(ABC):
             raise FileNotFoundError(msg)
         pass
 
+    @property
+    def _saved_agent_path(self) -> Path:
+        file_name = (
+            f"{self._training_config.agent_framework}_"
+            f"{self._training_config.agent_identifier}_"
+            f"{self.timestamp_str}.zip"
+        )
+        return self.learning_path / file_name
+
     @abstractmethod
     def save(self):
         """Save the agent."""
-        self._agent.save(self.session_path)
+        pass
 
     @abstractmethod
     def export(self):
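The new `_saved_agent_path` property composes the saved-agent zip name from the framework, the agent identifier, and the session timestamp. In isolation, the naming scheme looks like this (a standalone sketch, not the class itself):

```python
from pathlib import Path


def saved_agent_path(learning_path: Path, framework: str, identifier: str, timestamp: str) -> Path:
    # <agent_framework>_<agent_identifier>_<timestamp>.zip under the learning path
    return learning_path / f"{framework}_{identifier}_{timestamp}.zip"
```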

View File

@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import json
+import shutil
 from datetime import datetime
 from pathlib import Path
 from typing import Union
+from uuid import uuid4
 
 from ray.rllib.algorithms import Algorithm
 from ray.rllib.algorithms.a2c import A2CConfig
@@ -106,6 +108,7 @@ class RLlibAgent(AgentSessionABC):
                 timestamp_str=self.timestamp_str,
             ),
         )
+        self._agent_config.seed = self._training_config.seed
         self._agent_config.training(train_batch_size=self._training_config.num_steps)
         self._agent_config.framework(framework="tf")
@@ -120,9 +123,11 @@ class RLlibAgent(AgentSessionABC):
     def _save_checkpoint(self):
         checkpoint_n = self._training_config.checkpoint_every_n_episodes
         episode_count = self._current_result["episodes_total"]
-        if checkpoint_n > 0 and episode_count > 0:
-            if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
-                self._agent.save(str(self.checkpoints_path))
+        save_checkpoint = False
+        if checkpoint_n:
+            save_checkpoint = episode_count % checkpoint_n == 0
+        if episode_count and save_checkpoint:
+            self._agent.save(str(self.checkpoints_path))
 
     def learn(
         self,
@@ -140,9 +145,14 @@ class RLlibAgent(AgentSessionABC):
         for i in range(episodes):
             self._current_result = self._agent.train()
             self._save_checkpoint()
+            self.save()
 
         self._agent.stop()
         super().learn()
+
+        # save agent
+        self.save()
 
     def evaluate(
         self,
         **kwargs,
@@ -162,9 +172,25 @@ class RLlibAgent(AgentSessionABC):
"""Load an agent from file.""" """Load an agent from file."""
raise NotImplementedError raise NotImplementedError
def save(self): def save(self, overwrite_existing: bool = True):
"""Save the agent.""" """Save the agent."""
raise NotImplementedError # Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
temp_dir.mkdir()
# Save the agent to the temp dir
self._agent.save(str(temp_dir))
# Capture the saved Rllib checkpoint inside the temp directory
for file in temp_dir.iterdir():
checkpoint_dir = file
break
# Zip the folder
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self): def export(self):
"""Export the agent to transportable file format.""" """Export the agent to transportable file format."""

View File

@@ -59,16 +59,19 @@ class SB3Agent(AgentSessionABC):
             verbose=self.sb3_output_verbose_level,
             n_steps=self._training_config.num_steps,
             tensorboard_log=str(self._tensorboard_log_path),
+            seed=self._training_config.seed,
         )
     def _save_checkpoint(self):
         checkpoint_n = self._training_config.checkpoint_every_n_episodes
         episode_count = self._env.episode_count
-        if checkpoint_n > 0 and episode_count > 0:
-            if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
-                checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
-                self._agent.save(checkpoint_path)
-                _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
+        save_checkpoint = False
+        if checkpoint_n:
+            save_checkpoint = episode_count % checkpoint_n == 0
+        if episode_count and save_checkpoint:
+            checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
+            self._agent.save(checkpoint_path)
+            _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
 
     def _get_latest_checkpoint(self):
         pass
@@ -90,25 +93,27 @@ class SB3Agent(AgentSessionABC):
             self._agent.learn(total_timesteps=time_steps)
             self._save_checkpoint()
             self._env.reset()
+            self.save()
 
         self._env.close()
         super().learn()
+
+        # save agent
+        self.save()
     def evaluate(
         self,
-        deterministic: bool = True,
         **kwargs,
     ):
         """
         Evaluate the agent.
 
-        :param deterministic: Whether the evaluation is deterministic.
         :param kwargs: Any agent-specific key-word args to be passed.
         """
         time_steps = self._training_config.num_steps
         episodes = self._training_config.num_episodes
         self._env.set_as_eval()
         self.is_eval = True
-        if deterministic:
+        if self._training_config.deterministic:
             deterministic_str = "deterministic"
         else:
             deterministic_str = "non-deterministic"
@@ -119,7 +124,7 @@ class SB3Agent(AgentSessionABC):
             obs = self._env.reset()
             for step in range(time_steps):
-                action, _states = self._agent.predict(obs, deterministic=deterministic)
+                action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic)
                 if isinstance(action, np.ndarray):
                     action = np.int64(action)
                 obs, rewards, done, info = self._env.step(action)
@@ -134,7 +139,7 @@ class SB3Agent(AgentSessionABC):
     def save(self):
         """Save the agent."""
-        raise NotImplementedError
+        self._agent.save(self._saved_agent_path)
 
     def export(self):
         """Export the agent to transportable file format."""

View File

@@ -31,6 +31,16 @@ agent_identifier: PPO
 # False
 random_red_agent: False
 
+# The (integer) seed to be used in random number generation
+# Default is None (null)
+seed: null
+
+# Set whether the agent will be deterministic instead of stochastic
+# Options are:
+# True
+# False
+deterministic: False
+
 # Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
 # Options are:
 # "BASIC" (The current observation space only)
@@ -83,58 +93,58 @@ sb3_output_verbose_level: NONE
 # Generic
 all_ok: 0
 
 # Node Hardware State
-off_should_be_on: -10
+off_should_be_on: -0.001
-off_should_be_resetting: -5
+off_should_be_resetting: -0.0005
-on_should_be_off: -2
+on_should_be_off: -0.0002
-on_should_be_resetting: -5
+on_should_be_resetting: -0.0005
-resetting_should_be_on: -5
+resetting_should_be_on: -0.0005
-resetting_should_be_off: -2
+resetting_should_be_off: -0.0002
-resetting: -3
+resetting: -0.0003
 
 # Node Software or Service State
-good_should_be_patching: 2
+good_should_be_patching: 0.0002
-good_should_be_compromised: 5
+good_should_be_compromised: 0.0005
-good_should_be_overwhelmed: 5
+good_should_be_overwhelmed: 0.0005
-patching_should_be_good: -5
+patching_should_be_good: -0.0005
-patching_should_be_compromised: 2
+patching_should_be_compromised: 0.0002
-patching_should_be_overwhelmed: 2
+patching_should_be_overwhelmed: 0.0002
-patching: -3
+patching: -0.0003
-compromised_should_be_good: -20
+compromised_should_be_good: -0.002
-compromised_should_be_patching: -20
+compromised_should_be_patching: -0.002
-compromised_should_be_overwhelmed: -20
+compromised_should_be_overwhelmed: -0.002
-compromised: -20
+compromised: -0.002
-overwhelmed_should_be_good: -20
+overwhelmed_should_be_good: -0.002
-overwhelmed_should_be_patching: -20
+overwhelmed_should_be_patching: -0.002
-overwhelmed_should_be_compromised: -20
+overwhelmed_should_be_compromised: -0.002
-overwhelmed: -20
+overwhelmed: -0.002
 
 # Node File System State
-good_should_be_repairing: 2
+good_should_be_repairing: 0.0002
-good_should_be_restoring: 2
+good_should_be_restoring: 0.0002
-good_should_be_corrupt: 5
+good_should_be_corrupt: 0.0005
-good_should_be_destroyed: 10
+good_should_be_destroyed: 0.001
-repairing_should_be_good: -5
+repairing_should_be_good: -0.0005
-repairing_should_be_restoring: 2
+repairing_should_be_restoring: 0.0002
-repairing_should_be_corrupt: 2
+repairing_should_be_corrupt: 0.0002
-repairing_should_be_destroyed: 0
+repairing_should_be_destroyed: 0.0000
-repairing: -3
+repairing: -0.0003
-restoring_should_be_good: -10
+restoring_should_be_good: -0.001
-restoring_should_be_repairing: -2
+restoring_should_be_repairing: -0.0002
-restoring_should_be_corrupt: 1
+restoring_should_be_corrupt: 0.0001
-restoring_should_be_destroyed: 2
+restoring_should_be_destroyed: 0.0002
-restoring: -6
+restoring: -0.0006
-corrupt_should_be_good: -10
+corrupt_should_be_good: -0.001
-corrupt_should_be_repairing: -10
+corrupt_should_be_repairing: -0.001
-corrupt_should_be_restoring: -10
+corrupt_should_be_restoring: -0.001
-corrupt_should_be_destroyed: 2
+corrupt_should_be_destroyed: 0.0002
-corrupt: -10
+corrupt: -0.001
-destroyed_should_be_good: -20
+destroyed_should_be_good: -0.002
-destroyed_should_be_repairing: -20
+destroyed_should_be_repairing: -0.002
-destroyed_should_be_restoring: -20
+destroyed_should_be_restoring: -0.002
-destroyed_should_be_corrupt: -20
+destroyed_should_be_corrupt: -0.002
-destroyed: -20
+destroyed: -0.002
-scanning: -2
+scanning: -0.0002
 
 # IER status
-red_ier_running: -5
+red_ier_running: -0.0005
-green_ier_blocked: -10
+green_ier_blocked: -0.001
 
 # Patching / Reset durations
 os_patching_duration: 5  # The time taken to patch the OS

View File

@@ -104,64 +104,64 @@ class TrainingConfig:
     # Reward values
     # Generic
-    all_ok: int = 0
+    all_ok: float = 0
 
     # Node Hardware State
-    off_should_be_on: int = -10
+    off_should_be_on: float = -0.001
-    off_should_be_resetting: int = -5
+    off_should_be_resetting: float = -0.0005
-    on_should_be_off: int = -2
+    on_should_be_off: float = -0.0002
-    on_should_be_resetting: int = -5
+    on_should_be_resetting: float = -0.0005
-    resetting_should_be_on: int = -5
+    resetting_should_be_on: float = -0.0005
-    resetting_should_be_off: int = -2
+    resetting_should_be_off: float = -0.0002
-    resetting: int = -3
+    resetting: float = -0.0003
 
     # Node Software or Service State
-    good_should_be_patching: int = 2
+    good_should_be_patching: float = 0.0002
-    good_should_be_compromised: int = 5
+    good_should_be_compromised: float = 0.0005
-    good_should_be_overwhelmed: int = 5
+    good_should_be_overwhelmed: float = 0.0005
-    patching_should_be_good: int = -5
+    patching_should_be_good: float = -0.0005
-    patching_should_be_compromised: int = 2
+    patching_should_be_compromised: float = 0.0002
-    patching_should_be_overwhelmed: int = 2
+    patching_should_be_overwhelmed: float = 0.0002
-    patching: int = -3
+    patching: float = -0.0003
-    compromised_should_be_good: int = -20
+    compromised_should_be_good: float = -0.002
-    compromised_should_be_patching: int = -20
+    compromised_should_be_patching: float = -0.002
-    compromised_should_be_overwhelmed: int = -20
+    compromised_should_be_overwhelmed: float = -0.002
-    compromised: int = -20
+    compromised: float = -0.002
-    overwhelmed_should_be_good: int = -20
+    overwhelmed_should_be_good: float = -0.002
-    overwhelmed_should_be_patching: int = -20
+    overwhelmed_should_be_patching: float = -0.002
-    overwhelmed_should_be_compromised: int = -20
+    overwhelmed_should_be_compromised: float = -0.002
-    overwhelmed: int = -20
+    overwhelmed: float = -0.002
 
     # Node File System State
-    good_should_be_repairing: int = 2
+    good_should_be_repairing: float = 0.0002
-    good_should_be_restoring: int = 2
+    good_should_be_restoring: float = 0.0002
-    good_should_be_corrupt: int = 5
+    good_should_be_corrupt: float = 0.0005
-    good_should_be_destroyed: int = 10
+    good_should_be_destroyed: float = 0.001
-    repairing_should_be_good: int = -5
+    repairing_should_be_good: float = -0.0005
-    repairing_should_be_restoring: int = 2
+    repairing_should_be_restoring: float = 0.0002
-    repairing_should_be_corrupt: int = 2
+    repairing_should_be_corrupt: float = 0.0002
-    repairing_should_be_destroyed: int = 0
+    repairing_should_be_destroyed: float = 0.0000
-    repairing: int = -3
+    repairing: float = -0.0003
-    restoring_should_be_good: int = -10
+    restoring_should_be_good: float = -0.001
-    restoring_should_be_repairing: int = -2
+    restoring_should_be_repairing: float = -0.0002
-    restoring_should_be_corrupt: int = 1
+    restoring_should_be_corrupt: float = 0.0001
-    restoring_should_be_destroyed: int = 2
+    restoring_should_be_destroyed: float = 0.0002
-    restoring: int = -6
+    restoring: float = -0.0006
-    corrupt_should_be_good: int = -10
+    corrupt_should_be_good: float = -0.001
-    corrupt_should_be_repairing: int = -10
+    corrupt_should_be_repairing: float = -0.001
-    corrupt_should_be_restoring: int = -10
+    corrupt_should_be_restoring: float = -0.001
-    corrupt_should_be_destroyed: int = 2
+    corrupt_should_be_destroyed: float = 0.0002
-    corrupt: int = -10
+    corrupt: float = -0.001
-    destroyed_should_be_good: int = -20
+    destroyed_should_be_good: float = -0.002
-    destroyed_should_be_repairing: int = -20
+    destroyed_should_be_repairing: float = -0.002
-    destroyed_should_be_restoring: int = -20
+    destroyed_should_be_restoring: float = -0.002
-    destroyed_should_be_corrupt: int = -20
+    destroyed_should_be_corrupt: float = -0.002
-    destroyed: int = -20
+    destroyed: float = -0.002
-    scanning: int = -2
+    scanning: float = -0.0002
 
     # IER status
-    red_ier_running: int = -5
+    red_ier_running: float = -0.0005
-    green_ier_blocked: int = -10
+    green_ier_blocked: float = -0.001
 
     # Patching / Reset durations
     os_patching_duration: int = 5
@@ -188,6 +188,12 @@ class TrainingConfig:
     file_system_scanning_limit: int = 5
     "The time taken to scan the file system"
 
+    deterministic: bool = False
+    "If true, the training will be deterministic"
+
+    seed: Optional[int] = None
+    "The random number generator seed to be used while training the agent"
+
     @classmethod
     def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
         """

View File

@@ -148,10 +148,10 @@ class Primaite(Env):
         self.step_info = {}
 
         # Total reward
-        self.total_reward = 0
+        self.total_reward: float = 0
 
         # Average reward
-        self.average_reward = 0
+        self.average_reward: float = 0
 
         # Episode count
         self.episode_count = 0
@@ -289,9 +289,9 @@ class Primaite(Env):
             self._create_random_red_agent()
 
         # Reset counters and totals
-        self.total_reward = 0
+        self.total_reward = 0.0
         self.step_count = 0
-        self.average_reward = 0
+        self.average_reward = 0.0
 
         # Update observations space and return
         self.update_environent_obs()


@@ -20,7 +20,7 @@ def calculate_reward_function(
 red_iers,
 step_count,
 config_values,
-):
+) -> float:
 """
 Compares the states of the initial and final nodes/links to get a reward.
@@ -33,7 +33,7 @@ def calculate_reward_function(
 step_count: current step
 config_values: Config values
 """
-reward_value = 0
+reward_value: float = 0.0
 # For each node, compare hardware state, SoftwareState, service states
 for node_key, final_node in final_nodes.items():
@@ -94,7 +94,7 @@ def calculate_reward_function(
 return reward_value
-def score_node_operating_state(final_node, initial_node, reference_node, config_values):
+def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float:
 """
 Calculates score relating to the hardware state of a node.
@@ -104,7 +104,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
 reference_node: The node if there had been no red or blue effect
 config_values: Config values
 """
-score = 0
+score: float = 0.0
 final_node_operating_state = final_node.hardware_state
 reference_node_operating_state = reference_node.hardware_state
@@ -143,7 +143,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
 return score
-def score_node_os_state(final_node, initial_node, reference_node, config_values):
+def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float:
 """
 Calculates score relating to the Software State of a node.
@@ -153,7 +153,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
 reference_node: The node if there had been no red or blue effect
 config_values: Config values
 """
-score = 0
+score: float = 0.0
 final_node_os_state = final_node.software_state
 reference_node_os_state = reference_node.software_state
@@ -194,7 +194,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
 return score
-def score_node_service_state(final_node, initial_node, reference_node, config_values):
+def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float:
 """
 Calculates score relating to the service state(s) of a node.
@@ -204,7 +204,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
 reference_node: The node if there had been no red or blue effect
 config_values: Config values
 """
-score = 0
+score: float = 0.0
 final_node_services: Dict[str, Service] = final_node.services
 reference_node_services: Dict[str, Service] = reference_node.services
@@ -266,7 +266,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
 return score
-def score_node_file_system(final_node, initial_node, reference_node, config_values):
+def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float:
 """
 Calculates score relating to the file system state of a node.
@@ -275,7 +275,7 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
 initial_node: The node before red and blue agents take effect
 reference_node: The node if there had been no red or blue effect
 """
-score = 0
+score: float = 0.0
 final_node_file_system_state = final_node.file_system_state_actual
 reference_node_file_system_state = reference_node.file_system_state_actual
View File

@@ -31,7 +31,7 @@ class Transaction(object):
"The observation space before any actions are taken" "The observation space before any actions are taken"
self.obs_space_post = None self.obs_space_post = None
"The observation space after any actions are taken" "The observation space after any actions are taken"
self.reward = None self.reward: float = None
"The reward value" "The reward value"
self.action_space = None self.action_space = None
"The action space invoked by the agent" "The action space invoked by the agent"


@@ -0,0 +1,155 @@
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: None
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
  # flatten: true
  components:
    - name: NODE_LINK_TABLE
    # - name: NODE_STATUSES
    # - name: LINK_TRAFFIC_LEVELS
# Number of episodes to run per session
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 0
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0.0000
# Node Hardware State
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time taken to repair the file system
file_system_restoring_limit: 5 # The time taken to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system
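One subtlety in the new config file above: YAML has no `None` literal (null is spelled `null` or `~`), so the `seed: None` entry is parsed as the *string* `"None"`. A hypothetical normalisation step a config loader might apply (`normalise_seed` is illustrative, not PrimAITE's actual code):

```python
# Convert the YAML-parsed seed value into Optional[int].
def normalise_seed(raw):
    # YAML parses a bare `None` token as the string "None"
    if raw in (None, "None", "null"):
        return None
    return int(raw)


assert normalise_seed("None") is None
assert normalise_seed(67890) == 67890
```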


@@ -1,40 +1,94 @@
-# Main Config File
-# Generic config values
-# Choose one of these (dependent on Agent being trained)
-# "STABLE_BASELINES3_PPO"
-# "STABLE_BASELINES3_A2C"
-# "GENERIC"
-agent_identifier: STABLE_BASELINES3_A2C
+# Training Config File
+# Sets which agent algorithm framework will be used.
+# Options are:
+# "SB3" (Stable Baselines3)
+# "RLLIB" (Ray RLlib)
+# "CUSTOM" (Custom Agent)
+agent_framework: SB3
+# Sets which deep learning framework will be used (by RLlib ONLY).
+# Default is TF (Tensorflow).
+# Options are:
+# "TF" (Tensorflow)
+# TF2 (Tensorflow 2.X)
+# TORCH (PyTorch)
+deep_learning_framework: TF2
+# Sets which Agent class will be used.
+# Options are:
+# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
+# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
+# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
+# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
+# "RANDOM" (primaite.agents.simple.RandomAgent)
+# "DUMMY" (primaite.agents.simple.DummyAgent)
+agent_identifier: PPO
 # Sets whether Red Agent POL and IER is randomised.
 # Options are:
 # True
 # False
-random_red_agent: True
+random_red_agent: False
+# The (integer) seed to be used in random number generation
+# Default is None (null)
+seed: 67890
+# Set whether the agent will be deterministic instead of stochastic
+# Options are:
+# True
+# False
+deterministic: True
+# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
+# Options are:
+# "BASIC" (The current observation space only)
+# "FULL" (Full environment view with actions taken and reward feedback)
+hard_coded_agent_view: FULL
 # Sets How the Action Space is defined:
 # "NODE"
 # "ACL"
 # "ANY" node and acl actions
 action_type: NODE
+# observation space
+observation_space:
+  # flatten: true
+  components:
+    - name: NODE_LINK_TABLE
+    # - name: NODE_STATUSES
+    # - name: LINK_TRAFFIC_LEVELS
 # Number of episodes to run per session
 num_episodes: 10
 # Number of time_steps per episode
 num_steps: 256
-# Time delay between steps (for generic agents)
-time_delay: 10
-# Type of session to be run (TRAINING or EVALUATION)
-session_type: TRAINING
-# Determine whether to load an agent from file
-load_agent: False
-# File path and file name of agent if you're loading one in
-agent_load_file: C:\[Path]\[agent_saved_filename.zip]
+# Sets how often the agent will save a checkpoint (every n time episodes).
+# Set to 0 if no checkpoints are required. Default is 10
+checkpoint_every_n_episodes: 0
+# Time delay (milliseconds) between steps for CUSTOM agents.
+time_delay: 5
+# Type of session to be run. Options are:
+# "TRAIN" (Trains an agent)
+# "EVAL" (Evaluates an agent)
+# "TRAIN_EVAL" (Trains then evaluates an agent)
+session_type: TRAIN_EVAL
 # Environment config values
 # The high value for the observation space
 observation_space_high_value: 1000000000
+# The Stable Baselines3 learn/eval output verbosity level:
+# Options are:
+# "NONE" (No Output)
+# "INFO" (Info Messages (such as devices and wrappers used))
+# "DEBUG" (All Messages)
+sb3_output_verbose_level: NONE
 # Reward values
 # Generic
 all_ok: 0


@@ -58,7 +58,6 @@ class TempPrimaiteSession(PrimaiteSession):
 def __exit__(self, type, value, tb):
 shutil.rmtree(self.session_path)
-shutil.rmtree(self.session_path.parent)
 _LOGGER.debug(f"Deleted temp session directory: {self.session_path}")


@@ -1,6 +1,7 @@
 import tempfile
 from datetime import datetime
 from pathlib import Path
+from uuid import uuid4
 from primaite import getLogger
@@ -14,9 +14,7 @@ def get_temp_session_path(session_timestamp: datetime) -> Path:
 :param session_timestamp: This is the datetime that the session started.
 :return: The session directory path.
 """
-date_dir = session_timestamp.strftime("%Y-%m-%d")
-session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
-session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
+session_path = Path(tempfile.gettempdir()) / "primaite" / str(uuid4())
 session_path.mkdir(exist_ok=True, parents=True)
 _LOGGER.debug(f"Created temp session directory: {session_path}")
 return session_path
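The hunk above swaps a timestamp-derived directory name for a `uuid4` one, so two sessions started within the same second can no longer collide on the same path. A standalone sketch of the new behaviour (the timestamp parameter and logging are dropped here for brevity):

```python
import tempfile
from pathlib import Path
from uuid import uuid4


def get_temp_session_path() -> Path:
    # uuid4 names are unique per call, unlike second-resolution timestamps
    session_path = Path(tempfile.gettempdir()) / "primaite" / str(uuid4())
    session_path.mkdir(exist_ok=True, parents=True)
    return session_path


a, b = get_temp_session_path(), get_temp_session_path()
```

Because each call yields a fresh directory, the matching `__exit__` cleanup above only needs to remove `session_path` itself, not its parent.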


@@ -33,6 +33,9 @@ def test_primaite_session(temp_primaite_session):
 # Check that the network png file exists
 assert (session_path / f"network_{session.timestamp_str}.png").exists()
+# Check that the saved agent exists
+assert session._agent_session._saved_agent_path.exists()
 # Check that both the transactions and av reward csv files exist
 for file in session.learning_path.iterdir():
 if file.suffix == ".csv":


@@ -0,0 +1,49 @@
import pytest as pytest
from primaite.config.lay_down_config import dos_very_basic_config_path
from tests import TEST_CONFIG_ROOT
@pytest.mark.parametrize(
"temp_primaite_session",
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
indirect=True,
)
def test_seeded_learning(temp_primaite_session):
"""Test running seeded learning produces the same output when ran twice."""
expected_mean_reward_per_episode = {
1: -90.703125,
2: -91.15234375,
3: -87.5,
4: -92.2265625,
5: -94.6875,
6: -91.19140625,
7: -88.984375,
8: -88.3203125,
9: -112.79296875,
10: -100.01953125,
}
with temp_primaite_session as session:
assert session._training_config.seed == 67890, (
"Expected output is based upon a agent that was trained with " "seed 67890"
)
session.learn()
actual_mean_reward_per_episode = session.learn_av_reward_per_episode()
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL " "knowledge to investigate further.")
@pytest.mark.parametrize(
"temp_primaite_session",
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
indirect=True,
)
def test_deterministic_evaluation(temp_primaite_session):
"""Test running deterministic evaluation gives same av eward per episode."""
with temp_primaite_session as session:
# do stuff
session.learn()
session.evaluate()
eval_mean_reward = session.eval_av_reward_per_episode_csv()
assert len(set(eval_mean_reward.values())) == 1
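The determinism assertion in `test_deterministic_evaluation` boils down to: a deterministic agent evaluated on a fixed environment should yield identical average rewards every episode, so the set of per-episode values collapses to one element. A minimal sketch of that check (`is_deterministic_run` is an illustrative helper, not part of the test suite):

```python
def is_deterministic_run(av_reward_per_episode: dict) -> bool:
    # All episode means identical -> the value set has exactly one member
    return len(set(av_reward_per_episode.values())) == 1


assert is_deterministic_run({1: -90.5, 2: -90.5})
assert not is_deterministic_run({1: -90.5, 2: -91.0})
```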