diff --git a/.gitignore b/.gitignore index eed6c903..b65d1fd8 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,5 @@ dmypy.json # Cython debug symbols cython_debug/ + +.idea/ diff --git a/docs/source/config.rst b/docs/source/config.rst index 22fd0c01..71ade6c5 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -39,6 +39,10 @@ The environment config file consists of the following attributes: * RANDOM - A Stochastic random agent +* **random_red_agent** [bool] + + Determines if the session should be run with a random red agent + * **action_type** [enum] Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index a414bed9..91d9af11 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -25,6 +25,12 @@ deep_learning_framework: TF2 # "DUMMY" (primaite.agents.simple.DummyAgent) agent_identifier: PPO +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False +random_red_agent: False + # Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. # Options are: # "BASIC" (The current observation space only) diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml new file mode 100644 index 00000000..3e0a3e2f --- /dev/null +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -0,0 +1,99 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agent_identifier: STABLE_BASELINES3_A2C + +# RED AGENT IDENTIFIER +# RANDOM or NONE +random_red_agent: True + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# Number of episodes to run per session +num_episodes: 10 +# Number of time_steps per episode +num_steps: 256 +# Time delay between steps (for generic agents) +time_delay: 10 +# Type of session to be run (TRAINING or EVALUATION) +session_type: TRAINING +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 2ffc2a8c..bd73f65b 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -54,6 +54,9 @@ class TrainingConfig: hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL "The view the deterministic hard-coded agent has of the environment" + random_red_agent: bool = False + "Creates Random Red Agent Attacks" + action_type: ActionType = ActionType.ANY "The ActionType to use" diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index df51e21e..c80c36ec 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1,8 +1,14 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Main environment module containing the PRIMmary AI Training Evironment (Primaite) class.""" import copy +import csv +import logging +import uuid as uuid +from datetime import datetime from pathlib import Path from typing import Dict, Final, Tuple, Union +from random import choice, randint, sample, uniform +from typing import Dict, Tuple, Union import networkx as nx import numpy as np @@ -272,6 +278,10 @@ class Primaite(Env): # Does this for both live and reference nodes self.reset_environment() + # Create a random red agent to use for this episode + if self.training_config.random_red_agent: + self._create_random_red_agent() + # Reset counters and totals self.total_reward = 0 self.step_count = 0 @@ -1216,3 +1226,136 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} return combined_action_dict + + def _create_random_red_agent(self): + """Decide on random red agent for the episode to be called in env.reset().""" + # Reset the current red iers and red node pol + self.red_iers = {} + self.red_node_pol = {} + + # Decide how many nodes become compromised + node_list = list(self.nodes.values()) + computers = [node for node in node_list if node.node_type == NodeType.COMPUTER] + max_num_nodes_compromised = len( + computers + ) # only computers can become compromised + # random select between 1 and max_num_nodes_compromised + num_nodes_to_compromise = randint(1, max_num_nodes_compromised) + + # Decide which of the nodes to compromise + nodes_to_be_compromised = sample(computers, num_nodes_to_compromise) + + # choose a random compromise node to be source of attacks + source_node = choice(nodes_to_be_compromised) + + # For each of the nodes to be compromised decide which step they become compromised + max_step_compromised = ( + self.episode_steps // 2 + ) # always compromise in first half of episode + + # Bandwidth for all links + bandwidths = [i.get_bandwidth() for i in list(self.links.values())] + + if len(bandwidths) < 1: + msg = "Random red agent cannot be used on a network without any links" + _LOGGER.error(msg) + raise Exception(msg) + + servers = [node for node in node_list if node.node_type == NodeType.SERVER] + + for n, node in enumerate(nodes_to_be_compromised): + # 1: Use Node PoL to set node to compromised + + _id = str(uuid.uuid4()) + _start_step = randint(2, max_step_compromised + 1) # step compromised + pol_service_name = choice(list(node.services.keys())) + + source_node_service = choice(list(source_node.services.values())) + + red_pol = NodeStateInstructionRed( + _id=_id, + _start_step=_start_step, + _end_step=_start_step, # only run for 1 step + _target_node_id=node.node_id, + _pol_initiator="DIRECT", + _pol_type=NodePOLType["SERVICE"], + pol_protocol=pol_service_name, + _pol_state=SoftwareState.COMPROMISED, + _pol_source_node_id=source_node.node_id, + _pol_source_node_service=source_node_service.name, + _pol_source_node_service_state=source_node_service.software_state, + ) + + self.red_node_pol[_id] = red_pol + + # 2: Launch the attack from compromised node - set the IER + + ier_id = str(uuid.uuid4()) + # Launch the attack after node is compromised, and not right at the end of the episode + ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8)) + ier_end_step = self.episode_steps + + # Randomise the load, as a percentage of a random link bandwith + ier_load = uniform(0.4, 0.8) * choice(bandwidths) + ier_protocol = pol_service_name # Same protocol as compromised node + ier_service = node.services[pol_service_name] + ier_port = ier_service.port + ier_mission_criticality = ( + 0 # Red IER will never be important to green agent success + ) + # We choose a node to attack based on the first that applies: + # a. Green IERs, select dest node of the red ier based on dest node of green IER + # b. Attack a random server that doesn't have a DENY acl rule in default config + # c. Attack a random server + possible_ier_destinations = [ + ier.get_dest_node_id() + for ier in list(self.green_iers.values()) + if ier.get_source_node_id() == node.node_id + ] + if len(possible_ier_destinations) < 1: + for server in servers: + if not self.acl.is_blocked( + node.ip_address, + server.ip_address, + ier_service, + ier_port, + ): + possible_ier_destinations.append(server.node_id) + if len(possible_ier_destinations) < 1: + # If still none found choose from all servers + possible_ier_destinations = [server.node_id for server in servers] + ier_dest = choice(possible_ier_destinations) + self.red_iers[ier_id] = IER( + ier_id, + ier_start_step, + ier_end_step, + ier_load, + ier_protocol, + ier_port, + node.node_id, + ier_dest, + ier_mission_criticality, + ) + + overwhelm_pol = red_pol + overwhelm_pol.id = str(uuid.uuid4()) + overwhelm_pol.end_step = self.episode_steps + + # 3: Make sure the targetted node can be set to overwhelmed - with node pol + # # TODO remove duplicate red pol for same targetted service - must take into account start step + + o_pol_id = str(uuid.uuid4()) + o_red_pol = NodeStateInstructionRed( + _id=o_pol_id, + _start_step=ier_start_step, + _end_step=self.episode_steps, + _target_node_id=ier_dest, + _pol_initiator="DIRECT", + _pol_type=NodePOLType["SERVICE"], + pol_protocol=ier_protocol, + _pol_state=SoftwareState.OVERWHELMED, + _pol_source_node_id=source_node.node_id, + _pol_source_node_service=source_node_service.name, + _pol_source_node_service_state=source_node_service.software_state, + ) + self.red_node_pol[o_pol_id] = o_red_pol diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 7f62fe24..4272ce24 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -1,8 +1,11 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Defines node behaviour for Green PoL.""" +from dataclasses import dataclass + from primaite.common.enums import NodePOLType +@dataclass() class NodeStateInstructionRed(object): """The Node State Instruction class.""" diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py new file mode 100644 index 00000000..6b06dbb1 --- /dev/null +++ b/tests/test_red_random_agent_behaviour.py @@ -0,0 +1,77 @@ +from datetime import datetime + +from primaite.config.lay_down_config import data_manipulation_config_path +from primaite.environment.primaite_env import Primaite +from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed +from tests import TEST_CONFIG_ROOT +from tests.conftest import _get_temp_session_path + + +def run_generic(env, config_values): + """Run against a generic agent.""" + # Reset the environment at the start of the episode + env.reset() + for episode in range(0, config_values.num_episodes): + for step in range(0, config_values.num_steps): + # Send the observation space to the agent to get an action + # TEMP - random action for now + # action = env.blue_agent_action(obs) + # action = env.action_space.sample() + action = 0 + + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + + # Reset the environment at the end of the episode + env.reset() + + env.close() + + +def test_random_red_agent_behaviour(): + """ + Test that hardware state is penalised at each step. + + When the initial state is OFF compared to reference state which is ON. + """ + list_of_node_instructions = [] + + # RUN TWICE so we can make sure that red agent is randomised + for i in range(2): + """Takes a config path and returns the created instance of Primaite.""" + session_timestamp: datetime = datetime.now() + session_path = _get_temp_session_path(session_timestamp) + + timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + env = Primaite( + training_config_path=TEST_CONFIG_ROOT + / "one_node_states_on_off_main_config.yaml", + lay_down_config_path=data_manipulation_config_path(), + transaction_list=[], + session_path=session_path, + timestamp_str=timestamp_str, + ) + # set red_agent_ + env.training_config.random_red_agent = True + training_config = env.training_config + training_config.num_steps = env.episode_steps + + run_generic(env, training_config) + # add red pol instructions to list + list_of_node_instructions.append(env.red_node_pol) + + # compare instructions to make sure that red instructions are truly random + for index, instruction in enumerate(list_of_node_instructions): + for key in list_of_node_instructions[index].keys(): + instruction: NodeStateInstructionRed = list_of_node_instructions[index][key] + print(f"run {index}") + print(f"{key} start step: {instruction.get_start_step()}") + print(f"{key} end step: {instruction.get_end_step()}") + print(f"{key} target node id: {instruction.get_target_node_id()}") + print("") + + assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1])