diff --git a/.gitignore b/.gitignore index eed6c903..5adbdc57 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,5 @@ dmypy.json # Cython debug symbols cython_debug/ + +.idea/ \ No newline at end of file diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index c3d408d2..9ac3d8e6 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -9,6 +9,7 @@ from typing import Dict, Tuple, Union import networkx as nx import numpy as np +import uuid as uuid import yaml from gym import Env, spaces from matplotlib import pyplot as plt @@ -58,12 +59,12 @@ class Primaite(Env): ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - transaction_list, - session_path: Path, - timestamp_str: str, + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + transaction_list, + session_path: Path, + timestamp_str: str, ): """ The Primaite constructor. @@ -275,8 +276,8 @@ class Primaite(Env): self.reset_environment() # Create a random red agent to use for this episode - if self.training_config.red_agent_identifier == "RANDOM": - self.create_random_red_agent() + # if self.training_config.red_agent_identifier == "RANDOM": + # self.create_random_red_agent() # Reset counters and totals self.total_reward = 0 @@ -380,6 +381,7 @@ class Primaite(Env): self.nodes_post_pol, self.nodes_post_red, self.nodes_reference, + self.green_iers, self.green_iers_reference, self.red_iers, self.step_count, @@ -445,11 +447,11 @@ class Primaite(Env): elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 6 + len(self.action_dict[_action]) == 6 ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 4 + len(self.action_dict[_action]) == 4 ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: @@ -1247,14 +1249,17 @@ class Primaite(Env): computers ) # only computers can become compromised # random select between 1 and max_num_nodes_compromised - num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised + 1) + num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised) # Decide which of the nodes to compromise nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise) + # choose a random compromise node to be source of attacks + source_node = np.random.choice(nodes_to_be_compromised, 1)[0] + # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( - self.episode_steps // 2 + self.episode_steps // 2 ) # always compromise in first half of episode # Bandwidth for all links @@ -1264,57 +1269,50 @@ class Primaite(Env): for n, node in enumerate(nodes_to_be_compromised): # 1: Use Node PoL to set node to compromised - _id = str(1000 + n) # doesn't really matter, make sure it doesn't duplicate + _id = str(uuid.uuid4()) _start_step = np.random.randint( 2, max_step_compromised + 1 ) # step compromised - _end_step = _start_step # Become compromised on 1 step - _target_node_id = node.node_id - _pol_initiator = "DIRECT" - _pol_type = NodePOLType["SERVICE"] # All computers are service nodes pol_service_name = np.random.choice( - list(node.get_services().keys()) - ) # Random service may wish to change this, currently always TCP) - pol_protocol = pol_protocol - _pol_state = SoftwareState.COMPROMISED - is_entry_node = True # Assumes all computers in network are entry nodes - _pol_source_node_id = _pol_source_node_id - _pol_source_node_service = _pol_source_node_service - _pol_source_node_service_state = _pol_source_node_service_state + list(node.services.keys()) + ) + + source_node_service = np.random.choice( + list(source_node.services.values()) + ) + red_pol = NodeStateInstructionRed( - _id, - _start_step, - _end_step, - _target_node_id, - _pol_initiator, - _pol_type, - pol_protocol, - _pol_state, - _pol_source_node_id, - _pol_source_node_service, - _pol_source_node_service_state, + _id=_id, + _start_step=_start_step, + _end_step=_start_step, # only run for 1 step + _target_node_id=node.node_id, + _pol_initiator="DIRECT", + _pol_type=NodePOLType["SERVICE"], + pol_protocol=pol_service_name, + _pol_state=SoftwareState.COMPROMISED, + _pol_source_node_id=source_node.node_id, + _pol_source_node_service=source_node_service.name, + _pol_source_node_service_state=source_node_service.software_state ) self.red_node_pol[_id] = red_pol # 2: Launch the attack from compromised node - set the IER - ier_id = str(2000 + n) + ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode ier_start_step = np.random.randint( _start_step + 2, int(self.episode_steps * 0.8) ) ier_end_step = self.episode_steps - ier_source_node_id = node.get_id() + # Randomise the load, as a percentage of a random link bandwith ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice( bandwidths ) ier_protocol = pol_service_name # Same protocol as compromised node - ier_service = node.get_services()[ - pol_service_name - ] # same service as defined in the pol - ier_port = ier_service.get_port() + ier_service = node.services[pol_service_name] + ier_port = ier_service.port ier_mission_criticality = ( 0 # Red IER will never be important to green agent success ) @@ -1325,15 +1323,15 @@ class Primaite(Env): possible_ier_destinations = [ ier.get_dest_node_id() for ier in list(self.green_iers.values()) - if ier.get_source_node_id() == node.get_id() + if ier.get_source_node_id() == node.node_id ] if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.get_ip_address(), - server.ip_address, - ier_service, - ier_port, + node.get_ip_address(), + server.ip_address, + ier_service, + ier_port, ): possible_ier_destinations.append(server.node_id) if len(possible_ier_destinations) < 1: @@ -1347,37 +1345,42 @@ class Primaite(Env): ier_load, ier_protocol, ier_port, - ier_source_node_id, + node.node_id, ier_dest, ier_mission_criticality, ) - # 3: Make sure the targetted node can be set to overwhelmed - with node pol - # TODO remove duplicate red pol for same targetted service - must take into account start step + overwhelm_pol = red_pol + overwhelm_pol.id = str(uuid.uuid4()) + overwhelm_pol.end_step = self.episode_steps - o_pol_id = str(3000 + n) - o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched - o_pol_end_step = ( - self.episode_steps - ) # Can become compromised at any timestep after start - o_pol_node_id = ier_dest # Node effected is the one targetted by the IER - o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes - o_pol_service_name = ( - ier_protocol # Same protocol/service as the IER uses to attack - ) - o_pol_new_state = SoftwareState["OVERWHELMED"] - o_pol_entry_node = False # Assumes servers are not entry nodes + + # 3: Make sure the targetted node can be set to overwhelmed - with node pol + # # TODO remove duplicate red pol for same targetted service - must take into account start step + # + o_pol_id = str(uuid.uuid4()) + # o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched + # o_pol_end_step = ( + # self.episode_steps + # ) # Can become compromised at any timestep after start + # o_pol_node_id = ier_dest # Node effected is the one targetted by the IER + # o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes + # o_pol_service_name = ( + # ier_protocol # Same protocol/service as the IER uses to attack + # ) + # o_pol_new_state = SoftwareState["OVERWHELMED"] + # o_pol_entry_node = False # Assumes servers are not entry nodes o_red_pol = NodeStateInstructionRed( - _id, - _start_step, - _end_step, - _target_node_id, - _pol_initiator, - _pol_type, - pol_protocol, - _pol_state, - _pol_source_node_id, - _pol_source_node_service, - _pol_source_node_service_state, + _id=o_pol_id, + _start_step=ier_start_step, + _end_step=self.episode_steps, + _target_node_id=ier_dest, + _pol_initiator="DIRECT", + _pol_type=NodePOLType["SERVICE"], + pol_protocol=ier_protocol, + _pol_state=SoftwareState.OVERWHELMED, + _pol_source_node_id=source_node.node_id, + _pol_source_node_service=source_node_service.name, + _pol_source_node_service_state=source_node_service.software_state ) self.red_node_pol[o_pol_id] = o_red_pol