diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index d01f51f3..3fe668e2 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -6,6 +6,11 @@ # "STABLE_BASELINES3_A2C" # "GENERIC" agent_identifier: STABLE_BASELINES3_A2C + +# RED AGENT IDENTIFIER +# RANDOM or NONE +red_agent_identifier: "NONE" + # Sets How the Action Space is defined: # "NODE" # "ACL" diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml new file mode 100644 index 00000000..9382a2b5 --- /dev/null +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -0,0 +1,99 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agent_identifier: STABLE_BASELINES3_A2C + +# RED AGENT IDENTIFIER +# RANDOM or NONE +red_agent_identifier: "RANDOM" + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# Number of episodes to run per session +num_episodes: 10 +# Number of time_steps per episode +num_steps: 256 +# Time delay between steps (for generic agents) +time_delay: 10 +# Type of session to be run (TRAINING or EVALUATION) +session_type: TRAINING +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 
+on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time taken to repair the file system +file_system_restoring_limit: 5 # The time taken to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 9ac3d8e6..e592e21f 100644 --- 
a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -5,6 +5,7 @@ import csv import logging from datetime import datetime from pathlib import Path +from random import randint, choice, uniform, sample from typing import Dict, Tuple, Union import networkx as nx @@ -276,8 +277,8 @@ class Primaite(Env): self.reset_environment() # Create a random red agent to use for this episode - # if self.training_config.red_agent_identifier == "RANDOM": - # self.create_random_red_agent() + if self.training_config.red_agent_identifier == "RANDOM": + self.create_random_red_agent() # Reset counters and totals self.total_reward = 0 @@ -1249,13 +1250,13 @@ class Primaite(Env): computers ) # only computers can become compromised # random select between 1 and max_num_nodes_compromised - num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised) + num_nodes_to_compromise = randint(1, max_num_nodes_compromised) # Decide which of the nodes to compromise - nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise) + nodes_to_be_compromised = sample(computers, num_nodes_to_compromise) # choose a random compromise node to be source of attacks - source_node = np.random.choice(nodes_to_be_compromised, 1)[0] + source_node = choice(nodes_to_be_compromised) # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( @@ -1270,14 +1271,14 @@ class Primaite(Env): # 1: Use Node PoL to set node to compromised _id = str(uuid.uuid4()) - _start_step = np.random.randint( + _start_step = randint( 2, max_step_compromised + 1 ) # step compromised - pol_service_name = np.random.choice( + pol_service_name = choice( list(node.services.keys()) ) - source_node_service = np.random.choice( + source_node_service = choice( list(source_node.services.values()) ) @@ -1301,13 +1302,13 @@ class Primaite(Env): ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the 
end of the episode - ier_start_step = np.random.randint( + ier_start_step = randint( _start_step + 2, int(self.episode_steps * 0.8) ) ier_end_step = self.episode_steps # Randomise the load, as a percentage of a random link bandwith - ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice( + ier_load = uniform(0.4, 0.8) * choice( bandwidths ) ier_protocol = pol_service_name # Same protocol as compromised node @@ -1328,7 +1329,7 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.get_ip_address(), + node.ip_address, server.ip_address, ier_service, ier_port, @@ -1337,7 +1338,7 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: # If still none found choose from all servers possible_ier_destinations = [server.node_id for server in servers] - ier_dest = np.random.choice(possible_ier_destinations) + ier_dest = choice(possible_ier_destinations) self.red_iers[ier_id] = IER( ier_id, ier_start_step, @@ -1354,22 +1355,10 @@ class Primaite(Env): overwhelm_pol.id = str(uuid.uuid4()) overwhelm_pol.end_step = self.episode_steps - # 3: Make sure the targetted node can be set to overwhelmed - with node pol # # TODO remove duplicate red pol for same targetted service - must take into account start step - # + o_pol_id = str(uuid.uuid4()) - # o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched - # o_pol_end_step = ( - # self.episode_steps - # ) # Can become compromised at any timestep after start - # o_pol_node_id = ier_dest # Node effected is the one targetted by the IER - # o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes - # o_pol_service_name = ( - # ier_protocol # Same protocol/service as the IER uses to attack - # ) - # o_pol_new_state = SoftwareState["OVERWHELMED"] - # o_pol_entry_node = False # Assumes servers are not entry nodes o_red_pol = NodeStateInstructionRed( _id=o_pol_id, _start_step=ier_start_step,