Merge remote-tracking branch 'origin/dev' into feature/901-change-functionality-acl-rules
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -138,4 +138,8 @@ dmypy.json
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
|
||||
# outputs
|
||||
src/primaite/outputs/
|
||||
|
||||
63
README.md
63
README.md
@@ -1 +1,64 @@
|
||||
# PrimAITE
|
||||
|
||||
## Getting Started with PrimAITE
|
||||
|
||||
### Pre-Requisites
|
||||
|
||||
In order to get **PrimAITE** installed, you will need to have the following installed:
|
||||
|
||||
- `python3.8+`
|
||||
- `python3-pip`
|
||||
- `virtualenv`
|
||||
|
||||
**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS.
|
||||
|
||||
### Installation from source
|
||||
#### 1. Navigate to the PrimAITE folder and create a new python virtual environment (venv)
|
||||
|
||||
```unix
|
||||
python3 -m venv <name_of_venv>
|
||||
```
|
||||
|
||||
#### 2. Activate the venv
|
||||
|
||||
##### Unix
|
||||
```bash
|
||||
source <name_of_venv>/bin/activate
|
||||
```
|
||||
|
||||
##### Windows
|
||||
```powershell
|
||||
.\<name_of_venv>\Scripts\activate
|
||||
```
|
||||
|
||||
#### 3. Install `primaite` into the venv along with all of it's dependencies
|
||||
|
||||
```bash
|
||||
python3 -m pip install -e .
|
||||
```
|
||||
|
||||
### Development Installation
|
||||
To install the development dependencies, postfix the command in step 3 above with the `[dev]` extra. Example:
|
||||
|
||||
```bash
|
||||
python3 -m pip install -e .[dev]
|
||||
```
|
||||
|
||||
## Building documentation
|
||||
The PrimAITE documentation can be built with the following commands:
|
||||
|
||||
##### Unix
|
||||
```bash
|
||||
cd docs
|
||||
make html
|
||||
```
|
||||
|
||||
##### Windows
|
||||
```powershell
|
||||
cd docs
|
||||
.\make.bat html
|
||||
```
|
||||
|
||||
This will build the documentation as a collection of HTML files which uses the Read The Docs sphinx theme. Other build
|
||||
options are available but may require additional dependencies such as LaTeX and PDF. Please refer to the Sphinx documentation
|
||||
for your specific output requirements.
|
||||
|
||||
@@ -82,203 +82,203 @@ The environment config file consists of the following attributes:
|
||||
|
||||
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
|
||||
|
||||
* **Generic [all_ok]** [int]
|
||||
* **Generic [all_ok]** [float]
|
||||
|
||||
The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)
|
||||
|
||||
* **Node Hardware State [off_should_be_on]** [int]
|
||||
* **Node Hardware State [off_should_be_on]** [float]
|
||||
|
||||
The score to give when the node should be on, but is off
|
||||
|
||||
* **Node Hardware State [off_should_be_resetting]** [int]
|
||||
* **Node Hardware State [off_should_be_resetting]** [float]
|
||||
|
||||
The score to give when the node should be resetting, but is off
|
||||
|
||||
* **Node Hardware State [on_should_be_off]** [int]
|
||||
* **Node Hardware State [on_should_be_off]** [float]
|
||||
|
||||
The score to give when the node should be off, but is on
|
||||
|
||||
* **Node Hardware State [on_should_be_resetting]** [int]
|
||||
* **Node Hardware State [on_should_be_resetting]** [float]
|
||||
|
||||
The score to give when the node should be resetting, but is on
|
||||
|
||||
* **Node Hardware State [resetting_should_be_on]** [int]
|
||||
* **Node Hardware State [resetting_should_be_on]** [float]
|
||||
|
||||
The score to give when the node should be on, but is resetting
|
||||
|
||||
* **Node Hardware State [resetting_should_be_off]** [int]
|
||||
* **Node Hardware State [resetting_should_be_off]** [float]
|
||||
|
||||
The score to give when the node should be off, but is resetting
|
||||
|
||||
* **Node Hardware State [resetting]** [int]
|
||||
* **Node Hardware State [resetting]** [float]
|
||||
|
||||
The score to give when the node is resetting
|
||||
|
||||
* **Node Operating System or Service State [good_should_be_patching]** [int]
|
||||
* **Node Operating System or Service State [good_should_be_patching]** [float]
|
||||
|
||||
The score to give when the state should be patching, but is good
|
||||
|
||||
* **Node Operating System or Service State [good_should_be_compromised]** [int]
|
||||
* **Node Operating System or Service State [good_should_be_compromised]** [float]
|
||||
|
||||
The score to give when the state should be compromised, but is good
|
||||
|
||||
* **Node Operating System or Service State [good_should_be_overwhelmed]** [int]
|
||||
* **Node Operating System or Service State [good_should_be_overwhelmed]** [float]
|
||||
|
||||
The score to give when the state should be overwhelmed, but is good
|
||||
|
||||
* **Node Operating System or Service State [patching_should_be_good]** [int]
|
||||
* **Node Operating System or Service State [patching_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching_should_be_compromised]** [int]
|
||||
* **Node Operating System or Service State [patching_should_be_compromised]** [float]
|
||||
|
||||
The score to give when the state should be compromised, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int]
|
||||
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float]
|
||||
|
||||
The score to give when the state should be overwhelmed, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching]** [int]
|
||||
* **Node Operating System or Service State [patching]** [float]
|
||||
|
||||
The score to give when the state is patching
|
||||
|
||||
* **Node Operating System or Service State [compromised_should_be_good]** [int]
|
||||
* **Node Operating System or Service State [compromised_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised_should_be_patching]** [int]
|
||||
* **Node Operating System or Service State [compromised_should_be_patching]** [float]
|
||||
|
||||
The score to give when the state should be patching, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int]
|
||||
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [float]
|
||||
|
||||
The score to give when the state should be overwhelmed, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised]** [int]
|
||||
* **Node Operating System or Service State [compromised]** [float]
|
||||
|
||||
The score to give when the state is compromised
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_good]** [int]
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int]
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float]
|
||||
|
||||
The score to give when the state should be patching, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int]
|
||||
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float]
|
||||
|
||||
The score to give when the state should be compromised, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed]** [int]
|
||||
* **Node Operating System or Service State [overwhelmed]** [float]
|
||||
|
||||
The score to give when the state is overwhelmed
|
||||
|
||||
* **Node File System State [good_should_be_repairing]** [int]
|
||||
* **Node File System State [good_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is good
|
||||
|
||||
* **Node File System State [good_should_be_restoring]** [int]
|
||||
* **Node File System State [good_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is good
|
||||
|
||||
* **Node File System State [good_should_be_corrupt]** [int]
|
||||
* **Node File System State [good_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is good
|
||||
|
||||
* **Node File System State [good_should_be_destroyed]** [int]
|
||||
* **Node File System State [good_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is good
|
||||
|
||||
* **Node File System State [repairing_should_be_good]** [int]
|
||||
* **Node File System State [repairing_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is repairing
|
||||
|
||||
* **Node File System State [repairing_should_be_restoring]** [int]
|
||||
* **Node File System State [repairing_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is repairing
|
||||
|
||||
* **Node File System State [repairing_should_be_corrupt]** [int]
|
||||
* **Node File System State [repairing_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is repairing
|
||||
|
||||
* **Node File System State [repairing_should_be_destroyed]** [int]
|
||||
* **Node File System State [repairing_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is repairing
|
||||
|
||||
* **Node File System State [repairing]** [int]
|
||||
* **Node File System State [repairing]** [float]
|
||||
|
||||
The score to give when the state is repairing
|
||||
|
||||
* **Node File System State [restoring_should_be_good]** [int]
|
||||
* **Node File System State [restoring_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is restoring
|
||||
|
||||
* **Node File System State [restoring_should_be_repairing]** [int]
|
||||
* **Node File System State [restoring_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is restoring
|
||||
|
||||
* **Node File System State [restoring_should_be_corrupt]** [int]
|
||||
* **Node File System State [restoring_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is restoring
|
||||
|
||||
* **Node File System State [restoring_should_be_destroyed]** [int]
|
||||
* **Node File System State [restoring_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is restoring
|
||||
|
||||
* **Node File System State [restoring]** [int]
|
||||
* **Node File System State [restoring]** [float]
|
||||
|
||||
The score to give when the state is restoring
|
||||
|
||||
* **Node File System State [corrupt_should_be_good]** [int]
|
||||
* **Node File System State [corrupt_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt_should_be_repairing]** [int]
|
||||
* **Node File System State [corrupt_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt_should_be_restoring]** [int]
|
||||
* **Node File System State [corrupt_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt_should_be_destroyed]** [int]
|
||||
* **Node File System State [corrupt_should_be_destroyed]** [float]
|
||||
|
||||
The score to give when the state should be destroyed, but is corrupt
|
||||
|
||||
* **Node File System State [corrupt]** [int]
|
||||
* **Node File System State [corrupt]** [float]
|
||||
|
||||
The score to give when the state is corrupt
|
||||
|
||||
* **Node File System State [destroyed_should_be_good]** [int]
|
||||
* **Node File System State [destroyed_should_be_good]** [float]
|
||||
|
||||
The score to give when the state should be good, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed_should_be_repairing]** [int]
|
||||
* **Node File System State [destroyed_should_be_repairing]** [float]
|
||||
|
||||
The score to give when the state should be repairing, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed_should_be_restoring]** [int]
|
||||
* **Node File System State [destroyed_should_be_restoring]** [float]
|
||||
|
||||
The score to give when the state should be restoring, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed_should_be_corrupt]** [int]
|
||||
* **Node File System State [destroyed_should_be_corrupt]** [float]
|
||||
|
||||
The score to give when the state should be corrupt, but is destroyed
|
||||
|
||||
* **Node File System State [destroyed]** [int]
|
||||
* **Node File System State [destroyed]** [float]
|
||||
|
||||
The score to give when the state is destroyed
|
||||
|
||||
* **Node File System State [scanning]** [int]
|
||||
* **Node File System State [scanning]** [float]
|
||||
|
||||
The score to give when the state is scanning
|
||||
|
||||
* **IER Status [red_ier_running]** [int]
|
||||
* **IER Status [red_ier_running]** [float]
|
||||
|
||||
The score to give when a red agent IER is permitted to run
|
||||
|
||||
* **IER Status [green_ier_blocked]** [int]
|
||||
* **IER Status [green_ier_blocked]** [float]
|
||||
|
||||
The score to give when a green agent IER is prevented from running
|
||||
|
||||
@@ -308,6 +308,14 @@ Rewards are calculated based on the difference between the current state and ref
|
||||
|
||||
The number of steps to take when scanning the file system
|
||||
|
||||
* **deterministic** [bool]
|
||||
|
||||
Set to true if the agent evaluation should be deterministic. Default is ``False``
|
||||
|
||||
* **seed** [int]
|
||||
|
||||
Seed used in the randomisation in agent training. Default is ``None``
|
||||
|
||||
The Lay Down Config
|
||||
*******************
|
||||
|
||||
|
||||
@@ -257,10 +257,19 @@ class AgentSessionABC(ABC):
|
||||
raise FileNotFoundError(msg)
|
||||
pass
|
||||
|
||||
@property
|
||||
def _saved_agent_path(self) -> Path:
|
||||
file_name = (
|
||||
f"{self._training_config.agent_framework}_"
|
||||
f"{self._training_config.agent_identifier}_"
|
||||
f"{self.timestamp_str}.zip"
|
||||
)
|
||||
return self.learning_path / file_name
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
self._agent.save(self.session_path)
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def export(self):
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.a2c import A2CConfig
|
||||
@@ -106,6 +108,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
timestamp_str=self.timestamp_str,
|
||||
),
|
||||
)
|
||||
self._agent_config.seed = self._training_config.seed
|
||||
|
||||
self._agent_config.training(train_batch_size=self._training_config.num_steps)
|
||||
self._agent_config.framework(framework="tf")
|
||||
@@ -120,9 +123,11 @@ class RLlibAgent(AgentSessionABC):
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
save_checkpoint = False
|
||||
if checkpoint_n:
|
||||
save_checkpoint = episode_count % checkpoint_n == 0
|
||||
if episode_count and save_checkpoint:
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
@@ -140,9 +145,14 @@ class RLlibAgent(AgentSessionABC):
|
||||
for i in range(episodes):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
self.save()
|
||||
self._agent.stop()
|
||||
|
||||
super().learn()
|
||||
|
||||
# save agent
|
||||
self.save()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
@@ -162,9 +172,25 @@ class RLlibAgent(AgentSessionABC):
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self):
|
||||
def save(self, overwrite_existing: bool = True):
|
||||
"""Save the agent."""
|
||||
raise NotImplementedError
|
||||
# Make temp dir to save in isolation
|
||||
temp_dir = self.learning_path / str(uuid4())
|
||||
temp_dir.mkdir()
|
||||
|
||||
# Save the agent to the temp dir
|
||||
self._agent.save(str(temp_dir))
|
||||
|
||||
# Capture the saved Rllib checkpoint inside the temp directory
|
||||
for file in temp_dir.iterdir():
|
||||
checkpoint_dir = file
|
||||
break
|
||||
|
||||
# Zip the folder
|
||||
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
|
||||
|
||||
# Drop the temp directory
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
|
||||
@@ -59,16 +59,19 @@ class SB3Agent(AgentSessionABC):
|
||||
verbose=self.sb3_output_verbose_level,
|
||||
n_steps=self._training_config.num_steps,
|
||||
tensorboard_log=str(self._tensorboard_log_path),
|
||||
seed=self._training_config.seed,
|
||||
)
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
|
||||
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
save_checkpoint = False
|
||||
if checkpoint_n:
|
||||
save_checkpoint = episode_count % checkpoint_n == 0
|
||||
if episode_count and save_checkpoint:
|
||||
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
@@ -90,25 +93,27 @@ class SB3Agent(AgentSessionABC):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
self._env.reset()
|
||||
self.save()
|
||||
self._env.close()
|
||||
super().learn()
|
||||
|
||||
# save agent
|
||||
self.save()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
deterministic: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param deterministic: Whether the evaluation is deterministic.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
self._env.set_as_eval()
|
||||
self.is_eval = True
|
||||
if deterministic:
|
||||
if self._training_config.deterministic:
|
||||
deterministic_str = "deterministic"
|
||||
else:
|
||||
deterministic_str = "non-deterministic"
|
||||
@@ -119,7 +124,7 @@ class SB3Agent(AgentSessionABC):
|
||||
obs = self._env.reset()
|
||||
|
||||
for step in range(time_steps):
|
||||
action, _states = self._agent.predict(obs, deterministic=deterministic)
|
||||
action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic)
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
@@ -134,7 +139,7 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
raise NotImplementedError
|
||||
self._agent.save(self._saved_agent_path)
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
|
||||
@@ -31,6 +31,16 @@ agent_identifier: PPO
|
||||
# False
|
||||
random_red_agent: False
|
||||
|
||||
# The (integer) seed to be used in random number generation
|
||||
# Default is None (null)
|
||||
seed: null
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
deterministic: False
|
||||
|
||||
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
|
||||
# Options are:
|
||||
# "BASIC" (The current observation space only)
|
||||
@@ -83,58 +93,58 @@ sb3_output_verbose_level: NONE
|
||||
# Generic
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
off_should_be_on: -10
|
||||
off_should_be_resetting: -5
|
||||
on_should_be_off: -2
|
||||
on_should_be_resetting: -5
|
||||
resetting_should_be_on: -5
|
||||
resetting_should_be_off: -2
|
||||
resetting: -3
|
||||
off_should_be_on: -0.001
|
||||
off_should_be_resetting: -0.0005
|
||||
on_should_be_off: -0.0002
|
||||
on_should_be_resetting: -0.0005
|
||||
resetting_should_be_on: -0.0005
|
||||
resetting_should_be_off: -0.0002
|
||||
resetting: -0.0003
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 2
|
||||
good_should_be_compromised: 5
|
||||
good_should_be_overwhelmed: 5
|
||||
patching_should_be_good: -5
|
||||
patching_should_be_compromised: 2
|
||||
patching_should_be_overwhelmed: 2
|
||||
patching: -3
|
||||
compromised_should_be_good: -20
|
||||
compromised_should_be_patching: -20
|
||||
compromised_should_be_overwhelmed: -20
|
||||
compromised: -20
|
||||
overwhelmed_should_be_good: -20
|
||||
overwhelmed_should_be_patching: -20
|
||||
overwhelmed_should_be_compromised: -20
|
||||
overwhelmed: -20
|
||||
good_should_be_patching: 0.0002
|
||||
good_should_be_compromised: 0.0005
|
||||
good_should_be_overwhelmed: 0.0005
|
||||
patching_should_be_good: -0.0005
|
||||
patching_should_be_compromised: 0.0002
|
||||
patching_should_be_overwhelmed: 0.0002
|
||||
patching: -0.0003
|
||||
compromised_should_be_good: -0.002
|
||||
compromised_should_be_patching: -0.002
|
||||
compromised_should_be_overwhelmed: -0.002
|
||||
compromised: -0.002
|
||||
overwhelmed_should_be_good: -0.002
|
||||
overwhelmed_should_be_patching: -0.002
|
||||
overwhelmed_should_be_compromised: -0.002
|
||||
overwhelmed: -0.002
|
||||
# Node File System State
|
||||
good_should_be_repairing: 2
|
||||
good_should_be_restoring: 2
|
||||
good_should_be_corrupt: 5
|
||||
good_should_be_destroyed: 10
|
||||
repairing_should_be_good: -5
|
||||
repairing_should_be_restoring: 2
|
||||
repairing_should_be_corrupt: 2
|
||||
repairing_should_be_destroyed: 0
|
||||
repairing: -3
|
||||
restoring_should_be_good: -10
|
||||
restoring_should_be_repairing: -2
|
||||
restoring_should_be_corrupt: 1
|
||||
restoring_should_be_destroyed: 2
|
||||
restoring: -6
|
||||
corrupt_should_be_good: -10
|
||||
corrupt_should_be_repairing: -10
|
||||
corrupt_should_be_restoring: -10
|
||||
corrupt_should_be_destroyed: 2
|
||||
corrupt: -10
|
||||
destroyed_should_be_good: -20
|
||||
destroyed_should_be_repairing: -20
|
||||
destroyed_should_be_restoring: -20
|
||||
destroyed_should_be_corrupt: -20
|
||||
destroyed: -20
|
||||
scanning: -2
|
||||
good_should_be_repairing: 0.0002
|
||||
good_should_be_restoring: 0.0002
|
||||
good_should_be_corrupt: 0.0005
|
||||
good_should_be_destroyed: 0.001
|
||||
repairing_should_be_good: -0.0005
|
||||
repairing_should_be_restoring: 0.0002
|
||||
repairing_should_be_corrupt: 0.0002
|
||||
repairing_should_be_destroyed: 0.0000
|
||||
repairing: -0.0003
|
||||
restoring_should_be_good: -0.001
|
||||
restoring_should_be_repairing: -0.0002
|
||||
restoring_should_be_corrupt: 0.0001
|
||||
restoring_should_be_destroyed: 0.0002
|
||||
restoring: -0.0006
|
||||
corrupt_should_be_good: -0.001
|
||||
corrupt_should_be_repairing: -0.001
|
||||
corrupt_should_be_restoring: -0.001
|
||||
corrupt_should_be_destroyed: 0.0002
|
||||
corrupt: -0.001
|
||||
destroyed_should_be_good: -0.002
|
||||
destroyed_should_be_repairing: -0.002
|
||||
destroyed_should_be_restoring: -0.002
|
||||
destroyed_should_be_corrupt: -0.002
|
||||
destroyed: -0.002
|
||||
scanning: -0.0002
|
||||
# IER status
|
||||
red_ier_running: -5
|
||||
green_ier_blocked: -10
|
||||
red_ier_running: -0.0005
|
||||
green_ier_blocked: -0.001
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
|
||||
@@ -104,64 +104,64 @@ class TrainingConfig:
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: int = 0
|
||||
all_ok: float = 0
|
||||
|
||||
# Node Hardware State
|
||||
off_should_be_on: int = -10
|
||||
off_should_be_resetting: int = -5
|
||||
on_should_be_off: int = -2
|
||||
on_should_be_resetting: int = -5
|
||||
resetting_should_be_on: int = -5
|
||||
resetting_should_be_off: int = -2
|
||||
resetting: int = -3
|
||||
off_should_be_on: float = -0.001
|
||||
off_should_be_resetting: float = -0.0005
|
||||
on_should_be_off: float = -0.0002
|
||||
on_should_be_resetting: float = -0.0005
|
||||
resetting_should_be_on: float = -0.0005
|
||||
resetting_should_be_off: float = -0.0002
|
||||
resetting: float = -0.0003
|
||||
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: int = 2
|
||||
good_should_be_compromised: int = 5
|
||||
good_should_be_overwhelmed: int = 5
|
||||
patching_should_be_good: int = -5
|
||||
patching_should_be_compromised: int = 2
|
||||
patching_should_be_overwhelmed: int = 2
|
||||
patching: int = -3
|
||||
compromised_should_be_good: int = -20
|
||||
compromised_should_be_patching: int = -20
|
||||
compromised_should_be_overwhelmed: int = -20
|
||||
compromised: int = -20
|
||||
overwhelmed_should_be_good: int = -20
|
||||
overwhelmed_should_be_patching: int = -20
|
||||
overwhelmed_should_be_compromised: int = -20
|
||||
overwhelmed: int = -20
|
||||
good_should_be_patching: float = 0.0002
|
||||
good_should_be_compromised: float = 0.0005
|
||||
good_should_be_overwhelmed: float = 0.0005
|
||||
patching_should_be_good: float = -0.0005
|
||||
patching_should_be_compromised: float = 0.0002
|
||||
patching_should_be_overwhelmed: float = 0.0002
|
||||
patching: float = -0.0003
|
||||
compromised_should_be_good: float = -0.002
|
||||
compromised_should_be_patching: float = -0.002
|
||||
compromised_should_be_overwhelmed: float = -0.002
|
||||
compromised: float = -0.002
|
||||
overwhelmed_should_be_good: float = -0.002
|
||||
overwhelmed_should_be_patching: float = -0.002
|
||||
overwhelmed_should_be_compromised: float = -0.002
|
||||
overwhelmed: float = -0.002
|
||||
|
||||
# Node File System State
|
||||
good_should_be_repairing: int = 2
|
||||
good_should_be_restoring: int = 2
|
||||
good_should_be_corrupt: int = 5
|
||||
good_should_be_destroyed: int = 10
|
||||
repairing_should_be_good: int = -5
|
||||
repairing_should_be_restoring: int = 2
|
||||
repairing_should_be_corrupt: int = 2
|
||||
repairing_should_be_destroyed: int = 0
|
||||
repairing: int = -3
|
||||
restoring_should_be_good: int = -10
|
||||
restoring_should_be_repairing: int = -2
|
||||
restoring_should_be_corrupt: int = 1
|
||||
restoring_should_be_destroyed: int = 2
|
||||
restoring: int = -6
|
||||
corrupt_should_be_good: int = -10
|
||||
corrupt_should_be_repairing: int = -10
|
||||
corrupt_should_be_restoring: int = -10
|
||||
corrupt_should_be_destroyed: int = 2
|
||||
corrupt: int = -10
|
||||
destroyed_should_be_good: int = -20
|
||||
destroyed_should_be_repairing: int = -20
|
||||
destroyed_should_be_restoring: int = -20
|
||||
destroyed_should_be_corrupt: int = -20
|
||||
destroyed: int = -20
|
||||
scanning: int = -2
|
||||
good_should_be_repairing: float = 0.0002
|
||||
good_should_be_restoring: float = 0.0002
|
||||
good_should_be_corrupt: float = 0.0005
|
||||
good_should_be_destroyed: float = 0.001
|
||||
repairing_should_be_good: float = -0.0005
|
||||
repairing_should_be_restoring: float = 0.0002
|
||||
repairing_should_be_corrupt: float = 0.0002
|
||||
repairing_should_be_destroyed: float = 0.0000
|
||||
repairing: float = -0.0003
|
||||
restoring_should_be_good: float = -0.001
|
||||
restoring_should_be_repairing: float = -0.0002
|
||||
restoring_should_be_corrupt: float = 0.0001
|
||||
restoring_should_be_destroyed: float = 0.0002
|
||||
restoring: float = -0.0006
|
||||
corrupt_should_be_good: float = -0.001
|
||||
corrupt_should_be_repairing: float = -0.001
|
||||
corrupt_should_be_restoring: float = -0.001
|
||||
corrupt_should_be_destroyed: float = 0.0002
|
||||
corrupt: float = -0.001
|
||||
destroyed_should_be_good: float = -0.002
|
||||
destroyed_should_be_repairing: float = -0.002
|
||||
destroyed_should_be_restoring: float = -0.002
|
||||
destroyed_should_be_corrupt: float = -0.002
|
||||
destroyed: float = -0.002
|
||||
scanning: float = -0.0002
|
||||
|
||||
# IER status
|
||||
red_ier_running: int = -5
|
||||
green_ier_blocked: int = -10
|
||||
red_ier_running: float = -0.0005
|
||||
green_ier_blocked: float = -0.001
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: int = 5
|
||||
@@ -188,6 +188,12 @@ class TrainingConfig:
|
||||
file_system_scanning_limit: int = 5
|
||||
"The time taken to scan the file system"
|
||||
|
||||
deterministic: bool = False
|
||||
"If true, the training will be deterministic"
|
||||
|
||||
seed: Optional[int] = None
|
||||
"The random number generator seed to be used while training the agent"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
|
||||
"""
|
||||
|
||||
@@ -148,10 +148,10 @@ class Primaite(Env):
|
||||
self.step_info = {}
|
||||
|
||||
# Total reward
|
||||
self.total_reward = 0
|
||||
self.total_reward: float = 0
|
||||
|
||||
# Average reward
|
||||
self.average_reward = 0
|
||||
self.average_reward: float = 0
|
||||
|
||||
# Episode count
|
||||
self.episode_count = 0
|
||||
@@ -289,9 +289,9 @@ class Primaite(Env):
|
||||
self._create_random_red_agent()
|
||||
|
||||
# Reset counters and totals
|
||||
self.total_reward = 0
|
||||
self.total_reward = 0.0
|
||||
self.step_count = 0
|
||||
self.average_reward = 0
|
||||
self.average_reward = 0.0
|
||||
|
||||
# Update observations space and return
|
||||
self.update_environent_obs()
|
||||
|
||||
@@ -20,7 +20,7 @@ def calculate_reward_function(
|
||||
red_iers,
|
||||
step_count,
|
||||
config_values,
|
||||
):
|
||||
) -> float:
|
||||
"""
|
||||
Compares the states of the initial and final nodes/links to get a reward.
|
||||
|
||||
@@ -33,7 +33,7 @@ def calculate_reward_function(
|
||||
step_count: current step
|
||||
config_values: Config values
|
||||
"""
|
||||
reward_value = 0
|
||||
reward_value: float = 0.0
|
||||
|
||||
# For each node, compare hardware state, SoftwareState, service states
|
||||
for node_key, final_node in final_nodes.items():
|
||||
@@ -94,7 +94,7 @@ def calculate_reward_function(
|
||||
return reward_value
|
||||
|
||||
|
||||
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
|
||||
def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float:
|
||||
"""
|
||||
Calculates score relating to the hardware state of a node.
|
||||
|
||||
@@ -104,7 +104,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
score = 0
|
||||
score: float = 0.0
|
||||
final_node_operating_state = final_node.hardware_state
|
||||
reference_node_operating_state = reference_node.hardware_state
|
||||
|
||||
@@ -143,7 +143,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
return score
|
||||
|
||||
|
||||
def score_node_os_state(final_node, initial_node, reference_node, config_values):
|
||||
def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float:
|
||||
"""
|
||||
Calculates score relating to the Software State of a node.
|
||||
|
||||
@@ -153,7 +153,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
score = 0
|
||||
score: float = 0.0
|
||||
final_node_os_state = final_node.software_state
|
||||
reference_node_os_state = reference_node.software_state
|
||||
|
||||
@@ -194,7 +194,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
return score
|
||||
|
||||
|
||||
def score_node_service_state(final_node, initial_node, reference_node, config_values):
|
||||
def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float:
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
|
||||
@@ -204,7 +204,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
score = 0
|
||||
score: float = 0.0
|
||||
final_node_services: Dict[str, Service] = final_node.services
|
||||
reference_node_services: Dict[str, Service] = reference_node.services
|
||||
|
||||
@@ -266,7 +266,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
return score
|
||||
|
||||
|
||||
def score_node_file_system(final_node, initial_node, reference_node, config_values):
|
||||
def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float:
|
||||
"""
|
||||
Calculates score relating to the file system state of a node.
|
||||
|
||||
@@ -275,7 +275,7 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
|
||||
initial_node: The node before red and blue agents take effect
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
"""
|
||||
score = 0
|
||||
score: float = 0.0
|
||||
final_node_file_system_state = final_node.file_system_state_actual
|
||||
reference_node_file_system_state = reference_node.file_system_state_actual
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class Transaction(object):
|
||||
"The observation space before any actions are taken"
|
||||
self.obs_space_post = None
|
||||
"The observation space after any actions are taken"
|
||||
self.reward = None
|
||||
self.reward: float = None
|
||||
"The reward value"
|
||||
self.action_space = None
|
||||
"The action space invoked by the agent"
|
||||
|
||||
155
tests/config/ppo_not_seeded_training_config.yaml
Normal file
155
tests/config/ppo_not_seeded_training_config.yaml
Normal file
@@ -0,0 +1,155 @@
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: SB3
|
||||
|
||||
# Sets which deep learning framework will be used (by RLlib ONLY).
|
||||
# Default is TF (Tensorflow).
|
||||
# Options are:
|
||||
# "TF" (Tensorflow)
|
||||
# TF2 (Tensorflow 2.X)
|
||||
# TORCH (PyTorch)
|
||||
deep_learning_framework: TF2
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: PPO
|
||||
|
||||
# Sets whether Red Agent POL and IER is randomised.
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
random_red_agent: False
|
||||
|
||||
# The (integer) seed to be used in random number generation
|
||||
# Default is None (null)
|
||||
seed: None
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
deterministic: False
|
||||
|
||||
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
|
||||
# Options are:
|
||||
# "BASIC" (The current observation space only)
|
||||
# "FULL" (Full environment view with actions taken and reward feedback)
|
||||
hard_coded_agent_view: FULL
|
||||
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 10
|
||||
|
||||
# Number of time_steps per episode
|
||||
num_steps: 256
|
||||
|
||||
# Sets how often the agent will save a checkpoint (every n time episodes).
|
||||
# Set to 0 if no checkpoints are required. Default is 10
|
||||
checkpoint_every_n_episodes: 0
|
||||
|
||||
# Time delay (milliseconds) between steps for CUSTOM agents.
|
||||
time_delay: 5
|
||||
|
||||
# Type of session to be run. Options are:
|
||||
# "TRAIN" (Trains an agent)
|
||||
# "EVAL" (Evaluates an agent)
|
||||
# "TRAIN_EVAL" (Trains then evaluates an agent)
|
||||
session_type: TRAIN_EVAL
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
# "INFO" (Info Messages (such as devices and wrappers used))
|
||||
# "DEBUG" (All Messages)
|
||||
sb3_output_verbose_level: NONE
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0.0000
|
||||
# Node Hardware State
|
||||
off_should_be_on: -0.001
|
||||
off_should_be_resetting: -0.0005
|
||||
on_should_be_off: -0.0002
|
||||
on_should_be_resetting: -0.0005
|
||||
resetting_should_be_on: -0.0005
|
||||
resetting_should_be_off: -0.0002
|
||||
resetting: -0.0003
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 0.0002
|
||||
good_should_be_compromised: 0.0005
|
||||
good_should_be_overwhelmed: 0.0005
|
||||
patching_should_be_good: -0.0005
|
||||
patching_should_be_compromised: 0.0002
|
||||
patching_should_be_overwhelmed: 0.0002
|
||||
patching: -0.0003
|
||||
compromised_should_be_good: -0.002
|
||||
compromised_should_be_patching: -0.002
|
||||
compromised_should_be_overwhelmed: -0.002
|
||||
compromised: -0.002
|
||||
overwhelmed_should_be_good: -0.002
|
||||
overwhelmed_should_be_patching: -0.002
|
||||
overwhelmed_should_be_compromised: -0.002
|
||||
overwhelmed: -0.002
|
||||
# Node File System State
|
||||
good_should_be_repairing: 0.0002
|
||||
good_should_be_restoring: 0.0002
|
||||
good_should_be_corrupt: 0.0005
|
||||
good_should_be_destroyed: 0.001
|
||||
repairing_should_be_good: -0.0005
|
||||
repairing_should_be_restoring: 0.0002
|
||||
repairing_should_be_corrupt: 0.0002
|
||||
repairing_should_be_destroyed: 0.0000
|
||||
repairing: -0.0003
|
||||
restoring_should_be_good: -0.001
|
||||
restoring_should_be_repairing: -0.0002
|
||||
restoring_should_be_corrupt: 0.0001
|
||||
restoring_should_be_destroyed: 0.0002
|
||||
restoring: -0.0006
|
||||
corrupt_should_be_good: -0.001
|
||||
corrupt_should_be_repairing: -0.001
|
||||
corrupt_should_be_restoring: -0.001
|
||||
corrupt_should_be_destroyed: 0.0002
|
||||
corrupt: -0.001
|
||||
destroyed_should_be_good: -0.002
|
||||
destroyed_should_be_repairing: -0.002
|
||||
destroyed_should_be_restoring: -0.002
|
||||
destroyed_should_be_corrupt: -0.002
|
||||
destroyed: -0.002
|
||||
scanning: -0.0002
|
||||
# IER status
|
||||
red_ier_running: -0.0005
|
||||
green_ier_blocked: -0.001
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||
service_patching_duration: 5 # The time taken to patch a service
|
||||
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||
@@ -1,40 +1,94 @@
|
||||
# Main Config File
|
||||
# Training Config File
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_identifier: STABLE_BASELINES3_A2C
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: SB3
|
||||
|
||||
# Sets which deep learning framework will be used (by RLlib ONLY).
|
||||
# Default is TF (Tensorflow).
|
||||
# Options are:
|
||||
# "TF" (Tensorflow)
|
||||
# TF2 (Tensorflow 2.X)
|
||||
# TORCH (PyTorch)
|
||||
deep_learning_framework: TF2
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: PPO
|
||||
|
||||
# Sets whether Red Agent POL and IER is randomised.
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
random_red_agent: True
|
||||
random_red_agent: False
|
||||
|
||||
# The (integer) seed to be used in random number generation
|
||||
# Default is None (null)
|
||||
seed: 67890
|
||||
|
||||
# Set whether the agent will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
deterministic: True
|
||||
|
||||
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
|
||||
# Options are:
|
||||
# "BASIC" (The current observation space only)
|
||||
# "FULL" (Full environment view with actions taken and reward feedback)
|
||||
hard_coded_agent_view: FULL
|
||||
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
# observation space
|
||||
observation_space:
|
||||
# flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
# - name: NODE_STATUSES
|
||||
# - name: LINK_TRAFFIC_LEVELS
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 10
|
||||
|
||||
# Number of time_steps per episode
|
||||
num_steps: 256
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 10
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAINING
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Sets how often the agent will save a checkpoint (every n time episodes).
|
||||
# Set to 0 if no checkpoints are required. Default is 10
|
||||
checkpoint_every_n_episodes: 0
|
||||
|
||||
# Time delay (milliseconds) between steps for CUSTOM agents.
|
||||
time_delay: 5
|
||||
|
||||
# Type of session to be run. Options are:
|
||||
# "TRAIN" (Trains an agent)
|
||||
# "EVAL" (Evaluates an agent)
|
||||
# "TRAIN_EVAL" (Trains then evaluates an agent)
|
||||
session_type: TRAIN_EVAL
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
# "INFO" (Info Messages (such as devices and wrappers used))
|
||||
# "DEBUG" (All Messages)
|
||||
sb3_output_verbose_level: NONE
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
@@ -58,7 +58,6 @@ class TempPrimaiteSession(PrimaiteSession):
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
shutil.rmtree(self.session_path)
|
||||
shutil.rmtree(self.session_path.parent)
|
||||
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
@@ -14,9 +15,7 @@ def get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
:param session_timestamp: This is the datetime that the session started.
|
||||
:return: The session directory path.
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
session_path = Path(tempfile.gettempdir()) / "primaite" / str(uuid4())
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
_LOGGER.debug(f"Created temp session directory: {session_path}")
|
||||
return session_path
|
||||
|
||||
@@ -33,6 +33,9 @@ def test_primaite_session(temp_primaite_session):
|
||||
# Check that the network png file exists
|
||||
assert (session_path / f"network_{session.timestamp_str}.png").exists()
|
||||
|
||||
# Check that the saved agent exists
|
||||
assert session._agent_session._saved_agent_path.exists()
|
||||
|
||||
# Check that both the transactions and av reward csv files exist
|
||||
for file in session.learning_path.iterdir():
|
||||
if file.suffix == ".csv":
|
||||
|
||||
49
tests/test_seeding_and_deterministic_session.py
Normal file
49
tests/test_seeding_and_deterministic_session.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import pytest as pytest
|
||||
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
|
||||
indirect=True,
|
||||
)
|
||||
def test_seeded_learning(temp_primaite_session):
|
||||
"""Test running seeded learning produces the same output when ran twice."""
|
||||
expected_mean_reward_per_episode = {
|
||||
1: -90.703125,
|
||||
2: -91.15234375,
|
||||
3: -87.5,
|
||||
4: -92.2265625,
|
||||
5: -94.6875,
|
||||
6: -91.19140625,
|
||||
7: -88.984375,
|
||||
8: -88.3203125,
|
||||
9: -112.79296875,
|
||||
10: -100.01953125,
|
||||
}
|
||||
with temp_primaite_session as session:
|
||||
assert session._training_config.seed == 67890, (
|
||||
"Expected output is based upon a agent that was trained with " "seed 67890"
|
||||
)
|
||||
session.learn()
|
||||
actual_mean_reward_per_episode = session.learn_av_reward_per_episode()
|
||||
|
||||
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL " "knowledge to investigate further.")
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
|
||||
indirect=True,
|
||||
)
|
||||
def test_deterministic_evaluation(temp_primaite_session):
|
||||
"""Test running deterministic evaluation gives same av eward per episode."""
|
||||
with temp_primaite_session as session:
|
||||
# do stuff
|
||||
session.learn()
|
||||
session.evaluate()
|
||||
eval_mean_reward = session.eval_av_reward_per_episode_csv()
|
||||
assert len(set(eval_mean_reward.values())) == 1
|
||||
Reference in New Issue
Block a user