From e0f3d61f6511181e861cfe4bc27d46cb18c0fba7 Mon Sep 17 00:00:00 2001 From: Brian Kanyora Date: Thu, 22 Jun 2023 15:34:13 +0100 Subject: [PATCH 1/7] feature\1522: Create random red agent behaviour. --- src/primaite/config/training_config.py | 17 +- src/primaite/environment/primaite_env.py | 173 ++++++++++++++++-- .../nodes/node_state_instruction_red.py | 17 ++ tests/config/random_agent_main_config.yaml | 96 ++++++++++ tests/test_red_random_agent_behaviour.py | 74 ++++++++ 5 files changed, 356 insertions(+), 21 deletions(-) create mode 100644 tests/config/random_agent_main_config.yaml create mode 100644 tests/test_red_random_agent_behaviour.py diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 4af36abe..6e88e7cb 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,7 +1,7 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Final, Union, Optional +from typing import Any, Dict, Final, Optional, Union import yaml @@ -21,6 +21,9 @@ class TrainingConfig: agent_identifier: str = "STABLE_BASELINES3_A2C" "The Red Agent algo/class to be used." + red_agent_identifier: str = "RANDOM" + "Creates Random Red Agent Attacks" + action_type: ActionType = ActionType.ANY "The ActionType to use." @@ -167,8 +170,7 @@ def main_training_config_path() -> Path: return path -def load(file_path: Union[str, Path], - legacy_file: bool = False) -> TrainingConfig: +def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: """ Read in a training config yaml file. @@ -213,9 +215,7 @@ def load(file_path: Union[str, Path], def convert_legacy_training_config_dict( - legacy_config_dict: Dict[str, Any], - num_steps: int = 256, - action_type: str = "ANY" + legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY" ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -227,10 +227,7 @@ def convert_legacy_training_config_dict( don't have action_type values. :return: The converted training config dict. """ - config_dict = { - "num_steps": num_steps, - "action_type": action_type - } + config_dict = {"num_steps": num_steps, "action_type": action_type} for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) if new_key: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index da235971..9161fa43 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -14,8 +14,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, \ - is_valid_node_action +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -24,8 +23,9 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, + ObservationType, Priority, - SoftwareState, ObservationType, + SoftwareState, ) from primaite.common.service import Service from primaite.config import training_config @@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import \ - NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, \ - apply_red_agent_node_pol +from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction _LOGGER = logging.getLogger(__name__) @@ -177,7 +175,6 @@ class Primaite(Env): # It will be initialised later. self.obs_handler: ObservationsHandler - # Open the config file and build the environment laydown with open(self._lay_down_config_path, "r") as file: # Open the config file and build the environment laydown @@ -238,7 +235,9 @@ class Primaite(Env): self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}") + _LOGGER.info( + f"Invalid action type selected: {self.training_config.action_type}" + ) # Set up a csv to store the results of the training try: header = ["Episode", "Average Reward"] @@ -275,6 +274,10 @@ class Primaite(Env): # Does this for both live and reference nodes self.reset_environment() + # Create a random red agent to use for this episode + if self.training_config.red_agent_identifier == "RANDOM": + self.create_random_red_agent() + # Reset counters and totals self.total_reward = 0 self.step_count = 0 @@ -379,7 +382,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - #print(f" Step {self.step_count} Reward: {str(reward)}") + print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -1033,7 +1036,6 @@ class Primaite(Env): """ self.observation_type = ObservationType[observation_info["type"]] - def get_action_info(self, action_info): """ Extracts action_info. @@ -1216,3 +1218,152 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} return combined_action_dict + + def create_random_red_agent(self): + """Decide on random red agent for the episode to be called in env.reset().""" + + # Reset the current red iers and red node pol + self.red_iers = {} + self.red_node_pol = {} + + # Decide how many nodes become compromised + node_list = list(self.nodes.values()) + computers = [node for node in node_list if node.node_type == NodeType.COMPUTER] + max_num_nodes_compromised = len( + computers + ) # only computers can become compromised + # random select between 1 and max_num_nodes_compromised + num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised + 1) + + # Decide which of the nodes to compromise + nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise) + + # For each of the nodes to be compromised decide which step they become compromised + max_step_compromised = ( + self.episode_steps // 2 + ) # always compromise in first half of episode + + # Bandwidth for all links + bandwidths = [i.get_bandwidth() for i in list(self.links.values())] + servers = [node for node in node_list if node.node_type == NodeType.SERVER] + + for n, node in enumerate(nodes_to_be_compromised): + # 1: Use Node PoL to set node to compromised + + _id = str(1000 + n) # doesn't really matter, make sure it doesn't duplicate + _start_step = np.random.randint( + 2, max_step_compromised + 1 + ) # step compromised + _end_step = _start_step # Become compromised on 1 step + _target_node_id = node.node_id + _pol_initiator = "DIRECT" + _pol_type = NodePOLType["SERVICE"] # All computers are service nodes + pol_service_name = np.random.choice( + list(node.get_services().keys()) + ) # Random service may wish to change this, currently always TCP) + pol_protocol = pol_protocol + _pol_state = SoftwareState.COMPROMISED + is_entry_node = True # Assumes all computers in network are entry nodes + _pol_source_node_id = _pol_source_node_id + _pol_source_node_service = _pol_source_node_service + _pol_source_node_service_state = _pol_source_node_service_state + red_pol = NodeStateInstructionRed( + _id, + _start_step, + _end_step, + _target_node_id, + _pol_initiator, + _pol_type, + pol_protocol, + _pol_state, + _pol_source_node_id, + _pol_source_node_service, + _pol_source_node_service_state, + ) + + self.red_node_pol[_id] = red_pol + + # 2: Launch the attack from compromised node - set the IER + + ier_id = str(2000 + n) + # Launch the attack after node is compromised, and not right at the end of the episode + ier_start_step = np.random.randint( + _start_step + 2, int(self.episode_steps * 0.8) + ) + ier_end_step = self.episode_steps + ier_source_node_id = node.get_id() + # Randomise the load, as a percentage of a random link bandwith + ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice( + bandwidths + ) + ier_protocol = pol_service_name # Same protocol as compromised node + ier_service = node.get_services()[ + pol_service_name + ] # same service as defined in the pol + ier_port = ier_service.get_port() + ier_mission_criticality = ( + 0 # Red IER will never be important to green agent success + ) + # We choose a node to attack based on the first that applies: + # a. Green IERs, select dest node of the red ier based on dest node of green IER + # b. Attack a random server that doesn't have a DENY acl rule in default config + # c. Attack a random server + possible_ier_destinations = [ + ier.get_dest_node_id() + for ier in list(self.green_iers.values()) + if ier.get_source_node_id() == node.get_id() + ] + if len(possible_ier_destinations) < 1: + for server in servers: + if not self.acl.is_blocked( + node.get_ip_address(), + server.ip_address, + ier_service, + ier_port, + ): + possible_ier_destinations.append(server.node_id) + if len(possible_ier_destinations) < 1: + # If still none found choose from all servers + possible_ier_destinations = [server.node_id for server in servers] + ier_dest = np.random.choice(possible_ier_destinations) + self.red_iers[ier_id] = IER( + ier_id, + ier_start_step, + ier_end_step, + ier_load, + ier_protocol, + ier_port, + ier_source_node_id, + ier_dest, + ier_mission_criticality, + ) + + # 3: Make sure the targetted node can be set to overwhelmed - with node pol + # TODO remove duplicate red pol for same targetted service - must take into account start step + + o_pol_id = str(3000 + n) + o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched + o_pol_end_step = ( + self.episode_steps + ) # Can become compromised at any timestep after start + o_pol_node_id = ier_dest # Node effected is the one targetted by the IER + o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes + o_pol_service_name = ( + ier_protocol # Same protocol/service as the IER uses to attack + ) + o_pol_new_state = SoftwareState["OVERWHELMED"] + o_pol_entry_node = False # Assumes servers are not entry nodes + o_red_pol = NodeStateInstructionRed( + _id, + _start_step, + _end_step, + _target_node_id, + _pol_initiator, + _pol_type, + pol_protocol, + _pol_state, + _pol_source_node_id, + _pol_source_node_service, + _pol_source_node_service_state, + ) + self.red_node_pol[o_pol_id] = o_red_pol diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 7f62fe24..9ae917e9 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -137,3 +137,20 @@ class NodeStateInstructionRed(object): The source node service state """ return self.source_node_service_state + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"start_step={self.start_step}, " + f"end_step={self.end_step}, " + f"target_node_id={self.target_node_id}, " + f"initiator={self.initiator}, " + f"pol_type={self.pol_type}, " + f"service_name={self.service_name}, " + f"state={self.state}, " + f"source_node_id={self.source_node_id}, " + f"source_node_service={self.source_node_service}, " + f"source_node_service_state={self.source_node_service_state}" + f")" + ) \ No newline at end of file diff --git a/tests/config/random_agent_main_config.yaml b/tests/config/random_agent_main_config.yaml new file mode 100644 index 00000000..d2d18bbc --- /dev/null +++ b/tests/config/random_agent_main_config.yaml @@ -0,0 +1,96 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agent_identifier: GENERIC +# +red_agent_identifier: RANDOM +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: ANY +# Number of episodes to run per session +num_episodes: 1 +# Number of time_steps per episode +num_steps: 5 +# Time delay between steps (for generic agents) +time_delay: 1 +# Type of session to be run (TRAINING or EVALUATION) +session_type: TRAINING +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1_000_000_000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py new file mode 100644 index 00000000..a86e32c1 --- /dev/null +++ b/tests/test_red_random_agent_behaviour.py @@ -0,0 +1,74 @@ +from datetime import time, datetime + +from primaite.environment.primaite_env import Primaite +from tests import TEST_CONFIG_ROOT +from tests.conftest import _get_temp_session_path + + +def run_generic(env, config_values): + """Run against a generic agent.""" + # Reset the environment at the start of the episode + env.reset() + for episode in range(0, config_values.num_episodes): + for step in range(0, config_values.num_steps): + # Send the observation space to the agent to get an action + # TEMP - random action for now + # action = env.blue_agent_action(obs) + # action = env.action_space.sample() + action = 0 + + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + + # Introduce a delay between steps + time.sleep(config_values.time_delay / 1000) + + # Reset the environment at the end of the episode + env.reset() + + env.close() + + +def test_random_red_agent_behaviour(): + """ + Test that hardware state is penalised at each step. + + When the initial state is OFF compared to reference state which is ON. + """ + list_of_node_instructions = [] + for i in range(2): + + """Takes a config path and returns the created instance of Primaite.""" + session_timestamp: datetime = datetime.now() + session_path = _get_temp_session_path(session_timestamp) + + timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + env = Primaite( + training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", + transaction_list=[], + session_path=session_path, + timestamp_str=timestamp_str, + ) + training_config = env.training_config + training_config.num_steps = env.episode_steps + + # TOOD: This needs t be refactored to happen outside. Should be part of + # a main Session class. + if training_config.agent_identifier == "GENERIC": + run_generic(env, training_config) + all_red_actions = env.red_node_pol + list_of_node_instructions.append(all_red_actions) + + # assert not (list_of_node_instructions[0].__eq__(list_of_node_instructions[1])) + print(list_of_node_instructions[0]["1"].get_start_step()) + print(list_of_node_instructions[0]["1"].get_end_step()) + print(list_of_node_instructions[0]["1"].get_target_node_id()) + print(list_of_node_instructions[1]["1"].get_start_step()) + print(list_of_node_instructions[1]["1"].get_end_step()) + print(list_of_node_instructions[1]["1"].get_target_node_id()) + assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1]) From f61d50a96f88c6cd18ab47c9bab1fb52982d4b35 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 29 Jun 2023 15:03:11 +0100 Subject: [PATCH 2/7] #1522: fixing create random red agent function --- .gitignore | 2 + src/primaite/environment/primaite_env.py | 149 ++++++++++++----------- 2 files changed, 78 insertions(+), 73 deletions(-) diff --git a/.gitignore b/.gitignore index eed6c903..5adbdc57 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,5 @@ dmypy.json # Cython debug symbols cython_debug/ + +.idea/ \ No newline at end of file diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index c3d408d2..9ac3d8e6 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -9,6 +9,7 @@ from typing import Dict, Tuple, Union import networkx as nx import numpy as np +import uuid as uuid import yaml from gym import Env, spaces from matplotlib import pyplot as plt @@ -58,12 +59,12 @@ class Primaite(Env): ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - transaction_list, - session_path: Path, - timestamp_str: str, + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + transaction_list, + session_path: Path, + timestamp_str: str, ): """ The Primaite constructor. @@ -275,8 +276,8 @@ class Primaite(Env): self.reset_environment() # Create a random red agent to use for this episode - if self.training_config.red_agent_identifier == "RANDOM": - self.create_random_red_agent() + # if self.training_config.red_agent_identifier == "RANDOM": + # self.create_random_red_agent() # Reset counters and totals self.total_reward = 0 @@ -380,6 +381,7 @@ class Primaite(Env): self.nodes_post_pol, self.nodes_post_red, self.nodes_reference, + self.green_iers, self.green_iers_reference, self.red_iers, self.step_count, @@ -445,11 +447,11 @@ class Primaite(Env): elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 6 + len(self.action_dict[_action]) == 6 ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 4 + len(self.action_dict[_action]) == 4 ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: @@ -1247,14 +1249,17 @@ class Primaite(Env): computers ) # only computers can become compromised # random select between 1 and max_num_nodes_compromised - num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised + 1) + num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised) # Decide which of the nodes to compromise nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise) + # choose a random compromise node to be source of attacks + source_node = np.random.choice(nodes_to_be_compromised, 1)[0] + # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( - self.episode_steps // 2 + self.episode_steps // 2 ) # always compromise in first half of episode # Bandwidth for all links @@ -1264,57 +1269,50 @@ class Primaite(Env): for n, node in enumerate(nodes_to_be_compromised): # 1: Use Node PoL to set node to compromised - _id = str(1000 + n) # doesn't really matter, make sure it doesn't duplicate + _id = str(uuid.uuid4()) _start_step = np.random.randint( 2, max_step_compromised + 1 ) # step compromised - _end_step = _start_step # Become compromised on 1 step - _target_node_id = node.node_id - _pol_initiator = "DIRECT" - _pol_type = NodePOLType["SERVICE"] # All computers are service nodes pol_service_name = np.random.choice( - list(node.get_services().keys()) - ) # Random service may wish to change this, currently always TCP) - pol_protocol = pol_protocol - _pol_state = SoftwareState.COMPROMISED - is_entry_node = True # Assumes all computers in network are entry nodes - _pol_source_node_id = _pol_source_node_id - _pol_source_node_service = _pol_source_node_service - _pol_source_node_service_state = _pol_source_node_service_state + list(node.services.keys()) + ) + + source_node_service = np.random.choice( + list(source_node.services.values()) + ) + red_pol = NodeStateInstructionRed( - _id, - _start_step, - _end_step, - _target_node_id, - _pol_initiator, - _pol_type, - pol_protocol, - _pol_state, - _pol_source_node_id, - _pol_source_node_service, - _pol_source_node_service_state, + _id=_id, + _start_step=_start_step, + _end_step=_start_step, # only run for 1 step + _target_node_id=node.node_id, + _pol_initiator="DIRECT", + _pol_type=NodePOLType["SERVICE"], + pol_protocol=pol_service_name, + _pol_state=SoftwareState.COMPROMISED, + _pol_source_node_id=source_node.node_id, + _pol_source_node_service=source_node_service.name, + _pol_source_node_service_state=source_node_service.software_state ) self.red_node_pol[_id] = red_pol # 2: Launch the attack from compromised node - set the IER - ier_id = str(2000 + n) + ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode ier_start_step = np.random.randint( _start_step + 2, int(self.episode_steps * 0.8) ) ier_end_step = self.episode_steps - ier_source_node_id = node.get_id() + # Randomise the load, as a percentage of a random link bandwith ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice( bandwidths ) ier_protocol = pol_service_name # Same protocol as compromised node - ier_service = node.get_services()[ - pol_service_name - ] # same service as defined in the pol - ier_port = ier_service.get_port() + ier_service = node.services[pol_service_name] + ier_port = ier_service.port ier_mission_criticality = ( 0 # Red IER will never be important to green agent success ) @@ -1325,15 +1323,15 @@ class Primaite(Env): possible_ier_destinations = [ ier.get_dest_node_id() for ier in list(self.green_iers.values()) - if ier.get_source_node_id() == node.get_id() + if ier.get_source_node_id() == node.node_id ] if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.get_ip_address(), - server.ip_address, - ier_service, - ier_port, + node.get_ip_address(), + server.ip_address, + ier_service, + ier_port, ): possible_ier_destinations.append(server.node_id) if len(possible_ier_destinations) < 1: @@ -1347,37 +1345,42 @@ class Primaite(Env): ier_load, ier_protocol, ier_port, - ier_source_node_id, + node.node_id, ier_dest, ier_mission_criticality, ) - # 3: Make sure the targetted node can be set to overwhelmed - with node pol - # TODO remove duplicate red pol for same targetted service - must take into account start step + overwhelm_pol = red_pol + overwhelm_pol.id = str(uuid.uuid4()) + overwhelm_pol.end_step = self.episode_steps - o_pol_id = str(3000 + n) - o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched - o_pol_end_step = ( - self.episode_steps - ) # Can become compromised at any timestep after start - o_pol_node_id = ier_dest # Node effected is the one targetted by the IER - o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes - o_pol_service_name = ( - ier_protocol # Same protocol/service as the IER uses to attack - ) - o_pol_new_state = SoftwareState["OVERWHELMED"] - o_pol_entry_node = False # Assumes servers are not entry nodes + + # 3: Make sure the targetted node can be set to overwhelmed - with node pol + # # TODO remove duplicate red pol for same targetted service - must take into account start step + # + o_pol_id = str(uuid.uuid4()) + # o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched + # o_pol_end_step = ( + # self.episode_steps + # ) # Can become compromised at any timestep after start + # o_pol_node_id = ier_dest # Node effected is the one targetted by the IER + # o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes + # o_pol_service_name = ( + # ier_protocol # Same protocol/service as the IER uses to attack + # ) + # o_pol_new_state = SoftwareState["OVERWHELMED"] + # o_pol_entry_node = False # Assumes servers are not entry nodes o_red_pol = NodeStateInstructionRed( - _id, - _start_step, - _end_step, - _target_node_id, - _pol_initiator, - _pol_type, - pol_protocol, - _pol_state, - _pol_source_node_id, - _pol_source_node_service, - _pol_source_node_service_state, + _id=o_pol_id, + _start_step=ier_start_step, + _end_step=self.episode_steps, + _target_node_id=ier_dest, + _pol_initiator="DIRECT", + _pol_type=NodePOLType["SERVICE"], + pol_protocol=ier_protocol, + _pol_state=SoftwareState.OVERWHELMED, + _pol_source_node_id=source_node.node_id, + _pol_source_node_service=source_node_service.name, + _pol_source_node_service_state=source_node_service.software_state ) self.red_node_pol[o_pol_id] = o_red_pol From 3e691b4f4611309e17049d7e7f7f41b72fc312e6 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 30 Jun 2023 10:37:23 +0100 Subject: [PATCH 3/7] #1522: remove numpy randomisation + added random red agent config --- .../training/training_config_main.yaml | 5 + .../training_config_random_red_agent.yaml | 99 +++++++++++++++++++ src/primaite/environment/primaite_env.py | 39 +++----- 3 files changed, 118 insertions(+), 25 deletions(-) create mode 100644 src/primaite/config/_package_data/training/training_config_random_red_agent.yaml diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index d01f51f3..3fe668e2 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -6,6 +6,11 @@ # "STABLE_BASELINES3_A2C" # "GENERIC" agent_identifier: STABLE_BASELINES3_A2C + +# RED AGENT IDENTIFIER +# RANDOM or NONE +red_agent_identifier: "NONE" + # Sets How the Action Space is defined: # "NODE" # "ACL" diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml new file mode 100644 index 00000000..9382a2b5 --- /dev/null +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -0,0 +1,99 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agent_identifier: STABLE_BASELINES3_A2C + +# RED AGENT IDENTIFIER +# RANDOM or NONE +red_agent_identifier: "RANDOM" + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# Number of episodes to run per session +num_episodes: 10 +# Number of time_steps per episode +num_steps: 256 +# Time delay between steps (for generic agents) +time_delay: 10 +# Type of session to be run (TRAINING or EVALUATION) +session_type: TRAINING +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 9ac3d8e6..e592e21f 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -5,6 +5,7 @@ import csv import logging from datetime import datetime from pathlib import Path +from random import randint, choice, uniform, sample from typing import Dict, Tuple, Union import networkx as nx @@ -276,8 +277,8 @@ class Primaite(Env): self.reset_environment() # Create a random red agent to use for this episode - # if self.training_config.red_agent_identifier == "RANDOM": - # self.create_random_red_agent() + if self.training_config.red_agent_identifier == "RANDOM": + self.create_random_red_agent() # Reset counters and totals self.total_reward = 0 @@ -1249,13 +1250,13 @@ class Primaite(Env): computers ) # only computers can become compromised # random select between 1 and max_num_nodes_compromised - num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised) + num_nodes_to_compromise = randint(1, max_num_nodes_compromised) # Decide which of the nodes to compromise - nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise) + nodes_to_be_compromised = sample(computers, num_nodes_to_compromise) # choose a random compromise node to be source of attacks - source_node = np.random.choice(nodes_to_be_compromised, 1)[0] + source_node = choice(nodes_to_be_compromised) # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( @@ -1270,14 +1271,14 @@ class Primaite(Env): # 1: Use Node PoL to set node to compromised _id = str(uuid.uuid4()) - _start_step = np.random.randint( + _start_step = randint( 2, max_step_compromised + 1 ) # step compromised - pol_service_name = np.random.choice( + pol_service_name = choice( list(node.services.keys()) ) - source_node_service = np.random.choice( + source_node_service = choice( list(source_node.services.values()) ) @@ -1301,13 +1302,13 @@ class Primaite(Env): ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode - ier_start_step = np.random.randint( + ier_start_step = randint( _start_step + 2, int(self.episode_steps * 0.8) ) ier_end_step = self.episode_steps # Randomise the load, as a percentage of a random link bandwith - ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice( + ier_load = uniform(0.4, 0.8) * choice( bandwidths ) ier_protocol = pol_service_name # Same protocol as compromised node @@ -1328,7 +1329,7 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.get_ip_address(), + node.ip_address, server.ip_address, ier_service, ier_port, @@ -1337,7 +1338,7 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: # If still none found choose from all servers possible_ier_destinations = [server.node_id for server in servers] - ier_dest = np.random.choice(possible_ier_destinations) + ier_dest = choice(possible_ier_destinations) self.red_iers[ier_id] = IER( ier_id, ier_start_step, @@ -1354,22 +1355,10 @@ class Primaite(Env): overwhelm_pol.id = str(uuid.uuid4()) overwhelm_pol.end_step = self.episode_steps - # 3: Make sure the targetted node can be set to overwhelmed - with node pol # # TODO remove duplicate red pol for same targetted service - must take into account start step - # + o_pol_id = str(uuid.uuid4()) - # o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched - # o_pol_end_step = ( - # self.episode_steps - # ) # Can become compromised at any timestep after start - # o_pol_node_id = ier_dest # Node effected is the one targetted by the IER - # o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes - # o_pol_service_name = ( - # ier_protocol # Same protocol/service as the IER uses to attack - # ) - # o_pol_new_state = SoftwareState["OVERWHELMED"] - # o_pol_entry_node = False # Assumes servers are not entry nodes o_red_pol = NodeStateInstructionRed( _id=o_pol_id, _start_step=ier_start_step, From 4299170ce42e68e392f5be2e1ef646c62546c971 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 09:46:52 +0100 Subject: [PATCH 4/7] #1522: added a check for existing links in laydown + test that checks if red agent instructions are random --- src/primaite/environment/primaite_env.py | 6 ++++ tests/test_red_random_agent_behaviour.py | 36 ++++++++++++------------ 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index e592e21f..58932c4c 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1265,6 +1265,12 @@ class Primaite(Env): # Bandwidth for all links bandwidths = [i.get_bandwidth() for i in list(self.links.values())] + + if len(bandwidths) < 1: + msg = "Random red agent cannot be used on a network without any links" + _LOGGER.error(msg) + raise Exception(msg) + servers = [node for node in node_list if node.node_type == NodeType.SERVER] for n, node in enumerate(nodes_to_be_compromised): diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index a86e32c1..c9189c26 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -1,5 +1,6 @@ -from datetime import time, datetime +from datetime import datetime +from primaite.config.lay_down_config import data_manipulation_config_path from primaite.environment.primaite_env import Primaite from tests import TEST_CONFIG_ROOT from tests.conftest import _get_temp_session_path @@ -24,9 +25,6 @@ def run_generic(env, config_values): if done: break - # Introduce a delay between steps - time.sleep(config_values.time_delay / 1000) - # Reset the environment at the end of the episode env.reset() @@ -40,6 +38,8 @@ def test_random_red_agent_behaviour(): When the initial state is OFF compared to reference state which is ON. """ list_of_node_instructions = [] + + # RUN TWICE so we can make sure that red agent is randomised for i in range(2): """Takes a config path and returns the created instance of Primaite.""" @@ -49,7 +49,7 @@ def test_random_red_agent_behaviour(): timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") env = Primaite( training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", + lay_down_config_path=data_manipulation_config_path(), transaction_list=[], session_path=session_path, timestamp_str=timestamp_str, @@ -57,18 +57,18 @@ def test_random_red_agent_behaviour(): training_config = env.training_config training_config.num_steps = env.episode_steps - # TOOD: This needs t be refactored to happen outside. Should be part of - # a main Session class. - if training_config.agent_identifier == "GENERIC": - run_generic(env, training_config) - all_red_actions = env.red_node_pol - list_of_node_instructions.append(all_red_actions) + run_generic(env, training_config) + # add red pol instructions to list + list_of_node_instructions.append(env.red_node_pol) + + # compare instructions to make sure that red instructions are truly random + for index, instruction in enumerate(list_of_node_instructions): + for key in list_of_node_instructions[index].keys(): + instruction: NodeInstructionRed = list_of_node_instructions[index][key] + print(f"run {index}") + print(f"{key} start step: {instruction.get_start_step()}") + print(f"{key} end step: {instruction.get_end_step()}") + print(f"{key} target node id: {instruction.get_target_node_id()}") + print("") - # assert not (list_of_node_instructions[0].__eq__(list_of_node_instructions[1])) - print(list_of_node_instructions[0]["1"].get_start_step()) - print(list_of_node_instructions[0]["1"].get_end_step()) - print(list_of_node_instructions[0]["1"].get_target_node_id()) - print(list_of_node_instructions[1]["1"].get_start_step()) - print(list_of_node_instructions[1]["1"].get_end_step()) - print(list_of_node_instructions[1]["1"].get_target_node_id()) assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1]) From 6c4a538b41988869b5fa9d1f35515d969299e031 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 10:08:25 +0100 Subject: [PATCH 5/7] #1522: run pre-commit --- .gitignore | 2 +- src/primaite/environment/primaite_env.py | 55 ++++++++----------- .../nodes/node_state_instruction_red.py | 2 +- tests/test_red_random_agent_behaviour.py | 7 ++- 4 files changed, 28 insertions(+), 38 deletions(-) diff --git a/.gitignore b/.gitignore index 5adbdc57..b65d1fd8 100644 --- a/.gitignore +++ b/.gitignore @@ -138,4 +138,4 @@ dmypy.json # Cython debug symbols cython_debug/ -.idea/ \ No newline at end of file +.idea/ diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 58932c4c..eb0bc5de 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -3,14 +3,14 @@ import copy import csv import logging +import uuid as uuid from datetime import datetime from pathlib import Path -from random import randint, choice, uniform, sample +from random import choice, randint, sample, uniform from typing import Dict, Tuple, Union import networkx as nx import numpy as np -import uuid as uuid import yaml from gym import Env, spaces from matplotlib import pyplot as plt @@ -60,12 +60,12 @@ class Primaite(Env): ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - transaction_list, - session_path: Path, - timestamp_str: str, + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + transaction_list, + session_path: Path, + timestamp_str: str, ): """ The Primaite constructor. @@ -448,11 +448,11 @@ class Primaite(Env): elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 6 + len(self.action_dict[_action]) == 6 ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 4 + len(self.action_dict[_action]) == 4 ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: @@ -1238,7 +1238,6 @@ class Primaite(Env): def create_random_red_agent(self): """Decide on random red agent for the episode to be called in env.reset().""" - # Reset the current red iers and red node pol self.red_iers = {} self.red_node_pol = {} @@ -1260,7 +1259,7 @@ class Primaite(Env): # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( - self.episode_steps // 2 + self.episode_steps // 2 ) # always compromise in first half of episode # Bandwidth for all links @@ -1277,16 +1276,10 @@ class Primaite(Env): # 1: Use Node PoL to set node to compromised _id = str(uuid.uuid4()) - _start_step = randint( - 2, max_step_compromised + 1 - ) # step compromised - pol_service_name = choice( - list(node.services.keys()) - ) + _start_step = randint(2, max_step_compromised + 1) # step compromised + pol_service_name = choice(list(node.services.keys())) - source_node_service = choice( - list(source_node.services.values()) - ) + source_node_service = choice(list(source_node.services.values())) red_pol = NodeStateInstructionRed( _id=_id, @@ -1299,7 +1292,7 @@ class Primaite(Env): _pol_state=SoftwareState.COMPROMISED, _pol_source_node_id=source_node.node_id, _pol_source_node_service=source_node_service.name, - _pol_source_node_service_state=source_node_service.software_state + _pol_source_node_service_state=source_node_service.software_state, ) self.red_node_pol[_id] = red_pol @@ -1308,15 +1301,11 @@ class Primaite(Env): ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode - ier_start_step = randint( - _start_step + 2, int(self.episode_steps * 0.8) - ) + ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8)) ier_end_step = self.episode_steps # Randomise the load, as a percentage of a random link bandwith - ier_load = uniform(0.4, 0.8) * choice( - bandwidths - ) + ier_load = uniform(0.4, 0.8) * choice(bandwidths) ier_protocol = pol_service_name # Same protocol as compromised node ier_service = node.services[pol_service_name] ier_port = ier_service.port @@ -1335,10 +1324,10 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.ip_address, - server.ip_address, - ier_service, - ier_port, + node.ip_address, + server.ip_address, + ier_service, + ier_port, ): possible_ier_destinations.append(server.node_id) if len(possible_ier_destinations) < 1: @@ -1376,6 +1365,6 @@ class Primaite(Env): _pol_state=SoftwareState.OVERWHELMED, _pol_source_node_id=source_node.node_id, _pol_source_node_service=source_node_service.name, - _pol_source_node_service_state=source_node_service.software_state + _pol_source_node_service_state=source_node_service.software_state, ) self.red_node_pol[o_pol_id] = o_red_pol diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 9ae917e9..2f7d0622 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -153,4 +153,4 @@ class NodeStateInstructionRed(object): f"source_node_service={self.source_node_service}, " f"source_node_service_state={self.source_node_service_state}" f")" - ) \ No newline at end of file + ) diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index c9189c26..476a08f1 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -2,6 +2,7 @@ from datetime import datetime from primaite.config.lay_down_config import data_manipulation_config_path from primaite.environment.primaite_env import Primaite +from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from tests import TEST_CONFIG_ROOT from tests.conftest import _get_temp_session_path @@ -41,14 +42,14 @@ def test_random_red_agent_behaviour(): # RUN TWICE so we can make sure that red agent is randomised for i in range(2): - """Takes a config path and returns the created instance of Primaite.""" session_timestamp: datetime = datetime.now() session_path = _get_temp_session_path(session_timestamp) timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") env = Primaite( - training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + training_config_path=TEST_CONFIG_ROOT + / "one_node_states_on_off_main_config.yaml", lay_down_config_path=data_manipulation_config_path(), transaction_list=[], session_path=session_path, @@ -64,7 +65,7 @@ def test_random_red_agent_behaviour(): # compare instructions to make sure that red instructions are truly random for index, instruction in enumerate(list_of_node_instructions): for key in list_of_node_instructions[index].keys(): - instruction: NodeInstructionRed = list_of_node_instructions[index][key] + instruction: NodeStateInstructionRed = list_of_node_instructions[index][key] print(f"run {index}") print(f"{key} start step: {instruction.get_start_step()}") print(f"{key} end step: {instruction.get_end_step()}") From 0943e9511b55074f6f2721312231c840e2f243cb Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 12:18:58 +0100 Subject: [PATCH 6/7] #1522: refactor red_agent_identifier -> random_red_agent so that it is a boolean + documentation --- docs/source/config.rst | 4 + .../training/training_config_main.yaml | 2 +- .../training_config_random_red_agent.yaml | 2 +- src/primaite/config/training_config.py | 2 +- src/primaite/environment/primaite_env.py | 2 +- tests/config/random_agent_main_config.yaml | 96 ------------------- tests/test_red_random_agent_behaviour.py | 2 + 7 files changed, 10 insertions(+), 100 deletions(-) delete mode 100644 tests/config/random_agent_main_config.yaml diff --git a/docs/source/config.rst b/docs/source/config.rst index 74898ec1..fa58e6cf 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -28,6 +28,10 @@ The environment config file consists of the following attributes: * STABLE_BASELINES3_PPO - Use a SB3 PPO agent * STABLE_BASELINES3_A2C - use a SB3 A2C agent +* **random_red_agent** [bool] + + Determines if the session should be run with a random red agent + * **action_type** [enum] Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 3fe668e2..8f035d41 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -9,7 +9,7 @@ agent_identifier: STABLE_BASELINES3_A2C # RED AGENT IDENTIFIER # RANDOM or NONE -red_agent_identifier: "NONE" +random_red_agent: False # Sets How the Action Space is defined: # "NODE" diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml index 9382a2b5..3e0a3e2f 100644 --- a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -9,7 +9,7 @@ agent_identifier: STABLE_BASELINES3_A2C # RED AGENT IDENTIFIER # RANDOM or NONE -red_agent_identifier: "RANDOM" +random_red_agent: True # Sets How the Action Space is defined: # "NODE" diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 6e88e7cb..7995dfe8 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -21,7 +21,7 @@ class TrainingConfig: agent_identifier: str = "STABLE_BASELINES3_A2C" "The Red Agent algo/class to be used." - red_agent_identifier: str = "RANDOM" + random_red_agent: bool = False "Creates Random Red Agent Attacks" action_type: ActionType = ActionType.ANY diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index eb0bc5de..5cb85afd 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -277,7 +277,7 @@ class Primaite(Env): self.reset_environment() # Create a random red agent to use for this episode - if self.training_config.red_agent_identifier == "RANDOM": + if self.training_config.random_red_agent: self.create_random_red_agent() # Reset counters and totals diff --git a/tests/config/random_agent_main_config.yaml b/tests/config/random_agent_main_config.yaml deleted file mode 100644 index d2d18bbc..00000000 --- a/tests/config/random_agent_main_config.yaml +++ /dev/null @@ -1,96 +0,0 @@ -# Main Config File - -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: GENERIC -# -red_agent_identifier: RANDOM -# Sets How the Action Space is defined: -# "NODE" -# "ACL" -# "ANY" node and acl actions -action_type: ANY -# Number of episodes to run per session -num_episodes: 1 -# Number of time_steps per episode -num_steps: 5 -# Time delay between steps (for generic agents) -time_delay: 1 -# Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING -# Determine whether to load an agent from file -load_agent: False -# File path and file name of agent if you're loading one in -agent_load_file: C:\[Path]\[agent_saved_filename.zip] - -# Environment config values -# The high value for the observation space -observation_space_high_value: 1_000_000_000 - -# Reward values -# Generic -all_ok: 0 -# Node Hardware State -off_should_be_on: -10 -off_should_be_resetting: -5 -on_should_be_off: -2 -on_should_be_resetting: -5 -resetting_should_be_on: -5 -resetting_should_be_off: -2 -resetting: -3 -# Node Software or Service State -good_should_be_patching: 2 -good_should_be_compromised: 5 -good_should_be_overwhelmed: 5 -patching_should_be_good: -5 -patching_should_be_compromised: 2 -patching_should_be_overwhelmed: 2 -patching: -3 -compromised_should_be_good: -20 -compromised_should_be_patching: -20 -compromised_should_be_overwhelmed: -20 -compromised: -20 -overwhelmed_should_be_good: -20 -overwhelmed_should_be_patching: -20 -overwhelmed_should_be_compromised: -20 -overwhelmed: -20 -# Node File System State -good_should_be_repairing: 2 -good_should_be_restoring: 2 -good_should_be_corrupt: 5 -good_should_be_destroyed: 10 -repairing_should_be_good: -5 -repairing_should_be_restoring: 2 -repairing_should_be_corrupt: 2 -repairing_should_be_destroyed: 0 -repairing: -3 -restoring_should_be_good: -10 -restoring_should_be_repairing: -2 -restoring_should_be_corrupt: 1 -restoring_should_be_destroyed: 2 -restoring: -6 -corrupt_should_be_good: -10 -corrupt_should_be_repairing: -10 -corrupt_should_be_restoring: -10 -corrupt_should_be_destroyed: 2 -corrupt: -10 -destroyed_should_be_good: -20 -destroyed_should_be_repairing: -20 -destroyed_should_be_restoring: -20 -destroyed_should_be_corrupt: -20 -destroyed: -20 -scanning: -2 -# IER status -red_ier_running: -5 -green_ier_blocked: -10 - -# Patching / Reset durations -os_patching_duration: 5 # The time taken to patch the OS -node_reset_duration: 5 # The time taken to reset a node (hardware) -service_patching_duration: 5 # The time taken to patch a service -file_system_repairing_limit: 5 # The time take to repair the file system -file_system_restoring_limit: 5 # The time take to restore the file system -file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index 476a08f1..6b06dbb1 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -55,6 +55,8 @@ def test_random_red_agent_behaviour(): session_path=session_path, timestamp_str=timestamp_str, ) + # set red_agent_ + env.training_config.random_red_agent = True training_config = env.training_config training_config.num_steps = env.episode_steps From cb9d40579f4c5352350e52bfd66b615b35559ae5 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 13:36:14 +0100 Subject: [PATCH 7/7] #1522: create_random_red_agent -> _create_random_red_agent + converting NodeStateInstructionRed into a dataclass --- src/primaite/environment/primaite_env.py | 4 ++-- .../nodes/node_state_instruction_red.py | 20 +++---------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 5cb85afd..823c11fe 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -278,7 +278,7 @@ class Primaite(Env): # Create a random red agent to use for this episode if self.training_config.random_red_agent: - self.create_random_red_agent() + self._create_random_red_agent() # Reset counters and totals self.total_reward = 0 @@ -1236,7 +1236,7 @@ class Primaite(Env): combined_action_dict = {**acl_action_dict, **new_node_action_dict} return combined_action_dict - def create_random_red_agent(self): + def _create_random_red_agent(self): """Decide on random red agent for the episode to be called in env.reset().""" # Reset the current red iers and red node pol self.red_iers = {} diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 2f7d0622..4272ce24 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -1,8 +1,11 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Defines node behaviour for Green PoL.""" +from dataclasses import dataclass + from primaite.common.enums import NodePOLType +@dataclass() class NodeStateInstructionRed(object): """The Node State Instruction class.""" @@ -137,20 +140,3 @@ class NodeStateInstructionRed(object): The source node service state """ return self.source_node_service_state - - def __repr__(self): - return ( - f"{self.__class__.__name__}(" - f"id={self.id}, " - f"start_step={self.start_step}, " - f"end_step={self.end_step}, " - f"target_node_id={self.target_node_id}, " - f"initiator={self.initiator}, " - f"pol_type={self.pol_type}, " - f"service_name={self.service_name}, " - f"state={self.state}, " - f"source_node_id={self.source_node_id}, " - f"source_node_service={self.source_node_service}, " - f"source_node_service_state={self.source_node_service_state}" - f")" - )