diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index dd45907d..8bfdca02 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -25,6 +25,12 @@ steps: versionSpec: '$(python.version)' displayName: 'Use Python $(python.version)' +- script: | + python -m pip install pre-commit + pre-commit install + pre-commit run --all-files + displayName: 'Run pre-commits' + - script: | python -m pip install --upgrade pip==23.0.1 pip install wheel==0.38.4 --upgrade diff --git a/docs/source/config.rst b/docs/source/config.rst index 5410a877..01f1e325 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -28,6 +28,10 @@ The environment config file consists of the following attributes: * STABLE_BASELINES3_PPO - Use a SB3 PPO agent * STABLE_BASELINES3_A2C - use a SB3 A2C agent +* **random_red_agent** [bool] + + Determines if the session should be run with a random red agent + * **action_type** [enum] Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index 4f639f11..a59b2361 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -78,10 +78,9 @@ PrimAITE automatically creates two sets of results from each session: * Timestamp * Episode number * Step number - * Initial observation space (before red and blue agent actions have been taken). Individual elements of the observation space are presented in the format OSI_X_Y - * Resulting observation space (after the red and blue agent actions have been taken) Individual elements of the observation space are presented in the format OSN_X_Y + * Initial observation space (what the blue agent observed when it decided its action) * Reward value - * Action space (as presented by the blue agent on this step). 
Individual elements of the action space are presented in the format AS_X + * Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X **Diagrams** diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index d01f51f3..b4bfa75e 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -6,11 +6,23 @@ # "STABLE_BASELINES3_A2C" # "GENERIC" agent_identifier: STABLE_BASELINES3_A2C + +# RANDOM RED AGENT +# True or False: whether random red agent attacks are generated each episode +random_red_agent: False + # Sets How the Action Space is defined: # "NODE" # "ACL" # "ANY" node and acl actions action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS # Number of episodes to run per session num_episodes: 10 # Number of time_steps per episode diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml new file mode 100644 index 00000000..3e0a3e2f --- /dev/null +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -0,0 +1,99 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agent_identifier: STABLE_BASELINES3_A2C + +# RANDOM RED AGENT +# True or False: whether random red agent attacks are generated each episode +random_red_agent: True + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# Number of episodes to run per session +num_episodes: 10 +# Number of time_steps per episode +num_steps: 256 +# Time delay between steps (for generic agents) +time_delay: 10 +# Type of 
session to be run (TRAINING or EVALUATION) +session_type: TRAINING +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS 
+node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time taken to repair the file system +file_system_restoring_limit: 5 # The time taken to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 3f9d3eb1..5ec4d942 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -21,6 +21,9 @@ class TrainingConfig: agent_identifier: str = "STABLE_BASELINES3_A2C" "The Red Agent algo/class to be used." + random_red_agent: bool = False + "Creates Random Red Agent Attacks" + action_type: ActionType = ActionType.ANY "The ActionType to use." diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 9e71ef1b..81ddaaf5 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -29,6 +29,7 @@ class AbstractObservationComponent(ABC): self.env: "Primaite" = env self.space: spaces.Space self.current_observation: np.ndarray # type might be too restrictive? + self.structure: List[str] return NotImplemented @abstractmethod @@ -36,6 +37,11 @@ class AbstractObservationComponent(ABC): """Update the observation based on the current state of the environment.""" self.current_observation = NotImplemented + @abstractmethod + def generate_structure(self) -> List[str]: + """Return a list of labels for the components of the flattened observation space.""" + return NotImplemented + class NodeLinkTable(AbstractObservationComponent): """Table with nodes and links as rows and hardware/software status as cols. @@ -79,6 +85,8 @@ class NodeLinkTable(AbstractObservationComponent): # 3. 
Initialise Observation with zeroes self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) + self.structure = self.generate_structure() + def update(self): """Update the observation based on current environment state. @@ -131,6 +139,40 @@ class NodeLinkTable(AbstractObservationComponent): protocol_index += 1 item_index += 1 + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + nodes = self.env.nodes.values() + links = self.env.links.values() + + structure = [] + + for i, node in enumerate(nodes): + node_id = node.node_id + node_labels = [ + f"node_{node_id}_id", + f"node_{node_id}_hardware_status", + f"node_{node_id}_os_status", + f"node_{node_id}_fs_status", + ] + for j, serv in enumerate(self.env.services_list): + node_labels.append(f"node_{node_id}_service_{serv}_status") + + structure.extend(node_labels) + + for i, link in enumerate(links): + link_id = link.id + link_labels = [ + f"link_{link_id}_id", + f"link_{link_id}_n/a", + f"link_{link_id}_n/a", + f"link_{link_id}_n/a", + ] + for j, serv in enumerate(self.env.services_list): + link_labels.append(f"link_{link_id}_service_{serv}_load") + + structure.extend(link_labels) + return structure + class NodeStatuses(AbstractObservationComponent): """Flat list of nodes' hardware, OS, file system, and service states. @@ -179,6 +221,7 @@ class NodeStatuses(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + self.structure = self.generate_structure() def update(self): """Update the observation based on current environment state. 
@@ -205,6 +248,30 @@ class NodeStatuses(AbstractObservationComponent): ) self.current_observation[:] = obs + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + services = self.env.services_list + + structure = [] + for _, node in self.env.nodes.items(): + node_id = node.node_id + structure.append(f"node_{node_id}_hardware_state_NONE") + for state in HardwareState: + structure.append(f"node_{node_id}_hardware_state_{state.name}") + structure.append(f"node_{node_id}_software_state_NONE") + for state in SoftwareState: + structure.append(f"node_{node_id}_software_state_{state.name}") + structure.append(f"node_{node_id}_file_system_state_NONE") + for state in FileSystemState: + structure.append(f"node_{node_id}_file_system_state_{state.name}") + for service in services: + structure.append(f"node_{node_id}_service_{service}_state_NONE") + for state in SoftwareState: + structure.append( + f"node_{node_id}_service_{service}_state_{state.name}" + ) + return structure + class LinkTrafficLevels(AbstractObservationComponent): """Flat list of traffic levels encoded into banded categories. @@ -268,6 +335,8 @@ class LinkTrafficLevels(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + self.structure = self.generate_structure() + def update(self): """Update the observation based on current environment state. 
@@ -295,6 +364,21 @@ class LinkTrafficLevels(AbstractObservationComponent): self.current_observation[:] = obs + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + structure = [] + for _, link in self.env.links.items(): + link_id = link.id + if self._combine_service_traffic: + protocols = ["overall"] + else: + protocols = [protocol.name for protocol in link.protocol_list] + + for p in protocols: + for i in range(self._quantisation_levels): + structure.append(f"link_{link_id}_{p}_traffic_level_{i}") + return structure + class ObservationsHandler: """Component-based observation space handler. @@ -311,8 +395,17 @@ class ObservationsHandler: def __init__(self): self.registered_obs_components: List[AbstractObservationComponent] = [] - self.space: spaces.Space - self.current_observation: Union[Tuple[np.ndarray], np.ndarray] + + # internal the observation space (unflattened version of space if flatten=True) + self._space: spaces.Space + # flattened version of the observation space + self._flat_space: spaces.Space + + self._observation: Union[Tuple[np.ndarray], np.ndarray] + # used for transactions and when flatten=true + self._flat_observation: np.ndarray + + self.flatten: bool = False def update_obs(self): """Fetch fresh information about the environment.""" @@ -321,12 +414,11 @@ class ObservationsHandler: obs.update() current_obs.append(obs.current_observation) - # If there is only one component, don't use a tuple, just pass through that component's obs. if len(current_obs) == 1: - self.current_observation = current_obs[0] + self._observation = current_obs[0] else: - self.current_observation = tuple(current_obs) - # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. 
+ self._observation = tuple(current_obs) + self._flat_observation = spaces.flatten(self._space, self._observation) def register(self, obs_component: AbstractObservationComponent): """Add a component for this handler to track. @@ -353,12 +445,31 @@ class ObservationsHandler: for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) - # If there is only one component, don't use a tuple space, just pass through that component's space. + # if there are multiple components, build a composite tuple space if len(component_spaces) == 1: - self.space = component_spaces[0] + self._space = component_spaces[0] else: - self.space = spaces.Tuple(component_spaces) - # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. + self._space = spaces.Tuple(component_spaces) + if len(component_spaces) > 0: + self._flat_space = spaces.flatten_space(self._space) + else: + self._flat_space = spaces.Box(0, 1, (0,)) + + @property + def space(self): + """Observation space, return the flattened version if flatten is True.""" + if self.flatten: + return self._flat_space + else: + return self._space + + @property + def current_observation(self): + """Current observation, return the flattened version if flatten is True.""" + if self.flatten: + return self._flat_observation + else: + return self._observation @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): @@ -388,6 +499,9 @@ class ObservationsHandler: # Instantiate the handler handler = cls() + if obs_space_config.get("flatten"): + handler.flatten = True + for component_cfg in obs_space_config["components"]: # Figure out which class can instantiate the desired component comp_type = component_cfg["name"] @@ -401,3 +515,17 @@ class ObservationsHandler: handler.update_obs() return handler + + def describe_structure(self): + """Create a list of names for the features of the obs space. + + The order of labels follows the flattened version of the space. 
+ """ + # as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have + # to fake it. each component has to just hard-code the expected label order after flattening... + + labels = [] + for obs_comp in self.registered_obs_components: + labels.extend(obs_comp.structure) + + return labels diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 23890613..ce092cbd 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -3,8 +3,10 @@ import copy import csv import logging +import uuid as uuid from datetime import datetime from pathlib import Path +from random import choice, randint, sample, uniform from typing import Dict, Tuple, Union import networkx as nx @@ -197,7 +199,6 @@ class Primaite(Env): try: plt.tight_layout() nx.draw_networkx(self.network, with_labels=True) - # now = datetime.now() # current date and time file_path = session_path / f"network_{timestamp_str}.png" plt.savefig(file_path, format="PNG") @@ -281,6 +282,10 @@ class Primaite(Env): # Does this for both live and reference nodes self.reset_environment() + # Create a random red agent to use for this episode + if self.training_config.random_red_agent: + self._create_random_red_agent() + # Reset counters and totals self.total_reward = 0 self.step_count = 0 @@ -325,7 +330,8 @@ class Primaite(Env): datetime.now(), self.agent_identifier, self.episode_count, self.step_count ) # Load the initial observation space into the transaction - transaction.set_obs_space_pre(copy.deepcopy(self.env_obs)) + transaction.set_obs_space(self.obs_handler._flat_observation) + # Load the action space into the transaction transaction.set_action_space(copy.deepcopy(action)) @@ -406,8 +412,6 @@ class Primaite(Env): # 7. Update env_obs self.update_environent_obs() - # Load the new observation space into the transaction - transaction.set_obs_space_post(copy.deepcopy(self.env_obs)) # 8. 
Add the transaction to the list of transactions self.transaction_list.append(copy.deepcopy(transaction)) @@ -1240,3 +1244,136 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} return combined_action_dict + + def _create_random_red_agent(self): + """Decide on random red agent for the episode to be called in env.reset().""" + # Reset the current red iers and red node pol + self.red_iers = {} + self.red_node_pol = {} + + # Decide how many nodes become compromised + node_list = list(self.nodes.values()) + computers = [node for node in node_list if node.node_type == NodeType.COMPUTER] + max_num_nodes_compromised = len( + computers + ) # only computers can become compromised + # random select between 1 and max_num_nodes_compromised + num_nodes_to_compromise = randint(1, max_num_nodes_compromised) + + # Decide which of the nodes to compromise + nodes_to_be_compromised = sample(computers, num_nodes_to_compromise) + + # choose a random compromise node to be source of attacks + source_node = choice(nodes_to_be_compromised) + + # For each of the nodes to be compromised decide which step they become compromised + max_step_compromised = ( + self.episode_steps // 2 + ) # always compromise in first half of episode + + # Bandwidth for all links + bandwidths = [i.get_bandwidth() for i in list(self.links.values())] + + if len(bandwidths) < 1: + msg = "Random red agent cannot be used on a network without any links" + _LOGGER.error(msg) + raise Exception(msg) + + servers = [node for node in node_list if node.node_type == NodeType.SERVER] + + for n, node in enumerate(nodes_to_be_compromised): + # 1: Use Node PoL to set node to compromised + + _id = str(uuid.uuid4()) + _start_step = randint(2, max_step_compromised + 1) # step compromised + pol_service_name = choice(list(node.services.keys())) + + source_node_service = choice(list(source_node.services.values())) + + red_pol = NodeStateInstructionRed( + _id=_id, 
+ _start_step=_start_step, + _end_step=_start_step, # only run for 1 step + _target_node_id=node.node_id, + _pol_initiator="DIRECT", + _pol_type=NodePOLType["SERVICE"], + pol_protocol=pol_service_name, + _pol_state=SoftwareState.COMPROMISED, + _pol_source_node_id=source_node.node_id, + _pol_source_node_service=source_node_service.name, + _pol_source_node_service_state=source_node_service.software_state, + ) + + self.red_node_pol[_id] = red_pol + + # 2: Launch the attack from compromised node - set the IER + + ier_id = str(uuid.uuid4()) + # Launch the attack after node is compromised, and not right at the end of the episode + ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8)) + ier_end_step = self.episode_steps + + # Randomise the load, as a percentage of a random link bandwith + ier_load = uniform(0.4, 0.8) * choice(bandwidths) + ier_protocol = pol_service_name # Same protocol as compromised node + ier_service = node.services[pol_service_name] + ier_port = ier_service.port + ier_mission_criticality = ( + 0 # Red IER will never be important to green agent success + ) + # We choose a node to attack based on the first that applies: + # a. Green IERs, select dest node of the red ier based on dest node of green IER + # b. Attack a random server that doesn't have a DENY acl rule in default config + # c. 
Attack a random server + possible_ier_destinations = [ + ier.get_dest_node_id() + for ier in list(self.green_iers.values()) + if ier.get_source_node_id() == node.node_id + ] + if len(possible_ier_destinations) < 1: + for server in servers: + if not self.acl.is_blocked( + node.ip_address, + server.ip_address, + ier_service, + ier_port, + ): + possible_ier_destinations.append(server.node_id) + if len(possible_ier_destinations) < 1: + # If still none found choose from all servers + possible_ier_destinations = [server.node_id for server in servers] + ier_dest = choice(possible_ier_destinations) + self.red_iers[ier_id] = IER( + ier_id, + ier_start_step, + ier_end_step, + ier_load, + ier_protocol, + ier_port, + node.node_id, + ier_dest, + ier_mission_criticality, + ) + + overwhelm_pol = red_pol + overwhelm_pol.id = str(uuid.uuid4()) + overwhelm_pol.end_step = self.episode_steps + + # 3: Make sure the targetted node can be set to overwhelmed - with node pol + # # TODO remove duplicate red pol for same targetted service - must take into account start step + + o_pol_id = str(uuid.uuid4()) + o_red_pol = NodeStateInstructionRed( + _id=o_pol_id, + _start_step=ier_start_step, + _end_step=self.episode_steps, + _target_node_id=ier_dest, + _pol_initiator="DIRECT", + _pol_type=NodePOLType["SERVICE"], + pol_protocol=ier_protocol, + _pol_state=SoftwareState.OVERWHELMED, + _pol_source_node_id=source_node.node_id, + _pol_source_node_service=source_node_service.name, + _pol_source_node_service_state=source_node_service.software_state, + ) + self.red_node_pol[o_pol_id] = o_red_pol diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index aa9e4503..1a1a0770 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -78,8 +78,8 @@ def calculate_reward_function( start_step = ier_value.get_start_step() stop_step = ier_value.get_end_step() if step_count >= start_step and step_count <= stop_step: - reference_blocked = 
reference_ier.get_is_running() - live_blocked = ier_value.get_is_running() + reference_blocked = not reference_ier.get_is_running() + live_blocked = not ier_value.get_is_running() ier_reward = ( config_values.green_ier_blocked * ier_value.get_mission_criticality() ) diff --git a/src/primaite/main.py b/src/primaite/main.py index 8483f383..f315cd34 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -354,6 +354,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, transaction_list=transaction_list, session_path=session_dir, timestamp_str=timestamp_str, + obs_space_description=env.obs_handler.describe_structure(), ) print("Updating Session Metadata file...") diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 7f62fe24..4272ce24 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -1,8 +1,11 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
"""Defines node behaviour for Green PoL.""" +from dataclasses import dataclass + from primaite.common.enums import NodePOLType +@dataclass() class NodeStateInstructionRed(object): """The Node State Instruction class.""" diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index d4a5c8c8..324592c3 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -190,14 +190,14 @@ class ServiceNode(ActiveNode): service_value.reduce_patching_count() def update_resetting_status(self): - """Updates the resetting counter for any service that are resetting.""" + """Update resetting counter and set software state if it reached 0.""" super().update_resetting_status() if self.resetting_count <= 0: for service in self.services.values(): service.software_state = SoftwareState.GOOD def update_booting_status(self): - """Updates the booting counter for any service that are booting up.""" + """Update booting counter and set software to good if it reached 0.""" super().update_booting_status() if self.booting_count <= 0: for service in self.services.values(): diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index a4ce48e3..39236217 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -20,23 +20,14 @@ class Transaction(object): self.episode_number = _episode_number self.step_number = _step_number - def set_obs_space_pre(self, _obs_space_pre): + def set_obs_space(self, _obs_space): """ Sets the observation space (pre). Args: _obs_space_pre: The observation space before any actions are taken """ - self.obs_space_pre = _obs_space_pre - - def set_obs_space_post(self, _obs_space_post): - """ - Sets the observation space (post). 
- - Args: - _obs_space_post: The observation space after any actions are taken - """ - self.obs_space_post = _obs_space_post + self.obs_space = _obs_space def set_reward(self, _reward): """ diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index 11e68af8..4e364f0b 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -22,24 +22,12 @@ def turn_action_space_to_array(_action_space): return [str(_action_space)] -def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features): - """ - Turns observation space into a string array so it can be saved to csv. - - Args: - _obs_space: The observation space - _obs_assets: The number of assets (i.e. nodes or links) in the observation space - _obs_features: The number of features associated with the asset - """ - return_array = [] - for x in range(_obs_assets): - for y in range(_obs_features): - return_array.append(str(_obs_space[x][y])) - - return return_array - - -def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str): +def write_transaction_to_file( + transaction_list, + session_path: Path, + timestamp_str: str, + obs_space_description: list, +): """ Writes transaction logs to file to support training evaluation. 
@@ -56,13 +44,13 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st # This will be tied into the PrimAITE Use Case so that they make sense template_transation = transaction_list[0] action_length = template_transation.action_space.size - obs_shape = template_transation.obs_space_post.shape - obs_assets = template_transation.obs_space_post.shape[0] - if len(obs_shape) == 1: - # bit of a workaround but I think the way transactions are written will change soon - obs_features = 1 - else: - obs_features = template_transation.obs_space_post.shape[1] + # obs_shape = template_transation.obs_space_post.shape + # obs_assets = template_transation.obs_space_post.shape[0] + # if len(obs_shape) == 1: + # bit of a workaround but I think the way transactions are written will change soon + # obs_features = 1 + # else: + # obs_features = template_transation.obs_space_post.shape[1] # Create the action space headers array action_header = [] @@ -70,16 +58,12 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st action_header.append("AS_" + str(x)) # Create the observation space headers array - obs_header_initial = [] - obs_header_new = [] - for x in range(obs_assets): - for y in range(obs_features): - obs_header_initial.append("OSI_" + str(x) + "_" + str(y)) - obs_header_new.append("OSN_" + str(x) + "_" + str(y)) + # obs_header_initial = [f"pre_{o}" for o in obs_space_description] + # obs_header_new = [f"post_{o}" for o in obs_space_description] # Open up a csv file header = ["Timestamp", "Episode", "Step", "Reward"] - header = header + action_header + obs_header_initial + obs_header_new + header = header + action_header + obs_space_description try: filename = session_path / f"all_transactions_{timestamp_str}.csv" @@ -98,12 +82,7 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st csv_data = ( csv_data + turn_action_space_to_array(transaction.action_space) - + turn_obs_space_to_array( - 
transaction.obs_space_pre, obs_assets, obs_features - ) - + turn_obs_space_to_array( - transaction.obs_space_post, obs_assets, obs_features - ) + + transaction.obs_space.tolist() ) csv_writer.writerow(csv_data) diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py new file mode 100644 index 00000000..6b06dbb1 --- /dev/null +++ b/tests/test_red_random_agent_behaviour.py @@ -0,0 +1,77 @@ +from datetime import datetime + +from primaite.config.lay_down_config import data_manipulation_config_path +from primaite.environment.primaite_env import Primaite +from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed +from tests import TEST_CONFIG_ROOT +from tests.conftest import _get_temp_session_path + + +def run_generic(env, config_values): + """Run against a generic agent.""" + # Reset the environment at the start of the episode + env.reset() + for episode in range(0, config_values.num_episodes): + for step in range(0, config_values.num_steps): + # Send the observation space to the agent to get an action + # TEMP - fixed action (0) for now + # action = env.blue_agent_action(obs) + # action = env.action_space.sample() + action = 0 + + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + + # Reset the environment at the end of the episode + env.reset() + + env.close() + + +def test_random_red_agent_behaviour(): + """ + Test that the random red agent generates different instructions on each run. + + The environment is run twice and the resulting red node PoL instructions are compared. 
+ """ + list_of_node_instructions = [] + + # RUN TWICE so we can make sure that red agent is randomised + for i in range(2): + """Takes a config path and returns the created instance of Primaite.""" + session_timestamp: datetime = datetime.now() + session_path = _get_temp_session_path(session_timestamp) + + timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + env = Primaite( + training_config_path=TEST_CONFIG_ROOT + / "one_node_states_on_off_main_config.yaml", + lay_down_config_path=data_manipulation_config_path(), + transaction_list=[], + session_path=session_path, + timestamp_str=timestamp_str, + ) + # set red_agent_ + env.training_config.random_red_agent = True + training_config = env.training_config + training_config.num_steps = env.episode_steps + + run_generic(env, training_config) + # add red pol instructions to list + list_of_node_instructions.append(env.red_node_pol) + + # compare instructions to make sure that red instructions are truly random + for index, instruction in enumerate(list_of_node_instructions): + for key in list_of_node_instructions[index].keys(): + instruction: NodeStateInstructionRed = list_of_node_instructions[index][key] + print(f"run {index}") + print(f"{key} start step: {instruction.get_start_step()}") + print(f"{key} end step: {instruction.get_end_step()}") + print(f"{key} target node id: {instruction.get_target_node_id()}") + print("") + + assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1]) diff --git a/tests/test_reward.py b/tests/test_reward.py index c3fcdfc4..b8c92274 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -16,17 +16,26 @@ def test_rewards_are_being_penalised_at_each_step_function(): ) """ - On different steps (of the 13 in total) these are the following rewards for config_6 which are activated: - File System State: goodShouldBeCorrupt = 5 (between Steps 1 & 3) - Hardware State: onShouldBeOff = -2 (between Steps 4 & 6) - Service State: goodShouldBeCompromised = 5 (between 
Steps 7 & 9) - Software State (Software State): goodShouldBeCompromised = 5 (between Steps 10 & 12) + The config 'one_node_states_on_off_lay_down_config.yaml' has 15 steps: + On different steps, the laydown config has Patterns of Life (PoLs) which change the state of a node attribute. + For example, turning the node's file system state to CORRUPT from its original state GOOD. + As a result, the following rewards are activated: + File System State: corrupt_should_be_good = -10 * 2 (on Steps 1 & 2) + Hardware State: off_should_be_on = -10 * 2 (on Steps 4 & 5) + Service State: compromised_should_be_good = -20 * 2 (on Steps 7 & 8) + Software State: compromised_should_be_good = -20 * 2 (on Steps 10 & 11) - Total Reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26 - Step Count: 13 + Each Pattern of Life (PoL) lasts for 2 steps, so the agent is penalised twice. + + Note: This test run inherits from conftest.py where the PrimAITE environment is run and the blue agent is hard-coded + to do NOTHING on every step. + We use Patterns of Life (PoLs) to change the node's states and show that the agent is penalised on all steps + where the live network node differs from the network reference node. + + Total Reward: -10 + -10 + -10 + -10 + -20 + -20 + -20 + -20 = -120 + Step Count: 15 For the 4 steps where this occurs the average reward is: - Average Reward: 2 (26 / 13) + Average Reward: -8 (-120 / 15) """ - print("average reward", env.average_reward) assert env.average_reward == -8.0