#1386: added documentation + dealing with pre-commit checks
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -142,4 +142,4 @@ cython_debug/
|
||||
.idea/
|
||||
|
||||
# outputs
|
||||
src/primaite/outputs/
|
||||
src/primaite/outputs/
|
||||
|
||||
62
README.md
62
README.md
@@ -1 +1,63 @@
|
||||
# PrimAITE
|
||||
|
||||
## Getting Started with PrimAITE
|
||||
|
||||
### Pre-Requisites
|
||||
|
||||
In order to get **PrimAITE** installed, you will need to have the following installed:
|
||||
|
||||
- `python3.8+`
|
||||
- `python3-pip`
|
||||
- `virtualenv`
|
||||
|
||||
**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS.
|
||||
|
||||
### Installation from source
|
||||
#### 1. Navigate to the PrimAITE folder and create a new python virtual environment (venv)
|
||||
|
||||
```unix
|
||||
python3 -m venv <name_of_venv>
|
||||
```
|
||||
|
||||
#### 2. Activate the venv
|
||||
|
||||
##### Unix
|
||||
```bash
|
||||
source <name_of_venv>/bin/activate
|
||||
```
|
||||
|
||||
##### Windows
|
||||
```powershell
|
||||
.\<name_of_venv>\Scripts\activate
|
||||
```
|
||||
|
||||
#### 3. Install `primaite` into the venv along with all of it's dependencies
|
||||
|
||||
```bash
|
||||
python3 -m pip install -e .
|
||||
```
|
||||
|
||||
### Development Installation
|
||||
To install the development dependencies, postfix the command in step 3 above with the `[dev]` extra. Example:
|
||||
|
||||
```bash
|
||||
python3 -m pip install -e .[dev]
|
||||
|
||||
## Building documentation
|
||||
The PrimAITE documentation can be built with the following commands:
|
||||
|
||||
##### Unix
|
||||
```bash
|
||||
cd docs
|
||||
make html
|
||||
```
|
||||
|
||||
##### Windows
|
||||
```powershell
|
||||
cd docs
|
||||
.\make.bat html
|
||||
```
|
||||
|
||||
This will build the documentation as a collection of HTML files which uses the Read The Docs sphinx theme. Other build
|
||||
options are available but may require additional dependencies such as LaTeX and PDF. Please refer to the Sphinx documentation
|
||||
for your specific output requirements.
|
||||
|
||||
@@ -293,6 +293,14 @@ Rewards are calculated based on the difference between the current state and ref
|
||||
|
||||
The number of steps to take when scanning the file system
|
||||
|
||||
* **deterministic** [bool]
|
||||
|
||||
Set to true if the agent should use deterministic actions. Default is ``False``
|
||||
|
||||
* **seed** [int]
|
||||
|
||||
Seed used in the randomisation in training / evaluation. Default is ``None``
|
||||
|
||||
The Lay Down Config
|
||||
*******************
|
||||
|
||||
|
||||
@@ -320,4 +320,4 @@
|
||||
| ypy-websocket | 0.8.2 | UNKNOWN | https://github.com/y-crdt/ypy-websocket |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+
|
||||
| zipp | 3.15.0 | MIT License | https://github.com/jaraco/zipp |
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+
|
||||
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+
|
||||
|
||||
@@ -31,7 +31,7 @@ def _get_primaite_config():
|
||||
"INFO": logging.INFO,
|
||||
"WARN": logging.WARN,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
}
|
||||
primaite_config["log_level"] = log_level_map[primaite_config["log_level"]]
|
||||
return primaite_config
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pkg_resources
|
||||
@@ -44,6 +44,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
:param last_n: The number of lines to print. Default value is 10.
|
||||
"""
|
||||
import re
|
||||
|
||||
from primaite import LOG_PATH
|
||||
|
||||
if os.path.isfile(LOG_PATH):
|
||||
@@ -53,7 +54,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
print(re.sub(r"\n*", "", line))
|
||||
|
||||
|
||||
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
|
||||
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
|
||||
|
||||
|
||||
@app.command()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Union, Optional
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -173,8 +173,7 @@ def main_training_config_path() -> Path:
|
||||
return path
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path],
|
||||
legacy_file: bool = False) -> TrainingConfig:
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
|
||||
"""
|
||||
Read in a training config yaml file.
|
||||
|
||||
@@ -219,9 +218,7 @@ def load(file_path: Union[str, Path],
|
||||
|
||||
|
||||
def convert_legacy_training_config_dict(
|
||||
legacy_config_dict: Dict[str, Any],
|
||||
num_steps: int = 256,
|
||||
action_type: str = "ANY"
|
||||
legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy training config dict to the new format.
|
||||
@@ -233,10 +230,7 @@ def convert_legacy_training_config_dict(
|
||||
don't have action_type values.
|
||||
:return: The converted training config dict.
|
||||
"""
|
||||
config_dict = {
|
||||
"num_steps": num_steps,
|
||||
"action_type": action_type
|
||||
}
|
||||
config_dict = {"num_steps": num_steps, "action_type": action_type}
|
||||
for legacy_key, value in legacy_config_dict.items():
|
||||
new_key = _get_new_key_from_legacy(legacy_key)
|
||||
if new_key:
|
||||
|
||||
@@ -14,8 +14,7 @@ from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, \
|
||||
is_valid_node_action
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
@@ -24,8 +23,9 @@ from primaite.common.enums import (
|
||||
NodePOLInitiator,
|
||||
NodePOLType,
|
||||
NodeType,
|
||||
ObservationType,
|
||||
Priority,
|
||||
SoftwareState, ObservationType,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.config import training_config
|
||||
@@ -35,17 +35,14 @@ from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node import Node
|
||||
from primaite.nodes.node_state_instruction_green import \
|
||||
NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||
from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
|
||||
apply_red_agent_node_pol
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.transactions.transactions_to_file import write_transaction_to_file
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_LOGGER.setLevel(logging.INFO)
|
||||
@@ -178,7 +175,6 @@ class Primaite(Env):
|
||||
# It will be initialised later.
|
||||
self.obs_handler: ObservationsHandler
|
||||
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
with open(self._lay_down_config_path, "r") as file:
|
||||
# Open the config file and build the environment laydown
|
||||
@@ -200,7 +196,6 @@ class Primaite(Env):
|
||||
try:
|
||||
plt.tight_layout()
|
||||
nx.draw_networkx(self.network, with_labels=True)
|
||||
now = datetime.now() # current date and time
|
||||
|
||||
file_path = session_path / f"network_{timestamp_str}.png"
|
||||
plt.savefig(file_path, format="PNG")
|
||||
@@ -222,7 +217,9 @@ class Primaite(Env):
|
||||
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
|
||||
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
|
||||
self.action_dict = self.create_node_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
|
||||
self.action_space = spaces.Discrete(
|
||||
len(self.action_dict), seed=self.training_config.seed
|
||||
)
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
_LOGGER.info("Action space type ACL selected")
|
||||
# Terms (for ACL action space):
|
||||
@@ -233,13 +230,19 @@ class Primaite(Env):
|
||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
self.action_dict = self.create_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
|
||||
self.action_space = spaces.Discrete(
|
||||
len(self.action_dict), seed=self.training_config.seed
|
||||
)
|
||||
elif self.training_config.action_type == ActionType.ANY:
|
||||
_LOGGER.info("Action space type ANY selected - Node + ACL")
|
||||
self.action_dict = self.create_node_and_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
|
||||
self.action_space = spaces.Discrete(
|
||||
len(self.action_dict), seed=self.training_config.seed
|
||||
)
|
||||
else:
|
||||
_LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
|
||||
_LOGGER.info(
|
||||
f"Invalid action type selected: {self.training_config.action_type}"
|
||||
)
|
||||
# Set up a csv to store the results of the training
|
||||
try:
|
||||
header = ["Episode", "Average Reward"]
|
||||
@@ -380,7 +383,7 @@ class Primaite(Env):
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
)
|
||||
#print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
# print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
@@ -407,12 +410,11 @@ class Primaite(Env):
|
||||
return self.env_obs, reward, done, self.step_info
|
||||
|
||||
def close(self):
|
||||
"""Calls the __close__ method."""
|
||||
self.__close__()
|
||||
|
||||
def __close__(self):
|
||||
"""
|
||||
Override close function
|
||||
"""
|
||||
"""Override close function."""
|
||||
self.csv_file.close()
|
||||
|
||||
def init_acl(self):
|
||||
@@ -1039,7 +1041,6 @@ class Primaite(Env):
|
||||
"""
|
||||
self.observation_type = ObservationType[observation_info["type"]]
|
||||
|
||||
|
||||
def get_action_info(self, action_info):
|
||||
"""
|
||||
Extracts action_info.
|
||||
|
||||
@@ -22,8 +22,7 @@ from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
from primaite import SESSIONS_DIR, getLogger
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
from primaite.transactions.transactions_to_file import write_transaction_to_file
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -87,7 +86,13 @@ def run_stable_baselines3_ppo(
|
||||
_LOGGER.error("Could not load agent")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed)
|
||||
agent = PPO(
|
||||
PPOMlp,
|
||||
env,
|
||||
verbose=0,
|
||||
n_steps=config_values.num_steps,
|
||||
seed=env.training_config.seed,
|
||||
)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
@@ -106,8 +111,7 @@ def run_stable_baselines3_ppo(
|
||||
|
||||
for step in range(0, config_values.num_steps):
|
||||
action, _states = agent.predict(
|
||||
obs,
|
||||
deterministic=env.training_config.deterministic
|
||||
obs, deterministic=env.training_config.deterministic
|
||||
)
|
||||
# convert to int if action is a numpy array
|
||||
if isinstance(action, np.ndarray):
|
||||
@@ -146,7 +150,13 @@ def run_stable_baselines3_a2c(
|
||||
_LOGGER.error("Could not load agent")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed)
|
||||
agent = A2C(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
verbose=0,
|
||||
n_steps=config_values.num_steps,
|
||||
seed=env.training_config.seed,
|
||||
)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
@@ -164,8 +174,7 @@ def run_stable_baselines3_a2c(
|
||||
|
||||
for step in range(0, config_values.num_steps):
|
||||
action, _states = agent.predict(
|
||||
obs,
|
||||
deterministic=env.training_config.deterministic
|
||||
obs, deterministic=env.training_config.deterministic
|
||||
)
|
||||
# convert to int if action is a numpy array
|
||||
if isinstance(action, np.ndarray):
|
||||
@@ -368,5 +377,3 @@ if __name__ == "__main__":
|
||||
"Please provide a lay down config file using the --ldc " "argument"
|
||||
)
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ class Node:
|
||||
"""Sets the node state to ON."""
|
||||
self.hardware_state = HardwareState.BOOTING
|
||||
self.booting_count = self.config_values.node_booting_duration
|
||||
|
||||
def turn_off(self):
|
||||
"""Sets the node state to OFF."""
|
||||
self.hardware_state = HardwareState.OFF
|
||||
@@ -64,14 +65,14 @@ class Node:
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_booting_status(self):
|
||||
"""Updates the booting count"""
|
||||
"""Updates the booting count."""
|
||||
self.booting_count -= 1
|
||||
if self.booting_count <= 0:
|
||||
self.booting_count = 0
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_shutdown_status(self):
|
||||
"""Updates the shutdown count"""
|
||||
"""Updates the shutdown count."""
|
||||
self.shutting_down_count -= 1
|
||||
if self.shutting_down_count <= 0:
|
||||
self.shutting_down_count = 0
|
||||
|
||||
@@ -190,13 +190,15 @@ class ServiceNode(ActiveNode):
|
||||
service_value.reduce_patching_count()
|
||||
|
||||
def update_resetting_status(self):
|
||||
"""Updates the resetting counter for any service that are resetting."""
|
||||
super().update_resetting_status()
|
||||
if self.resetting_count <= 0:
|
||||
for service in self.services.values():
|
||||
service.software_state = SoftwareState.GOOD
|
||||
|
||||
def update_booting_status(self):
|
||||
"""Updates the booting counter for any service that are booting up."""
|
||||
super().update_booting_status()
|
||||
if self.booting_count <= 0:
|
||||
for service in self.services.values():
|
||||
service.software_state =SoftwareState.GOOD
|
||||
service.software_state = SoftwareState.GOOD
|
||||
|
||||
@@ -17,7 +17,6 @@ def start_jupyter_session():
|
||||
|
||||
.. todo:: Figure out how to get this working for Linux and MacOS too.
|
||||
"""
|
||||
|
||||
if importlib.util.find_spec("jupyter") is not None:
|
||||
jupyter_cmd = "python3 -m jupyter lab"
|
||||
if sys.platform == "win32":
|
||||
|
||||
@@ -3,7 +3,7 @@ import tempfile
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Union, Final
|
||||
from typing import Final, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -30,7 +30,7 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
|
||||
|
||||
def _get_primaite_env_from_config(
|
||||
training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
|
||||
training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
|
||||
):
|
||||
"""Takes a config path and returns the created instance of Primaite."""
|
||||
session_timestamp: datetime = datetime.now()
|
||||
@@ -84,7 +84,7 @@ def run_generic(env, config_values):
|
||||
|
||||
|
||||
def compare_file_content(output_a_file_path: str, output_b_file_path: str):
|
||||
"""Function used to check if output of both given files are the same"""
|
||||
"""Function used to check if output of both given files are the same."""
|
||||
with open(output_a_file_path) as f1:
|
||||
with open(output_b_file_path) as f2:
|
||||
f1_content = f1.read()
|
||||
@@ -95,13 +95,15 @@ def compare_file_content(output_a_file_path: str, output_b_file_path: str):
|
||||
# both files have the same content
|
||||
return True
|
||||
# both files have different content
|
||||
print(f"{output_a_file_path} and {output_b_file_path} has different contents")
|
||||
print(
|
||||
f"{output_a_file_path} and {output_b_file_path} has different contents"
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def compare_transaction_file(output_a_file_path: str, output_b_file_path: str):
|
||||
"""Function used to check if contents of transaction files are the same"""
|
||||
"""Function used to check if contents of transaction files are the same."""
|
||||
# load output a file
|
||||
data_a = pd.read_csv(output_a_file_path)
|
||||
|
||||
@@ -109,19 +111,17 @@ def compare_transaction_file(output_a_file_path: str, output_b_file_path: str):
|
||||
data_b = pd.read_csv(output_b_file_path)
|
||||
|
||||
# remove the time stamp column
|
||||
data_a.drop('Timestamp', inplace=True, axis=1)
|
||||
data_b.drop('Timestamp', inplace=True, axis=1)
|
||||
data_a.drop("Timestamp", inplace=True, axis=1)
|
||||
data_b.drop("Timestamp", inplace=True, axis=1)
|
||||
|
||||
# if the comparison is empty, both files are the same i.e. True
|
||||
return data_a.compare(data_b).empty
|
||||
|
||||
|
||||
class TestSession:
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path,
|
||||
laydown_config_path
|
||||
):
|
||||
"""Class that contains session values."""
|
||||
|
||||
def __init__(self, training_config_path, laydown_config_path):
|
||||
self.session_timestamp: Final[datetime] = datetime.now()
|
||||
self.session_dir = _get_session_path(self.session_timestamp)
|
||||
self.timestamp_str = self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
@@ -140,5 +140,8 @@ class TestSession:
|
||||
print("Writing Session Metadata file...")
|
||||
|
||||
_write_session_metadata_file(
|
||||
session_dir=self.session_dir, uuid="test", session_timestamp=self.session_timestamp, env=self.env
|
||||
session_dir=self.session_dir,
|
||||
uuid="test",
|
||||
session_timestamp=self.session_timestamp,
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
@@ -2,8 +2,13 @@ import time
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
from primaite.main import run_stable_baselines3_a2c, \
|
||||
run_stable_baselines3_ppo, run_generic, _update_session_metadata_file, _get_session_path
|
||||
from primaite.main import (
|
||||
_get_session_path,
|
||||
_update_session_metadata_file,
|
||||
run_generic,
|
||||
run_stable_baselines3_a2c,
|
||||
run_stable_baselines3_ppo,
|
||||
)
|
||||
from primaite.transactions.transactions_to_file import write_transaction_to_file
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
from tests.conftest import TestSession, compare_file_content, compare_transaction_file
|
||||
@@ -12,18 +17,17 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def test_generic_same_results():
|
||||
"""Runs seeded and deterministic Generic Primaite sessions and checks that the results are the same"""
|
||||
"""Runs seeded and deterministic Generic Primaite sessions and checks that the results are the same."""
|
||||
print("")
|
||||
print("=======================")
|
||||
print("Generic test run")
|
||||
print("=======================")
|
||||
print("")
|
||||
|
||||
|
||||
# run session 1
|
||||
session1 = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
config_values = session1.env.training_config
|
||||
@@ -31,17 +35,14 @@ def test_generic_same_results():
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
config_values.num_steps = session1.env.episode_steps
|
||||
|
||||
run_generic(
|
||||
env=session1.env,
|
||||
config_values=session1.env.training_config
|
||||
)
|
||||
run_generic(env=session1.env, config_values=session1.env.training_config)
|
||||
|
||||
_update_session_metadata_file(session_dir=session1.session_dir, env=session1.env)
|
||||
|
||||
# run session 2
|
||||
session2 = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
config_values = session2.env.training_config
|
||||
@@ -49,10 +50,7 @@ def test_generic_same_results():
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
config_values.num_steps = session2.env.episode_steps
|
||||
|
||||
run_generic(
|
||||
env=session2.env,
|
||||
config_values=session2.env.training_config
|
||||
)
|
||||
run_generic(env=session2.env, config_values=session2.env.training_config)
|
||||
|
||||
_update_session_metadata_file(session_dir=session2.session_dir, env=session2.env)
|
||||
|
||||
@@ -61,36 +59,41 @@ def test_generic_same_results():
|
||||
time.sleep(1)
|
||||
|
||||
# check if both outputs are the same
|
||||
assert compare_file_content(
|
||||
session1.env.csv_file.name,
|
||||
session2.env.csv_file.name,
|
||||
) is True
|
||||
assert (
|
||||
compare_file_content(
|
||||
session1.env.csv_file.name,
|
||||
session2.env.csv_file.name,
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
# deterministic run
|
||||
deterministic = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
deterministic.env.training_config.deterministic = True
|
||||
|
||||
run_generic(
|
||||
env=deterministic.env,
|
||||
config_values=deterministic.env.training_config
|
||||
run_generic(env=deterministic.env, config_values=deterministic.env.training_config)
|
||||
|
||||
_update_session_metadata_file(
|
||||
session_dir=deterministic.session_dir, env=deterministic.env
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env)
|
||||
|
||||
# check if both outputs are the same
|
||||
assert compare_file_content(
|
||||
deterministic.env.csv_file.name,
|
||||
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_generic.csv",
|
||||
) is True
|
||||
assert (
|
||||
compare_file_content(
|
||||
deterministic.env.csv_file.name,
|
||||
TEST_CONFIG_ROOT
|
||||
/ "e2e/deterministic_test_outputs/deterministic_generic.csv",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_ppo_same_results():
|
||||
"""Runs seeded and deterministic PPO Primaite sessions and checks that the results are the same"""
|
||||
|
||||
"""Runs seeded and deterministic PPO Primaite sessions and checks that the results are the same."""
|
||||
print("")
|
||||
print("=======================")
|
||||
print("PPO test run")
|
||||
@@ -99,7 +102,7 @@ def test_ppo_same_results():
|
||||
|
||||
training_session = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
# Train agent
|
||||
@@ -123,12 +126,14 @@ def test_ppo_same_results():
|
||||
timestamp_str=training_session.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=training_session.session_dir, env=training_session.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=training_session.session_dir, env=training_session.env
|
||||
)
|
||||
|
||||
# Evaluate Agent again
|
||||
eval_session1 = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
@@ -137,7 +142,10 @@ def test_ppo_same_results():
|
||||
|
||||
# load the agent that was trained previously
|
||||
eval_session1.env.training_config.load_agent = True
|
||||
eval_session1.env.training_config.agent_load_file = _get_session_path(training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
eval_session1.env.training_config.agent_load_file = (
|
||||
_get_session_path(training_session.session_timestamp)
|
||||
/ f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
)
|
||||
|
||||
config_values = eval_session1.env.training_config
|
||||
|
||||
@@ -154,11 +162,13 @@ def test_ppo_same_results():
|
||||
timestamp_str=eval_session1.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=eval_session1.session_dir, env=eval_session1.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=eval_session1.session_dir, env=eval_session1.env
|
||||
)
|
||||
|
||||
eval_session2 = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
@@ -167,8 +177,10 @@ def test_ppo_same_results():
|
||||
|
||||
# load the agent that was trained previously
|
||||
eval_session2.env.training_config.load_agent = True
|
||||
eval_session2.env.training_config.agent_load_file = _get_session_path(
|
||||
training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
eval_session2.env.training_config.agent_load_file = (
|
||||
_get_session_path(training_session.session_timestamp)
|
||||
/ f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
)
|
||||
|
||||
config_values = eval_session2.env.training_config
|
||||
|
||||
@@ -185,18 +197,25 @@ def test_ppo_same_results():
|
||||
timestamp_str=eval_session2.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=eval_session2.session_dir, env=eval_session2.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=eval_session2.session_dir, env=eval_session2.env
|
||||
)
|
||||
|
||||
# check if both eval outputs are the same
|
||||
assert compare_transaction_file(
|
||||
eval_session1.session_dir / f"all_transactions_{eval_session1.timestamp_str}.csv",
|
||||
eval_session2.session_dir / f"all_transactions_{eval_session2.timestamp_str}.csv",
|
||||
) is True
|
||||
assert (
|
||||
compare_transaction_file(
|
||||
eval_session1.session_dir
|
||||
/ f"all_transactions_{eval_session1.timestamp_str}.csv",
|
||||
eval_session2.session_dir
|
||||
/ f"all_transactions_{eval_session2.timestamp_str}.csv",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
# deterministic run
|
||||
deterministic = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
deterministic.env.training_config.deterministic = True
|
||||
@@ -214,18 +233,23 @@ def test_ppo_same_results():
|
||||
timestamp_str=deterministic.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=deterministic.session_dir, env=deterministic.env
|
||||
)
|
||||
|
||||
# check if both outputs are the same
|
||||
assert compare_transaction_file(
|
||||
deterministic.session_dir / f"all_transactions_{deterministic.timestamp_str}.csv",
|
||||
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_ppo.csv",
|
||||
) is True
|
||||
assert (
|
||||
compare_transaction_file(
|
||||
deterministic.session_dir
|
||||
/ f"all_transactions_{deterministic.timestamp_str}.csv",
|
||||
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_ppo.csv",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_a2c_same_results():
|
||||
"""Runs seeded and deterministic A2C Primaite sessions and checks that the results are the same"""
|
||||
|
||||
"""Runs seeded and deterministic A2C Primaite sessions and checks that the results are the same."""
|
||||
print("")
|
||||
print("=======================")
|
||||
print("A2C test run")
|
||||
@@ -234,7 +258,7 @@ def test_a2c_same_results():
|
||||
|
||||
training_session = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
# Train agent
|
||||
@@ -258,12 +282,14 @@ def test_a2c_same_results():
|
||||
timestamp_str=training_session.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=training_session.session_dir, env=training_session.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=training_session.session_dir, env=training_session.env
|
||||
)
|
||||
|
||||
# Evaluate Agent again
|
||||
eval_session1 = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
@@ -272,8 +298,10 @@ def test_a2c_same_results():
|
||||
|
||||
# load the agent that was trained previously
|
||||
eval_session1.env.training_config.load_agent = True
|
||||
eval_session1.env.training_config.agent_load_file = _get_session_path(
|
||||
training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
eval_session1.env.training_config.agent_load_file = (
|
||||
_get_session_path(training_session.session_timestamp)
|
||||
/ f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
)
|
||||
|
||||
config_values = eval_session1.env.training_config
|
||||
|
||||
@@ -290,11 +318,13 @@ def test_a2c_same_results():
|
||||
timestamp_str=eval_session1.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=eval_session1.session_dir, env=eval_session1.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=eval_session1.session_dir, env=eval_session1.env
|
||||
)
|
||||
|
||||
eval_session2 = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
@@ -303,8 +333,10 @@ def test_a2c_same_results():
|
||||
|
||||
# load the agent that was trained previously
|
||||
eval_session2.env.training_config.load_agent = True
|
||||
eval_session2.env.training_config.agent_load_file = _get_session_path(
|
||||
training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
eval_session2.env.training_config.agent_load_file = (
|
||||
_get_session_path(training_session.session_timestamp)
|
||||
/ f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
)
|
||||
|
||||
config_values = eval_session2.env.training_config
|
||||
|
||||
@@ -321,18 +353,25 @@ def test_a2c_same_results():
|
||||
timestamp_str=eval_session2.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=eval_session2.session_dir, env=eval_session2.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=eval_session2.session_dir, env=eval_session2.env
|
||||
)
|
||||
|
||||
# check if both eval outputs are the same
|
||||
assert compare_transaction_file(
|
||||
eval_session1.session_dir / f"all_transactions_{eval_session1.timestamp_str}.csv",
|
||||
eval_session2.session_dir / f"all_transactions_{eval_session2.timestamp_str}.csv",
|
||||
) is True
|
||||
assert (
|
||||
compare_transaction_file(
|
||||
eval_session1.session_dir
|
||||
/ f"all_transactions_{eval_session1.timestamp_str}.csv",
|
||||
eval_session2.session_dir
|
||||
/ f"all_transactions_{eval_session2.timestamp_str}.csv",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
# deterministic run
|
||||
deterministic = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
deterministic.env.training_config.deterministic = True
|
||||
@@ -350,10 +389,16 @@ def test_a2c_same_results():
|
||||
timestamp_str=deterministic.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=deterministic.session_dir, env=deterministic.env
|
||||
)
|
||||
|
||||
# check if both outputs are the same
|
||||
assert compare_transaction_file(
|
||||
deterministic.session_dir / f"all_transactions_{deterministic.timestamp_str}.csv",
|
||||
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_a2c.csv",
|
||||
) is True
|
||||
assert (
|
||||
compare_transaction_file(
|
||||
deterministic.session_dir
|
||||
/ f"all_transactions_{deterministic.timestamp_str}.csv",
|
||||
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_a2c.csv",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
@@ -27,7 +27,8 @@ def env(request):
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_without_obs.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
)
|
||||
@@ -43,7 +44,8 @@ def test_default_obs_space(env: Primaite):
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_without_obs.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
)
|
||||
@@ -140,7 +142,8 @@ class TestNodeLinkTable:
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_NODE_STATUSES.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
"""Used to test Active Node functions."""
|
||||
import pytest
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState, NodeType, Priority
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodeType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
@@ -10,24 +16,20 @@ from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"starting_operating_state, expected_operating_state",
|
||||
[
|
||||
(HardwareState.RESETTING, HardwareState.ON)
|
||||
],
|
||||
[(HardwareState.RESETTING, HardwareState.ON)],
|
||||
)
|
||||
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
|
||||
"""
|
||||
Tests that a node resets correctly.
|
||||
"""
|
||||
"""Tests that a node resets correctly."""
|
||||
active_node = ActiveNode(
|
||||
node_id = "0",
|
||||
name = "node",
|
||||
node_type = NodeType.COMPUTER,
|
||||
priority = Priority.P1,
|
||||
hardware_state = starting_operating_state,
|
||||
ip_address = "192.168.0.1",
|
||||
software_state = SoftwareState.COMPROMISED,
|
||||
file_system_state = FileSystemState.CORRUPT,
|
||||
config_values=TrainingConfig()
|
||||
node_id="0",
|
||||
name="node",
|
||||
node_type=NodeType.COMPUTER,
|
||||
priority=Priority.P1,
|
||||
hardware_state=starting_operating_state,
|
||||
ip_address="192.168.0.1",
|
||||
software_state=SoftwareState.COMPROMISED,
|
||||
file_system_state=FileSystemState.CORRUPT,
|
||||
config_values=TrainingConfig(),
|
||||
)
|
||||
|
||||
for x in range(5):
|
||||
@@ -37,35 +39,28 @@ def test_node_resets_correctly(starting_operating_state, expected_operating_stat
|
||||
assert active_node.file_system_state_actual == FileSystemState.GOOD
|
||||
assert active_node.hardware_state == expected_operating_state
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_operating_state",
|
||||
[
|
||||
(HardwareState.BOOTING, HardwareState.ON)
|
||||
],
|
||||
[(HardwareState.BOOTING, HardwareState.ON)],
|
||||
)
|
||||
def test_node_boots_correctly(operating_state, expected_operating_state):
|
||||
"""
|
||||
Tests that a node boots correctly.
|
||||
"""
|
||||
"""Tests that a node boots correctly."""
|
||||
service_node = ServiceNode(
|
||||
node_id = 0,
|
||||
name = "node",
|
||||
node_type = "COMPUTER",
|
||||
priority = "1",
|
||||
hardware_state = operating_state,
|
||||
ip_address = "192.168.0.1",
|
||||
software_state = SoftwareState.GOOD,
|
||||
file_system_state = "GOOD",
|
||||
config_values = 1,
|
||||
node_id=0,
|
||||
name="node",
|
||||
node_type="COMPUTER",
|
||||
priority="1",
|
||||
hardware_state=operating_state,
|
||||
ip_address="192.168.0.1",
|
||||
software_state=SoftwareState.GOOD,
|
||||
file_system_state="GOOD",
|
||||
config_values=1,
|
||||
)
|
||||
service_attributes = Service(
|
||||
name = "node",
|
||||
port = "80",
|
||||
software_state = SoftwareState.COMPROMISED
|
||||
)
|
||||
service_node.add_service(
|
||||
service_attributes
|
||||
name="node", port="80", software_state=SoftwareState.COMPROMISED
|
||||
)
|
||||
service_node.add_service(service_attributes)
|
||||
|
||||
for x in range(5):
|
||||
service_node.update_booting_status()
|
||||
@@ -73,31 +68,26 @@ def test_node_boots_correctly(operating_state, expected_operating_state):
|
||||
assert service_attributes.software_state == SoftwareState.GOOD
|
||||
assert service_node.hardware_state == expected_operating_state
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_operating_state",
|
||||
[
|
||||
(HardwareState.SHUTTING_DOWN, HardwareState.OFF)
|
||||
],
|
||||
[(HardwareState.SHUTTING_DOWN, HardwareState.OFF)],
|
||||
)
|
||||
def test_node_shutdown_correctly(operating_state, expected_operating_state):
|
||||
"""
|
||||
Tests that a node shutdown correctly.
|
||||
"""
|
||||
"""Tests that a node shutdown correctly."""
|
||||
active_node = ActiveNode(
|
||||
node_id = 0,
|
||||
name = "node",
|
||||
node_type = "COMPUTER",
|
||||
priority = "1",
|
||||
hardware_state = operating_state,
|
||||
ip_address = "192.168.0.1",
|
||||
software_state = SoftwareState.GOOD,
|
||||
file_system_state = "GOOD",
|
||||
config_values = 1,
|
||||
node_id=0,
|
||||
name="node",
|
||||
node_type="COMPUTER",
|
||||
priority="1",
|
||||
hardware_state=operating_state,
|
||||
ip_address="192.168.0.1",
|
||||
software_state=SoftwareState.GOOD,
|
||||
file_system_state="GOOD",
|
||||
config_values=1,
|
||||
)
|
||||
|
||||
for x in range(5):
|
||||
active_node.update_shutdown_status()
|
||||
|
||||
assert active_node.hardware_state == expected_operating_state
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,8 @@ def test_single_action_space_is_valid():
|
||||
"""Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations."""
|
||||
env = _get_primaite_env_from_config(
|
||||
training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT
|
||||
/ "single_action_space_lay_down_config.yaml",
|
||||
)
|
||||
|
||||
run_generic_set_actions(env)
|
||||
@@ -77,8 +78,10 @@ def test_single_action_space_is_valid():
|
||||
def test_agent_is_executing_actions_from_both_spaces():
|
||||
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
|
||||
env = _get_primaite_env_from_config(
|
||||
training_config_path=TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT
|
||||
/ "single_action_space_lay_down_config.yaml",
|
||||
)
|
||||
# Run environment with specified fixed blue agent actions only
|
||||
run_generic_set_actions(env)
|
||||
|
||||
Reference in New Issue
Block a user