diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml index 3e0a3e2f..96243daf 100644 --- a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -7,8 +7,10 @@ # "GENERIC" agent_identifier: STABLE_BASELINES3_A2C -# RED AGENT IDENTIFIER -# RANDOM or NONE +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False random_red_agent: True # Sets How the Action Space is defined: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index c80c36ec..36632155 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -427,7 +427,12 @@ class Primaite(Env): for link_key, link_value in self.links.items(): _LOGGER.debug("Link ID: " + link_value.get_id()) for protocol in link_value.protocol_list: - _LOGGER.debug(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load())) + _LOGGER.debug( + " Protocol: " + + protocol.get_name().name + + ", Load: " + + str(protocol.get_load()) + ) def interpret_action_and_apply(self, _action): """ @@ -437,16 +442,21 @@ class Primaite(Env): _action: The action space from the agent """ # At the moment, actions are only affecting nodes + if self.training_config.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) - elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6 + elif ( + len(self.action_dict[_action]) == 6 + ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) - elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4 + elif ( + len(self.action_dict[_action]) == 4 + ): 
# Node actions in multidiscrete (array) form have len 4 self.apply_actions_to_nodes(_action) else: - _LOGGER.error("Invalid action type found") + _LOGGER.error("Invalid action type found") def apply_actions_to_nodes(self, _action): """ @@ -510,8 +520,7 @@ class Primaite(Env): elif property_action == 1: # Patch (valid action if it's good or compromised) node.set_service_state( - self.services_list[service_index], - SoftwareState.PATCHING, + self.services_list[service_index], SoftwareState.PATCHING ) else: # Node is not of Service Type @@ -709,7 +718,8 @@ class Primaite(Env): _LOGGER.error(f"Invalid item_type: {item_type}") pass - _LOGGER.debug("Environment configuration loaded") + _LOGGER.info("Environment configuration loaded") + print("Environment configuration loaded") def create_node(self, item): """ @@ -1166,12 +1176,7 @@ class Primaite(Env): # Use MAX to ensure we get them all for node_action in range(4): for service_state in range(self.num_services): - action = [ - node, - node_property, - node_action, - service_state, - ] + action = [node, node_property, node_action, service_state] # check to see if it's a nothing action (has no effect) if is_valid_node_action(action): actions[action_key] = action @@ -1221,7 +1226,11 @@ class Primaite(Env): # Change node keys to not overlap with acl keys # Only 1 nothing action (key 0) is required, remove the other - new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0} + new_node_action_dict = { + k + len(acl_action_dict) - 1: v + for k, v in node_action_dict.items() + if k != 0 + } # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} @@ -1235,7 +1244,8 @@ class Primaite(Env): # Decide how many nodes become compromised node_list = list(self.nodes.values()) - computers = [node for node in node_list if node.node_type == NodeType.COMPUTER] + computers = [node for node in node_list if + node.node_type == NodeType.COMPUTER] 
max_num_nodes_compromised = len( computers ) # only computers can become compromised @@ -1250,7 +1260,7 @@ class Primaite(Env): # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( - self.episode_steps // 2 + self.episode_steps // 2 ) # always compromise in first half of episode # Bandwidth for all links @@ -1261,13 +1271,15 @@ class Primaite(Env): _LOGGER.error(msg) raise Exception(msg) - servers = [node for node in node_list if node.node_type == NodeType.SERVER] + servers = [node for node in node_list if + node.node_type == NodeType.SERVER] for n, node in enumerate(nodes_to_be_compromised): # 1: Use Node PoL to set node to compromised _id = str(uuid.uuid4()) - _start_step = randint(2, max_step_compromised + 1) # step compromised + _start_step = randint(2, + max_step_compromised + 1) # step compromised pol_service_name = choice(list(node.services.keys())) source_node_service = choice(list(source_node.services.values())) @@ -1292,7 +1304,8 @@ class Primaite(Env): ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode - ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8)) + ier_start_step = randint(_start_step + 2, + int(self.episode_steps * 0.8)) ier_end_step = self.episode_steps # Randomise the load, as a percentage of a random link bandwith @@ -1315,15 +1328,16 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.ip_address, - server.ip_address, - ier_service, - ier_port, + node.ip_address, + server.ip_address, + ier_service, + ier_port, ): possible_ier_destinations.append(server.node_id) if len(possible_ier_destinations) < 1: # If still none found choose from all servers - possible_ier_destinations = [server.node_id for server in servers] + possible_ier_destinations = [server.node_id for server in + servers] ier_dest = choice(possible_ier_destinations) 
self.red_iers[ier_id] = IER( ier_id, diff --git a/tests/config/test_random_red_main_config.yaml b/tests/config/test_random_red_main_config.yaml new file mode 100644 index 00000000..800fe808 --- /dev/null +++ b/tests/config/test_random_red_main_config.yaml @@ -0,0 +1,112 @@ +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. +# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: DUMMY + +# Sets whether Red Agent POL and IER is randomised. 
+# Options are: +# True +# False +random_red_agent: True + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# Number of episodes to run per session +num_episodes: 2 +# Number of time_steps per episode +num_steps: 15 +# Time delay between steps (for generic agents) +time_delay: 1 + +# Type of session to be run (TRAINING or EVALUATION) +session_type: EVAL +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 
+corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index 6b06dbb1..8cf60236 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -1,68 +1,29 @@ -from datetime import datetime +import pytest from primaite.config.lay_down_config import data_manipulation_config_path -from primaite.environment.primaite_env import Primaite from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from tests import TEST_CONFIG_ROOT -from tests.conftest import _get_temp_session_path - -def run_generic(env, config_values): - """Run against a generic agent.""" - # Reset the environment at the start of the episode - env.reset() - for episode in range(0, config_values.num_episodes): - for step in range(0, config_values.num_steps): - # Send the observation space to the agent to get an action - # TEMP - random action for now - # action = env.blue_agent_action(obs) - # action = env.action_space.sample() - action = 0 - - # Run the simulation step on the live environment - obs, reward, done, info = env.step(action) - - # Break if done is True - if done: - break - - # Reset the environment at the end of the episode - env.reset() - - env.close() - - -def 
test_random_red_agent_behaviour(): - """ - Test that hardware state is penalised at each step. - - When the initial state is OFF compared to reference state which is ON. - """ +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "test_random_red_main_config.yaml", + data_manipulation_config_path(), + ] + ], + indirect=True, +) +def test_random_red_agent_behaviour(temp_primaite_session): + """Test that red agent POL is randomised each episode.""" list_of_node_instructions = [] - # RUN TWICE so we can make sure that red agent is randomised - for i in range(2): - """Takes a config path and returns the created instance of Primaite.""" - session_timestamp: datetime = datetime.now() - session_path = _get_temp_session_path(session_timestamp) + with temp_primaite_session as session: + session.evaluate() + list_of_node_instructions.append(session.env.red_node_pol) - timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - env = Primaite( - training_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=data_manipulation_config_path(), - transaction_list=[], - session_path=session_path, - timestamp_str=timestamp_str, - ) - # set red_agent_ - env.training_config.random_red_agent = True - training_config = env.training_config - training_config.num_steps = env.episode_steps - - run_generic(env, training_config) - # add red pol instructions to list - list_of_node_instructions.append(env.red_node_pol) + session.evaluate() + list_of_node_instructions.append(session.env.red_node_pol) # compare instructions to make sure that red instructions are truly random for index, instruction in enumerate(list_of_node_instructions): @@ -73,5 +34,5 @@ def test_random_red_agent_behaviour(): print(f"{key} end step: {instruction.get_end_step()}") print(f"{key} target node id: {instruction.get_target_node_id()}") print("") - assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1]) + assert list_of_node_instructions[0] != list_of_node_instructions[1]