#1386: added documentation + dealing with pre-commit checks

This commit is contained in:
Czar Echavez
2023-06-20 11:19:05 +01:00
parent 9fb30ffe1b
commit 99399cbda6
17 changed files with 311 additions and 192 deletions

2
.gitignore vendored
View File

@@ -142,4 +142,4 @@ cython_debug/
.idea/
# outputs
src/primaite/outputs/
src/primaite/outputs/

View File

@@ -1 +1,63 @@
# PrimAITE
## Getting Started with PrimAITE
### Pre-Requisites
In order to get **PrimAITE** installed, you will need to have the following installed:
- `python3.8+`
- `python3-pip`
- `virtualenv`
**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS.
### Installation from source
#### 1. Navigate to the PrimAITE folder and create a new python virtual environment (venv)
```unix
python3 -m venv <name_of_venv>
```
#### 2. Activate the venv
##### Unix
```bash
source <name_of_venv>/bin/activate
```
##### Windows
```powershell
.\<name_of_venv>\Scripts\activate
```
#### 3. Install `primaite` into the venv along with all of it's dependencies
```bash
python3 -m pip install -e .
```
### Development Installation
To install the development dependencies, postfix the command in step 3 above with the `[dev]` extra. Example:
```bash
python3 -m pip install -e .[dev]
## Building documentation
The PrimAITE documentation can be built with the following commands:
##### Unix
```bash
cd docs
make html
```
##### Windows
```powershell
cd docs
.\make.bat html
```
This will build the documentation as a collection of HTML files which uses the Read The Docs sphinx theme. Other build
options are available but may require additional dependencies such as LaTeX and PDF. Please refer to the Sphinx documentation
for your specific output requirements.

View File

@@ -293,6 +293,14 @@ Rewards are calculated based on the difference between the current state and ref
The number of steps to take when scanning the file system
* **deterministic** [bool]
Set to true if the agent should use deterministic actions. Default is ``False``
* **seed** [int]
Seed used in the randomisation in training / evaluation. Default is ``None``
The Lay Down Config
*******************

View File

@@ -320,4 +320,4 @@
| ypy-websocket | 0.8.2 | UNKNOWN | https://github.com/y-crdt/ypy-websocket |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+
| zipp | 3.15.0 | MIT License | https://github.com/jaraco/zipp |
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+

View File

@@ -31,7 +31,7 @@ def _get_primaite_config():
"INFO": logging.INFO,
"WARN": logging.WARN,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL
"CRITICAL": logging.CRITICAL,
}
primaite_config["log_level"] = log_level_map[primaite_config["log_level"]]
return primaite_config

View File

@@ -3,8 +3,8 @@
import logging
import os
import shutil
from pathlib import Path
from enum import Enum
from pathlib import Path
from typing import Optional
import pkg_resources
@@ -44,6 +44,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
:param last_n: The number of lines to print. Default value is 10.
"""
import re
from primaite import LOG_PATH
if os.path.isfile(LOG_PATH):
@@ -53,7 +54,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
print(re.sub(r"\n*", "", line))
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
@app.command()

View File

@@ -1,7 +1,7 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Final, Union, Optional
from typing import Any, Dict, Final, Optional, Union
import yaml
@@ -173,8 +173,7 @@ def main_training_config_path() -> Path:
return path
def load(file_path: Union[str, Path],
legacy_file: bool = False) -> TrainingConfig:
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
"""
Read in a training config yaml file.
@@ -219,9 +218,7 @@ def load(file_path: Union[str, Path],
def convert_legacy_training_config_dict(
legacy_config_dict: Dict[str, Any],
num_steps: int = 256,
action_type: str = "ANY"
legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY"
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
@@ -233,10 +230,7 @@ def convert_legacy_training_config_dict(
don't have action_type values.
:return: The converted training config dict.
"""
config_dict = {
"num_steps": num_steps,
"action_type": action_type
}
config_dict = {"num_steps": num_steps, "action_type": action_type}
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)
if new_key:

View File

@@ -14,8 +14,7 @@ from gym import Env, spaces
from matplotlib import pyplot as plt
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, \
is_valid_node_action
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
@@ -24,8 +23,9 @@ from primaite.common.enums import (
NodePOLInitiator,
NodePOLType,
NodeType,
ObservationType,
Priority,
SoftwareState, ObservationType,
SoftwareState,
)
from primaite.common.service import Service
from primaite.config import training_config
@@ -35,17 +35,14 @@ from primaite.environment.reward import calculate_reward_function
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node import Node
from primaite.nodes.node_state_instruction_green import \
NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.green_pol import apply_iers, apply_node_pol
from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
apply_red_agent_node_pol
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction
from primaite.transactions.transactions_to_file import write_transaction_to_file
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.INFO)
@@ -178,7 +175,6 @@ class Primaite(Env):
# It will be initialised later.
self.obs_handler: ObservationsHandler
# Open the config file and build the environment laydown
with open(self._lay_down_config_path, "r") as file:
# Open the config file and build the environment laydown
@@ -200,7 +196,6 @@ class Primaite(Env):
try:
plt.tight_layout()
nx.draw_networkx(self.network, with_labels=True)
now = datetime.now() # current date and time
file_path = session_path / f"network_{timestamp_str}.png"
plt.savefig(file_path, format="PNG")
@@ -222,7 +217,9 @@ class Primaite(Env):
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
self.action_dict = self.create_node_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
self.action_space = spaces.Discrete(
len(self.action_dict), seed=self.training_config.seed
)
elif self.training_config.action_type == ActionType.ACL:
_LOGGER.info("Action space type ACL selected")
# Terms (for ACL action space):
@@ -233,13 +230,19 @@ class Primaite(Env):
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
self.action_dict = self.create_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
self.action_space = spaces.Discrete(
len(self.action_dict), seed=self.training_config.seed
)
elif self.training_config.action_type == ActionType.ANY:
_LOGGER.info("Action space type ANY selected - Node + ACL")
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
self.action_space = spaces.Discrete(
len(self.action_dict), seed=self.training_config.seed
)
else:
_LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
_LOGGER.info(
f"Invalid action type selected: {self.training_config.action_type}"
)
# Set up a csv to store the results of the training
try:
header = ["Episode", "Average Reward"]
@@ -380,7 +383,7 @@ class Primaite(Env):
self.step_count,
self.training_config,
)
#print(f" Step {self.step_count} Reward: {str(reward)}")
# print(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -407,12 +410,11 @@ class Primaite(Env):
return self.env_obs, reward, done, self.step_info
def close(self):
"""Calls the __close__ method."""
self.__close__()
def __close__(self):
"""
Override close function
"""
"""Override close function."""
self.csv_file.close()
def init_acl(self):
@@ -1039,7 +1041,6 @@ class Primaite(Env):
"""
self.observation_type = ObservationType[observation_info["type"]]
def get_action_info(self, action_info):
"""
Extracts action_info.

View File

@@ -22,8 +22,7 @@ from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import SESSIONS_DIR, getLogger
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
from primaite.transactions.transactions_to_file import \
write_transaction_to_file
from primaite.transactions.transactions_to_file import write_transaction_to_file
_LOGGER = getLogger(__name__)
@@ -87,7 +86,13 @@ def run_stable_baselines3_ppo(
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occured", exc_info=True)
else:
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed)
agent = PPO(
PPOMlp,
env,
verbose=0,
n_steps=config_values.num_steps,
seed=env.training_config.seed,
)
if config_values.session_type == "TRAINING":
# We're in a training session
@@ -106,8 +111,7 @@ def run_stable_baselines3_ppo(
for step in range(0, config_values.num_steps):
action, _states = agent.predict(
obs,
deterministic=env.training_config.deterministic
obs, deterministic=env.training_config.deterministic
)
# convert to int if action is a numpy array
if isinstance(action, np.ndarray):
@@ -146,7 +150,13 @@ def run_stable_baselines3_a2c(
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occured", exc_info=True)
else:
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed)
agent = A2C(
"MlpPolicy",
env,
verbose=0,
n_steps=config_values.num_steps,
seed=env.training_config.seed,
)
if config_values.session_type == "TRAINING":
# We're in a training session
@@ -164,8 +174,7 @@ def run_stable_baselines3_a2c(
for step in range(0, config_values.num_steps):
action, _states = agent.predict(
obs,
deterministic=env.training_config.deterministic
obs, deterministic=env.training_config.deterministic
)
# convert to int if action is a numpy array
if isinstance(action, np.ndarray):
@@ -368,5 +377,3 @@ if __name__ == "__main__":
"Please provide a lay down config file using the --ldc " "argument"
)
run(training_config_path=args.tc, lay_down_config_path=args.ldc)

View File

@@ -46,6 +46,7 @@ class Node:
"""Sets the node state to ON."""
self.hardware_state = HardwareState.BOOTING
self.booting_count = self.config_values.node_booting_duration
def turn_off(self):
"""Sets the node state to OFF."""
self.hardware_state = HardwareState.OFF
@@ -64,14 +65,14 @@ class Node:
self.hardware_state = HardwareState.ON
def update_booting_status(self):
"""Updates the booting count"""
"""Updates the booting count."""
self.booting_count -= 1
if self.booting_count <= 0:
self.booting_count = 0
self.hardware_state = HardwareState.ON
def update_shutdown_status(self):
"""Updates the shutdown count"""
"""Updates the shutdown count."""
self.shutting_down_count -= 1
if self.shutting_down_count <= 0:
self.shutting_down_count = 0

View File

@@ -190,13 +190,15 @@ class ServiceNode(ActiveNode):
service_value.reduce_patching_count()
def update_resetting_status(self):
"""Updates the resetting counter for any service that are resetting."""
super().update_resetting_status()
if self.resetting_count <= 0:
for service in self.services.values():
service.software_state = SoftwareState.GOOD
def update_booting_status(self):
"""Updates the booting counter for any service that are booting up."""
super().update_booting_status()
if self.booting_count <= 0:
for service in self.services.values():
service.software_state =SoftwareState.GOOD
service.software_state = SoftwareState.GOOD

View File

@@ -17,7 +17,6 @@ def start_jupyter_session():
.. todo:: Figure out how to get this working for Linux and MacOS too.
"""
if importlib.util.find_spec("jupyter") is not None:
jupyter_cmd = "python3 -m jupyter lab"
if sys.platform == "win32":

View File

@@ -3,7 +3,7 @@ import tempfile
import time
from datetime import datetime
from pathlib import Path
from typing import Union, Final
from typing import Final, Union
import pandas as pd
@@ -30,7 +30,7 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path:
def _get_primaite_env_from_config(
training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
):
"""Takes a config path and returns the created instance of Primaite."""
session_timestamp: datetime = datetime.now()
@@ -84,7 +84,7 @@ def run_generic(env, config_values):
def compare_file_content(output_a_file_path: str, output_b_file_path: str):
"""Function used to check if output of both given files are the same"""
"""Function used to check if output of both given files are the same."""
with open(output_a_file_path) as f1:
with open(output_b_file_path) as f2:
f1_content = f1.read()
@@ -95,13 +95,15 @@ def compare_file_content(output_a_file_path: str, output_b_file_path: str):
# both files have the same content
return True
# both files have different content
print(f"{output_a_file_path} and {output_b_file_path} has different contents")
print(
f"{output_a_file_path} and {output_b_file_path} has different contents"
)
return False
def compare_transaction_file(output_a_file_path: str, output_b_file_path: str):
"""Function used to check if contents of transaction files are the same"""
"""Function used to check if contents of transaction files are the same."""
# load output a file
data_a = pd.read_csv(output_a_file_path)
@@ -109,19 +111,17 @@ def compare_transaction_file(output_a_file_path: str, output_b_file_path: str):
data_b = pd.read_csv(output_b_file_path)
# remove the time stamp column
data_a.drop('Timestamp', inplace=True, axis=1)
data_b.drop('Timestamp', inplace=True, axis=1)
data_a.drop("Timestamp", inplace=True, axis=1)
data_b.drop("Timestamp", inplace=True, axis=1)
# if the comparison is empty, both files are the same i.e. True
return data_a.compare(data_b).empty
class TestSession:
def __init__(
self,
training_config_path,
laydown_config_path
):
"""Class that contains session values."""
def __init__(self, training_config_path, laydown_config_path):
self.session_timestamp: Final[datetime] = datetime.now()
self.session_dir = _get_session_path(self.session_timestamp)
self.timestamp_str = self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
@@ -140,5 +140,8 @@ class TestSession:
print("Writing Session Metadata file...")
_write_session_metadata_file(
session_dir=self.session_dir, uuid="test", session_timestamp=self.session_timestamp, env=self.env
session_dir=self.session_dir,
uuid="test",
session_timestamp=self.session_timestamp,
env=self.env,
)

View File

@@ -2,8 +2,13 @@ import time
from primaite import getLogger
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.main import run_stable_baselines3_a2c, \
run_stable_baselines3_ppo, run_generic, _update_session_metadata_file, _get_session_path
from primaite.main import (
_get_session_path,
_update_session_metadata_file,
run_generic,
run_stable_baselines3_a2c,
run_stable_baselines3_ppo,
)
from primaite.transactions.transactions_to_file import write_transaction_to_file
from tests import TEST_CONFIG_ROOT
from tests.conftest import TestSession, compare_file_content, compare_transaction_file
@@ -12,18 +17,17 @@ _LOGGER = getLogger(__name__)
def test_generic_same_results():
"""Runs seeded and deterministic Generic Primaite sessions and checks that the results are the same"""
"""Runs seeded and deterministic Generic Primaite sessions and checks that the results are the same."""
print("")
print("=======================")
print("Generic test run")
print("=======================")
print("")
# run session 1
session1 = TestSession(
TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
config_values = session1.env.training_config
@@ -31,17 +35,14 @@ def test_generic_same_results():
# Get the number of steps (which is stored in the child config file)
config_values.num_steps = session1.env.episode_steps
run_generic(
env=session1.env,
config_values=session1.env.training_config
)
run_generic(env=session1.env, config_values=session1.env.training_config)
_update_session_metadata_file(session_dir=session1.session_dir, env=session1.env)
# run session 2
session2 = TestSession(
TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
config_values = session2.env.training_config
@@ -49,10 +50,7 @@ def test_generic_same_results():
# Get the number of steps (which is stored in the child config file)
config_values.num_steps = session2.env.episode_steps
run_generic(
env=session2.env,
config_values=session2.env.training_config
)
run_generic(env=session2.env, config_values=session2.env.training_config)
_update_session_metadata_file(session_dir=session2.session_dir, env=session2.env)
@@ -61,36 +59,41 @@ def test_generic_same_results():
time.sleep(1)
# check if both outputs are the same
assert compare_file_content(
session1.env.csv_file.name,
session2.env.csv_file.name,
) is True
assert (
compare_file_content(
session1.env.csv_file.name,
session2.env.csv_file.name,
)
is True
)
# deterministic run
deterministic = TestSession(
TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
deterministic.env.training_config.deterministic = True
run_generic(
env=deterministic.env,
config_values=deterministic.env.training_config
run_generic(env=deterministic.env, config_values=deterministic.env.training_config)
_update_session_metadata_file(
session_dir=deterministic.session_dir, env=deterministic.env
)
_update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env)
# check if both outputs are the same
assert compare_file_content(
deterministic.env.csv_file.name,
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_generic.csv",
) is True
assert (
compare_file_content(
deterministic.env.csv_file.name,
TEST_CONFIG_ROOT
/ "e2e/deterministic_test_outputs/deterministic_generic.csv",
)
is True
)
def test_ppo_same_results():
"""Runs seeded and deterministic PPO Primaite sessions and checks that the results are the same"""
"""Runs seeded and deterministic PPO Primaite sessions and checks that the results are the same."""
print("")
print("=======================")
print("PPO test run")
@@ -99,7 +102,7 @@ def test_ppo_same_results():
training_session = TestSession(
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
# Train agent
@@ -123,12 +126,14 @@ def test_ppo_same_results():
timestamp_str=training_session.timestamp_str,
)
_update_session_metadata_file(session_dir=training_session.session_dir, env=training_session.env)
_update_session_metadata_file(
session_dir=training_session.session_dir, env=training_session.env
)
# Evaluate Agent again
eval_session1 = TestSession(
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
# Get the number of steps (which is stored in the child config file)
@@ -137,7 +142,10 @@ def test_ppo_same_results():
# load the agent that was trained previously
eval_session1.env.training_config.load_agent = True
eval_session1.env.training_config.agent_load_file = _get_session_path(training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
eval_session1.env.training_config.agent_load_file = (
_get_session_path(training_session.session_timestamp)
/ f"agent_saved_{training_session.timestamp_str}.zip"
)
config_values = eval_session1.env.training_config
@@ -154,11 +162,13 @@ def test_ppo_same_results():
timestamp_str=eval_session1.timestamp_str,
)
_update_session_metadata_file(session_dir=eval_session1.session_dir, env=eval_session1.env)
_update_session_metadata_file(
session_dir=eval_session1.session_dir, env=eval_session1.env
)
eval_session2 = TestSession(
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
# Get the number of steps (which is stored in the child config file)
@@ -167,8 +177,10 @@ def test_ppo_same_results():
# load the agent that was trained previously
eval_session2.env.training_config.load_agent = True
eval_session2.env.training_config.agent_load_file = _get_session_path(
training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
eval_session2.env.training_config.agent_load_file = (
_get_session_path(training_session.session_timestamp)
/ f"agent_saved_{training_session.timestamp_str}.zip"
)
config_values = eval_session2.env.training_config
@@ -185,18 +197,25 @@ def test_ppo_same_results():
timestamp_str=eval_session2.timestamp_str,
)
_update_session_metadata_file(session_dir=eval_session2.session_dir, env=eval_session2.env)
_update_session_metadata_file(
session_dir=eval_session2.session_dir, env=eval_session2.env
)
# check if both eval outputs are the same
assert compare_transaction_file(
eval_session1.session_dir / f"all_transactions_{eval_session1.timestamp_str}.csv",
eval_session2.session_dir / f"all_transactions_{eval_session2.timestamp_str}.csv",
) is True
assert (
compare_transaction_file(
eval_session1.session_dir
/ f"all_transactions_{eval_session1.timestamp_str}.csv",
eval_session2.session_dir
/ f"all_transactions_{eval_session2.timestamp_str}.csv",
)
is True
)
# deterministic run
deterministic = TestSession(
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
deterministic.env.training_config.deterministic = True
@@ -214,18 +233,23 @@ def test_ppo_same_results():
timestamp_str=deterministic.timestamp_str,
)
_update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env)
_update_session_metadata_file(
session_dir=deterministic.session_dir, env=deterministic.env
)
# check if both outputs are the same
assert compare_transaction_file(
deterministic.session_dir / f"all_transactions_{deterministic.timestamp_str}.csv",
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_ppo.csv",
) is True
assert (
compare_transaction_file(
deterministic.session_dir
/ f"all_transactions_{deterministic.timestamp_str}.csv",
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_ppo.csv",
)
is True
)
def test_a2c_same_results():
"""Runs seeded and deterministic A2C Primaite sessions and checks that the results are the same"""
"""Runs seeded and deterministic A2C Primaite sessions and checks that the results are the same."""
print("")
print("=======================")
print("A2C test run")
@@ -234,7 +258,7 @@ def test_a2c_same_results():
training_session = TestSession(
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
# Train agent
@@ -258,12 +282,14 @@ def test_a2c_same_results():
timestamp_str=training_session.timestamp_str,
)
_update_session_metadata_file(session_dir=training_session.session_dir, env=training_session.env)
_update_session_metadata_file(
session_dir=training_session.session_dir, env=training_session.env
)
# Evaluate Agent again
eval_session1 = TestSession(
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
# Get the number of steps (which is stored in the child config file)
@@ -272,8 +298,10 @@ def test_a2c_same_results():
# load the agent that was trained previously
eval_session1.env.training_config.load_agent = True
eval_session1.env.training_config.agent_load_file = _get_session_path(
training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
eval_session1.env.training_config.agent_load_file = (
_get_session_path(training_session.session_timestamp)
/ f"agent_saved_{training_session.timestamp_str}.zip"
)
config_values = eval_session1.env.training_config
@@ -290,11 +318,13 @@ def test_a2c_same_results():
timestamp_str=eval_session1.timestamp_str,
)
_update_session_metadata_file(session_dir=eval_session1.session_dir, env=eval_session1.env)
_update_session_metadata_file(
session_dir=eval_session1.session_dir, env=eval_session1.env
)
eval_session2 = TestSession(
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
# Get the number of steps (which is stored in the child config file)
@@ -303,8 +333,10 @@ def test_a2c_same_results():
# load the agent that was trained previously
eval_session2.env.training_config.load_agent = True
eval_session2.env.training_config.agent_load_file = _get_session_path(
training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
eval_session2.env.training_config.agent_load_file = (
_get_session_path(training_session.session_timestamp)
/ f"agent_saved_{training_session.timestamp_str}.zip"
)
config_values = eval_session2.env.training_config
@@ -321,18 +353,25 @@ def test_a2c_same_results():
timestamp_str=eval_session2.timestamp_str,
)
_update_session_metadata_file(session_dir=eval_session2.session_dir, env=eval_session2.env)
_update_session_metadata_file(
session_dir=eval_session2.session_dir, env=eval_session2.env
)
# check if both eval outputs are the same
assert compare_transaction_file(
eval_session1.session_dir / f"all_transactions_{eval_session1.timestamp_str}.csv",
eval_session2.session_dir / f"all_transactions_{eval_session2.timestamp_str}.csv",
) is True
assert (
compare_transaction_file(
eval_session1.session_dir
/ f"all_transactions_{eval_session1.timestamp_str}.csv",
eval_session2.session_dir
/ f"all_transactions_{eval_session2.timestamp_str}.csv",
)
is True
)
# deterministic run
deterministic = TestSession(
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
deterministic.env.training_config.deterministic = True
@@ -350,10 +389,16 @@ def test_a2c_same_results():
timestamp_str=deterministic.timestamp_str,
)
_update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env)
_update_session_metadata_file(
session_dir=deterministic.session_dir, env=deterministic.env
)
# check if both outputs are the same
assert compare_transaction_file(
deterministic.session_dir / f"all_transactions_{deterministic.timestamp_str}.csv",
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_a2c.csv",
) is True
assert (
compare_transaction_file(
deterministic.session_dir
/ f"all_transactions_{deterministic.timestamp_str}.csv",
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_a2c.csv",
)
is True
)

View File

@@ -27,7 +27,8 @@ def env(request):
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_without_obs.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)
@@ -43,7 +44,8 @@ def test_default_obs_space(env: Primaite):
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_without_obs.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)
@@ -140,7 +142,8 @@ class TestNodeLinkTable:
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_NODE_STATUSES.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)

View File

@@ -1,7 +1,13 @@
"""Used to test Active Node functions."""
import pytest
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState, NodeType, Priority
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
@@ -10,24 +16,20 @@ from primaite.nodes.service_node import ServiceNode
@pytest.mark.parametrize(
"starting_operating_state, expected_operating_state",
[
(HardwareState.RESETTING, HardwareState.ON)
],
[(HardwareState.RESETTING, HardwareState.ON)],
)
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
"""
Tests that a node resets correctly.
"""
"""Tests that a node resets correctly."""
active_node = ActiveNode(
node_id = "0",
name = "node",
node_type = NodeType.COMPUTER,
priority = Priority.P1,
hardware_state = starting_operating_state,
ip_address = "192.168.0.1",
software_state = SoftwareState.COMPROMISED,
file_system_state = FileSystemState.CORRUPT,
config_values=TrainingConfig()
node_id="0",
name="node",
node_type=NodeType.COMPUTER,
priority=Priority.P1,
hardware_state=starting_operating_state,
ip_address="192.168.0.1",
software_state=SoftwareState.COMPROMISED,
file_system_state=FileSystemState.CORRUPT,
config_values=TrainingConfig(),
)
for x in range(5):
@@ -37,35 +39,28 @@ def test_node_resets_correctly(starting_operating_state, expected_operating_stat
assert active_node.file_system_state_actual == FileSystemState.GOOD
assert active_node.hardware_state == expected_operating_state
@pytest.mark.parametrize(
"operating_state, expected_operating_state",
[
(HardwareState.BOOTING, HardwareState.ON)
],
[(HardwareState.BOOTING, HardwareState.ON)],
)
def test_node_boots_correctly(operating_state, expected_operating_state):
"""
Tests that a node boots correctly.
"""
"""Tests that a node boots correctly."""
service_node = ServiceNode(
node_id = 0,
name = "node",
node_type = "COMPUTER",
priority = "1",
hardware_state = operating_state,
ip_address = "192.168.0.1",
software_state = SoftwareState.GOOD,
file_system_state = "GOOD",
config_values = 1,
node_id=0,
name="node",
node_type="COMPUTER",
priority="1",
hardware_state=operating_state,
ip_address="192.168.0.1",
software_state=SoftwareState.GOOD,
file_system_state="GOOD",
config_values=1,
)
service_attributes = Service(
name = "node",
port = "80",
software_state = SoftwareState.COMPROMISED
)
service_node.add_service(
service_attributes
name="node", port="80", software_state=SoftwareState.COMPROMISED
)
service_node.add_service(service_attributes)
for x in range(5):
service_node.update_booting_status()
@@ -73,31 +68,26 @@ def test_node_boots_correctly(operating_state, expected_operating_state):
assert service_attributes.software_state == SoftwareState.GOOD
assert service_node.hardware_state == expected_operating_state
@pytest.mark.parametrize(
"operating_state, expected_operating_state",
[
(HardwareState.SHUTTING_DOWN, HardwareState.OFF)
],
[(HardwareState.SHUTTING_DOWN, HardwareState.OFF)],
)
def test_node_shutdown_correctly(operating_state, expected_operating_state):
"""
Tests that a node shutdown correctly.
"""
"""Tests that a node shutdown correctly."""
active_node = ActiveNode(
node_id = 0,
name = "node",
node_type = "COMPUTER",
priority = "1",
hardware_state = operating_state,
ip_address = "192.168.0.1",
software_state = SoftwareState.GOOD,
file_system_state = "GOOD",
config_values = 1,
node_id=0,
name="node",
node_type="COMPUTER",
priority="1",
hardware_state=operating_state,
ip_address="192.168.0.1",
software_state=SoftwareState.GOOD,
file_system_state="GOOD",
config_values=1,
)
for x in range(5):
active_node.update_shutdown_status()
assert active_node.hardware_state == expected_operating_state

View File

@@ -48,7 +48,8 @@ def test_single_action_space_is_valid():
"""Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations."""
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "single_action_space_lay_down_config.yaml",
)
run_generic_set_actions(env)
@@ -77,8 +78,10 @@ def test_single_action_space_is_valid():
def test_agent_is_executing_actions_from_both_spaces():
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "single_action_space_fixed_blue_actions_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "single_action_space_lay_down_config.yaml",
)
# Run environment with specified fixed blue agent actions only
run_generic_set_actions(env)