From cd991a7d61d7c1ddd4bb991c18a58bac2942c732 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 11:10:21 +0100 Subject: [PATCH 1/9] Fix reference IERs --- src/primaite/environment/primaite_env.py | 18 +++++++++++++++++- src/primaite/environment/reward.py | 10 +++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index da235971..1307a930 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -109,6 +109,7 @@ class Primaite(Env): # Create a dictionary to hold all the green IERs (this will come from an external source) self.green_iers: Dict[str, IER] = {} + self.green_iers_reference: Dict[str, IER] = {} # Create a dictionary to hold all the node PoLs (this will come from an external source) self.node_pol = {} @@ -310,6 +311,9 @@ class Primaite(Env): # Need to clear traffic on all links first for link_key, link_value in self.links.items(): link_value.clear_traffic() + + for link in self.links_reference.values(): + link.clear_traffic() # Create a Transaction (metric) object for this step transaction = Transaction( @@ -348,7 +352,7 @@ class Primaite(Env): self.network_reference, self.nodes_reference, self.links_reference, - self.green_iers, + self.green_iers_reference, self.acl, self.step_count, ) # Network PoL @@ -375,6 +379,7 @@ class Primaite(Env): self.nodes_post_red, self.nodes_reference, self.green_iers, + self.green_iers_reference, self.red_iers, self.step_count, self.training_config, @@ -866,6 +871,17 @@ class Primaite(Env): ier_destination, ier_mission_criticality, ) + self.green_iers_reference[ier_id] = IER( + ier_id, + ier_start_step, + ier_end_step, + ier_load, + ier_protocol, + ier_port, + ier_source, + ier_destination, + ier_mission_criticality, + ) def create_red_ier(self, item): """ diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index a620f9b3..777dcf74 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -6,6 +6,9 @@ from primaite.common.enums import FileSystemState, HardwareState, SoftwareState from primaite.common.service import Service from primaite.nodes.active_node import ActiveNode from primaite.nodes.service_node import ServiceNode +from primaite import getLogger + +_LOGGER = getLogger(__name__) def calculate_reward_function( @@ -13,6 +16,7 @@ def calculate_reward_function( final_nodes, reference_nodes, green_iers, + green_iers_reference, red_iers, step_count, config_values, @@ -68,11 +72,15 @@ def calculate_reward_function( reward_value += config_values.red_ier_running # Go through each green IER - penalise if it's not running (weighted) + # but only if it's supposed to be running (it's running in reference) for ier_key, ier_value in green_iers.items(): + ref_ier = green_iers_reference[ier_key] start_step = ier_value.get_start_step() stop_step = ier_value.get_end_step() if step_count >= start_step and step_count <= stop_step: - if not ier_value.get_is_running(): + if not ier_value.get_is_running() and ref_ier.get_is_running(): + # what should happen if reference IER is blocked but live IER is running? + _LOGGER.debug(f"Applying penalty of {config_values.green_ier_blocked * ier_value.get_mission_criticality()} due to IER {ier_key} being blocked") reward_value += ( config_values.green_ier_blocked * ier_value.get_mission_criticality() From 3774fb8319dde057fad82743610781aa1c3a0c07 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 11:20:18 +0100 Subject: [PATCH 2/9] apply pre-commits --- src/primaite/environment/primaite_env.py | 24 +++++++++++------------- src/primaite/environment/reward.py | 10 ++++++++-- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 1307a930..bdfe00dd 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -14,8 +14,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, \ - is_valid_node_action +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -24,8 +23,9 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, + ObservationType, Priority, - SoftwareState, ObservationType, + SoftwareState, ) from primaite.common.service import Service from primaite.config import training_config @@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import \ - NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, \ - apply_red_agent_node_pol +from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction _LOGGER = logging.getLogger(__name__) @@ -178,7 +176,6 @@ class Primaite(Env): # It will be initialised later. self.obs_handler: ObservationsHandler - # Open the config file and build the environment laydown with open(self._lay_down_config_path, "r") as file: # Open the config file and build the environment laydown @@ -200,7 +197,7 @@ class Primaite(Env): try: plt.tight_layout() nx.draw_networkx(self.network, with_labels=True) - now = datetime.now() # current date and time + # now = datetime.now() # current date and time file_path = session_path / f"network_{timestamp_str}.png" plt.savefig(file_path, format="PNG") @@ -239,7 +236,9 @@ class Primaite(Env): self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}") + _LOGGER.info( + f"Invalid action type selected: {self.training_config.action_type}" + ) # Set up a csv to store the results of the training try: header = ["Episode", "Average Reward"] @@ -311,7 +310,7 @@ class Primaite(Env): # Need to clear traffic on all links first for link_key, link_value in self.links.items(): link_value.clear_traffic() - + for link in self.links_reference.values(): link.clear_traffic() @@ -384,7 +383,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - #print(f" Step {self.step_count} Reward: {str(reward)}") + # print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -1049,7 +1048,6 @@ class Primaite(Env): """ self.observation_type = ObservationType[observation_info["type"]] - def get_action_info(self, action_info): """ Extracts action_info. diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 777dcf74..f48db259 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -2,11 +2,11 @@ """Implements reward function.""" from typing import Dict +from primaite import getLogger from primaite.common.enums import FileSystemState, HardwareState, SoftwareState from primaite.common.service import Service from primaite.nodes.active_node import ActiveNode from primaite.nodes.service_node import ServiceNode -from primaite import getLogger _LOGGER = getLogger(__name__) @@ -80,7 +80,13 @@ def calculate_reward_function( if step_count >= start_step and step_count <= stop_step: if not ier_value.get_is_running() and ref_ier.get_is_running(): # what should happen if reference IER is blocked but live IER is running? - _LOGGER.debug(f"Applying penalty of {config_values.green_ier_blocked * ier_value.get_mission_criticality()} due to IER {ier_key} being blocked") + _LOGGER.debug( + ( + f"Applying penalty of " + f"{config_values.green_ier_blocked * ier_value.get_mission_criticality()} " + f"due to IER {ier_key} being blocked" + ) + ) reward_value += ( config_values.green_ier_blocked * ier_value.get_mission_criticality() From dc43e5dc15dd345bc75c2d68754f8e695cb2b62a Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 10:45:45 +0000 Subject: [PATCH 3/9] rename to prevent confusion --- src/primaite/environment/reward.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index f48db259..00ae3528 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -74,11 +74,11 @@ def calculate_reward_function( # Go through each green IER - penalise if it's not running (weighted) # but only if it's supposed to be running (it's running in reference) for ier_key, ier_value in green_iers.items(): - ref_ier = green_iers_reference[ier_key] + reference_ier = green_iers_reference[ier_key] start_step = ier_value.get_start_step() stop_step = ier_value.get_end_step() if step_count >= start_step and step_count <= stop_step: - if not ier_value.get_is_running() and ref_ier.get_is_running(): + if not ier_value.get_is_running() and reference_ier.get_is_running(): # what should happen if reference IER is blocked but live IER is running? _LOGGER.debug( ( From cdeb6abf607c81d4f94fa44d6560be0373ddf098 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 12:44:42 +0100 Subject: [PATCH 4/9] More descriptive debug msg --- src/primaite/environment/reward.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 00ae3528..0befd547 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -78,19 +78,23 @@ def calculate_reward_function( start_step = ier_value.get_start_step() stop_step = ier_value.get_end_step() if step_count >= start_step and step_count <= stop_step: - if not ier_value.get_is_running() and reference_ier.get_is_running(): - # what should happen if reference IER is blocked but live IER is running? - _LOGGER.debug( - ( - f"Applying penalty of " - f"{config_values.green_ier_blocked * ier_value.get_mission_criticality()} " - f"due to IER {ier_key} being blocked" + if not ier_value.get_is_running(): + if reference_ier.get_is_running(): + ier_reward = ( + config_values.green_ier_blocked + * ier_value.get_mission_criticality() + ) + _LOGGER.debug( + f"Applying reward of {ier_reward} because IER {ier_key} is blocked" + ) + reward_value += ier_reward + else: + _LOGGER.debug( + ( + f"IER {ier_key} is blocked in the reference and live environments. " + f"Therefore, no penalty was applied." + ) ) - ) - reward_value += ( - config_values.green_ier_blocked - * ier_value.get_mission_criticality() - ) return reward_value From de91a50581b52bb679d2b032fe03e0fa1e51e576 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 12:56:15 +0100 Subject: [PATCH 5/9] Improve readability --- src/primaite/environment/reward.py | 38 ++++++++++++++++++------------ 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 0befd547..aa9e4503 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -78,23 +78,31 @@ def calculate_reward_function( start_step = ier_value.get_start_step() stop_step = ier_value.get_end_step() if step_count >= start_step and step_count <= stop_step: - if not ier_value.get_is_running(): - if reference_ier.get_is_running(): - ier_reward = ( - config_values.green_ier_blocked - * ier_value.get_mission_criticality() + reference_blocked = reference_ier.get_is_running() + live_blocked = ier_value.get_is_running() + ier_reward = ( + config_values.green_ier_blocked * ier_value.get_mission_criticality() + ) + + if live_blocked and not reference_blocked: + _LOGGER.debug( + f"Applying reward of {ier_reward} because IER {ier_key} is blocked" + ) + reward_value += ier_reward + elif live_blocked and reference_blocked: + _LOGGER.debug( + ( + f"IER {ier_key} is blocked in the reference and live environments. " + f"Penalty of {ier_reward} was NOT applied." ) - _LOGGER.debug( - f"Applying reward of {ier_reward} because IER {ier_key} is blocked" - ) - reward_value += ier_reward - else: - _LOGGER.debug( - ( - f"IER {ier_key} is blocked in the reference and live environments. " - f"Therefore, no penalty was applied." - ) + ) + elif not live_blocked and reference_blocked: + _LOGGER.debug( + ( + f"IER {ier_key} is blocked in the reference env but not in the live one. " + f"Penalty of {ier_reward} was NOT applied." ) + ) return reward_value From beae1e5c4f503016d9b802d9fce50ede1c2bd064 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 13:06:10 +0100 Subject: [PATCH 6/9] Cosmetic changes to satisfy pre-commit --- src/primaite/__init__.py | 2 +- src/primaite/cli.py | 5 +- src/primaite/config/training_config.py | 14 +--- src/primaite/environment/primaite_env.py | 21 +++-- src/primaite/main.py | 5 +- src/primaite/nodes/node.py | 5 +- src/primaite/nodes/service_node.py | 4 +- src/primaite/notebooks/__init__.py | 1 - tests/test_observation_space.py | 9 ++- tests/test_resetting_node.py | 98 +++++++++++------------- tests/test_single_action_space.py | 9 ++- 11 files changed, 80 insertions(+), 93 deletions(-) diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 1ea110c9..420420f4 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -31,7 +31,7 @@ def _get_primaite_config(): "INFO": logging.INFO, "WARN": logging.WARN, "ERROR": logging.ERROR, - "CRITICAL": logging.CRITICAL + "CRITICAL": logging.CRITICAL, } primaite_config["log_level"] = log_level_map[primaite_config["log_level"]] return primaite_config diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 19746d01..319d643c 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -3,8 +3,8 @@ import logging import os import shutil -from pathlib import Path from enum import Enum +from pathlib import Path from typing import Optional import pkg_resources @@ -44,6 +44,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]): :param last_n: The number of lines to print. Default value is 10. """ import re + from primaite import LOG_PATH if os.path.isfile(LOG_PATH): @@ -53,7 +54,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]): print(re.sub(r"\n*", "", line)) -_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa +_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa @app.command() diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 4af36abe..7bafc910 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,7 +1,7 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Final, Union, Optional +from typing import Any, Dict, Final, Optional, Union import yaml @@ -167,8 +167,7 @@ def main_training_config_path() -> Path: return path -def load(file_path: Union[str, Path], - legacy_file: bool = False) -> TrainingConfig: +def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: """ Read in a training config yaml file. @@ -213,9 +212,7 @@ def load(file_path: Union[str, Path], def convert_legacy_training_config_dict( - legacy_config_dict: Dict[str, Any], - num_steps: int = 256, - action_type: str = "ANY" + legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY" ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -227,10 +224,7 @@ def convert_legacy_training_config_dict( don't have action_type values. :return: The converted training config dict. """ - config_dict = { - "num_steps": num_steps, - "action_type": action_type - } + config_dict = {"num_steps": num_steps, "action_type": action_type} for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) if new_key: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index da235971..7608e7db 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -14,8 +14,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, \ - is_valid_node_action +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -24,8 +23,9 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, + ObservationType, Priority, - SoftwareState, ObservationType, + SoftwareState, ) from primaite.common.service import Service from primaite.config import training_config @@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import \ - NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, \ - apply_red_agent_node_pol +from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction _LOGGER = logging.getLogger(__name__) @@ -177,7 +175,6 @@ class Primaite(Env): # It will be initialised later. self.obs_handler: ObservationsHandler - # Open the config file and build the environment laydown with open(self._lay_down_config_path, "r") as file: # Open the config file and build the environment laydown @@ -199,7 +196,6 @@ class Primaite(Env): try: plt.tight_layout() nx.draw_networkx(self.network, with_labels=True) - now = datetime.now() # current date and time file_path = session_path / f"network_{timestamp_str}.png" plt.savefig(file_path, format="PNG") @@ -238,7 +234,9 @@ class Primaite(Env): self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}") + _LOGGER.info( + f"Invalid action type selected: {self.training_config.action_type}" + ) # Set up a csv to store the results of the training try: header = ["Episode", "Average Reward"] @@ -379,7 +377,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - #print(f" Step {self.step_count} Reward: {str(reward)}") + # print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -1033,7 +1031,6 @@ class Primaite(Env): """ self.observation_type = ObservationType[observation_info["type"]] - def get_action_info(self, action_info): """ Extracts action_info. diff --git a/src/primaite/main.py b/src/primaite/main.py index ac32a018..f5e94509 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -22,8 +22,7 @@ from stable_baselines3.ppo import MlpPolicy as PPOMlp from primaite import SESSIONS_DIR, getLogger from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite -from primaite.transactions.transactions_to_file import \ - write_transaction_to_file +from primaite.transactions.transactions_to_file import write_transaction_to_file _LOGGER = getLogger(__name__) @@ -349,5 +348,3 @@ if __name__ == "__main__": "Please provide a lay down config file using the --ldc " "argument" ) run(training_config_path=args.tc, lay_down_config_path=args.ldc) - - diff --git a/src/primaite/nodes/node.py b/src/primaite/nodes/node.py index 00cd01c2..bac1792d 100644 --- a/src/primaite/nodes/node.py +++ b/src/primaite/nodes/node.py @@ -46,6 +46,7 @@ class Node: """Sets the node state to ON.""" self.hardware_state = HardwareState.BOOTING self.booting_count = self.config_values.node_booting_duration + def turn_off(self): """Sets the node state to OFF.""" self.hardware_state = HardwareState.OFF @@ -64,14 +65,14 @@ class Node: self.hardware_state = HardwareState.ON def update_booting_status(self): - """Updates the booting count""" + """Updates the booting count.""" self.booting_count -= 1 if self.booting_count <= 0: self.booting_count = 0 self.hardware_state = HardwareState.ON def update_shutdown_status(self): - """Updates the shutdown count""" + """Updates the shutdown count.""" self.shutting_down_count -= 1 if self.shutting_down_count <= 0: self.shutting_down_count = 0 diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index 84a7c587..324592c3 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -190,13 +190,15 @@ class ServiceNode(ActiveNode): service_value.reduce_patching_count() def update_resetting_status(self): + """Update resetting counter and set software state if it reached 0.""" super().update_resetting_status() if self.resetting_count <= 0: for service in self.services.values(): service.software_state = SoftwareState.GOOD def update_booting_status(self): + """Update booting counter and set software to good if it reached 0.""" super().update_booting_status() if self.booting_count <= 0: for service in self.services.values(): - service.software_state =SoftwareState.GOOD + service.software_state = SoftwareState.GOOD diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py index 6d822961..71ed343e 100644 --- a/src/primaite/notebooks/__init__.py +++ b/src/primaite/notebooks/__init__.py @@ -17,7 +17,6 @@ def start_jupyter_session(): .. todo:: Figure out how to get this working for Linux and MacOS too. """ - if importlib.util.find_spec("jupyter") is not None: jupyter_cmd = "python3 -m jupyter lab" if sys.platform == "win32": diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index dbcdf2d6..efca7b0b 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -27,7 +27,8 @@ def env(request): @pytest.mark.env_config_paths( dict( - training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + training_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_without_obs.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) @@ -43,7 +44,8 @@ def test_default_obs_space(env: Primaite): @pytest.mark.env_config_paths( dict( - training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + training_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_without_obs.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) @@ -140,7 +142,8 @@ class TestNodeLinkTable: @pytest.mark.env_config_paths( dict( - training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml", + training_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_NODE_STATUSES.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) diff --git a/tests/test_resetting_node.py b/tests/test_resetting_node.py index b2843f7f..abe8115c 100644 --- a/tests/test_resetting_node.py +++ b/tests/test_resetting_node.py @@ -1,7 +1,13 @@ """Used to test Active Node functions.""" import pytest -from primaite.common.enums import FileSystemState, HardwareState, SoftwareState, NodeType, Priority +from primaite.common.enums import ( + FileSystemState, + HardwareState, + NodeType, + Priority, + SoftwareState, +) from primaite.common.service import Service from primaite.config.training_config import TrainingConfig from primaite.nodes.active_node import ActiveNode @@ -10,24 +16,20 @@ from primaite.nodes.service_node import ServiceNode @pytest.mark.parametrize( "starting_operating_state, expected_operating_state", - [ - (HardwareState.RESETTING, HardwareState.ON) - ], + [(HardwareState.RESETTING, HardwareState.ON)], ) def test_node_resets_correctly(starting_operating_state, expected_operating_state): - """ - Tests that a node resets correctly. - """ + """Tests that a node resets correctly.""" active_node = ActiveNode( - node_id = "0", - name = "node", - node_type = NodeType.COMPUTER, - priority = Priority.P1, - hardware_state = starting_operating_state, - ip_address = "192.168.0.1", - software_state = SoftwareState.COMPROMISED, - file_system_state = FileSystemState.CORRUPT, - config_values=TrainingConfig() + node_id="0", + name="node", + node_type=NodeType.COMPUTER, + priority=Priority.P1, + hardware_state=starting_operating_state, + ip_address="192.168.0.1", + software_state=SoftwareState.COMPROMISED, + file_system_state=FileSystemState.CORRUPT, + config_values=TrainingConfig(), ) for x in range(5): @@ -37,35 +39,28 @@ def test_node_resets_correctly(starting_operating_state, expected_operating_stat assert active_node.file_system_state_actual == FileSystemState.GOOD assert active_node.hardware_state == expected_operating_state + @pytest.mark.parametrize( "operating_state, expected_operating_state", - [ - (HardwareState.BOOTING, HardwareState.ON) - ], + [(HardwareState.BOOTING, HardwareState.ON)], ) def test_node_boots_correctly(operating_state, expected_operating_state): - """ - Tests that a node boots correctly. - """ + """Tests that a node boots correctly.""" service_node = ServiceNode( - node_id = 0, - name = "node", - node_type = "COMPUTER", - priority = "1", - hardware_state = operating_state, - ip_address = "192.168.0.1", - software_state = SoftwareState.GOOD, - file_system_state = "GOOD", - config_values = 1, + node_id=0, + name="node", + node_type="COMPUTER", + priority="1", + hardware_state=operating_state, + ip_address="192.168.0.1", + software_state=SoftwareState.GOOD, + file_system_state="GOOD", + config_values=1, ) service_attributes = Service( - name = "node", - port = "80", - software_state = SoftwareState.COMPROMISED - ) - service_node.add_service( - service_attributes + name="node", port="80", software_state=SoftwareState.COMPROMISED ) + service_node.add_service(service_attributes) for x in range(5): service_node.update_booting_status() @@ -73,31 +68,26 @@ def test_node_boots_correctly(operating_state, expected_operating_state): assert service_attributes.software_state == SoftwareState.GOOD assert service_node.hardware_state == expected_operating_state + @pytest.mark.parametrize( "operating_state, expected_operating_state", - [ - (HardwareState.SHUTTING_DOWN, HardwareState.OFF) - ], + [(HardwareState.SHUTTING_DOWN, HardwareState.OFF)], ) def test_node_shutdown_correctly(operating_state, expected_operating_state): - """ - Tests that a node shutdown correctly. - """ + """Tests that a node shutdown correctly.""" active_node = ActiveNode( - node_id = 0, - name = "node", - node_type = "COMPUTER", - priority = "1", - hardware_state = operating_state, - ip_address = "192.168.0.1", - software_state = SoftwareState.GOOD, - file_system_state = "GOOD", - config_values = 1, + node_id=0, + name="node", + node_type="COMPUTER", + priority="1", + hardware_state=operating_state, + ip_address="192.168.0.1", + software_state=SoftwareState.GOOD, + file_system_state="GOOD", + config_values=1, ) for x in range(5): active_node.update_shutdown_status() assert active_node.hardware_state == expected_operating_state - - diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 16b9d03e..8ff43fe6 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -48,7 +48,8 @@ def test_single_action_space_is_valid(): """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" env = _get_primaite_env_from_config( training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "single_action_space_lay_down_config.yaml", ) run_generic_set_actions(env) @@ -77,8 +78,10 @@ def test_single_action_space_is_valid(): def test_agent_is_executing_actions_from_both_spaces(): """Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL).""" env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", + training_config_path=TEST_CONFIG_ROOT + / "single_action_space_fixed_blue_actions_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "single_action_space_lay_down_config.yaml", ) # Run environment with specified fixed blue agent actions only run_generic_set_actions(env) From a3e50293b7d74c63fdfc455494db0c27ada81148 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 12:07:33 +0000 Subject: [PATCH 7/9] Add pre-commits to build pipeline --- .azure/azure-ci-build-pipeline.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index dd45907d..ceda11c6 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -25,6 +25,11 @@ steps: versionSpec: '$(python.version)' displayName: 'Use Python $(python.version)' +- script: | + pre-commit install + pre-commit run --all-files + displayName: 'Run pre-commits' + - script: | python -m pip install --upgrade pip==23.0.1 pip install wheel==0.38.4 --upgrade From 33f7e9f5066a057a572ee93d174b95131f7abe62 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 13:07:54 +0000 Subject: [PATCH 8/9] Add pre-commit --- .azure/azure-ci-build-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index ceda11c6..8bfdca02 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -26,6 +26,7 @@ steps: displayName: 'Use Python $(python.version)' - script: | + python -m pip install pre-commit pre-commit install pre-commit run --all-files displayName: 'Run pre-commits' From 349a18a4eb0ef067442bc7ec8cfb8f1167ee7a53 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 15:27:56 +0100 Subject: [PATCH 9/9] Fix ier reward calculation --- src/primaite/environment/reward.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index aa9e4503..1a1a0770 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -78,8 +78,8 @@ def calculate_reward_function( start_step = ier_value.get_start_step() stop_step = ier_value.get_end_step() if step_count >= start_step and step_count <= stop_step: - reference_blocked = reference_ier.get_is_running() - live_blocked = ier_value.get_is_running() + reference_blocked = not reference_ier.get_is_running() + live_blocked = not ier_value.get_is_running() ier_reward = ( config_values.green_ier_blocked * ier_value.get_mission_criticality() )