From e0f3d61f6511181e861cfe4bc27d46cb18c0fba7 Mon Sep 17 00:00:00 2001 From: Brian Kanyora Date: Thu, 22 Jun 2023 15:34:13 +0100 Subject: [PATCH] feature\1522: Create random red agent behaviour. --- src/primaite/config/training_config.py | 17 +- src/primaite/environment/primaite_env.py | 173 ++++++++++++++++-- .../nodes/node_state_instruction_red.py | 17 ++ tests/config/random_agent_main_config.yaml | 96 ++++++++++ tests/test_red_random_agent_behaviour.py | 74 ++++++++ 5 files changed, 356 insertions(+), 21 deletions(-) create mode 100644 tests/config/random_agent_main_config.yaml create mode 100644 tests/test_red_random_agent_behaviour.py diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 4af36abe..6e88e7cb 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,7 +1,7 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Final, Union, Optional +from typing import Any, Dict, Final, Optional, Union import yaml @@ -21,6 +21,9 @@ class TrainingConfig: agent_identifier: str = "STABLE_BASELINES3_A2C" "The Red Agent algo/class to be used." + red_agent_identifier: str = "RANDOM" + "Creates Random Red Agent Attacks" + action_type: ActionType = ActionType.ANY "The ActionType to use." @@ -167,8 +170,7 @@ def main_training_config_path() -> Path: return path -def load(file_path: Union[str, Path], - legacy_file: bool = False) -> TrainingConfig: +def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: """ Read in a training config yaml file. 
@@ -213,9 +215,7 @@ def load(file_path: Union[str, Path], def convert_legacy_training_config_dict( - legacy_config_dict: Dict[str, Any], - num_steps: int = 256, - action_type: str = "ANY" + legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY" ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -227,10 +227,7 @@ def convert_legacy_training_config_dict( don't have action_type values. :return: The converted training config dict. """ - config_dict = { - "num_steps": num_steps, - "action_type": action_type - } + config_dict = {"num_steps": num_steps, "action_type": action_type} for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) if new_key: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index da235971..9161fa43 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -14,8 +14,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, \ - is_valid_node_action +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -24,8 +23,9 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, + ObservationType, Priority, - SoftwareState, ObservationType, + SoftwareState, ) from primaite.common.service import Service from primaite.config import training_config @@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import \ - NodeStateInstructionGreen +from 
primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, \ - apply_red_agent_node_pol +from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction _LOGGER = logging.getLogger(__name__) @@ -177,7 +175,6 @@ class Primaite(Env): # It will be initialised later. self.obs_handler: ObservationsHandler - # Open the config file and build the environment laydown with open(self._lay_down_config_path, "r") as file: # Open the config file and build the environment laydown @@ -238,7 +235,9 @@ class Primaite(Env): self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}") + _LOGGER.info( + f"Invalid action type selected: {self.training_config.action_type}" + ) # Set up a csv to store the results of the training try: header = ["Episode", "Average Reward"] @@ -275,6 +274,10 @@ class Primaite(Env): # Does this for both live and reference nodes self.reset_environment() + # Create a random red agent to use for this episode + if self.training_config.red_agent_identifier == "RANDOM": + self.create_random_red_agent() + # Reset counters and totals self.total_reward = 0 self.step_count = 0 @@ -379,7 +382,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - #print(f" Step {self.step_count} Reward: {str(reward)}") + _LOGGER.debug("Step %s Reward: %s", self.step_count, reward) # was a bare print() left enabled for debugging + self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count
@@ -1033,7 +1036,6 @@ class Primaite(Env): """ self.observation_type = ObservationType[observation_info["type"]] - def get_action_info(self, action_info): """ Extracts action_info. @@ -1216,3 +1218,152 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} return combined_action_dict + + def create_random_red_agent(self): + """Decide on random red agent for the episode to be called in env.reset().""" + + # Reset the current red iers and red node pol + self.red_iers = {} + self.red_node_pol = {} + + # Decide how many nodes become compromised + node_list = list(self.nodes.values()) + computers = [node for node in node_list if node.node_type == NodeType.COMPUTER] + max_num_nodes_compromised = len( + computers + ) # only computers can become compromised + # random select between 1 and max_num_nodes_compromised + num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised + 1) + + # Decide which of the nodes to compromise + nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise) + + # For each of the nodes to be compromised decide which step they become compromised + max_step_compromised = ( + self.episode_steps // 2 + ) # always compromise in first half of episode + + # Bandwidth for all links + bandwidths = [i.get_bandwidth() for i in list(self.links.values())] + servers = [node for node in node_list if node.node_type == NodeType.SERVER] + + for n, node in enumerate(nodes_to_be_compromised): + # 1: Use Node PoL to set node to compromised + + _id = str(1000 + n) # doesn't really matter, make sure it doesn't duplicate + _start_step = np.random.randint( + 2, max_step_compromised + 1 + ) # step compromised + _end_step = _start_step # Become compromised on 1 step + _target_node_id = node.node_id + _pol_initiator = "DIRECT" + _pol_type = NodePOLType["SERVICE"] # All computers are service nodes + pol_service_name = np.random.choice( + list(node.get_services().keys()) 
+ ) # Random service may wish to change this, currently always TCP + pol_protocol = pol_service_name # fixed: was a self-assignment of an undefined name (NameError); the service name is used as the protocol elsewhere in this method + _pol_state = SoftwareState.COMPROMISED + is_entry_node = True # Assumes all computers in network are entry nodes + _pol_source_node_id = "NONE" # fixed: was a self-assignment of an undefined name (NameError); unused for a DIRECT initiator - TODO confirm sentinel + _pol_source_node_service = "NONE" # fixed: was a self-assignment of an undefined name (NameError); unused for a DIRECT initiator - TODO confirm sentinel + _pol_source_node_service_state = "NONE" # fixed: was a self-assignment of an undefined name (NameError); unused for a DIRECT initiator - TODO confirm sentinel + red_pol = NodeStateInstructionRed( + _id, + _start_step, + _end_step, + _target_node_id, + _pol_initiator, + _pol_type, + pol_protocol, + _pol_state, + _pol_source_node_id, + _pol_source_node_service, + _pol_source_node_service_state, + ) + + self.red_node_pol[_id] = red_pol + + # 2: Launch the attack from compromised node - set the IER + + ier_id = str(2000 + n) + # Launch the attack after node is compromised, and not right at the end of the episode + ier_start_step = np.random.randint( + _start_step + 2, int(self.episode_steps * 0.8) + ) + ier_end_step = self.episode_steps + ier_source_node_id = node.get_id() + # Randomise the load, as a percentage of a random link bandwidth + ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice( + bandwidths + ) + ier_protocol = pol_service_name # Same protocol as compromised node + ier_service = node.get_services()[ + pol_service_name + ] # same service as defined in the pol + ier_port = ier_service.get_port() + ier_mission_criticality = ( + 0 # Red IER will never be important to green agent success + ) + # We choose a node to attack based on the first that applies: + # a. Green IERs, select dest node of the red ier based on dest node of green IER + # b. Attack a random server that doesn't have a DENY acl rule in default config + # c. 
Attack a random server + possible_ier_destinations = [ + ier.get_dest_node_id() + for ier in list(self.green_iers.values()) + if ier.get_source_node_id() == node.get_id() + ] + if len(possible_ier_destinations) < 1: + for server in servers: + if not self.acl.is_blocked( + node.get_ip_address(), + server.ip_address, + ier_service, + ier_port, + ): + possible_ier_destinations.append(server.node_id) + if len(possible_ier_destinations) < 1: + # If still none found choose from all servers + possible_ier_destinations = [server.node_id for server in servers] + ier_dest = np.random.choice(possible_ier_destinations) + self.red_iers[ier_id] = IER( + ier_id, + ier_start_step, + ier_end_step, + ier_load, + ier_protocol, + ier_port, + ier_source_node_id, + ier_dest, + ier_mission_criticality, + ) + + # 3: Make sure the targeted node can be set to overwhelmed - with node pol + # TODO remove duplicate red pol for same targeted service - must take into account start step + + o_pol_id = str(3000 + n) + o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched + o_pol_end_step = ( + self.episode_steps + ) # Can become compromised at any timestep after start + o_pol_node_id = ier_dest # Node affected is the one targeted by the IER + o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes + o_pol_service_name = ( + ier_protocol # Same protocol/service as the IER uses to attack + ) + o_pol_new_state = SoftwareState["OVERWHELMED"] + o_pol_entry_node = False # Assumes servers are not entry nodes + # fixed: this instruction was built from the step-1 compromise variables, + # which made the overwhelm PoL an exact duplicate of the compromise PoL + o_red_pol = NodeStateInstructionRed( + o_pol_id, + o_pol_start_step, + o_pol_end_step, + o_pol_node_id, + _pol_initiator, + o_pol_node_type, + o_pol_service_name, + o_pol_new_state, + _pol_source_node_id, + _pol_source_node_service, + _pol_source_node_service_state, + ) + self.red_node_pol[o_pol_id] = o_red_pol diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 7f62fe24..9ae917e9 100644 --- 
a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -137,3 +137,20 @@ class NodeStateInstructionRed(object): The source node service state """ return self.source_node_service_state + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"start_step={self.start_step}, " + f"end_step={self.end_step}, " + f"target_node_id={self.target_node_id}, " + f"initiator={self.initiator}, " + f"pol_type={self.pol_type}, " + f"service_name={self.service_name}, " + f"state={self.state}, " + f"source_node_id={self.source_node_id}, " + f"source_node_service={self.source_node_service}, " + f"source_node_service_state={self.source_node_service_state}" + f")" + ) \ No newline at end of file diff --git a/tests/config/random_agent_main_config.yaml b/tests/config/random_agent_main_config.yaml new file mode 100644 index 00000000..d2d18bbc --- /dev/null +++ b/tests/config/random_agent_main_config.yaml @@ -0,0 +1,96 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agent_identifier: GENERIC +# +red_agent_identifier: RANDOM +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: ANY +# Number of episodes to run per session +num_episodes: 1 +# Number of time_steps per episode +num_steps: 5 +# Time delay between steps (for generic agents) +time_delay: 1 +# Type of session to be run (TRAINING or EVALUATION) +session_type: TRAINING +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1_000_000_000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 
+off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py new file 
mode 100644 index 00000000..a86e32c1 --- /dev/null +++ b/tests/test_red_random_agent_behaviour.py @@ -0,0 +1,77 @@ +import time # fixed: was 'from datetime import time, datetime' - datetime.time is a class and has no sleep() +from datetime import datetime + +from primaite.environment.primaite_env import Primaite +from tests import TEST_CONFIG_ROOT +from tests.conftest import _get_temp_session_path + + +def run_generic(env, config_values): + """Run against a generic agent.""" + # Reset the environment at the start of the episode + env.reset() + for episode in range(0, config_values.num_episodes): + for step in range(0, config_values.num_steps): + # Send the observation space to the agent to get an action + # TEMP - random action for now + # action = env.blue_agent_action(obs) + # action = env.action_space.sample() + action = 0 + + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + + # Introduce a delay between steps + time.sleep(config_values.time_delay / 1000) + + # Reset the environment at the end of the episode + env.reset() + + env.close() + + +def test_random_red_agent_behaviour(): + """ + Test that the random red agent generates different behaviour per episode. + + Runs two environments and compares the red node PoL instructions that + create_random_red_agent() generated for each. + """ + list_of_node_instructions = [] + for i in range(2): + + # Takes a config path and returns the created instance of Primaite. + session_timestamp: datetime = datetime.now() + session_path = _get_temp_session_path(session_timestamp) + + timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + env = Primaite( + training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", + transaction_list=[], + session_path=session_path, + timestamp_str=timestamp_str, + ) + training_config = env.training_config + training_config.num_steps = env.episode_steps + + # TODO: This needs to be refactored to happen outside. 
Should be part of + # a main Session class. + if training_config.agent_identifier == "GENERIC": + run_generic(env, training_config) + all_red_actions = env.red_node_pol + list_of_node_instructions.append(all_red_actions) + + # assert not (list_of_node_instructions[0].__eq__(list_of_node_instructions[1])) + # NOTE(review): create_random_red_agent() clears red_node_pol and keys it as str(1000 + n), + # so key "1" cannot exist - use "1000" (first compromised node). TODO confirm against laydown. + print(list_of_node_instructions[0]["1000"].get_start_step()) + print(list_of_node_instructions[0]["1000"].get_end_step()) + print(list_of_node_instructions[0]["1000"].get_target_node_id()) + print(list_of_node_instructions[1]["1000"].get_start_step()) + print(list_of_node_instructions[1]["1000"].get_end_step()) + print(list_of_node_instructions[1]["1000"].get_target_node_id()) + assert list_of_node_instructions[0] != list_of_node_instructions[1] # NOTE(review): values lack __eq__, so this compares identity and always passes - TODO compare instruction fields