diff --git a/.gitignore b/.gitignore index 260a980b..4bb700b2 100644 --- a/.gitignore +++ b/.gitignore @@ -142,4 +142,4 @@ cython_debug/ .idea/ # outputs -src/primaite/outputs/ \ No newline at end of file +src/primaite/outputs/ diff --git a/README.md b/README.md index 78f36fba..7782e8a9 100644 --- a/README.md +++ b/README.md @@ -1 +1,63 @@ # PrimAITE + +## Getting Started with PrimAITE + +### Pre-Requisites + +In order to get **PrimAITE** installed, you will need to have the following installed: + +- `python3.8+` +- `python3-pip` +- `virtualenv` + +**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS. + +### Installation from source +#### 1. Navigate to the PrimAITE folder and create a new python virtual environment (venv) + +```unix +python3 -m venv +``` + +#### 2. Activate the venv + +##### Unix +```bash +source /bin/activate +``` + +##### Windows +```powershell +.\\Scripts\activate +``` + +#### 3. Install `primaite` into the venv along with all of it's dependencies + +```bash +python3 -m pip install -e . +``` + +### Development Installation +To install the development dependencies, postfix the command in step 3 above with the `[dev]` extra. Example: + +```bash +python3 -m pip install -e .[dev] + +## Building documentation +The PrimAITE documentation can be built with the following commands: + +##### Unix +```bash +cd docs +make html +``` + +##### Windows +```powershell +cd docs +.\make.bat html +``` + +This will build the documentation as a collection of HTML files which uses the Read The Docs sphinx theme. Other build +options are available but may require additional dependencies such as LaTeX and PDF. Please refer to the Sphinx documentation +for your specific output requirements. diff --git a/docs/source/config.rst b/docs/source/config.rst index 74898ec1..5410a877 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -293,6 +293,14 @@ Rewards are calculated based on the difference between the current state and ref The number of steps to take when scanning the file system +* **deterministic** [bool] + + Set to true if the agent should use deterministic actions. Default is ``False`` + +* **seed** [int] + + Seed used in the randomisation in training / evaluation. Default is ``None`` + The Lay Down Config ******************* diff --git a/docs/source/primaite-dependencies.rst b/docs/source/primaite-dependencies.rst index a7a0ec26..67971d2b 100644 --- a/docs/source/primaite-dependencies.rst +++ b/docs/source/primaite-dependencies.rst @@ -320,4 +320,4 @@ | ypy-websocket | 0.8.2 | UNKNOWN | https://github.com/y-crdt/ypy-websocket | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ | zipp | 3.15.0 | MIT License | https://github.com/jaraco/zipp | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ \ No newline at end of file ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------+ diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 1ea110c9..420420f4 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -31,7 +31,7 @@ def _get_primaite_config(): "INFO": logging.INFO, "WARN": logging.WARN, "ERROR": logging.ERROR, - "CRITICAL": logging.CRITICAL + "CRITICAL": logging.CRITICAL, } primaite_config["log_level"] = log_level_map[primaite_config["log_level"]] return primaite_config diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 19746d01..319d643c 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -3,8 +3,8 @@ import logging import os import shutil -from pathlib import Path from enum import Enum +from pathlib import Path from typing import Optional import pkg_resources @@ -44,6 +44,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]): :param last_n: The number of lines to print. Default value is 10. """ import re + from primaite import LOG_PATH if os.path.isfile(LOG_PATH): @@ -53,7 +54,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]): print(re.sub(r"\n*", "", line)) -_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa +_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa @app.command() diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 87a473e0..50970d56 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,7 +1,7 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Final, Union, Optional +from typing import Any, Dict, Final, Optional, Union import yaml @@ -173,8 +173,7 @@ def main_training_config_path() -> Path: return path -def load(file_path: Union[str, Path], - legacy_file: bool = False) -> TrainingConfig: +def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: """ Read in a training config yaml file. @@ -219,9 +218,7 @@ def load(file_path: Union[str, Path], def convert_legacy_training_config_dict( - legacy_config_dict: Dict[str, Any], - num_steps: int = 256, - action_type: str = "ANY" + legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY" ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -233,10 +230,7 @@ def convert_legacy_training_config_dict( don't have action_type values. :return: The converted training config dict. """ - config_dict = { - "num_steps": num_steps, - "action_type": action_type - } + config_dict = {"num_steps": num_steps, "action_type": action_type} for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) if new_key: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index a4b69ba6..1955bfb0 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -14,8 +14,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, \ - is_valid_node_action +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -24,8 +23,9 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, + ObservationType, Priority, - SoftwareState, ObservationType, + SoftwareState, ) from primaite.common.service import Service from primaite.config import training_config @@ -35,17 +35,14 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import \ - NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, \ - apply_red_agent_node_pol +from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction -from primaite.transactions.transactions_to_file import write_transaction_to_file _LOGGER = logging.getLogger(__name__) _LOGGER.setLevel(logging.INFO) @@ -178,7 +175,6 @@ class Primaite(Env): # It will be initialised later. self.obs_handler: ObservationsHandler - # Open the config file and build the environment laydown with open(self._lay_down_config_path, "r") as file: # Open the config file and build the environment laydown @@ -200,7 +196,6 @@ class Primaite(Env): try: plt.tight_layout() nx.draw_networkx(self.network, with_labels=True) - now = datetime.now() # current date and time file_path = session_path / f"network_{timestamp_str}.png" plt.savefig(file_path, format="PNG") @@ -222,7 +217,9 @@ class Primaite(Env): # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa self.action_dict = self.create_node_action_dict() - self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed) + self.action_space = spaces.Discrete( + len(self.action_dict), seed=self.training_config.seed + ) elif self.training_config.action_type == ActionType.ACL: _LOGGER.info("Action space type ACL selected") # Terms (for ACL action space): @@ -233,13 +230,19 @@ class Primaite(Env): # [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) # [0, num ports] - Port (0 = any, then 1 -> x resolving to port) self.action_dict = self.create_acl_action_dict() - self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed) + self.action_space = spaces.Discrete( + len(self.action_dict), seed=self.training_config.seed + ) elif self.training_config.action_type == ActionType.ANY: _LOGGER.info("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() - self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed) + self.action_space = spaces.Discrete( + len(self.action_dict), seed=self.training_config.seed + ) else: - _LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}") + _LOGGER.info( + f"Invalid action type selected: {self.training_config.action_type}" + ) # Set up a csv to store the results of the training try: header = ["Episode", "Average Reward"] @@ -380,7 +383,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - #print(f" Step {self.step_count} Reward: {str(reward)}") + # print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -407,12 +410,11 @@ class Primaite(Env): return self.env_obs, reward, done, self.step_info def close(self): + """Calls the __close__ method.""" self.__close__() def __close__(self): - """ - Override close function - """ + """Override close function.""" self.csv_file.close() def init_acl(self): @@ -1039,7 +1041,6 @@ class Primaite(Env): """ self.observation_type = ObservationType[observation_info["type"]] - def get_action_info(self, action_info): """ Extracts action_info. diff --git a/src/primaite/main.py b/src/primaite/main.py index fc549590..8483f383 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -22,8 +22,7 @@ from stable_baselines3.ppo import MlpPolicy as PPOMlp from primaite import SESSIONS_DIR, getLogger from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite -from primaite.transactions.transactions_to_file import \ - write_transaction_to_file +from primaite.transactions.transactions_to_file import write_transaction_to_file _LOGGER = getLogger(__name__) @@ -87,7 +86,13 @@ def run_stable_baselines3_ppo( _LOGGER.error("Could not load agent") _LOGGER.error("Exception occured", exc_info=True) else: - agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed) + agent = PPO( + PPOMlp, + env, + verbose=0, + n_steps=config_values.num_steps, + seed=env.training_config.seed, + ) if config_values.session_type == "TRAINING": # We're in a training session @@ -106,8 +111,7 @@ def run_stable_baselines3_ppo( for step in range(0, config_values.num_steps): action, _states = agent.predict( - obs, - deterministic=env.training_config.deterministic + obs, deterministic=env.training_config.deterministic ) # convert to int if action is a numpy array if isinstance(action, np.ndarray): @@ -146,7 +150,13 @@ def run_stable_baselines3_a2c( _LOGGER.error("Could not load agent") _LOGGER.error("Exception occured", exc_info=True) else: - agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed) + agent = A2C( + "MlpPolicy", + env, + verbose=0, + n_steps=config_values.num_steps, + seed=env.training_config.seed, + ) if config_values.session_type == "TRAINING": # We're in a training session @@ -164,8 +174,7 @@ def run_stable_baselines3_a2c( for step in range(0, config_values.num_steps): action, _states = agent.predict( - obs, - deterministic=env.training_config.deterministic + obs, deterministic=env.training_config.deterministic ) # convert to int if action is a numpy array if isinstance(action, np.ndarray): @@ -368,5 +377,3 @@ if __name__ == "__main__": "Please provide a lay down config file using the --ldc " "argument" ) run(training_config_path=args.tc, lay_down_config_path=args.ldc) - - diff --git a/src/primaite/nodes/node.py b/src/primaite/nodes/node.py index 00cd01c2..bac1792d 100644 --- a/src/primaite/nodes/node.py +++ b/src/primaite/nodes/node.py @@ -46,6 +46,7 @@ class Node: """Sets the node state to ON.""" self.hardware_state = HardwareState.BOOTING self.booting_count = self.config_values.node_booting_duration + def turn_off(self): """Sets the node state to OFF.""" self.hardware_state = HardwareState.OFF @@ -64,14 +65,14 @@ class Node: self.hardware_state = HardwareState.ON def update_booting_status(self): - """Updates the booting count""" + """Updates the booting count.""" self.booting_count -= 1 if self.booting_count <= 0: self.booting_count = 0 self.hardware_state = HardwareState.ON def update_shutdown_status(self): - """Updates the shutdown count""" + """Updates the shutdown count.""" self.shutting_down_count -= 1 if self.shutting_down_count <= 0: self.shutting_down_count = 0 diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index 84a7c587..d4a5c8c8 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -190,13 +190,15 @@ class ServiceNode(ActiveNode): service_value.reduce_patching_count() def update_resetting_status(self): + """Updates the resetting counter for any service that are resetting.""" super().update_resetting_status() if self.resetting_count <= 0: for service in self.services.values(): service.software_state = SoftwareState.GOOD def update_booting_status(self): + """Updates the booting counter for any service that are booting up.""" super().update_booting_status() if self.booting_count <= 0: for service in self.services.values(): - service.software_state =SoftwareState.GOOD + service.software_state = SoftwareState.GOOD diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py index 6d822961..71ed343e 100644 --- a/src/primaite/notebooks/__init__.py +++ b/src/primaite/notebooks/__init__.py @@ -17,7 +17,6 @@ def start_jupyter_session(): .. todo:: Figure out how to get this working for Linux and MacOS too. """ - if importlib.util.find_spec("jupyter") is not None: jupyter_cmd = "python3 -m jupyter lab" if sys.platform == "win32": diff --git a/tests/conftest.py b/tests/conftest.py index 7d552d7c..20ad8b23 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import tempfile import time from datetime import datetime from pathlib import Path -from typing import Union, Final +from typing import Final, Union import pandas as pd @@ -30,7 +30,7 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path: def _get_primaite_env_from_config( - training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path] + training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path] ): """Takes a config path and returns the created instance of Primaite.""" session_timestamp: datetime = datetime.now() @@ -84,7 +84,7 @@ def run_generic(env, config_values): def compare_file_content(output_a_file_path: str, output_b_file_path: str): - """Function used to check if output of both given files are the same""" + """Function used to check if output of both given files are the same.""" with open(output_a_file_path) as f1: with open(output_b_file_path) as f2: f1_content = f1.read() @@ -95,13 +95,15 @@ def compare_file_content(output_a_file_path: str, output_b_file_path: str): # both files have the same content return True # both files have different content - print(f"{output_a_file_path} and {output_b_file_path} has different contents") + print( + f"{output_a_file_path} and {output_b_file_path} has different contents" + ) return False def compare_transaction_file(output_a_file_path: str, output_b_file_path: str): - """Function used to check if contents of transaction files are the same""" + """Function used to check if contents of transaction files are the same.""" # load output a file data_a = pd.read_csv(output_a_file_path) @@ -109,19 +111,17 @@ def compare_transaction_file(output_a_file_path: str, output_b_file_path: str): data_b = pd.read_csv(output_b_file_path) # remove the time stamp column - data_a.drop('Timestamp', inplace=True, axis=1) - data_b.drop('Timestamp', inplace=True, axis=1) + data_a.drop("Timestamp", inplace=True, axis=1) + data_b.drop("Timestamp", inplace=True, axis=1) # if the comparison is empty, both files are the same i.e. True return data_a.compare(data_b).empty class TestSession: - def __init__( - self, - training_config_path, - laydown_config_path - ): + """Class that contains session values.""" + + def __init__(self, training_config_path, laydown_config_path): self.session_timestamp: Final[datetime] = datetime.now() self.session_dir = _get_session_path(self.session_timestamp) self.timestamp_str = self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") @@ -140,5 +140,8 @@ class TestSession: print("Writing Session Metadata file...") _write_session_metadata_file( - session_dir=self.session_dir, uuid="test", session_timestamp=self.session_timestamp, env=self.env + session_dir=self.session_dir, + uuid="test", + session_timestamp=self.session_timestamp, + env=self.env, ) diff --git a/tests/e2e_integration_tests/test_session_repeatability.py b/tests/e2e_integration_tests/test_session_repeatability.py index 587b8501..a1a8f16a 100644 --- a/tests/e2e_integration_tests/test_session_repeatability.py +++ b/tests/e2e_integration_tests/test_session_repeatability.py @@ -2,8 +2,13 @@ import time from primaite import getLogger from primaite.config.lay_down_config import data_manipulation_config_path -from primaite.main import run_stable_baselines3_a2c, \ - run_stable_baselines3_ppo, run_generic, _update_session_metadata_file, _get_session_path +from primaite.main import ( + _get_session_path, + _update_session_metadata_file, + run_generic, + run_stable_baselines3_a2c, + run_stable_baselines3_ppo, +) from primaite.transactions.transactions_to_file import write_transaction_to_file from tests import TEST_CONFIG_ROOT from tests.conftest import TestSession, compare_file_content, compare_transaction_file @@ -12,18 +17,17 @@ _LOGGER = getLogger(__name__) def test_generic_same_results(): - """Runs seeded and deterministic Generic Primaite sessions and checks that the results are the same""" + """Runs seeded and deterministic Generic Primaite sessions and checks that the results are the same.""" print("") print("=======================") print("Generic test run") print("=======================") print("") - # run session 1 session1 = TestSession( TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml", - data_manipulation_config_path() + data_manipulation_config_path(), ) config_values = session1.env.training_config @@ -31,17 +35,14 @@ def test_generic_same_results(): # Get the number of steps (which is stored in the child config file) config_values.num_steps = session1.env.episode_steps - run_generic( - env=session1.env, - config_values=session1.env.training_config - ) + run_generic(env=session1.env, config_values=session1.env.training_config) _update_session_metadata_file(session_dir=session1.session_dir, env=session1.env) # run session 2 session2 = TestSession( TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml", - data_manipulation_config_path() + data_manipulation_config_path(), ) config_values = session2.env.training_config @@ -49,10 +50,7 @@ def test_generic_same_results(): # Get the number of steps (which is stored in the child config file) config_values.num_steps = session2.env.episode_steps - run_generic( - env=session2.env, - config_values=session2.env.training_config - ) + run_generic(env=session2.env, config_values=session2.env.training_config) _update_session_metadata_file(session_dir=session2.session_dir, env=session2.env) @@ -61,36 +59,41 @@ def test_generic_same_results(): time.sleep(1) # check if both outputs are the same - assert compare_file_content( - session1.env.csv_file.name, - session2.env.csv_file.name, - ) is True + assert ( + compare_file_content( + session1.env.csv_file.name, + session2.env.csv_file.name, + ) + is True + ) # deterministic run deterministic = TestSession( TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml", - data_manipulation_config_path() + data_manipulation_config_path(), ) deterministic.env.training_config.deterministic = True - run_generic( - env=deterministic.env, - config_values=deterministic.env.training_config + run_generic(env=deterministic.env, config_values=deterministic.env.training_config) + + _update_session_metadata_file( + session_dir=deterministic.session_dir, env=deterministic.env ) - _update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env) - # check if both outputs are the same - assert compare_file_content( - deterministic.env.csv_file.name, - TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_generic.csv", - ) is True + assert ( + compare_file_content( + deterministic.env.csv_file.name, + TEST_CONFIG_ROOT + / "e2e/deterministic_test_outputs/deterministic_generic.csv", + ) + is True + ) def test_ppo_same_results(): - """Runs seeded and deterministic PPO Primaite sessions and checks that the results are the same""" - + """Runs seeded and deterministic PPO Primaite sessions and checks that the results are the same.""" print("") print("=======================") print("PPO test run") @@ -99,7 +102,7 @@ def test_ppo_same_results(): training_session = TestSession( TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml", - data_manipulation_config_path() + data_manipulation_config_path(), ) # Train agent @@ -123,12 +126,14 @@ def test_ppo_same_results(): timestamp_str=training_session.timestamp_str, ) - _update_session_metadata_file(session_dir=training_session.session_dir, env=training_session.env) + _update_session_metadata_file( + session_dir=training_session.session_dir, env=training_session.env + ) # Evaluate Agent again eval_session1 = TestSession( TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml", - data_manipulation_config_path() + data_manipulation_config_path(), ) # Get the number of steps (which is stored in the child config file) @@ -137,7 +142,10 @@ def test_ppo_same_results(): # load the agent that was trained previously eval_session1.env.training_config.load_agent = True - eval_session1.env.training_config.agent_load_file = _get_session_path(training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip" + eval_session1.env.training_config.agent_load_file = ( + _get_session_path(training_session.session_timestamp) + / f"agent_saved_{training_session.timestamp_str}.zip" + ) config_values = eval_session1.env.training_config @@ -154,11 +162,13 @@ def test_ppo_same_results(): timestamp_str=eval_session1.timestamp_str, ) - _update_session_metadata_file(session_dir=eval_session1.session_dir, env=eval_session1.env) + _update_session_metadata_file( + session_dir=eval_session1.session_dir, env=eval_session1.env + ) eval_session2 = TestSession( TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml", - data_manipulation_config_path() + data_manipulation_config_path(), ) # Get the number of steps (which is stored in the child config file) @@ -167,8 +177,10 @@ def test_ppo_same_results(): # load the agent that was trained previously eval_session2.env.training_config.load_agent = True - eval_session2.env.training_config.agent_load_file = _get_session_path( - training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip" + eval_session2.env.training_config.agent_load_file = ( + _get_session_path(training_session.session_timestamp) + / f"agent_saved_{training_session.timestamp_str}.zip" + ) config_values = eval_session2.env.training_config @@ -185,18 +197,25 @@ def test_ppo_same_results(): timestamp_str=eval_session2.timestamp_str, ) - _update_session_metadata_file(session_dir=eval_session2.session_dir, env=eval_session2.env) + _update_session_metadata_file( + session_dir=eval_session2.session_dir, env=eval_session2.env + ) # check if both eval outputs are the same - assert compare_transaction_file( - eval_session1.session_dir / f"all_transactions_{eval_session1.timestamp_str}.csv", - eval_session2.session_dir / f"all_transactions_{eval_session2.timestamp_str}.csv", - ) is True + assert ( + compare_transaction_file( + eval_session1.session_dir + / f"all_transactions_{eval_session1.timestamp_str}.csv", + eval_session2.session_dir + / f"all_transactions_{eval_session2.timestamp_str}.csv", + ) + is True + ) # deterministic run deterministic = TestSession( TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml", - data_manipulation_config_path() + data_manipulation_config_path(), ) deterministic.env.training_config.deterministic = True @@ -214,18 +233,23 @@ def test_ppo_same_results(): timestamp_str=deterministic.timestamp_str, ) - _update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env) + _update_session_metadata_file( + session_dir=deterministic.session_dir, env=deterministic.env + ) # check if both outputs are the same - assert compare_transaction_file( - deterministic.session_dir / f"all_transactions_{deterministic.timestamp_str}.csv", - TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_ppo.csv", - ) is True + assert ( + compare_transaction_file( + deterministic.session_dir + / f"all_transactions_{deterministic.timestamp_str}.csv", + TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_ppo.csv", + ) + is True + ) def test_a2c_same_results(): - """Runs seeded and deterministic A2C Primaite sessions and checks that the results are the same""" - + """Runs seeded and deterministic A2C Primaite sessions and checks that the results are the same.""" print("") print("=======================") print("A2C test run") @@ -234,7 +258,7 @@ def test_a2c_same_results(): training_session = TestSession( TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml", - data_manipulation_config_path() + data_manipulation_config_path(), ) # Train agent @@ -258,12 +282,14 @@ def test_a2c_same_results(): timestamp_str=training_session.timestamp_str, ) - _update_session_metadata_file(session_dir=training_session.session_dir, env=training_session.env) + _update_session_metadata_file( + session_dir=training_session.session_dir, env=training_session.env + ) # Evaluate Agent again eval_session1 = TestSession( TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml", - data_manipulation_config_path() + data_manipulation_config_path(), ) # Get the number of steps (which is stored in the child config file) @@ -272,8 +298,10 @@ def test_a2c_same_results(): # load the agent that was trained previously eval_session1.env.training_config.load_agent = True - eval_session1.env.training_config.agent_load_file = _get_session_path( - training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip" + eval_session1.env.training_config.agent_load_file = ( + _get_session_path(training_session.session_timestamp) + / f"agent_saved_{training_session.timestamp_str}.zip" + ) config_values = eval_session1.env.training_config @@ -290,11 +318,13 @@ def test_a2c_same_results(): timestamp_str=eval_session1.timestamp_str, ) - _update_session_metadata_file(session_dir=eval_session1.session_dir, env=eval_session1.env) + _update_session_metadata_file( + session_dir=eval_session1.session_dir, env=eval_session1.env + ) eval_session2 = TestSession( TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml", - data_manipulation_config_path() + data_manipulation_config_path(), ) # Get the number of steps (which is stored in the child config file) @@ -303,8 +333,10 @@ def test_a2c_same_results(): # load the agent that was trained previously eval_session2.env.training_config.load_agent = True - eval_session2.env.training_config.agent_load_file = _get_session_path( - training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip" + eval_session2.env.training_config.agent_load_file = ( + _get_session_path(training_session.session_timestamp) + / f"agent_saved_{training_session.timestamp_str}.zip" + ) config_values = eval_session2.env.training_config @@ -321,18 +353,25 @@ def test_a2c_same_results(): timestamp_str=eval_session2.timestamp_str, ) - _update_session_metadata_file(session_dir=eval_session2.session_dir, env=eval_session2.env) + _update_session_metadata_file( + session_dir=eval_session2.session_dir, env=eval_session2.env + ) # check if both eval outputs are the same - assert compare_transaction_file( - eval_session1.session_dir / f"all_transactions_{eval_session1.timestamp_str}.csv", - eval_session2.session_dir / f"all_transactions_{eval_session2.timestamp_str}.csv", - ) is True + assert ( + compare_transaction_file( + eval_session1.session_dir + / f"all_transactions_{eval_session1.timestamp_str}.csv", + eval_session2.session_dir + / f"all_transactions_{eval_session2.timestamp_str}.csv", + ) + is True + ) # deterministic run deterministic = TestSession( TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml", - data_manipulation_config_path() + data_manipulation_config_path(), ) deterministic.env.training_config.deterministic = True @@ -350,10 +389,16 @@ def test_a2c_same_results(): timestamp_str=deterministic.timestamp_str, ) - _update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env) + _update_session_metadata_file( + session_dir=deterministic.session_dir, env=deterministic.env + ) # check if both outputs are the same - assert compare_transaction_file( - deterministic.session_dir / f"all_transactions_{deterministic.timestamp_str}.csv", - TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_a2c.csv", - ) is True + assert ( + compare_transaction_file( + deterministic.session_dir + / f"all_transactions_{deterministic.timestamp_str}.csv", + TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_a2c.csv", + ) + is True + ) diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index dbcdf2d6..efca7b0b 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -27,7 +27,8 @@ def env(request): @pytest.mark.env_config_paths( dict( - training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + training_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_without_obs.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) @@ -43,7 +44,8 @@ def test_default_obs_space(env: Primaite): @pytest.mark.env_config_paths( dict( - training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + training_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_without_obs.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) @@ -140,7 +142,8 @@ class TestNodeLinkTable: @pytest.mark.env_config_paths( dict( - training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml", + training_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_NODE_STATUSES.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) diff --git a/tests/test_resetting_node.py b/tests/test_resetting_node.py index b2843f7f..abe8115c 100644 --- a/tests/test_resetting_node.py +++ b/tests/test_resetting_node.py @@ -1,7 +1,13 @@ """Used to test Active Node functions.""" import pytest -from primaite.common.enums import FileSystemState, HardwareState, SoftwareState, NodeType, Priority +from primaite.common.enums import ( + FileSystemState, + HardwareState, + NodeType, + Priority, + SoftwareState, +) from primaite.common.service import Service from primaite.config.training_config import TrainingConfig from primaite.nodes.active_node import ActiveNode @@ -10,24 +16,20 @@ from primaite.nodes.service_node import ServiceNode @pytest.mark.parametrize( "starting_operating_state, expected_operating_state", - [ - (HardwareState.RESETTING, HardwareState.ON) - ], + [(HardwareState.RESETTING, HardwareState.ON)], ) def test_node_resets_correctly(starting_operating_state, expected_operating_state): - """ - Tests that a node resets correctly. - """ + """Tests that a node resets correctly.""" active_node = ActiveNode( - node_id = "0", - name = "node", - node_type = NodeType.COMPUTER, - priority = Priority.P1, - hardware_state = starting_operating_state, - ip_address = "192.168.0.1", - software_state = SoftwareState.COMPROMISED, - file_system_state = FileSystemState.CORRUPT, - config_values=TrainingConfig() + node_id="0", + name="node", + node_type=NodeType.COMPUTER, + priority=Priority.P1, + hardware_state=starting_operating_state, + ip_address="192.168.0.1", + software_state=SoftwareState.COMPROMISED, + file_system_state=FileSystemState.CORRUPT, + config_values=TrainingConfig(), ) for x in range(5): @@ -37,35 +39,28 @@ def test_node_resets_correctly(starting_operating_state, expected_operating_stat assert active_node.file_system_state_actual == FileSystemState.GOOD assert active_node.hardware_state == expected_operating_state + @pytest.mark.parametrize( "operating_state, expected_operating_state", - [ - (HardwareState.BOOTING, HardwareState.ON) - ], + [(HardwareState.BOOTING, HardwareState.ON)], ) def test_node_boots_correctly(operating_state, expected_operating_state): - """ - Tests that a node boots correctly. - """ + """Tests that a node boots correctly.""" service_node = ServiceNode( - node_id = 0, - name = "node", - node_type = "COMPUTER", - priority = "1", - hardware_state = operating_state, - ip_address = "192.168.0.1", - software_state = SoftwareState.GOOD, - file_system_state = "GOOD", - config_values = 1, + node_id=0, + name="node", + node_type="COMPUTER", + priority="1", + hardware_state=operating_state, + ip_address="192.168.0.1", + software_state=SoftwareState.GOOD, + file_system_state="GOOD", + config_values=1, ) service_attributes = Service( - name = "node", - port = "80", - software_state = SoftwareState.COMPROMISED - ) - service_node.add_service( - service_attributes + name="node", port="80", software_state=SoftwareState.COMPROMISED ) + service_node.add_service(service_attributes) for x in range(5): service_node.update_booting_status() @@ -73,31 +68,26 @@ def test_node_boots_correctly(operating_state, expected_operating_state): assert service_attributes.software_state == SoftwareState.GOOD assert service_node.hardware_state == expected_operating_state + @pytest.mark.parametrize( "operating_state, expected_operating_state", - [ - (HardwareState.SHUTTING_DOWN, HardwareState.OFF) - ], + [(HardwareState.SHUTTING_DOWN, HardwareState.OFF)], ) def test_node_shutdown_correctly(operating_state, expected_operating_state): - """ - Tests that a node shutdown correctly. - """ + """Tests that a node shutdown correctly.""" active_node = ActiveNode( - node_id = 0, - name = "node", - node_type = "COMPUTER", - priority = "1", - hardware_state = operating_state, - ip_address = "192.168.0.1", - software_state = SoftwareState.GOOD, - file_system_state = "GOOD", - config_values = 1, + node_id=0, + name="node", + node_type="COMPUTER", + priority="1", + hardware_state=operating_state, + ip_address="192.168.0.1", + software_state=SoftwareState.GOOD, + file_system_state="GOOD", + config_values=1, ) for x in range(5): active_node.update_shutdown_status() assert active_node.hardware_state == expected_operating_state - - diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 16b9d03e..8ff43fe6 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -48,7 +48,8 @@ def test_single_action_space_is_valid(): """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" env = _get_primaite_env_from_config( training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "single_action_space_lay_down_config.yaml", ) run_generic_set_actions(env) @@ -77,8 +78,10 @@ def test_single_action_space_is_valid(): def test_agent_is_executing_actions_from_both_spaces(): """Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL).""" env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", + training_config_path=TEST_CONFIG_ROOT + / "single_action_space_fixed_blue_actions_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "single_action_space_lay_down_config.yaml", ) # Run environment with specified fixed blue agent actions only run_generic_set_actions(env)