From 57315a6789668ae15bb2e7a6af0e07d35bede4eb Mon Sep 17 00:00:00 2001 From: Brian Kanyora Date: Thu, 22 Jun 2023 15:34:13 +0100 Subject: [PATCH 01/21] feature\1522: Create random red agent behaviour. --- src/primaite/config/training_config.py | 17 +- src/primaite/environment/primaite_env.py | 173 ++++++++++++++++-- .../nodes/node_state_instruction_red.py | 17 ++ tests/config/random_agent_main_config.yaml | 96 ++++++++++ tests/test_red_random_agent_behaviour.py | 74 ++++++++ 5 files changed, 356 insertions(+), 21 deletions(-) create mode 100644 tests/config/random_agent_main_config.yaml create mode 100644 tests/test_red_random_agent_behaviour.py diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 4af36abe..6e88e7cb 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,7 +1,7 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Final, Union, Optional +from typing import Any, Dict, Final, Optional, Union import yaml @@ -21,6 +21,9 @@ class TrainingConfig: agent_identifier: str = "STABLE_BASELINES3_A2C" "The Red Agent algo/class to be used." + red_agent_identifier: str = "RANDOM" + "Creates Random Red Agent Attacks" + action_type: ActionType = ActionType.ANY "The ActionType to use." @@ -167,8 +170,7 @@ def main_training_config_path() -> Path: return path -def load(file_path: Union[str, Path], - legacy_file: bool = False) -> TrainingConfig: +def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: """ Read in a training config yaml file. @@ -213,9 +215,7 @@ def load(file_path: Union[str, Path], def convert_legacy_training_config_dict( - legacy_config_dict: Dict[str, Any], - num_steps: int = 256, - action_type: str = "ANY" + legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY" ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -227,10 +227,7 @@ def convert_legacy_training_config_dict( don't have action_type values. :return: The converted training config dict. """ - config_dict = { - "num_steps": num_steps, - "action_type": action_type - } + config_dict = {"num_steps": num_steps, "action_type": action_type} for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) if new_key: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index da235971..9161fa43 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -14,8 +14,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, \ - is_valid_node_action +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -24,8 +23,9 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, + ObservationType, Priority, - SoftwareState, ObservationType, + SoftwareState, ) from primaite.common.service import Service from primaite.config import training_config @@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import \ - NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, \ - apply_red_agent_node_pol +from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction _LOGGER = logging.getLogger(__name__) @@ -177,7 +175,6 @@ class Primaite(Env): # It will be initialised later. self.obs_handler: ObservationsHandler - # Open the config file and build the environment laydown with open(self._lay_down_config_path, "r") as file: # Open the config file and build the environment laydown @@ -238,7 +235,9 @@ class Primaite(Env): self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}") + _LOGGER.info( + f"Invalid action type selected: {self.training_config.action_type}" + ) # Set up a csv to store the results of the training try: header = ["Episode", "Average Reward"] @@ -275,6 +274,10 @@ class Primaite(Env): # Does this for both live and reference nodes self.reset_environment() + # Create a random red agent to use for this episode + if self.training_config.red_agent_identifier == "RANDOM": + self.create_random_red_agent() + # Reset counters and totals self.total_reward = 0 self.step_count = 0 @@ -379,7 +382,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - #print(f" Step {self.step_count} Reward: {str(reward)}") + print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -1033,7 +1036,6 @@ class Primaite(Env): """ self.observation_type = ObservationType[observation_info["type"]] - def get_action_info(self, action_info): """ Extracts action_info. @@ -1216,3 +1218,152 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} return combined_action_dict + + def create_random_red_agent(self): + """Decide on random red agent for the episode to be called in env.reset().""" + + # Reset the current red iers and red node pol + self.red_iers = {} + self.red_node_pol = {} + + # Decide how many nodes become compromised + node_list = list(self.nodes.values()) + computers = [node for node in node_list if node.node_type == NodeType.COMPUTER] + max_num_nodes_compromised = len( + computers + ) # only computers can become compromised + # random select between 1 and max_num_nodes_compromised + num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised + 1) + + # Decide which of the nodes to compromise + nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise) + + # For each of the nodes to be compromised decide which step they become compromised + max_step_compromised = ( + self.episode_steps // 2 + ) # always compromise in first half of episode + + # Bandwidth for all links + bandwidths = [i.get_bandwidth() for i in list(self.links.values())] + servers = [node for node in node_list if node.node_type == NodeType.SERVER] + + for n, node in enumerate(nodes_to_be_compromised): + # 1: Use Node PoL to set node to compromised + + _id = str(1000 + n) # doesn't really matter, make sure it doesn't duplicate + _start_step = np.random.randint( + 2, max_step_compromised + 1 + ) # step compromised + _end_step = _start_step # Become compromised on 1 step + _target_node_id = node.node_id + _pol_initiator = "DIRECT" + _pol_type = NodePOLType["SERVICE"] # All computers are service nodes + pol_service_name = np.random.choice( + list(node.get_services().keys()) + ) # Random service may wish to change this, currently always TCP) + pol_protocol = pol_protocol + _pol_state = SoftwareState.COMPROMISED + is_entry_node = True # Assumes all computers in network are entry nodes + _pol_source_node_id = _pol_source_node_id + _pol_source_node_service = _pol_source_node_service + _pol_source_node_service_state = _pol_source_node_service_state + red_pol = NodeStateInstructionRed( + _id, + _start_step, + _end_step, + _target_node_id, + _pol_initiator, + _pol_type, + pol_protocol, + _pol_state, + _pol_source_node_id, + _pol_source_node_service, + _pol_source_node_service_state, + ) + + self.red_node_pol[_id] = red_pol + + # 2: Launch the attack from compromised node - set the IER + + ier_id = str(2000 + n) + # Launch the attack after node is compromised, and not right at the end of the episode + ier_start_step = np.random.randint( + _start_step + 2, int(self.episode_steps * 0.8) + ) + ier_end_step = self.episode_steps + ier_source_node_id = node.get_id() + # Randomise the load, as a percentage of a random link bandwith + ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice( + bandwidths + ) + ier_protocol = pol_service_name # Same protocol as compromised node + ier_service = node.get_services()[ + pol_service_name + ] # same service as defined in the pol + ier_port = ier_service.get_port() + ier_mission_criticality = ( + 0 # Red IER will never be important to green agent success + ) + # We choose a node to attack based on the first that applies: + # a. Green IERs, select dest node of the red ier based on dest node of green IER + # b. Attack a random server that doesn't have a DENY acl rule in default config + # c. Attack a random server + possible_ier_destinations = [ + ier.get_dest_node_id() + for ier in list(self.green_iers.values()) + if ier.get_source_node_id() == node.get_id() + ] + if len(possible_ier_destinations) < 1: + for server in servers: + if not self.acl.is_blocked( + node.get_ip_address(), + server.ip_address, + ier_service, + ier_port, + ): + possible_ier_destinations.append(server.node_id) + if len(possible_ier_destinations) < 1: + # If still none found choose from all servers + possible_ier_destinations = [server.node_id for server in servers] + ier_dest = np.random.choice(possible_ier_destinations) + self.red_iers[ier_id] = IER( + ier_id, + ier_start_step, + ier_end_step, + ier_load, + ier_protocol, + ier_port, + ier_source_node_id, + ier_dest, + ier_mission_criticality, + ) + + # 3: Make sure the targetted node can be set to overwhelmed - with node pol + # TODO remove duplicate red pol for same targetted service - must take into account start step + + o_pol_id = str(3000 + n) + o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched + o_pol_end_step = ( + self.episode_steps + ) # Can become compromised at any timestep after start + o_pol_node_id = ier_dest # Node effected is the one targetted by the IER + o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes + o_pol_service_name = ( + ier_protocol # Same protocol/service as the IER uses to attack + ) + o_pol_new_state = SoftwareState["OVERWHELMED"] + o_pol_entry_node = False # Assumes servers are not entry nodes + o_red_pol = NodeStateInstructionRed( + _id, + _start_step, + _end_step, + _target_node_id, + _pol_initiator, + _pol_type, + pol_protocol, + _pol_state, + _pol_source_node_id, + _pol_source_node_service, + _pol_source_node_service_state, + ) + self.red_node_pol[o_pol_id] = o_red_pol diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 7f62fe24..9ae917e9 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -137,3 +137,20 @@ class NodeStateInstructionRed(object): The source node service state """ return self.source_node_service_state + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"start_step={self.start_step}, " + f"end_step={self.end_step}, " + f"target_node_id={self.target_node_id}, " + f"initiator={self.initiator}, " + f"pol_type={self.pol_type}, " + f"service_name={self.service_name}, " + f"state={self.state}, " + f"source_node_id={self.source_node_id}, " + f"source_node_service={self.source_node_service}, " + f"source_node_service_state={self.source_node_service_state}" + f")" + ) \ No newline at end of file diff --git a/tests/config/random_agent_main_config.yaml b/tests/config/random_agent_main_config.yaml new file mode 100644 index 00000000..d2d18bbc --- /dev/null +++ b/tests/config/random_agent_main_config.yaml @@ -0,0 +1,96 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agent_identifier: GENERIC +# +red_agent_identifier: RANDOM +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: ANY +# Number of episodes to run per session +num_episodes: 1 +# Number of time_steps per episode +num_steps: 5 +# Time delay between steps (for generic agents) +time_delay: 1 +# Type of session to be run (TRAINING or EVALUATION) +session_type: TRAINING +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1_000_000_000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py new file mode 100644 index 00000000..a86e32c1 --- /dev/null +++ b/tests/test_red_random_agent_behaviour.py @@ -0,0 +1,74 @@ +from datetime import time, datetime + +from primaite.environment.primaite_env import Primaite +from tests import TEST_CONFIG_ROOT +from tests.conftest import _get_temp_session_path + + +def run_generic(env, config_values): + """Run against a generic agent.""" + # Reset the environment at the start of the episode + env.reset() + for episode in range(0, config_values.num_episodes): + for step in range(0, config_values.num_steps): + # Send the observation space to the agent to get an action + # TEMP - random action for now + # action = env.blue_agent_action(obs) + # action = env.action_space.sample() + action = 0 + + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + + # Introduce a delay between steps + time.sleep(config_values.time_delay / 1000) + + # Reset the environment at the end of the episode + env.reset() + + env.close() + + +def test_random_red_agent_behaviour(): + """ + Test that hardware state is penalised at each step. + + When the initial state is OFF compared to reference state which is ON. + """ + list_of_node_instructions = [] + for i in range(2): + + """Takes a config path and returns the created instance of Primaite.""" + session_timestamp: datetime = datetime.now() + session_path = _get_temp_session_path(session_timestamp) + + timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + env = Primaite( + training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", + transaction_list=[], + session_path=session_path, + timestamp_str=timestamp_str, + ) + training_config = env.training_config + training_config.num_steps = env.episode_steps + + # TOOD: This needs t be refactored to happen outside. Should be part of + # a main Session class. + if training_config.agent_identifier == "GENERIC": + run_generic(env, training_config) + all_red_actions = env.red_node_pol + list_of_node_instructions.append(all_red_actions) + + # assert not (list_of_node_instructions[0].__eq__(list_of_node_instructions[1])) + print(list_of_node_instructions[0]["1"].get_start_step()) + print(list_of_node_instructions[0]["1"].get_end_step()) + print(list_of_node_instructions[0]["1"].get_target_node_id()) + print(list_of_node_instructions[1]["1"].get_start_step()) + print(list_of_node_instructions[1]["1"].get_end_step()) + print(list_of_node_instructions[1]["1"].get_target_node_id()) + assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1]) From 36f1dff9b858d38a7e943a222db9f41857a87ae0 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 27 Jun 2023 12:27:57 +0100 Subject: [PATCH 02/21] 1555 - updated doc-string to make test understanding easier --- tests/test_reward.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/test_reward.py b/tests/test_reward.py index c3fcdfc4..56e31ed5 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -16,17 +16,25 @@ def test_rewards_are_being_penalised_at_each_step_function(): ) """ - On different steps (of the 13 in total) these are the following rewards for config_6 which are activated: - File System State: goodShouldBeCorrupt = 5 (between Steps 1 & 3) - Hardware State: onShouldBeOff = -2 (between Steps 4 & 6) - Service State: goodShouldBeCompromised = 5 (between Steps 7 & 9) - Software State (Software State): goodShouldBeCompromised = 5 (between Steps 10 & 12) + The config 'one_node_states_on_off_lay_down_config.yaml' has 15 steps: + On different steps, the laydown config has Pattern of Life (PoLs) which change a state of the node's attribute. + For example, turning the nodes' file system state to CORRUPT from its original state GOOD. + As a result these are the following rewards are activated: + File System State: corrupt_should_be_good = -10 * 2 (on Steps 1 = 3) + Hardware State: off_should_be_on = -10 * 2 (on Steps 4 - 6) + Service State: compromised_should_be_good = -20 * 2 (on Steps 7 - 9) + Software State: compromised_should_be_good = -20 * 2 (on Steps 10 - 12) - Total Reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26 - Step Count: 13 + The Pattern of Life (PoLs) last for 2 steps, so the agent is penalised twice. + + Note: This test run inherits conftest.py where the PrimAITE environment is ran and the blue agent is hard-coded + to do NOTHING on every step so we use Pattern of Lifes (PoLs) to change the nodes states and display that the agent + is being penalised on every step where the live network node differs from the network reference node. + + Total Reward: -10 + -10 + -10 + -10 + -20 + -20 + -20 + -20 = -120 + Step Count: 15 For the 4 steps where this occurs the average reward is: - Average Reward: 2 (26 / 13) + Average Reward: -8 (-120 / 15) """ - print("average reward", env.average_reward) assert env.average_reward == -8.0 From beae1e5c4f503016d9b802d9fce50ede1c2bd064 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 13:06:10 +0100 Subject: [PATCH 03/21] Cosmetic changes to satisfy pre-commit --- src/primaite/__init__.py | 2 +- src/primaite/cli.py | 5 +- src/primaite/config/training_config.py | 14 +--- src/primaite/environment/primaite_env.py | 21 +++-- src/primaite/main.py | 5 +- src/primaite/nodes/node.py | 5 +- src/primaite/nodes/service_node.py | 4 +- src/primaite/notebooks/__init__.py | 1 - tests/test_observation_space.py | 9 ++- tests/test_resetting_node.py | 98 +++++++++++------------- tests/test_single_action_space.py | 9 ++- 11 files changed, 80 insertions(+), 93 deletions(-) diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 1ea110c9..420420f4 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -31,7 +31,7 @@ def _get_primaite_config(): "INFO": logging.INFO, "WARN": logging.WARN, "ERROR": logging.ERROR, - "CRITICAL": logging.CRITICAL + "CRITICAL": logging.CRITICAL, } primaite_config["log_level"] = log_level_map[primaite_config["log_level"]] return primaite_config diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 19746d01..319d643c 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -3,8 +3,8 @@ import logging import os import shutil -from pathlib import Path from enum import Enum +from pathlib import Path from typing import Optional import pkg_resources @@ -44,6 +44,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]): :param last_n: The number of lines to print. Default value is 10. """ import re + from primaite import LOG_PATH if os.path.isfile(LOG_PATH): @@ -53,7 +54,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]): print(re.sub(r"\n*", "", line)) -_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa +_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa @app.command() diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 4af36abe..7bafc910 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,7 +1,7 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Final, Union, Optional +from typing import Any, Dict, Final, Optional, Union import yaml @@ -167,8 +167,7 @@ def main_training_config_path() -> Path: return path -def load(file_path: Union[str, Path], - legacy_file: bool = False) -> TrainingConfig: +def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: """ Read in a training config yaml file. @@ -213,9 +212,7 @@ def load(file_path: Union[str, Path], def convert_legacy_training_config_dict( - legacy_config_dict: Dict[str, Any], - num_steps: int = 256, - action_type: str = "ANY" + legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY" ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -227,10 +224,7 @@ def convert_legacy_training_config_dict( don't have action_type values. :return: The converted training config dict. """ - config_dict = { - "num_steps": num_steps, - "action_type": action_type - } + config_dict = {"num_steps": num_steps, "action_type": action_type} for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) if new_key: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index da235971..7608e7db 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -14,8 +14,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, \ - is_valid_node_action +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -24,8 +23,9 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, + ObservationType, Priority, - SoftwareState, ObservationType, + SoftwareState, ) from primaite.common.service import Service from primaite.config import training_config @@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import \ - NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, \ - apply_red_agent_node_pol +from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction _LOGGER = logging.getLogger(__name__) @@ -177,7 +175,6 @@ class Primaite(Env): # It will be initialised later. self.obs_handler: ObservationsHandler - # Open the config file and build the environment laydown with open(self._lay_down_config_path, "r") as file: # Open the config file and build the environment laydown @@ -199,7 +196,6 @@ class Primaite(Env): try: plt.tight_layout() nx.draw_networkx(self.network, with_labels=True) - now = datetime.now() # current date and time file_path = session_path / f"network_{timestamp_str}.png" plt.savefig(file_path, format="PNG") @@ -238,7 +234,9 @@ class Primaite(Env): self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}") + _LOGGER.info( + f"Invalid action type selected: {self.training_config.action_type}" + ) # Set up a csv to store the results of the training try: header = ["Episode", "Average Reward"] @@ -379,7 +377,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - #print(f" Step {self.step_count} Reward: {str(reward)}") + # print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -1033,7 +1031,6 @@ class Primaite(Env): """ self.observation_type = ObservationType[observation_info["type"]] - def get_action_info(self, action_info): """ Extracts action_info. diff --git a/src/primaite/main.py b/src/primaite/main.py index ac32a018..f5e94509 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -22,8 +22,7 @@ from stable_baselines3.ppo import MlpPolicy as PPOMlp from primaite import SESSIONS_DIR, getLogger from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite -from primaite.transactions.transactions_to_file import \ - write_transaction_to_file +from primaite.transactions.transactions_to_file import write_transaction_to_file _LOGGER = getLogger(__name__) @@ -349,5 +348,3 @@ if __name__ == "__main__": "Please provide a lay down config file using the --ldc " "argument" ) run(training_config_path=args.tc, lay_down_config_path=args.ldc) - - diff --git a/src/primaite/nodes/node.py b/src/primaite/nodes/node.py index 00cd01c2..bac1792d 100644 --- a/src/primaite/nodes/node.py +++ b/src/primaite/nodes/node.py @@ -46,6 +46,7 @@ class Node: """Sets the node state to ON.""" self.hardware_state = HardwareState.BOOTING self.booting_count = self.config_values.node_booting_duration + def turn_off(self): """Sets the node state to OFF.""" self.hardware_state = HardwareState.OFF @@ -64,14 +65,14 @@ class Node: self.hardware_state = HardwareState.ON def update_booting_status(self): - """Updates the booting count""" + """Updates the booting count.""" self.booting_count -= 1 if self.booting_count <= 0: self.booting_count = 0 self.hardware_state = HardwareState.ON def update_shutdown_status(self): - """Updates the shutdown count""" + """Updates the shutdown count.""" self.shutting_down_count -= 1 if self.shutting_down_count <= 0: self.shutting_down_count = 0 diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index 84a7c587..324592c3 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -190,13 +190,15 @@ class ServiceNode(ActiveNode): service_value.reduce_patching_count() def update_resetting_status(self): + """Update resetting counter and set software state if it reached 0.""" super().update_resetting_status() if self.resetting_count <= 0: for service in self.services.values(): service.software_state = SoftwareState.GOOD def update_booting_status(self): + """Update booting counter and set software to good if it reached 0.""" super().update_booting_status() if self.booting_count <= 0: for service in self.services.values(): - service.software_state =SoftwareState.GOOD + service.software_state = SoftwareState.GOOD diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py index 6d822961..71ed343e 100644 --- a/src/primaite/notebooks/__init__.py +++ b/src/primaite/notebooks/__init__.py @@ -17,7 +17,6 @@ def start_jupyter_session(): .. todo:: Figure out how to get this working for Linux and MacOS too. """ - if importlib.util.find_spec("jupyter") is not None: jupyter_cmd = "python3 -m jupyter lab" if sys.platform == "win32": diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index dbcdf2d6..efca7b0b 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -27,7 +27,8 @@ def env(request): @pytest.mark.env_config_paths( dict( - training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + training_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_without_obs.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) @@ -43,7 +44,8 @@ def test_default_obs_space(env: Primaite): @pytest.mark.env_config_paths( dict( - training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + training_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_without_obs.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) @@ -140,7 +142,8 @@ class TestNodeLinkTable: @pytest.mark.env_config_paths( dict( - training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml", + training_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_NODE_STATUSES.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) diff --git a/tests/test_resetting_node.py b/tests/test_resetting_node.py index b2843f7f..abe8115c 100644 --- a/tests/test_resetting_node.py +++ b/tests/test_resetting_node.py @@ -1,7 +1,13 @@ """Used to test Active Node functions.""" import pytest -from primaite.common.enums import FileSystemState, HardwareState, SoftwareState, NodeType, Priority +from primaite.common.enums import ( + FileSystemState, + HardwareState, + NodeType, + Priority, + SoftwareState, +) from primaite.common.service import Service from primaite.config.training_config import TrainingConfig from primaite.nodes.active_node import ActiveNode @@ -10,24 +16,20 @@ from primaite.nodes.service_node import ServiceNode @pytest.mark.parametrize( "starting_operating_state, expected_operating_state", - [ - (HardwareState.RESETTING, HardwareState.ON) - ], + [(HardwareState.RESETTING, HardwareState.ON)], ) def test_node_resets_correctly(starting_operating_state, expected_operating_state): - """ - Tests that a node resets correctly. - """ + """Tests that a node resets correctly.""" active_node = ActiveNode( - node_id = "0", - name = "node", - node_type = NodeType.COMPUTER, - priority = Priority.P1, - hardware_state = starting_operating_state, - ip_address = "192.168.0.1", - software_state = SoftwareState.COMPROMISED, - file_system_state = FileSystemState.CORRUPT, - config_values=TrainingConfig() + node_id="0", + name="node", + node_type=NodeType.COMPUTER, + priority=Priority.P1, + hardware_state=starting_operating_state, + ip_address="192.168.0.1", + software_state=SoftwareState.COMPROMISED, + file_system_state=FileSystemState.CORRUPT, + config_values=TrainingConfig(), ) for x in range(5): @@ -37,35 +39,28 @@ def test_node_resets_correctly(starting_operating_state, expected_operating_stat assert active_node.file_system_state_actual == FileSystemState.GOOD assert active_node.hardware_state == expected_operating_state + @pytest.mark.parametrize( "operating_state, expected_operating_state", - [ - (HardwareState.BOOTING, HardwareState.ON) - ], + [(HardwareState.BOOTING, HardwareState.ON)], ) def test_node_boots_correctly(operating_state, expected_operating_state): - """ - Tests that a node boots correctly. - """ + """Tests that a node boots correctly.""" service_node = ServiceNode( - node_id = 0, - name = "node", - node_type = "COMPUTER", - priority = "1", - hardware_state = operating_state, - ip_address = "192.168.0.1", - software_state = SoftwareState.GOOD, - file_system_state = "GOOD", - config_values = 1, + node_id=0, + name="node", + node_type="COMPUTER", + priority="1", + hardware_state=operating_state, + ip_address="192.168.0.1", + software_state=SoftwareState.GOOD, + file_system_state="GOOD", + config_values=1, ) service_attributes = Service( - name = "node", - port = "80", - software_state = SoftwareState.COMPROMISED - ) - service_node.add_service( - service_attributes + name="node", port="80", software_state=SoftwareState.COMPROMISED ) + service_node.add_service(service_attributes) for x in range(5): service_node.update_booting_status() @@ -73,31 +68,26 @@ def test_node_boots_correctly(operating_state, expected_operating_state): assert service_attributes.software_state == SoftwareState.GOOD assert service_node.hardware_state == expected_operating_state + @pytest.mark.parametrize( "operating_state, expected_operating_state", - [ - (HardwareState.SHUTTING_DOWN, HardwareState.OFF) - ], + [(HardwareState.SHUTTING_DOWN, HardwareState.OFF)], ) def test_node_shutdown_correctly(operating_state, expected_operating_state): - """ - Tests that a node shutdown correctly. - """ + """Tests that a node shutdown correctly.""" active_node = ActiveNode( - node_id = 0, - name = "node", - node_type = "COMPUTER", - priority = "1", - hardware_state = operating_state, - ip_address = "192.168.0.1", - software_state = SoftwareState.GOOD, - file_system_state = "GOOD", - config_values = 1, + node_id=0, + name="node", + node_type="COMPUTER", + priority="1", + hardware_state=operating_state, + ip_address="192.168.0.1", + software_state=SoftwareState.GOOD, + file_system_state="GOOD", + config_values=1, ) for x in range(5): active_node.update_shutdown_status() assert active_node.hardware_state == expected_operating_state - - diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 16b9d03e..8ff43fe6 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -48,7 +48,8 @@ def test_single_action_space_is_valid(): """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" env = _get_primaite_env_from_config( training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "single_action_space_lay_down_config.yaml", ) run_generic_set_actions(env) @@ -77,8 +78,10 @@ def test_single_action_space_is_valid(): def test_agent_is_executing_actions_from_both_spaces(): """Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL).""" env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", + training_config_path=TEST_CONFIG_ROOT + / "single_action_space_fixed_blue_actions_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "single_action_space_lay_down_config.yaml", ) # Run environment with specified fixed blue agent actions only run_generic_set_actions(env) From a3e50293b7d74c63fdfc455494db0c27ada81148 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 12:07:33 +0000 Subject: [PATCH 04/21] Add pre-commits to build pipeline --- .azure/azure-ci-build-pipeline.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index dd45907d..ceda11c6 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -25,6 +25,11 @@ steps: versionSpec: '$(python.version)' displayName: 'Use Python $(python.version)' +- script: | + pre-commit install + pre-commit run --all-files + displayName: 'Run pre-commits' + - script: | python -m pip install --upgrade pip==23.0.1 pip install wheel==0.38.4 --upgrade From 33f7e9f5066a057a572ee93d174b95131f7abe62 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 13:07:54 +0000 Subject: [PATCH 05/21] Add pre-commit --- .azure/azure-ci-build-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index ceda11c6..8bfdca02 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -26,6 +26,7 @@ steps: displayName: 'Use Python $(python.version)' - script: | + python -m pip install pre-commit pre-commit install pre-commit run --all-files displayName: 'Run pre-commits' From 349a18a4eb0ef067442bc7ec8cfb8f1167ee7a53 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Jun 2023 15:27:56 +0100 Subject: [PATCH 06/21] Fix ier reward calculation --- src/primaite/environment/reward.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index aa9e4503..1a1a0770 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -78,8 +78,8 @@ def calculate_reward_function( start_step = ier_value.get_start_step() stop_step = ier_value.get_end_step() if step_count >= start_step and step_count <= stop_step: - reference_blocked = reference_ier.get_is_running() - live_blocked = ier_value.get_is_running() + reference_blocked = not reference_ier.get_is_running() + live_blocked = not ier_value.get_is_running() ier_reward = ( config_values.green_ier_blocked * ier_value.get_mission_criticality() ) From 9623b1450a527b89162c4319998b9fa70681f9a3 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 27 Jun 2023 16:59:43 +0100 Subject: [PATCH 07/21] 1555 - added specific steps to doc string --- tests/test_reward.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/test_reward.py b/tests/test_reward.py index 56e31ed5..b8c92274 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -20,16 +20,17 @@ def test_rewards_are_being_penalised_at_each_step_function(): On different steps, the laydown config has Pattern of Life (PoLs) which change a state of the node's attribute. For example, turning the nodes' file system state to CORRUPT from its original state GOOD. As a result these are the following rewards are activated: - File System State: corrupt_should_be_good = -10 * 2 (on Steps 1 = 3) - Hardware State: off_should_be_on = -10 * 2 (on Steps 4 - 6) - Service State: compromised_should_be_good = -20 * 2 (on Steps 7 - 9) - Software State: compromised_should_be_good = -20 * 2 (on Steps 10 - 12) + File System State: corrupt_should_be_good = -10 * 2 (on Steps 1 & 2) + Hardware State: off_should_be_on = -10 * 2 (on Steps 4 & 5) + Service State: compromised_should_be_good = -20 * 2 (on Steps 7 & 8) + Software State: compromised_should_be_good = -20 * 2 (on Steps 10 & 11) The Pattern of Life (PoLs) last for 2 steps, so the agent is penalised twice. - Note: This test run inherits conftest.py where the PrimAITE environment is ran and the blue agent is hard-coded - to do NOTHING on every step so we use Pattern of Lifes (PoLs) to change the nodes states and display that the agent - is being penalised on every step where the live network node differs from the network reference node. + Note: This test run inherits from conftest.py where the PrimAITE environment is ran and the blue agent is hard-coded + to do NOTHING on every step. + We use Pattern of Lifes (PoLs) to change the nodes states and display that the agent is being penalised on all steps + where the live network node differs from the network reference node. Total Reward: -10 + -10 + -10 + -10 + -20 + -20 + -20 + -20 = -120 Step Count: 15 From e086d419adc4e4721a1010b88548f36b7e2138ab Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 28 Jun 2023 11:07:45 +0100 Subject: [PATCH 08/21] Attempt to add flat spaces --- scratch.py | 6 +++++ .../training/training_config_main.yaml | 9 +++++-- src/primaite/environment/observations.py | 24 +++++++++++++++---- 3 files changed, 33 insertions(+), 6 deletions(-) create mode 100644 scratch.py diff --git a/scratch.py b/scratch.py new file mode 100644 index 00000000..6bab60c1 --- /dev/null +++ b/scratch.py @@ -0,0 +1,6 @@ +from primaite.main import run + +run( + "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/training/training_config_main.yaml", + "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml", +) diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index d01f51f3..a679400c 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -11,12 +11,17 @@ agent_identifier: STABLE_BASELINES3_A2C # "ACL" # "ANY" node and acl actions action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE # Number of episodes to run per session -num_episodes: 10 +num_episodes: 1000 # Number of time_steps per episode num_steps: 256 # Time delay between steps (for generic agents) -time_delay: 10 +time_delay: 0 # Type of session to be run (TRAINING or EVALUATION) session_type: TRAINING # Determine whether to load an agent from file diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 9e71ef1b..e6eb533c 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -311,8 +311,13 @@ class ObservationsHandler: def __init__(self): self.registered_obs_components: List[AbstractObservationComponent] = [] + + # need to keep track of the flattened and unflattened version of the space (if there is one) self.space: spaces.Space + self.unflattened_space: spaces.Space + self.current_observation: Union[Tuple[np.ndarray], np.ndarray] + self.flatten: bool = False def update_obs(self): """Fetch fresh information about the environment.""" @@ -324,9 +329,14 @@ class ObservationsHandler: # If there is only one component, don't use a tuple, just pass through that component's obs. if len(current_obs) == 1: self.current_observation = current_obs[0] + # If there are many compoenents, the space may need to be flattened else: - self.current_observation = tuple(current_obs) - # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. + if self.flatten: + self.current_observation = spaces.flatten( + self.unflattened_space, tuple(current_obs) + ) + else: + self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): """Add a component for this handler to track. @@ -357,8 +367,11 @@ class ObservationsHandler: if len(component_spaces) == 1: self.space = component_spaces[0] else: - self.space = spaces.Tuple(component_spaces) - # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. + self.unflattened_space = spaces.Tuple(component_spaces) + if self.flatten: + self.space = spaces.flatten_space(spaces.Tuple(component_spaces)) + else: + self.space = self.unflattened_space @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): @@ -388,6 +401,9 @@ class ObservationsHandler: # Instantiate the handler handler = cls() + if obs_space_config.get("flatten"): + handler.flatten = True + for component_cfg in obs_space_config["components"]: # Figure out which class can instantiate the desired component comp_type = component_cfg["name"] From 10e432eb01e5491c68c6b71a41d104178e7bbac4 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 29 Jun 2023 15:03:11 +0100 Subject: [PATCH 09/21] #1522: fixing create random red agent function --- .gitignore | 2 + src/primaite/environment/primaite_env.py | 149 ++++++++++++----------- 2 files changed, 78 insertions(+), 73 deletions(-) diff --git a/.gitignore b/.gitignore index eed6c903..5adbdc57 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,5 @@ dmypy.json # Cython debug symbols cython_debug/ + +.idea/ \ No newline at end of file diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index c3d408d2..9ac3d8e6 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -9,6 +9,7 @@ from typing import Dict, Tuple, Union import networkx as nx import numpy as np +import uuid as uuid import yaml from gym import Env, spaces from matplotlib import pyplot as plt @@ -58,12 +59,12 @@ class Primaite(Env): ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - transaction_list, - session_path: Path, - timestamp_str: str, + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + transaction_list, + session_path: Path, + timestamp_str: str, ): """ The Primaite constructor. @@ -275,8 +276,8 @@ class Primaite(Env): self.reset_environment() # Create a random red agent to use for this episode - if self.training_config.red_agent_identifier == "RANDOM": - self.create_random_red_agent() + # if self.training_config.red_agent_identifier == "RANDOM": + # self.create_random_red_agent() # Reset counters and totals self.total_reward = 0 @@ -380,6 +381,7 @@ class Primaite(Env): self.nodes_post_pol, self.nodes_post_red, self.nodes_reference, + self.green_iers, self.green_iers_reference, self.red_iers, self.step_count, @@ -445,11 +447,11 @@ class Primaite(Env): elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 6 + len(self.action_dict[_action]) == 6 ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 4 + len(self.action_dict[_action]) == 4 ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: @@ -1247,14 +1249,17 @@ class Primaite(Env): computers ) # only computers can become compromised # random select between 1 and max_num_nodes_compromised - num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised + 1) + num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised) # Decide which of the nodes to compromise nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise) + # choose a random compromise node to be source of attacks + source_node = np.random.choice(nodes_to_be_compromised, 1)[0] + # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( - self.episode_steps // 2 + self.episode_steps // 2 ) # always compromise in first half of episode # Bandwidth for all links @@ -1264,57 +1269,50 @@ class Primaite(Env): for n, node in enumerate(nodes_to_be_compromised): # 1: Use Node PoL to set node to compromised - _id = str(1000 + n) # doesn't really matter, make sure it doesn't duplicate + _id = str(uuid.uuid4()) _start_step = np.random.randint( 2, max_step_compromised + 1 ) # step compromised - _end_step = _start_step # Become compromised on 1 step - _target_node_id = node.node_id - _pol_initiator = "DIRECT" - _pol_type = NodePOLType["SERVICE"] # All computers are service nodes pol_service_name = np.random.choice( - list(node.get_services().keys()) - ) # Random service may wish to change this, currently always TCP) - pol_protocol = pol_protocol - _pol_state = SoftwareState.COMPROMISED - is_entry_node = True # Assumes all computers in network are entry nodes - _pol_source_node_id = _pol_source_node_id - _pol_source_node_service = _pol_source_node_service - _pol_source_node_service_state = _pol_source_node_service_state + list(node.services.keys()) + ) + + source_node_service = np.random.choice( + list(source_node.services.values()) + ) + red_pol = NodeStateInstructionRed( - _id, - _start_step, - _end_step, - _target_node_id, - _pol_initiator, - _pol_type, - pol_protocol, - _pol_state, - _pol_source_node_id, - _pol_source_node_service, - _pol_source_node_service_state, + _id=_id, + _start_step=_start_step, + _end_step=_start_step, # only run for 1 step + _target_node_id=node.node_id, + _pol_initiator="DIRECT", + _pol_type=NodePOLType["SERVICE"], + pol_protocol=pol_service_name, + _pol_state=SoftwareState.COMPROMISED, + _pol_source_node_id=source_node.node_id, + _pol_source_node_service=source_node_service.name, + _pol_source_node_service_state=source_node_service.software_state ) self.red_node_pol[_id] = red_pol # 2: Launch the attack from compromised node - set the IER - ier_id = str(2000 + n) + ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode ier_start_step = np.random.randint( _start_step + 2, int(self.episode_steps * 0.8) ) ier_end_step = self.episode_steps - ier_source_node_id = node.get_id() + # Randomise the load, as a percentage of a random link bandwith ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice( bandwidths ) ier_protocol = pol_service_name # Same protocol as compromised node - ier_service = node.get_services()[ - pol_service_name - ] # same service as defined in the pol - ier_port = ier_service.get_port() + ier_service = node.services[pol_service_name] + ier_port = ier_service.port ier_mission_criticality = ( 0 # Red IER will never be important to green agent success ) @@ -1325,15 +1323,15 @@ class Primaite(Env): possible_ier_destinations = [ ier.get_dest_node_id() for ier in list(self.green_iers.values()) - if ier.get_source_node_id() == node.get_id() + if ier.get_source_node_id() == node.node_id ] if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.get_ip_address(), - server.ip_address, - ier_service, - ier_port, + node.get_ip_address(), + server.ip_address, + ier_service, + ier_port, ): possible_ier_destinations.append(server.node_id) if len(possible_ier_destinations) < 1: @@ -1347,37 +1345,42 @@ class Primaite(Env): ier_load, ier_protocol, ier_port, - ier_source_node_id, + node.node_id, ier_dest, ier_mission_criticality, ) - # 3: Make sure the targetted node can be set to overwhelmed - with node pol - # TODO remove duplicate red pol for same targetted service - must take into account start step + overwhelm_pol = red_pol + overwhelm_pol.id = str(uuid.uuid4()) + overwhelm_pol.end_step = self.episode_steps - o_pol_id = str(3000 + n) - o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched - o_pol_end_step = ( - self.episode_steps - ) # Can become compromised at any timestep after start - o_pol_node_id = ier_dest # Node effected is the one targetted by the IER - o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes - o_pol_service_name = ( - ier_protocol # Same protocol/service as the IER uses to attack - ) - o_pol_new_state = SoftwareState["OVERWHELMED"] - o_pol_entry_node = False # Assumes servers are not entry nodes + + # 3: Make sure the targetted node can be set to overwhelmed - with node pol + # # TODO remove duplicate red pol for same targetted service - must take into account start step + # + o_pol_id = str(uuid.uuid4()) + # o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched + # o_pol_end_step = ( + # self.episode_steps + # ) # Can become compromised at any timestep after start + # o_pol_node_id = ier_dest # Node effected is the one targetted by the IER + # o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes + # o_pol_service_name = ( + # ier_protocol # Same protocol/service as the IER uses to attack + # ) + # o_pol_new_state = SoftwareState["OVERWHELMED"] + # o_pol_entry_node = False # Assumes servers are not entry nodes o_red_pol = NodeStateInstructionRed( - _id, - _start_step, - _end_step, - _target_node_id, - _pol_initiator, - _pol_type, - pol_protocol, - _pol_state, - _pol_source_node_id, - _pol_source_node_service, - _pol_source_node_service_state, + _id=o_pol_id, + _start_step=ier_start_step, + _end_step=self.episode_steps, + _target_node_id=ier_dest, + _pol_initiator="DIRECT", + _pol_type=NodePOLType["SERVICE"], + pol_protocol=ier_protocol, + _pol_state=SoftwareState.OVERWHELMED, + _pol_source_node_id=source_node.node_id, + _pol_source_node_service=source_node_service.name, + _pol_source_node_service_state=source_node_service.software_state ) self.red_node_pol[o_pol_id] = o_red_pol From c9f58fdb2a1031b2aee16fc929bfe76e4765fb74 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 29 Jun 2023 15:26:07 +0100 Subject: [PATCH 10/21] Fix observation representation in transactions --- src/primaite/environment/observations.py | 149 +++++++++++++++--- src/primaite/environment/primaite_env.py | 5 +- src/primaite/main.py | 1 + .../transactions/transactions_to_file.py | 54 ++----- 4 files changed, 150 insertions(+), 59 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index e6eb533c..023c5f30 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -29,6 +29,7 @@ class AbstractObservationComponent(ABC): self.env: "Primaite" = env self.space: spaces.Space self.current_observation: np.ndarray # type might be too restrictive? + self.structure: list[str] return NotImplemented @abstractmethod @@ -36,6 +37,11 @@ class AbstractObservationComponent(ABC): """Update the observation based on the current state of the environment.""" self.current_observation = NotImplemented + @abstractmethod + def generate_structure(self) -> List[str]: + """Return a list of labels for the components of the flattened observation space.""" + return NotImplemented + class NodeLinkTable(AbstractObservationComponent): """Table with nodes and links as rows and hardware/software status as cols. @@ -79,6 +85,8 @@ class NodeLinkTable(AbstractObservationComponent): # 3. Initialise Observation with zeroes self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) + self.structure = self.generate_structure() + def update(self): """Update the observation based on current environment state. @@ -131,6 +139,40 @@ class NodeLinkTable(AbstractObservationComponent): protocol_index += 1 item_index += 1 + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + nodes = self.env.nodes.values() + links = self.env.links.values() + + structure = [] + + for i, node in enumerate(nodes): + node_id = node.node_id + node_labels = [ + f"node_{node_id}_id", + f"node_{node_id}_hardware_status", + f"node_{node_id}_os_status", + f"node_{node_id}_fs_status", + ] + for j, serv in enumerate(self.env.services_list): + node_labels.append(f"node_{node_id}_service_{serv}_status") + + structure.extend(node_labels) + + for i, link in enumerate(links): + link_id = link.id + link_labels = [ + f"link_{link_id}_id", + f"link_{link_id}_n/a", + f"link_{link_id}_n/a", + f"link_{link_id}_n/a", + ] + for j, serv in enumerate(self.env.services_list): + link_labels.append(f"node_{node_id}_service_{serv}_load") + + structure.extend(link_labels) + return structure + class NodeStatuses(AbstractObservationComponent): """Flat list of nodes' hardware, OS, file system, and service states. @@ -179,6 +221,7 @@ class NodeStatuses(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + self.structure = self.generate_structure() def update(self): """Update the observation based on current environment state. @@ -205,6 +248,30 @@ class NodeStatuses(AbstractObservationComponent): ) self.current_observation[:] = obs + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + services = self.env.services_list + + structure = [] + for _, node in self.env.nodes.items(): + node_id = node.node_id + structure.append(f"node_{node_id}_hardware_state_NONE") + for state in HardwareState: + structure.append(f"node_{node_id}_hardware_state_{state.name}") + structure.append(f"node_{node_id}_software_state_NONE") + for state in SoftwareState: + structure.append(f"node_{node_id}_software_state_{state.name}") + structure.append(f"node_{node_id}_file_system_state_NONE") + for state in FileSystemState: + structure.append(f"node_{node_id}_file_system_state_{state.name}") + for service in services: + structure.append(f"node_{node_id}_service_{service}_state_NONE") + for state in SoftwareState: + structure.append( + f"node_{node_id}_service_{service}_state_{state.name}" + ) + return structure + class LinkTrafficLevels(AbstractObservationComponent): """Flat list of traffic levels encoded into banded categories. @@ -268,6 +335,8 @@ class LinkTrafficLevels(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + self.structure = self.generate_structure() + def update(self): """Update the observation based on current environment state. @@ -295,6 +364,21 @@ class LinkTrafficLevels(AbstractObservationComponent): self.current_observation[:] = obs + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + structure = [] + for _, link in self.env.links.items(): + link_id = link.id + if self._combine_service_traffic: + protocols = ["overall"] + else: + protocols = [protocol.name for protocol in link.protocol_list] + + for p in protocols: + for i in range(self._quantisation_levels): + structure.append(f"link_{link_id}_{p}_traffic_level_{i}") + return structure + class ObservationsHandler: """Component-based observation space handler. @@ -312,11 +396,15 @@ class ObservationsHandler: def __init__(self): self.registered_obs_components: List[AbstractObservationComponent] = [] - # need to keep track of the flattened and unflattened version of the space (if there is one) - self.space: spaces.Space - self.unflattened_space: spaces.Space + # internal the observation space (unflattened version of space if flatten=True) + self._space: spaces.Space + # flattened version of the observation space + self._flat_space: spaces.Space + + self._observation: Union[Tuple[np.ndarray], np.ndarray] + # used for transactions and when flatten=true + self._flat_observation: np.ndarray - self.current_observation: Union[Tuple[np.ndarray], np.ndarray] self.flatten: bool = False def update_obs(self): @@ -326,17 +414,11 @@ class ObservationsHandler: obs.update() current_obs.append(obs.current_observation) - # If there is only one component, don't use a tuple, just pass through that component's obs. if len(current_obs) == 1: - self.current_observation = current_obs[0] - # If there are many compoenents, the space may need to be flattened + self._observation = current_obs[0] else: - if self.flatten: - self.current_observation = spaces.flatten( - self.unflattened_space, tuple(current_obs) - ) - else: - self.current_observation = tuple(current_obs) + self._observation = tuple(current_obs) + self._flat_observation = spaces.flatten(self._space, self._observation) def register(self, obs_component: AbstractObservationComponent): """Add a component for this handler to track. @@ -363,15 +445,28 @@ class ObservationsHandler: for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) - # If there is only one component, don't use a tuple space, just pass through that component's space. + # if there are multiple components, build a composite tuple space if len(component_spaces) == 1: - self.space = component_spaces[0] + self._space = component_spaces[0] else: - self.unflattened_space = spaces.Tuple(component_spaces) - if self.flatten: - self.space = spaces.flatten_space(spaces.Tuple(component_spaces)) - else: - self.space = self.unflattened_space + self._space = spaces.Tuple(component_spaces) + self._flat_space = spaces.flatten_space(self._space) + + @property + def space(self): + """Observation space, return the flattened version if flatten is True.""" + if self.flatten: + return self._flat_space + else: + return self._space + + @property + def current_observation(self): + """Current observation, return the flattened version if flatten is True.""" + if self.flatten: + return self._flat_observation + else: + return self._observation @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): @@ -417,3 +512,17 @@ class ObservationsHandler: handler.update_obs() return handler + + def describe_structure(self): + """Create a list of names for the features of the obs space. + + The order of labels follows the flattened version of the space. + """ + # as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have + # to fake it. each component has to just hard-code the expected label order after flattening... + + labels = [] + for obs_comp in self.registered_obs_components: + labels.extend(obs_comp.structure) + + return labels diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index be4cc434..e56abf9d 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -318,7 +318,8 @@ class Primaite(Env): datetime.now(), self.agent_identifier, self.episode_count, self.step_count ) # Load the initial observation space into the transaction - transaction.set_obs_space_pre(copy.deepcopy(self.env_obs)) + transaction.set_obs_space_pre(self.obs_handler._flat_observation) + # Load the action space into the transaction transaction.set_action_space(copy.deepcopy(action)) @@ -400,7 +401,7 @@ class Primaite(Env): # 7. Update env_obs self.update_environent_obs() # Load the new observation space into the transaction - transaction.set_obs_space_post(copy.deepcopy(self.env_obs)) + transaction.set_obs_space_post(self.obs_handler._flat_observation) # 8. Add the transaction to the list of transactions self.transaction_list.append(copy.deepcopy(transaction)) diff --git a/src/primaite/main.py b/src/primaite/main.py index f5e94509..4d83f604 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -325,6 +325,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, transaction_list=transaction_list, session_path=session_dir, timestamp_str=timestamp_str, + obs_space_description=env.obs_handler.describe_structure(), ) print("Updating Session Metadata file...") diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index 11e68af8..b2a4d40d 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -22,24 +22,12 @@ def turn_action_space_to_array(_action_space): return [str(_action_space)] -def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features): - """ - Turns observation space into a string array so it can be saved to csv. - - Args: - _obs_space: The observation space - _obs_assets: The number of assets (i.e. nodes or links) in the observation space - _obs_features: The number of features associated with the asset - """ - return_array = [] - for x in range(_obs_assets): - for y in range(_obs_features): - return_array.append(str(_obs_space[x][y])) - - return return_array - - -def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str): +def write_transaction_to_file( + transaction_list, + session_path: Path, + timestamp_str: str, + obs_space_description: list, +): """ Writes transaction logs to file to support training evaluation. @@ -56,13 +44,13 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st # This will be tied into the PrimAITE Use Case so that they make sense template_transation = transaction_list[0] action_length = template_transation.action_space.size - obs_shape = template_transation.obs_space_post.shape - obs_assets = template_transation.obs_space_post.shape[0] - if len(obs_shape) == 1: - # bit of a workaround but I think the way transactions are written will change soon - obs_features = 1 - else: - obs_features = template_transation.obs_space_post.shape[1] + # obs_shape = template_transation.obs_space_post.shape + # obs_assets = template_transation.obs_space_post.shape[0] + # if len(obs_shape) == 1: + # bit of a workaround but I think the way transactions are written will change soon + # obs_features = 1 + # else: + # obs_features = template_transation.obs_space_post.shape[1] # Create the action space headers array action_header = [] @@ -70,12 +58,8 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st action_header.append("AS_" + str(x)) # Create the observation space headers array - obs_header_initial = [] - obs_header_new = [] - for x in range(obs_assets): - for y in range(obs_features): - obs_header_initial.append("OSI_" + str(x) + "_" + str(y)) - obs_header_new.append("OSN_" + str(x) + "_" + str(y)) + obs_header_initial = [f"pre_{o}" for o in obs_space_description] + obs_header_new = [f"post_{o}" for o in obs_space_description] # Open up a csv file header = ["Timestamp", "Episode", "Step", "Reward"] @@ -98,12 +82,8 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st csv_data = ( csv_data + turn_action_space_to_array(transaction.action_space) - + turn_obs_space_to_array( - transaction.obs_space_pre, obs_assets, obs_features - ) - + turn_obs_space_to_array( - transaction.obs_space_post, obs_assets, obs_features - ) + + transaction.obs_space_pre.tolist() + + transaction.obs_space_post.tolist() ) csv_writer.writerow(csv_data) From fb48f75adffca6560984cc3590e932f9d9f78b82 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 09:54:34 +0100 Subject: [PATCH 11/21] Remove temporary file --- scratch.py | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 scratch.py diff --git a/scratch.py b/scratch.py deleted file mode 100644 index 6bab60c1..00000000 --- a/scratch.py +++ /dev/null @@ -1,6 +0,0 @@ -from primaite.main import run - -run( - "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/training/training_config_main.yaml", - "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml", -) From 4e1e0ef4b45cb5bf5ec430a6eb0c6dba6cad4eef Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 30 Jun 2023 10:37:23 +0100 Subject: [PATCH 12/21] #1522: remove numpy randomisation + added random red agent config --- .../training/training_config_main.yaml | 5 + .../training_config_random_red_agent.yaml | 99 +++++++++++++++++++ src/primaite/environment/primaite_env.py | 39 +++----- 3 files changed, 118 insertions(+), 25 deletions(-) create mode 100644 src/primaite/config/_package_data/training/training_config_random_red_agent.yaml diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index d01f51f3..3fe668e2 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -6,6 +6,11 @@ # "STABLE_BASELINES3_A2C" # "GENERIC" agent_identifier: STABLE_BASELINES3_A2C + +# RED AGENT IDENTIFIER +# RANDOM or NONE +red_agent_identifier: "NONE" + # Sets How the Action Space is defined: # "NODE" # "ACL" diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml new file mode 100644 index 00000000..9382a2b5 --- /dev/null +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -0,0 +1,99 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agent_identifier: STABLE_BASELINES3_A2C + +# RED AGENT IDENTIFIER +# RANDOM or NONE +red_agent_identifier: "RANDOM" + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# Number of episodes to run per session +num_episodes: 10 +# Number of time_steps per episode +num_steps: 256 +# Time delay between steps (for generic agents) +time_delay: 10 +# Type of session to be run (TRAINING or EVALUATION) +session_type: TRAINING +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 9ac3d8e6..e592e21f 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -5,6 +5,7 @@ import csv import logging from datetime import datetime from pathlib import Path +from random import randint, choice, uniform, sample from typing import Dict, Tuple, Union import networkx as nx @@ -276,8 +277,8 @@ class Primaite(Env): self.reset_environment() # Create a random red agent to use for this episode - # if self.training_config.red_agent_identifier == "RANDOM": - # self.create_random_red_agent() + if self.training_config.red_agent_identifier == "RANDOM": + self.create_random_red_agent() # Reset counters and totals self.total_reward = 0 @@ -1249,13 +1250,13 @@ class Primaite(Env): computers ) # only computers can become compromised # random select between 1 and max_num_nodes_compromised - num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised) + num_nodes_to_compromise = randint(1, max_num_nodes_compromised) # Decide which of the nodes to compromise - nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise) + nodes_to_be_compromised = sample(computers, num_nodes_to_compromise) # choose a random compromise node to be source of attacks - source_node = np.random.choice(nodes_to_be_compromised, 1)[0] + source_node = choice(nodes_to_be_compromised) # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( @@ -1270,14 +1271,14 @@ class Primaite(Env): # 1: Use Node PoL to set node to compromised _id = str(uuid.uuid4()) - _start_step = np.random.randint( + _start_step = randint( 2, max_step_compromised + 1 ) # step compromised - pol_service_name = np.random.choice( + pol_service_name = choice( list(node.services.keys()) ) - source_node_service = np.random.choice( + source_node_service = choice( list(source_node.services.values()) ) @@ -1301,13 +1302,13 @@ class Primaite(Env): ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode - ier_start_step = np.random.randint( + ier_start_step = randint( _start_step + 2, int(self.episode_steps * 0.8) ) ier_end_step = self.episode_steps # Randomise the load, as a percentage of a random link bandwith - ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice( + ier_load = uniform(0.4, 0.8) * choice( bandwidths ) ier_protocol = pol_service_name # Same protocol as compromised node @@ -1328,7 +1329,7 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.get_ip_address(), + node.ip_address, server.ip_address, ier_service, ier_port, @@ -1337,7 +1338,7 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: # If still none found choose from all servers possible_ier_destinations = [server.node_id for server in servers] - ier_dest = np.random.choice(possible_ier_destinations) + ier_dest = choice(possible_ier_destinations) self.red_iers[ier_id] = IER( ier_id, ier_start_step, @@ -1354,22 +1355,10 @@ class Primaite(Env): overwhelm_pol.id = str(uuid.uuid4()) overwhelm_pol.end_step = self.episode_steps - # 3: Make sure the targetted node can be set to overwhelmed - with node pol # # TODO remove duplicate red pol for same targetted service - must take into account start step - # + o_pol_id = str(uuid.uuid4()) - # o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched - # o_pol_end_step = ( - # self.episode_steps - # ) # Can become compromised at any timestep after start - # o_pol_node_id = ier_dest # Node effected is the one targetted by the IER - # o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes - # o_pol_service_name = ( - # ier_protocol # Same protocol/service as the IER uses to attack - # ) - # o_pol_new_state = SoftwareState["OVERWHELMED"] - # o_pol_entry_node = False # Assumes servers are not entry nodes o_red_pol = NodeStateInstructionRed( _id=o_pol_id, _start_step=ier_start_step, From 99ba05c6ee644223bfbd6e560ccdffa2aa90e5cc Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 10:41:56 +0100 Subject: [PATCH 13/21] Remove redundant cols from transactions --- src/primaite/environment/observations.py | 2 +- src/primaite/environment/primaite_env.py | 4 +--- src/primaite/transactions/transaction.py | 13 ++----------- src/primaite/transactions/transactions_to_file.py | 9 ++++----- 4 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 023c5f30..fcd52559 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -168,7 +168,7 @@ class NodeLinkTable(AbstractObservationComponent): f"link_{link_id}_n/a", ] for j, serv in enumerate(self.env.services_list): - link_labels.append(f"node_{node_id}_service_{serv}_load") + link_labels.append(f"link_{link_id}_service_{serv}_load") structure.extend(link_labels) return structure diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index e56abf9d..2418cac0 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -318,7 +318,7 @@ class Primaite(Env): datetime.now(), self.agent_identifier, self.episode_count, self.step_count ) # Load the initial observation space into the transaction - transaction.set_obs_space_pre(self.obs_handler._flat_observation) + transaction.set_obs_space(self.obs_handler._flat_observation) # Load the action space into the transaction transaction.set_action_space(copy.deepcopy(action)) @@ -400,8 +400,6 @@ class Primaite(Env): # 7. Update env_obs self.update_environent_obs() - # Load the new observation space into the transaction - transaction.set_obs_space_post(self.obs_handler._flat_observation) # 8. Add the transaction to the list of transactions self.transaction_list.append(copy.deepcopy(transaction)) diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index a4ce48e3..39236217 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -20,23 +20,14 @@ class Transaction(object): self.episode_number = _episode_number self.step_number = _step_number - def set_obs_space_pre(self, _obs_space_pre): + def set_obs_space(self, _obs_space): """ Sets the observation space (pre). Args: _obs_space_pre: The observation space before any actions are taken """ - self.obs_space_pre = _obs_space_pre - - def set_obs_space_post(self, _obs_space_post): - """ - Sets the observation space (post). - - Args: - _obs_space_post: The observation space after any actions are taken - """ - self.obs_space_post = _obs_space_post + self.obs_space = _obs_space def set_reward(self, _reward): """ diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index b2a4d40d..4e364f0b 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -58,12 +58,12 @@ def write_transaction_to_file( action_header.append("AS_" + str(x)) # Create the observation space headers array - obs_header_initial = [f"pre_{o}" for o in obs_space_description] - obs_header_new = [f"post_{o}" for o in obs_space_description] + # obs_header_initial = [f"pre_{o}" for o in obs_space_description] + # obs_header_new = [f"post_{o}" for o in obs_space_description] # Open up a csv file header = ["Timestamp", "Episode", "Step", "Reward"] - header = header + action_header + obs_header_initial + obs_header_new + header = header + action_header + obs_space_description try: filename = session_path / f"all_transactions_{timestamp_str}.csv" @@ -82,8 +82,7 @@ def write_transaction_to_file( csv_data = ( csv_data + turn_action_space_to_array(transaction.action_space) - + transaction.obs_space_pre.tolist() - + transaction.obs_space_post.tolist() + + transaction.obs_space.tolist() ) csv_writer.writerow(csv_data) From 09883e13c2b2b5c4254a43cf8c4a89fa7d2b4431 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 10:44:04 +0100 Subject: [PATCH 14/21] Update docs --- docs/source/primaite_session.rst | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index 4f639f11..a59b2361 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -78,10 +78,9 @@ PrimAITE automatically creates two sets of results from each session: * Timestamp * Episode number * Step number - * Initial observation space (before red and blue agent actions have been taken). Individual elements of the observation space are presented in the format OSI_X_Y - * Resulting observation space (after the red and blue agent actions have been taken) Individual elements of the observation space are presented in the format OSN_X_Y + * Initial observation space (what the blue agent observed when it decided its action) * Reward value - * Action space (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X + * Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X **Diagrams** From d86489a9c27272750dfb35b9a354e2e97a2edeea Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 13:16:30 +0100 Subject: [PATCH 15/21] revert unnecessary changes. --- .../_package_data/training/training_config_main.yaml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index a679400c..ac63c667 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -5,7 +5,7 @@ # "STABLE_BASELINES3_PPO" # "STABLE_BASELINES3_A2C" # "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C +agent_identifier: STABLE_BASELINES3_PPO # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -16,12 +16,14 @@ observation_space: # flatten: true components: - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS # Number of episodes to run per session -num_episodes: 1000 +num_episodes: 10 # Number of time_steps per episode num_steps: 256 # Time delay between steps (for generic agents) -time_delay: 0 +time_delay: 10 # Type of session to be run (TRAINING or EVALUATION) session_type: TRAINING # Determine whether to load an agent from file From 7e6fe2759b02694b3d19d04cff7b94f5be86ef22 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 15:43:15 +0100 Subject: [PATCH 16/21] Fix flattening when there are no components. --- src/primaite/environment/observations.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index fcd52559..b19bd29f 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -450,7 +450,10 @@ class ObservationsHandler: self._space = component_spaces[0] else: self._space = spaces.Tuple(component_spaces) - self._flat_space = spaces.flatten_space(self._space) + if len(component_spaces) > 0: + self._flat_space = spaces.flatten_space(self._space) + else: + self._flat_space = spaces.Box(0, 1, (0,)) @property def space(self): From 046937d8382a43679052e5a4f6a7b3ac3ddb7975 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 3 Jul 2023 08:00:51 +0000 Subject: [PATCH 17/21] Apply suggestions from code review --- src/primaite/environment/observations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index b19bd29f..81ddaaf5 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -29,7 +29,7 @@ class AbstractObservationComponent(ABC): self.env: "Primaite" = env self.space: spaces.Space self.current_observation: np.ndarray # type might be too restrictive? - self.structure: list[str] + self.structure: List[str] return NotImplemented @abstractmethod From 68457aa0b2f2caf76f1d2aa146857240f3d00368 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 09:46:52 +0100 Subject: [PATCH 18/21] #1522: added a check for existing links in laydown + test that checks if red agent instructions are random --- src/primaite/environment/primaite_env.py | 6 ++++ tests/test_red_random_agent_behaviour.py | 36 ++++++++++++------------ 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index e592e21f..58932c4c 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1265,6 +1265,12 @@ class Primaite(Env): # Bandwidth for all links bandwidths = [i.get_bandwidth() for i in list(self.links.values())] + + if len(bandwidths) < 1: + msg = "Random red agent cannot be used on a network without any links" + _LOGGER.error(msg) + raise Exception(msg) + servers = [node for node in node_list if node.node_type == NodeType.SERVER] for n, node in enumerate(nodes_to_be_compromised): diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index a86e32c1..c9189c26 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -1,5 +1,6 @@ -from datetime import time, datetime +from datetime import datetime +from primaite.config.lay_down_config import data_manipulation_config_path from primaite.environment.primaite_env import Primaite from tests import TEST_CONFIG_ROOT from tests.conftest import _get_temp_session_path @@ -24,9 +25,6 @@ def run_generic(env, config_values): if done: break - # Introduce a delay between steps - time.sleep(config_values.time_delay / 1000) - # Reset the environment at the end of the episode env.reset() @@ -40,6 +38,8 @@ def test_random_red_agent_behaviour(): When the initial state is OFF compared to reference state which is ON. """ list_of_node_instructions = [] + + # RUN TWICE so we can make sure that red agent is randomised for i in range(2): """Takes a config path and returns the created instance of Primaite.""" @@ -49,7 +49,7 @@ def test_random_red_agent_behaviour(): timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") env = Primaite( training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", + lay_down_config_path=data_manipulation_config_path(), transaction_list=[], session_path=session_path, timestamp_str=timestamp_str, @@ -57,18 +57,18 @@ def test_random_red_agent_behaviour(): training_config = env.training_config training_config.num_steps = env.episode_steps - # TOOD: This needs t be refactored to happen outside. Should be part of - # a main Session class. - if training_config.agent_identifier == "GENERIC": - run_generic(env, training_config) - all_red_actions = env.red_node_pol - list_of_node_instructions.append(all_red_actions) + run_generic(env, training_config) + # add red pol instructions to list + list_of_node_instructions.append(env.red_node_pol) + + # compare instructions to make sure that red instructions are truly random + for index, instruction in enumerate(list_of_node_instructions): + for key in list_of_node_instructions[index].keys(): + instruction: NodeInstructionRed = list_of_node_instructions[index][key] + print(f"run {index}") + print(f"{key} start step: {instruction.get_start_step()}") + print(f"{key} end step: {instruction.get_end_step()}") + print(f"{key} target node id: {instruction.get_target_node_id()}") + print("") - # assert not (list_of_node_instructions[0].__eq__(list_of_node_instructions[1])) - print(list_of_node_instructions[0]["1"].get_start_step()) - print(list_of_node_instructions[0]["1"].get_end_step()) - print(list_of_node_instructions[0]["1"].get_target_node_id()) - print(list_of_node_instructions[1]["1"].get_start_step()) - print(list_of_node_instructions[1]["1"].get_end_step()) - print(list_of_node_instructions[1]["1"].get_target_node_id()) assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1]) From 6b4530bded0e8e568cac0fcf681640359d6b0071 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 10:08:25 +0100 Subject: [PATCH 19/21] #1522: run pre-commit --- .gitignore | 2 +- src/primaite/environment/primaite_env.py | 55 ++++++++----------- .../nodes/node_state_instruction_red.py | 2 +- tests/test_red_random_agent_behaviour.py | 7 ++- 4 files changed, 28 insertions(+), 38 deletions(-) diff --git a/.gitignore b/.gitignore index 5adbdc57..b65d1fd8 100644 --- a/.gitignore +++ b/.gitignore @@ -138,4 +138,4 @@ dmypy.json # Cython debug symbols cython_debug/ -.idea/ \ No newline at end of file +.idea/ diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 58932c4c..eb0bc5de 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -3,14 +3,14 @@ import copy import csv import logging +import uuid as uuid from datetime import datetime from pathlib import Path -from random import randint, choice, uniform, sample +from random import choice, randint, sample, uniform from typing import Dict, Tuple, Union import networkx as nx import numpy as np -import uuid as uuid import yaml from gym import Env, spaces from matplotlib import pyplot as plt @@ -60,12 +60,12 @@ class Primaite(Env): ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - transaction_list, - session_path: Path, - timestamp_str: str, + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + transaction_list, + session_path: Path, + timestamp_str: str, ): """ The Primaite constructor. @@ -448,11 +448,11 @@ class Primaite(Env): elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 6 + len(self.action_dict[_action]) == 6 ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 4 + len(self.action_dict[_action]) == 4 ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: @@ -1238,7 +1238,6 @@ class Primaite(Env): def create_random_red_agent(self): """Decide on random red agent for the episode to be called in env.reset().""" - # Reset the current red iers and red node pol self.red_iers = {} self.red_node_pol = {} @@ -1260,7 +1259,7 @@ class Primaite(Env): # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( - self.episode_steps // 2 + self.episode_steps // 2 ) # always compromise in first half of episode # Bandwidth for all links @@ -1277,16 +1276,10 @@ class Primaite(Env): # 1: Use Node PoL to set node to compromised _id = str(uuid.uuid4()) - _start_step = randint( - 2, max_step_compromised + 1 - ) # step compromised - pol_service_name = choice( - list(node.services.keys()) - ) + _start_step = randint(2, max_step_compromised + 1) # step compromised + pol_service_name = choice(list(node.services.keys())) - source_node_service = choice( - list(source_node.services.values()) - ) + source_node_service = choice(list(source_node.services.values())) red_pol = NodeStateInstructionRed( _id=_id, @@ -1299,7 +1292,7 @@ class Primaite(Env): _pol_state=SoftwareState.COMPROMISED, _pol_source_node_id=source_node.node_id, _pol_source_node_service=source_node_service.name, - _pol_source_node_service_state=source_node_service.software_state + _pol_source_node_service_state=source_node_service.software_state, ) self.red_node_pol[_id] = red_pol @@ -1308,15 +1301,11 @@ class Primaite(Env): ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode - ier_start_step = randint( - _start_step + 2, int(self.episode_steps * 0.8) - ) + ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8)) ier_end_step = self.episode_steps # Randomise the load, as a percentage of a random link bandwith - ier_load = uniform(0.4, 0.8) * choice( - bandwidths - ) + ier_load = uniform(0.4, 0.8) * choice(bandwidths) ier_protocol = pol_service_name # Same protocol as compromised node ier_service = node.services[pol_service_name] ier_port = ier_service.port @@ -1335,10 +1324,10 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.ip_address, - server.ip_address, - ier_service, - ier_port, + node.ip_address, + server.ip_address, + ier_service, + ier_port, ): possible_ier_destinations.append(server.node_id) if len(possible_ier_destinations) < 1: @@ -1376,6 +1365,6 @@ class Primaite(Env): _pol_state=SoftwareState.OVERWHELMED, _pol_source_node_id=source_node.node_id, _pol_source_node_service=source_node_service.name, - _pol_source_node_service_state=source_node_service.software_state + _pol_source_node_service_state=source_node_service.software_state, ) self.red_node_pol[o_pol_id] = o_red_pol diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 9ae917e9..2f7d0622 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -153,4 +153,4 @@ class NodeStateInstructionRed(object): f"source_node_service={self.source_node_service}, " f"source_node_service_state={self.source_node_service_state}" f")" - ) \ No newline at end of file + ) diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index c9189c26..476a08f1 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -2,6 +2,7 @@ from datetime import datetime from primaite.config.lay_down_config import data_manipulation_config_path from primaite.environment.primaite_env import Primaite +from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from tests import TEST_CONFIG_ROOT from tests.conftest import _get_temp_session_path @@ -41,14 +42,14 @@ def test_random_red_agent_behaviour(): # RUN TWICE so we can make sure that red agent is randomised for i in range(2): - """Takes a config path and returns the created instance of Primaite.""" session_timestamp: datetime = datetime.now() session_path = _get_temp_session_path(session_timestamp) timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") env = Primaite( - training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + training_config_path=TEST_CONFIG_ROOT + / "one_node_states_on_off_main_config.yaml", lay_down_config_path=data_manipulation_config_path(), transaction_list=[], session_path=session_path, @@ -64,7 +65,7 @@ def test_random_red_agent_behaviour(): # compare instructions to make sure that red instructions are truly random for index, instruction in enumerate(list_of_node_instructions): for key in list_of_node_instructions[index].keys(): - instruction: NodeInstructionRed = list_of_node_instructions[index][key] + instruction: NodeStateInstructionRed = list_of_node_instructions[index][key] print(f"run {index}") print(f"{key} start step: {instruction.get_start_step()}") print(f"{key} end step: {instruction.get_end_step()}") From befd183b2c899208dc34674025ba5b7ac973114e Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 12:18:58 +0100 Subject: [PATCH 20/21] #1522: refactor red_agent_identifier -> random_red_agent so that it is a boolean + documentation --- docs/source/config.rst | 4 + .../training/training_config_main.yaml | 2 +- .../training_config_random_red_agent.yaml | 2 +- src/primaite/config/training_config.py | 2 +- src/primaite/environment/primaite_env.py | 2 +- tests/config/random_agent_main_config.yaml | 96 ------------------- tests/test_red_random_agent_behaviour.py | 2 + 7 files changed, 10 insertions(+), 100 deletions(-) delete mode 100644 tests/config/random_agent_main_config.yaml diff --git a/docs/source/config.rst b/docs/source/config.rst index 74898ec1..fa58e6cf 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -28,6 +28,10 @@ The environment config file consists of the following attributes: * STABLE_BASELINES3_PPO - Use a SB3 PPO agent * STABLE_BASELINES3_A2C - use a SB3 A2C agent +* **random_red_agent** [bool] + + Determines if the session should be run with a random red agent + * **action_type** [enum] Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 3fe668e2..8f035d41 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -9,7 +9,7 @@ agent_identifier: STABLE_BASELINES3_A2C # RED AGENT IDENTIFIER # RANDOM or NONE -red_agent_identifier: "NONE" +random_red_agent: False # Sets How the Action Space is defined: # "NODE" diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml index 9382a2b5..3e0a3e2f 100644 --- a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -9,7 +9,7 @@ agent_identifier: STABLE_BASELINES3_A2C # RED AGENT IDENTIFIER # RANDOM or NONE -red_agent_identifier: "RANDOM" +random_red_agent: True # Sets How the Action Space is defined: # "NODE" diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 6e88e7cb..7995dfe8 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -21,7 +21,7 @@ class TrainingConfig: agent_identifier: str = "STABLE_BASELINES3_A2C" "The Red Agent algo/class to be used." - red_agent_identifier: str = "RANDOM" + random_red_agent: bool = False "Creates Random Red Agent Attacks" action_type: ActionType = ActionType.ANY diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index eb0bc5de..5cb85afd 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -277,7 +277,7 @@ class Primaite(Env): self.reset_environment() # Create a random red agent to use for this episode - if self.training_config.red_agent_identifier == "RANDOM": + if self.training_config.random_red_agent: self.create_random_red_agent() # Reset counters and totals diff --git a/tests/config/random_agent_main_config.yaml b/tests/config/random_agent_main_config.yaml deleted file mode 100644 index d2d18bbc..00000000 --- a/tests/config/random_agent_main_config.yaml +++ /dev/null @@ -1,96 +0,0 @@ -# Main Config File - -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: GENERIC -# -red_agent_identifier: RANDOM -# Sets How the Action Space is defined: -# "NODE" -# "ACL" -# "ANY" node and acl actions -action_type: ANY -# Number of episodes to run per session -num_episodes: 1 -# Number of time_steps per episode -num_steps: 5 -# Time delay between steps (for generic agents) -time_delay: 1 -# Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING -# Determine whether to load an agent from file -load_agent: False -# File path and file name of agent if you're loading one in -agent_load_file: C:\[Path]\[agent_saved_filename.zip] - -# Environment config values -# The high value for the observation space -observation_space_high_value: 1_000_000_000 - -# Reward values -# Generic -all_ok: 0 -# Node Hardware State -off_should_be_on: -10 -off_should_be_resetting: -5 -on_should_be_off: -2 -on_should_be_resetting: -5 -resetting_should_be_on: -5 -resetting_should_be_off: -2 -resetting: -3 -# Node Software or Service State -good_should_be_patching: 2 -good_should_be_compromised: 5 -good_should_be_overwhelmed: 5 -patching_should_be_good: -5 -patching_should_be_compromised: 2 -patching_should_be_overwhelmed: 2 -patching: -3 -compromised_should_be_good: -20 -compromised_should_be_patching: -20 -compromised_should_be_overwhelmed: -20 -compromised: -20 -overwhelmed_should_be_good: -20 -overwhelmed_should_be_patching: -20 -overwhelmed_should_be_compromised: -20 -overwhelmed: -20 -# Node File System State -good_should_be_repairing: 2 -good_should_be_restoring: 2 -good_should_be_corrupt: 5 -good_should_be_destroyed: 10 -repairing_should_be_good: -5 -repairing_should_be_restoring: 2 -repairing_should_be_corrupt: 2 -repairing_should_be_destroyed: 0 -repairing: -3 -restoring_should_be_good: -10 -restoring_should_be_repairing: -2 -restoring_should_be_corrupt: 1 -restoring_should_be_destroyed: 2 -restoring: -6 -corrupt_should_be_good: -10 -corrupt_should_be_repairing: -10 -corrupt_should_be_restoring: -10 -corrupt_should_be_destroyed: 2 -corrupt: -10 -destroyed_should_be_good: -20 -destroyed_should_be_repairing: -20 -destroyed_should_be_restoring: -20 -destroyed_should_be_corrupt: -20 -destroyed: -20 -scanning: -2 -# IER status -red_ier_running: -5 -green_ier_blocked: -10 - -# Patching / Reset durations -os_patching_duration: 5 # The time taken to patch the OS -node_reset_duration: 5 # The time taken to reset a node (hardware) -service_patching_duration: 5 # The time taken to patch a service -file_system_repairing_limit: 5 # The time take to repair the file system -file_system_restoring_limit: 5 # The time take to restore the file system -file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index 476a08f1..6b06dbb1 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -55,6 +55,8 @@ def test_random_red_agent_behaviour(): session_path=session_path, timestamp_str=timestamp_str, ) + # set red_agent_ + env.training_config.random_red_agent = True training_config = env.training_config training_config.num_steps = env.episode_steps From a7913487b8e15b32cfce56363b40f9ad5b46444d Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 13:36:14 +0100 Subject: [PATCH 21/21] #1522: create_random_red_agent -> _create_random_red_agent + converting NodeStateInstructionRed into a dataclass --- src/primaite/environment/primaite_env.py | 4 ++-- .../nodes/node_state_instruction_red.py | 20 +++---------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 5cb85afd..823c11fe 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -278,7 +278,7 @@ class Primaite(Env): # Create a random red agent to use for this episode if self.training_config.random_red_agent: - self.create_random_red_agent() + self._create_random_red_agent() # Reset counters and totals self.total_reward = 0 @@ -1236,7 +1236,7 @@ class Primaite(Env): combined_action_dict = {**acl_action_dict, **new_node_action_dict} return combined_action_dict - def create_random_red_agent(self): + def _create_random_red_agent(self): """Decide on random red agent for the episode to be called in env.reset().""" # Reset the current red iers and red node pol self.red_iers = {} diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 2f7d0622..4272ce24 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -1,8 +1,11 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Defines node behaviour for Green PoL.""" +from dataclasses import dataclass + from primaite.common.enums import NodePOLType +@dataclass() class NodeStateInstructionRed(object): """The Node State Instruction class.""" @@ -137,20 +140,3 @@ class NodeStateInstructionRed(object): The source node service state """ return self.source_node_service_state - - def __repr__(self): - return ( - f"{self.__class__.__name__}(" - f"id={self.id}, " - f"start_step={self.start_step}, " - f"end_step={self.end_step}, " - f"target_node_id={self.target_node_id}, " - f"initiator={self.initiator}, " - f"pol_type={self.pol_type}, " - f"service_name={self.service_name}, " - f"state={self.state}, " - f"source_node_id={self.source_node_id}, " - f"source_node_service={self.source_node_service}, " - f"source_node_service_state={self.source_node_service_state}" - f")" - )