From 3cd5864f25f966c9118b993e488ad74e08e0e764 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Fri, 26 May 2023 10:17:45 +0100 Subject: [PATCH 01/37] 1429 - created new branch from dev, added enums to enums.py, created agents package and utils.py file, added option to primaite_env.py for ANY action type and changed the action spaces are defined using ADSP branch --- src/primaite/agents/__init__.py | 0 src/primaite/agents/utils.py | 509 +++++++++++++++++++ src/primaite/common/enums.py | 26 + src/primaite/environment/primaite_env.py | 118 ++++- tests/config/single_action_space_config.yaml | 89 ++++ tests/test_single_action_space.py | 12 + 6 files changed, 735 insertions(+), 19 deletions(-) create mode 100644 src/primaite/agents/__init__.py create mode 100644 src/primaite/agents/utils.py create mode 100644 tests/config/single_action_space_config.yaml create mode 100644 tests/test_single_action_space.py diff --git a/src/primaite/agents/__init__.py b/src/primaite/agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py new file mode 100644 index 00000000..d3924b24 --- /dev/null +++ b/src/primaite/agents/utils.py @@ -0,0 +1,509 @@ +import logging +import os.path +from datetime import datetime + +import numpy as np +import yaml + +from primaite.common.config_values_main import ConfigValuesMain +from primaite.common.enums import ( + ActionType, + HardwareState, + LinkStatus, + NodeHardwareAction, + NodePOLType, + NodeSoftwareAction, + SoftwareState, +) + + +def load_config_values(config_path): + """Loads the config values from the main config file into a config object.""" + config_file_main = open(config_path, "r") + config_data = yaml.safe_load(config_file_main) + # Create a config class + config_values = ConfigValuesMain() + + try: + # Generic + config_values.red_agent_identifier = config_data["redAgentIdentifier"] + config_values.action_type = ActionType[config_data["actionType"]] + config_values.config_filename_use_case = config_data["configFilename"] + # Reward values + # Generic + config_values.all_ok = float(config_data["allOk"]) + # Node Operating State + config_values.off_should_be_on = float(config_data["offShouldBeOn"]) + config_values.off_should_be_resetting = float( + config_data["offShouldBeResetting"] + ) + config_values.on_should_be_off = float(config_data["onShouldBeOff"]) + config_values.on_should_be_resetting = float(config_data["onShouldBeResetting"]) + config_values.resetting_should_be_on = float(config_data["resettingShouldBeOn"]) + config_values.resetting_should_be_off = float( + config_data["resettingShouldBeOff"] + ) + # Node O/S or Service State + config_values.good_should_be_patching = float( + config_data["goodShouldBePatching"] + ) + config_values.good_should_be_compromised = float( + config_data["goodShouldBeCompromised"] + ) + config_values.good_should_be_overwhelmed = float( + config_data["goodShouldBeOverwhelmed"] + ) + config_values.patching_should_be_good = float( + config_data["patchingShouldBeGood"] + ) + config_values.patching_should_be_compromised = float( + config_data["patchingShouldBeCompromised"] + ) + config_values.patching_should_be_overwhelmed = float( + config_data["patchingShouldBeOverwhelmed"] + ) + config_values.compromised_should_be_good = float( + config_data["compromisedShouldBeGood"] + ) + config_values.compromised_should_be_patching = float( + config_data["compromisedShouldBePatching"] + ) + config_values.compromised_should_be_overwhelmed = float( + config_data["compromisedShouldBeOverwhelmed"] + ) + config_values.compromised = float(config_data["compromised"]) + config_values.overwhelmed_should_be_good = float( + config_data["overwhelmedShouldBeGood"] + ) + config_values.overwhelmed_should_be_patching = float( + config_data["overwhelmedShouldBePatching"] + ) + config_values.overwhelmed_should_be_compromised = float( + config_data["overwhelmedShouldBeCompromised"] + ) + config_values.overwhelmed = float(config_data["overwhelmed"]) + # IER status + config_values.red_ier_running = float(config_data["redIerRunning"]) + config_values.green_ier_blocked = float(config_data["greenIerBlocked"]) + # Patching / Reset durations + config_values.os_patching_duration = int(config_data["osPatchingDuration"]) + config_values.node_reset_duration = int(config_data["nodeResetDuration"]) + config_values.service_patching_duration = int( + config_data["servicePatchingDuration"] + ) + + except Exception as e: + print(f"Could not save load config data: {e} ") + + return config_values + + +def configure_logging(log_name): + """Configures logging.""" + try: + now = datetime.now() # current date and time + time = now.strftime("%Y%m%d_%H%M%S") + filename = "/app/logs/" + log_name + "/" + time + ".log" + path = f"/app/logs/{log_name}" + is_dir = os.path.isdir(path) + if not is_dir: + os.makedirs(path) + logging.basicConfig( + filename=filename, + filemode="w", + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%d-%b-%y %H:%M:%S", + level=logging.INFO, + ) + except Exception as e: + print("ERROR: Could not start logging", e) + + +def transform_change_obs_readable(obs): + """Transform list of transactions to readable list of each observation property. + + example: + np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] + """ + ids = [i for i in obs[:, 0]] + operating_states = [HardwareState(i).name for i in obs[:, 1]] + os_states = [SoftwareState(i).name for i in obs[:, 2]] + new_obs = [ids, operating_states, os_states] + + for service in range(3, obs.shape[1]): + # Links bit/s don't have a service state + service_states = [ + SoftwareState(i).name if i <= 4 else i for i in obs[:, service] + ] + new_obs.append(service_states) + + return new_obs + + +def transform_obs_readable(obs): + """ + Transform obs readable function. + + example: + np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]. + """ + changed_obs = transform_change_obs_readable(obs) + new_obs = list(zip(*changed_obs)) + # Convert list of tuples to list of lists + new_obs = [list(i) for i in new_obs] + + return new_obs + + +def convert_to_new_obs(obs, num_nodes=10): + """Convert original gym Box observation space to new multiDiscrete observation space.""" + # Remove ID columns, remove links and flatten to MultiDiscrete observation space + new_obs = obs[:num_nodes, 1:].flatten() + return new_obs + + +def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): + """ + Convert to old observation, links filled with 0's as no information is included in new observation space. + + example: + obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1]) + + new_obs = array([[ 1, 1, 1, 1], + [ 2, 1, 1, 1], + [ 3, 1, 1, 1], + ... + [20, 0, 0, 0]]) + """ + # Convert back to more readable, original format + reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2) + + # Add empty links back and add node ID back + s = np.zeros( + [reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], + dtype=np.int64, + ) + s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back + s[:num_nodes, 1:] = reshaped_nodes # put values back in + new_obs = s + + # Add links back in + links = obs[-num_links:] + # Links will be added to the last protocol/service slot but they are not specific to that service + new_obs[num_nodes:, -1] = links + + return new_obs + + +def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): + """Return string describing change between two observations. + + example: + obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) + obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]]) + output = 'ID 1: SERVICE 2 set to GOOD' + """ + obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services) + obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services) + list_of_changes = [] + for n, row in enumerate(obs1 - obs2): + if row.any() != 0: + relevant_changes = np.where(row != 0, obs2[n], -1) + relevant_changes[0] = obs2[n, 0] # ID is always relevant + is_link = relevant_changes[0] > num_nodes + desc = _describe_obs_change_helper(relevant_changes, is_link) + list_of_changes.append(desc) + + change_string = "\n ".join(list_of_changes) + if len(list_of_changes) > 0: + change_string = "\n " + change_string + return change_string + + +def _describe_obs_change_helper(obs_change, is_link): + """ + Helper funcion to describe what has changed. + + example: + [ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD" + + Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.' + """ + # Indexes where a change has occured, not including 0th index + index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] + # Node pol types, Indexes >= 3 are service nodes + node_pol_types = [ + NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) + for i in index_changed + ] + # Account for hardware states, software sattes and links + states = [ + LinkStatus(obs_change[i]).name + if is_link + else HardwareState(obs_change[i]).name + if i == 1 + else SoftwareState(obs_change[i]).name + for i in index_changed + ] + + if not is_link: + desc = f"ID {obs_change[0]}:" + for node_pol_type, state in list(zip(node_pol_types, states)): + desc = desc + " " + node_pol_type + " changed to " + state + "." + else: + desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}." + + return desc + + +def transform_action_node_enum(action): + """ + Convert a node action from readable string format, to enumerated format. + + example: + [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] + """ + action_node_id = action[0] + action_node_property = NodePOLType[action[1]].value + + if action[1] == "OPERATING": + property_action = NodeHardwareAction[action[2]].value + elif action[1] == "OS" or action[1] == "SERVICE": + property_action = NodeSoftwareAction[action[2]].value + else: + property_action = 0 + + action_service_index = action[3] + + new_action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] + + return new_action + + +def transform_action_node_readable(action): + """ + Convert a node action from enumerated format to readable format. + + example: + [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] + """ + action_node_property = NodePOLType(action[1]).name + + if action_node_property == "OPERATING": + property_action = NodeHardwareAction(action[2]).name + elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[ + 2 + ] <= 1: + property_action = NodeSoftwareAction(action[2]).name + else: + property_action = "NONE" + + new_action = [action[0], action_node_property, property_action, action[3]] + return new_action + + +def node_action_description(action): + """Generate string describing a node-based action.""" + if isinstance(action[1], (int, np.int64)): + # transform action to readable format + action = transform_action_node_readable(action) + + node_id = action[0] + node_property = action[1] + property_action = action[2] + service_id = action[3] + + if property_action == "NONE": + return "" + if node_property == "OPERATING" or node_property == "OS": + description = f"NODE {node_id}, {node_property}, SET TO {property_action}" + elif node_property == "SERVICE": + description = ( + f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" + ) + else: + return "" + + return description + + +def transform_action_acl_readable(action): + """ + Transform an ACL action to a more readable format. + + example: + [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] + """ + action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"} + action_permissions = {0: "DENY", 1: "ALLOW"} + + action_decision = action_decisions[action[0]] + action_permission = action_permissions[action[1]] + + # For IPs, Ports and Protocols, 0 means any, otherwise its just an index + new_action = [action_decision, action_permission] + list(action[2:6]) + for n, val in enumerate(list(action[2:6])): + if val == 0: + new_action[n + 2] = "ANY" + + return new_action + + +def transform_action_acl_enum(action): + """Convert a acl action from readable string format, to enumerated format.""" + action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2} + action_permissions = {"DENY": 0, "ALLOW": 1} + + action_decision = action_decisions[action[0]] + action_permission = action_permissions[action[1]] + + # For IPs, Ports and Protocols, ANY has value 0, otherwise its just an index + new_action = [action_decision, action_permission] + list(action[2:6]) + for n, val in enumerate(list(action[2:6])): + if val == "ANY": + new_action[n + 2] = 0 + + new_action = np.array(new_action) + return new_action + + +def acl_action_description(action): + """Generate string describing a acl-based action.""" + if isinstance(action[0], (int, np.int64)): + # transform action to readable format + action = transform_action_acl_readable(action) + if action[0] == "NONE": + description = "NO ACL RULE APPLIED" + else: + description = ( + f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]}," + f" for protocol/service index {action[4]} on port index {action[5]}" + ) + + return description + + +def get_node_of_ip(ip, node_dict): + """ + Get the node ID of an IP address. + + node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes) + """ + for node_key, node_value in node_dict.items(): + node_ip = node_value.get_ip_address() + if node_ip == ip: + return node_key + + +def is_valid_node_action(action): + """Is the node action an actual valid action. + + Only uses information about the action to determine if the action has an effect + + Does NOT consider: + - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch + - Node already being in that state (turning an ON node ON) + """ + action_r = transform_action_node_readable(action) + + node_property = action_r[1] + node_action = action_r[2] + + if node_property == "NONE": + return False + if node_action == "NONE": + return False + if node_property == "OPERATING" and node_action == "PATCHING": + # Operating State cannot PATCH + return False + if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]: + # Software States can only do Nothing or Patch + return False + return True + + +def is_valid_acl_action(action): + """ + Is the ACL action an actual valid action. + + Only uses information about the action to determine if the action has an effect. + + Does NOT consider: + - Trying to create identical rules + - Trying to create a rule which is a subset of another rule (caused by "ANY") + """ + action_r = transform_action_acl_readable(action) + + action_decision = action_r[0] + action_permission = action_r[1] + action_source_id = action_r[2] + action_destination_id = action_r[3] + + if action_decision == "NONE": + return False + if ( + action_source_id == action_destination_id + and action_source_id != "ANY" + and action_destination_id != "ANY" + ): + # ACL rule towards itself + return False + if action_permission == "DENY": + # DENY is unnecessary, we can create and delete allow rules instead + # No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY. + return False + + return True + + +def is_valid_acl_action_extra(action): + """Harsher version of valid acl actions, does not allow action.""" + if is_valid_acl_action(action) is False: + return False + + action_r = transform_action_acl_readable(action) + action_protocol = action_r[4] + action_port = action_r[5] + + # Don't allow protocols or ports to be ANY + # in the future we might want to do the opposite, and only have ANY option for ports and service + if action_protocol == "ANY": + return False + if action_port == "ANY": + return False + + return True + + +def get_new_action(old_action, action_dict): + """Get new action (e.g. 32) from old action e.g. [1,1,1,0]. + + old_action can be either node or acl action type. + """ + for key, val in action_dict.items(): + if list(val) == list(old_action): + return key + # Not all possible actions are included in dict, only valid action are + # if action is not in the dict, its an invalid action so return 0 + return 0 + + +def get_action_description(action, action_dict): + """Get a string describing/explaining what an action is doing in words.""" + action_array = action_dict[action] + if len(action_array) == 4: + # node actions have length 4 + action_description = node_action_description(action_array) + elif len(action_array) == 6: + # acl actions have length 6 + action_description = acl_action_description(action_array) + else: + # Should never happen + action_description = "Unrecognised action" + + return action_description diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 0e00c9e4..0aebf2a4 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -91,3 +91,29 @@ class FileSystemState(Enum): DESTROYED = 3 REPAIRING = 4 RESTORING = 5 + + +class NodeHardwareAction(Enum): + """Node hardware action.""" + + NONE = 0 + ON = 1 + OFF = 2 + RESET = 3 + + +class NodeSoftwareAction(Enum): + """Node software action.""" + + NONE = 0 + PATCHING = 1 + + +class LinkStatus(Enum): + """Link traffic status.""" + + NONE = 0 + LOW = 1 + MEDIUM = 2 + HIGH = 3 + OVERLOAD = 4 diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 99c7c09f..1c72bba0 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -15,6 +15,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -232,15 +233,9 @@ class Primaite(Env): # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa - self.action_space = spaces.MultiDiscrete( - [ - self.num_nodes, - self.ACTION_SPACE_NODE_PROPERTY_VALUES, - self.ACTION_SPACE_NODE_ACTION_VALUES, - self.num_services, - ] - ) - else: + self.action_dict = self.create_node_action_dict() + self.action_space = spaces.Discrete(len(self.action_dict)) + elif self.action_type == ActionType.ACL: _LOGGER.info("Action space type ACL selected") # Terms (for ACL action space): # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) @@ -249,16 +244,12 @@ class Primaite(Env): # [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) # [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) # [0, num ports] - Port (0 = any, then 1 -> x resolving to port) - self.action_space = spaces.MultiDiscrete( - [ - self.ACTION_SPACE_ACL_ACTION_VALUES, - self.ACTION_SPACE_ACL_PERMISSION_VALUES, - self.num_nodes + 1, - self.num_nodes + 1, - self.num_services + 1, - self.num_ports + 1, - ] - ) + self.action_dict = self.create_acl_action_dict() + self.action_space = spaces.Discrete(len(self.action_dict)) + else: + _LOGGER.info("Action space type ANY selected") + self.action_dict = self.create_node_and_acl_action_dict() + self.action_space = spaces.Discrete(len(self.action_dict)) # Set up a csv to store the results of the training try: @@ -1163,3 +1154,92 @@ class Primaite(Env): else: # Bad formatting pass + + def create_node_action_dict(self): + """ + Creates a dictionary mapping each possible discrete action to more readable multidiscrete action. + + Note: Only actions that have the potential to change the state exist in the mapping (except for key 0) + + example return: + {0: [1, 0, 0, 0], + 1: [1, 1, 1, 0], + 2: [1, 1, 2, 0], + 3: [1, 1, 3, 0], + 4: [1, 2, 1, 0], + 5: [1, 3, 1, 0], + ... + } + + """ + # reserve 0 action to be a nothing action + actions = {0: [1, 0, 0, 0]} + + action_key = 1 + for node in range(1, self.num_nodes + 1): + # 4 node properties (NONE, OPERATING, OS, SERVICE) + for node_property in range(4): + # Node Actions either: + # (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state + # Use MAX to ensure we get them all + for node_action in range(4): + for service_state in range(self.num_services): + action = [node, node_property, node_action, service_state] + # check to see if its a nothing aciton (has no effect) + if is_valid_node_action(action): + actions[action_key] = action + action_key += 1 + + return actions + + def create_acl_action_dict(self): + """Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.""" + # reserve 0 action to be a nothing action + actions = {0: [0, 0, 0, 0, 0, 0]} + + action_key = 1 + # 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE + for action_decision in range(3): + # 2 possible action permissions 0 = DENY, 1 = CREATE + for action_permission in range(2): + # Number of nodes + 1 (for any) + for source_ip in range(self.num_nodes + 1): + for dest_ip in range(self.num_nodes + 1): + for protocol in range(self.num_services + 1): + for port in range(self.num_ports + 1): + action = [ + action_decision, + action_permission, + source_ip, + dest_ip, + protocol, + port, + ] + # Check to see if its an action we want to include as possible i.e. not a nothing action + if is_valid_acl_action_extra(action): + actions[action_key] = action + action_key += 1 + + return actions + + def create_node_and_acl_action_dict(self): + """ + Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action. + + The dictionary contains actions of both Node and ACL action types. + + """ + node_action_dict = self.create_node_action_dict() + acl_action_dict = self.create_acl_action_dict() + + # Change node keys to not overlap with acl keys + # Only 1 nothing action (key 0) is required, remove the other + new_node_action_dict = { + k + len(acl_action_dict) - 1: v + for k, v in node_action_dict.items() + if k != 0 + } + + # Combine the Node dict and ACL dict + combined_action_dict = {**acl_action_dict, **new_node_action_dict} + return combined_action_dict diff --git a/tests/config/single_action_space_config.yaml b/tests/config/single_action_space_config.yaml new file mode 100644 index 00000000..6f6bb4e6 --- /dev/null +++ b/tests/config/single_action_space_config.yaml @@ -0,0 +1,89 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: GENERIC +# Number of episodes to run per session +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1000000000 + +# Reward values +# Generic +allOk: 0 +# Node Operating State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node O/S or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py new file mode 100644 index 00000000..3ec1dc2e --- /dev/null +++ b/tests/test_single_action_space.py @@ -0,0 +1,12 @@ +from tests import TEST_CONFIG_ROOT +from tests.conftest import _get_primaite_env_from_config + + +def test_single_action_space(): + """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" + env = _get_primaite_env_from_config( + main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "one_node_states_on_off_lay_down_config.yaml", + ) + print("Average Reward:", env.average_reward) From e2fb03b9bd2b91c826b24650ea0884da30f9f022 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Fri, 26 May 2023 14:29:02 +0100 Subject: [PATCH 02/37] 1429 - added code from ADSP branch to primaite_env.py and added NONE = 0 to NodePOLType in enums.py --- src/primaite/common/enums.py | 1 + src/primaite/environment/primaite_env.py | 38 ++++++++++++++++-------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 0aebf2a4..20660e86 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -49,6 +49,7 @@ class SoftwareState(Enum): class NodePOLType(Enum): """Node Pattern of Life type enumeration.""" + NONE = 0 OPERATING = 1 OS = 2 SERVICE = 3 diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 1c72bba0..0ebcd973 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -247,7 +247,7 @@ class Primaite(Env): self.action_dict = self.create_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info("Action space type ANY selected") + _LOGGER.info("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) @@ -453,8 +453,18 @@ class Primaite(Env): # At the moment, actions are only affecting nodes if self.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) - else: + elif self.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) + elif ( + len(self.action_dict[_action]) == 6 + ): # ACL actions in multidiscrete form have len 6 + self.apply_actions_to_acl(_action) + elif ( + len(self.action_dict[_action]) == 4 + ): # Node actions in multdiscrete (array) from have len 4 + self.apply_actions_to_nodes(_action) + else: + logging.error("Invalid action type found") def apply_actions_to_nodes(self, _action): """ @@ -463,10 +473,11 @@ class Primaite(Env): Args: _action: The action space from the agent """ - node_id = _action[0] - node_property = _action[1] - property_action = _action[2] - service_index = _action[3] + readable_action = self.action_dict[_action] + node_id = readable_action[0] + node_property = readable_action[1] + property_action = readable_action[2] + service_index = readable_action[3] # Check that the action is requesting a valid node try: @@ -552,12 +563,15 @@ class Primaite(Env): Args: _action: The action space from the agent """ - action_decision = _action[0] - action_permission = _action[1] - action_source_ip = _action[2] - action_destination_ip = _action[3] - action_protocol = _action[4] - action_port = _action[5] + # Convert discrete value back to multidiscrete + multidiscrete_action = self.action_dict[_action] + + action_decision = multidiscrete_action[0] + action_permission = multidiscrete_action[1] + action_source_ip = multidiscrete_action[2] + action_destination_ip = multidiscrete_action[3] + action_protocol = multidiscrete_action[4] + action_port = multidiscrete_action[5] if action_decision == 0: # It's decided to do nothing From 843f32bf718158590598f970b294af83b40df515 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 26 May 2023 14:50:15 +0100 Subject: [PATCH 03/37] Fix minor logic errors in main script --- src/primaite/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/main.py b/src/primaite/main.py index 0963fa7e..0f94c3f8 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -25,6 +25,7 @@ from primaite.transactions.transactions_to_file import write_transaction_to_file def run_generic(): """Run against a generic agent.""" for episode in range(0, config_values.num_episodes): + env.reset() for step in range(0, config_values.num_steps): # Send the observation space to the agent to get an action # TEMP - random action for now @@ -42,7 +43,6 @@ def run_generic(): time.sleep(config_values.time_delay / 1000) # Reset the environment at the end of the episode - env.reset() env.close() @@ -375,7 +375,7 @@ logging.info("Saving transaction logs...") write_transaction_to_file(transaction_list) -config_file_main.close +config_file_main.close() print("Finished") logging.info("Finished") From dd780b7451f8db7f1368275333ae361e50669332 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 30 May 2023 08:50:57 +0000 Subject: [PATCH 04/37] Make reward calculation consider red POL --- src/primaite/environment/primaite_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 99c7c09f..3fe7f0f6 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -398,7 +398,7 @@ class Primaite(Env): # 5. Calculate reward signal (for RL) reward = calculate_reward_function( self.nodes_post_pol, - self.nodes_post_blue, + self.nodes_post_red, self.nodes_reference, self.green_iers, self.red_iers, From 20d13f42a2b8a797a7ffcb970d0d1a653770ee29 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Wed, 31 May 2023 13:15:25 +0100 Subject: [PATCH 05/37] 1443 - added changes from ADSP to observation space in primaite_env.py --- src/primaite/environment/primaite_env.py | 41 +++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 0ebcd973..84b485bd 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -207,6 +207,7 @@ class Primaite(Env): # Calculate the number of items that need to be included in the # observation space + """ num_items = self.num_links + self.num_nodes # Set the number of observation parameters, being # of services plus id, # hardware state, file system state and SoftwareState (i.e. 4) @@ -221,6 +222,23 @@ class Primaite(Env): shape=self.observation_shape, dtype=np.int64, ) + """ + self.num_observation_parameters = ( + self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS + ) + # Define the observation space: + # There are: + # 4 Operating States (ON/OFF/RESETTING) + NONE (0) + # 4 OS States (GOOD/PATCHING/COMPROMISED) + NONE + # 5 Service States (NONE/GOOD/PATCHING/COMPROMISED/OVERWHELMED) + NONE + # There can be any number of services + # There are 5 node types No Traffic, Low Traffic, Medium Traffic, High Traffic, Overloaded/max traffic + self.observation_space = spaces.MultiDiscrete( + ([4, 4] + [5] * self.num_services) * self.num_nodes + [5] * self.num_links + ) + + # Define the observation shape + self.observation_shape = self.observation_space.sample().shape # This is the observation that is sent back via the rest and step functions self.env_obs = np.zeros(self.observation_shape, dtype=np.int64) @@ -396,7 +414,7 @@ class Primaite(Env): self.step_count, self.config_values, ) - # print(f" Step {self.step_count} Reward: {str(reward)}") + print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -678,6 +696,19 @@ class Primaite(Env): def update_environent_obs(self): """Updates the observation space based on the node and link status.""" + # Convert back to more readable, original format + reshaped_nodes = self.env_obs[: -self.num_links].reshape( + self.num_nodes, self.num_services + 2 + ) + + # Add empty links back and add node ID back + s = np.zeros( + [reshaped_nodes.shape[0] + self.num_links, reshaped_nodes.shape[1] + 1], + dtype=np.int64, + ) + s[:, 0] = range(1, self.num_nodes + self.num_links + 1) # Adding ID back + s[: self.num_nodes, 1:] = reshaped_nodes # put values back in + self.env_obs = s item_index = 0 # Do nodes first @@ -720,6 +751,13 @@ class Primaite(Env): protocol_index += 1 item_index += 1 + # Remove ID columns, remove links and flatten to 1D array + node_obs = self.env_obs[: self.num_nodes, 1:].flatten() + # Remove ID, remove all data except link traffic status + link_obs = self.env_obs[self.num_nodes :, 3:].flatten() + # Combine nodes and links + self.env_obs = np.append(node_obs, link_obs) + def load_config(self): """Loads config data in order to build the environment configuration.""" for item in self.config_data: @@ -1187,6 +1225,7 @@ class Primaite(Env): """ # reserve 0 action to be a nothing action + # Used to be {0: [1, 0, 0, 0]} actions = {0: [1, 0, 0, 0]} action_key = 1 From ae2f4d472ec78d67c4268352cd773fe097572bf2 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Wed, 31 May 2023 14:11:15 +0100 Subject: [PATCH 06/37] 1443 - reverted changes made to observation space and added config files for testing --- src/primaite/environment/primaite_env.py | 38 ------------------- .../single_action_space_lay_down_config.yaml | 29 ++++++++++++++ ...l => single_action_space_main_config.yaml} | 0 tests/test_single_action_space.py | 4 +- 4 files changed, 31 insertions(+), 40 deletions(-) create mode 100644 tests/config/single_action_space_lay_down_config.yaml rename tests/config/{single_action_space_config.yaml => single_action_space_main_config.yaml} (100%) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 84b485bd..49c45f3e 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -207,7 +207,6 @@ class Primaite(Env): # Calculate the number of items that need to be included in the # observation space - """ num_items = self.num_links + self.num_nodes # Set the number of observation parameters, being # of services plus id, # hardware state, file system state and SoftwareState (i.e. 4) @@ -222,23 +221,6 @@ class Primaite(Env): shape=self.observation_shape, dtype=np.int64, ) - """ - self.num_observation_parameters = ( - self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS - ) - # Define the observation space: - # There are: - # 4 Operating States (ON/OFF/RESETTING) + NONE (0) - # 4 OS States (GOOD/PATCHING/COMPROMISED) + NONE - # 5 Service States (NONE/GOOD/PATCHING/COMPROMISED/OVERWHELMED) + NONE - # There can be any number of services - # There are 5 node types No Traffic, Low Traffic, Medium Traffic, High Traffic, Overloaded/max traffic - self.observation_space = spaces.MultiDiscrete( - ([4, 4] + [5] * self.num_services) * self.num_nodes + [5] * self.num_links - ) - - # Define the observation shape - self.observation_shape = self.observation_space.sample().shape # This is the observation that is sent back via the rest and step functions self.env_obs = np.zeros(self.observation_shape, dtype=np.int64) @@ -696,19 +678,6 @@ class Primaite(Env): def update_environent_obs(self): """Updates the observation space based on the node and link status.""" - # Convert back to more readable, original format - reshaped_nodes = self.env_obs[: -self.num_links].reshape( - self.num_nodes, self.num_services + 2 - ) - - # Add empty links back and add node ID back - s = np.zeros( - [reshaped_nodes.shape[0] + self.num_links, reshaped_nodes.shape[1] + 1], - dtype=np.int64, - ) - s[:, 0] = range(1, self.num_nodes + self.num_links + 1) # Adding ID back - s[: self.num_nodes, 1:] = reshaped_nodes # put values back in - self.env_obs = s item_index = 0 # Do nodes first @@ -751,13 +720,6 @@ class Primaite(Env): protocol_index += 1 item_index += 1 - # Remove ID columns, remove links and flatten to 1D array - node_obs = self.env_obs[: self.num_nodes, 1:].flatten() - # Remove ID, remove all data except link traffic status - link_obs = self.env_obs[self.num_nodes :, 3:].flatten() - # Combine nodes and links - self.env_obs = np.append(node_obs, link_obs) - def load_config(self): """Loads config data in order to build the environment configuration.""" for item in self.config_data: diff --git a/tests/config/single_action_space_lay_down_config.yaml b/tests/config/single_action_space_lay_down_config.yaml new file mode 100644 index 00000000..058b790b --- /dev/null +++ b/tests/config/single_action_space_lay_down_config.yaml @@ -0,0 +1,29 @@ +- itemType: ACTIONS + type: NODE +- itemType: STEPS + steps: 15 +- itemType: PORTS + portsList: + - port: '21' +- itemType: SERVICES + serviceList: + - name: ftp +- itemType: NODE + node_id: '1' + name: node + node_class: SERVICE + node_type: COMPUTER + priority: P1 + hardware_state: 'ON' + ip_address: 192.168.0.1 + software_state: GOOD + file_system_state: GOOD + services: + - name: ftp + port: '21' + state: GOOD +- itemType: POSITION + positions: + - node: '1' + x_pos: 309 + y_pos: 78 diff --git a/tests/config/single_action_space_config.yaml b/tests/config/single_action_space_main_config.yaml similarity index 100% rename from tests/config/single_action_space_config.yaml rename to tests/config/single_action_space_main_config.yaml diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 3ec1dc2e..8c87d57b 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -5,8 +5,8 @@ from tests.conftest import _get_primaite_env_from_config def test_single_action_space(): """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" env = _get_primaite_env_from_config( - main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + main_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_lay_down_config.yaml", + / "single_action_space_lay_down_config.yaml", ) print("Average Reward:", env.average_reward) From 4108f8036c0eac8e402593a3f3ace2368bb45dd4 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 31 May 2023 17:03:53 +0100 Subject: [PATCH 07/37] Start creating observations module --- src/primaite/environment/observations.py | 227 +++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 src/primaite/environment/observations.py diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py new file mode 100644 index 00000000..40cc26a5 --- /dev/null +++ b/src/primaite/environment/observations.py @@ -0,0 +1,227 @@ +# """Module for handling configurable observation spaces in PrimAITE.""" +# import logging +# from abc import ABC, abstractmethod +# from enum import Enum + +# import numpy as np +# from gym import spaces + +# from primaite.common.enums import FileSystemState, HardwareState, SoftwareState +# from primaite.environment.primaite_env import Primaite +# from primaite.nodes.active_node import ActiveNode +# from primaite.nodes.service_node import ServiceNode + +# _LOGGER = logging.getLogger(__name__) + + +# class AbstractObservationComponent(ABC): +# """Represents a part of the PrimAITE observation space.""" +# @abstractmethod +# def __init__(self, env: Primaite): +# _LOGGER.info(f"Initialising {self} observation component") +# self.env: Primaite = env +# self.space: spaces.Space +# self.current_observation: np.ndarray # type might be too restrictive? +# return NotImplemented + +# @abstractmethod +# def update(self): +# """Look at the environment and update the current observation value""" +# self.current_observation = NotImplemented + +# # @abstractmethod +# # def export(self): +# # return NotImplemented + + +# class NodeLinkTable(AbstractObservationComponent): +# """Table with nodes/links as rows and hardware/software status as cols. + +# #todo: write full description + +# """ + +# _FIXED_PARAMETERS = 4 +# _MAX_VAL = 1_000_000 +# _DATA_TYPE = np.int64 + +# def __init__(self, env: Primaite): +# super().__init__() + +# # 1. Define the shape of your observation space component +# num_items = self.env.num_links + self.env.num_nodes +# num_columns = self.env.num_services + self._FIXED_PARAMETERS +# observation_shape = (num_items, num_columns) + +# # 2. Create Observation space +# self.space = spaces.Box( +# low=0, +# high=self._MAX_VAL, +# shape=observation_shape, +# dtype=self._DATA_TYPE, +# ) + +# # 3. Initialise Observation with zeroes +# self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE) + +# def update_obs(self): +# item_index = 0 +# nodes = self.env.nodes +# links = self.env.links +# # Do nodes first +# for _, node in nodes.items(): +# self.current_observation[item_index][0] = int(node.node_id) +# self.current_observation[item_index][1] = node.hardware_state.value +# if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): +# self.current_observation[item_index][2] = node.software_state.value +# self.current_observation[item_index][ +# 3 +# ] = node.file_system_state_observed.value +# else: +# self.current_observation[item_index][2] = 0 +# self.current_observation[item_index][3] = 0 +# service_index = 4 +# if isinstance(node, ServiceNode): +# for service in self.env.services_list: +# if node.has_service(service): +# self.current_observation[item_index][ +# service_index +# ] = node.get_service_state(service).value +# else: +# self.current_observation[item_index][service_index] = 0 +# service_index += 1 +# else: +# # Not a service node +# for service in self.env.services_list: +# self.current_observation[item_index][service_index] = 0 +# service_index += 1 +# item_index += 1 + +# # Now do links +# for _, link in links.items(): +# self.current_observation[item_index][0] = int(link.get_id()) +# self.current_observation[item_index][1] = 0 +# self.current_observation[item_index][2] = 0 +# self.current_observation[item_index][3] = 0 +# protocol_list = link.get_protocol_list() +# protocol_index = 0 +# for protocol in protocol_list: +# self.current_observation[item_index][ +# protocol_index + 4 +# ] = protocol.get_load() +# protocol_index += 1 +# item_index += 1 + + +# class NodeStatuses(AbstractObservationComponent): +# _DATA_TYPE = np.int64 + +# def __init__(self): +# super().__init__() + +# # 1. Define the shape of your observation space component +# shape = [ +# len(HardwareState) + 1, +# len(SoftwareState) + 1, +# len(FileSystemState) + 1, +# ] +# services_shape = [len(SoftwareState) + 1] * self.env.num_services +# shape = shape + services_shape + +# # 2. Create Observation space +# self.space = spaces.MultiDiscrete(shape) + +# # 3. Initialise observation with zeroes +# self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + +# def update_obs(self): +# obs = [] +# for _, node in self.env.nodes.items(): +# hardware_state = node.hardware_state.value +# software_state = 0 +# file_system_state = 0 +# service_states = [0] * self.env.num_services + +# if isinstance(node, ActiveNode): +# software_state = node.software_state.value +# file_system_state = node.file_system_state_observed.value + +# if isinstance(node, ServiceNode): +# for i, service in enumerate(self.env.services_list): +# if node.has_service(service): +# service_states[i] = node.get_service_state(service).value +# obs.extend([hardware_state, software_state, file_system_state, *service_states]) +# self.current_observation[:] = obs + + +# class LinkTrafficLevels(AbstractObservationComponent): +# _DATA_TYPE = np.int64 + +# def __init__( +# self, combine_service_traffic: bool = False, quantisation_levels: int = 5 +# ): +# super().__init__() +# self._combine_service_traffic: bool = combine_service_traffic +# self._quantisation_levels: int = quantisation_levels +# self._entries_per_link: int = 1 + +# if not self._combine_service_traffic: +# self._entries_per_link = self.env.num_services + +# # 1. Define the shape of your observation space component +# shape = ( +# [self._quantisation_levels] * self.env.num_links * self._entries_per_link +# ) + +# # 2. Create Observation space +# self.space = spaces.MultiDiscrete(shape) + +# # 3. Initialise observation with zeroes +# self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + +# def update_obs(self): +# obs = [] +# for _, link in self.env.links.items(): +# bandwidth = link.bandwidth +# if self._combine_service_traffic: +# loads = [link.get_current_load()] +# else: +# loads = [protocol.get_load() for protocol in link.protocol_list] + +# for load in loads: +# if load <= 0: +# traffic_level = 0 +# elif load >= bandwidth: +# traffic_level = self._quantisation_levels - 1 +# else: +# traffic_level = (load / bandwidth) // ( +# 1 / (self._quantisation_levels - 1) +# ) + 1 + +# obs.append(int(traffic_level)) + +# self.current_observation[:] = obs + + +# class ObservationsHandler: +# class registry(Enum): +# NODE_LINK_TABLE: NodeLinkTable +# NODE_STATUSES: NodeStatuses +# LINK_TRAFFIC_LEVELS: LinkTrafficLevels + +# def __init__(self): +# ... +# # i can access the registry items like this: +# # self.registry.LINK_TRAFFIC_LEVELS + +# def update_obs(self): +# ... + +# def register(self): +# ... + +# def deregister(self, observation: AbstractObservationComponent): +# ... + +# def export(self): +# ... From d351e575aeede80242c1f8f95055da0ca7bea6d1 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 13:28:40 +0100 Subject: [PATCH 08/37] Integrate observation handler with components --- src/primaite/environment/observations.py | 393 ++++++++++++----------- 1 file changed, 211 insertions(+), 182 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 40cc26a5..338c11a1 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,227 +1,256 @@ -# """Module for handling configurable observation spaces in PrimAITE.""" -# import logging -# from abc import ABC, abstractmethod -# from enum import Enum +"""Module for handling configurable observation spaces in PrimAITE.""" +import logging +from abc import ABC, abstractmethod +from enum import Enum +from typing import List, Tuple -# import numpy as np -# from gym import spaces +import numpy as np +from gym import spaces -# from primaite.common.enums import FileSystemState, HardwareState, SoftwareState -# from primaite.environment.primaite_env import Primaite -# from primaite.nodes.active_node import ActiveNode -# from primaite.nodes.service_node import ServiceNode +from primaite.common.enums import FileSystemState, HardwareState, SoftwareState +from primaite.environment.primaite_env import Primaite +from primaite.nodes.active_node import ActiveNode +from primaite.nodes.service_node import ServiceNode -# _LOGGER = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) -# class AbstractObservationComponent(ABC): -# """Represents a part of the PrimAITE observation space.""" -# @abstractmethod -# def __init__(self, env: Primaite): -# _LOGGER.info(f"Initialising {self} observation component") -# self.env: Primaite = env -# self.space: spaces.Space -# self.current_observation: np.ndarray # type might be too restrictive? -# return NotImplemented +class AbstractObservationComponent(ABC): + """Represents a part of the PrimAITE observation space.""" -# @abstractmethod -# def update(self): -# """Look at the environment and update the current observation value""" -# self.current_observation = NotImplemented + @abstractmethod + def __init__(self, env: Primaite): + _LOGGER.info(f"Initialising {self} observation component") + self.env: Primaite = env + self.space: spaces.Space + self.current_observation: np.ndarray # type might be too restrictive? + return NotImplemented -# # @abstractmethod -# # def export(self): -# # return NotImplemented + @abstractmethod + def update(self): + """Look at the environment and update the current observation value.""" + self.current_observation = NotImplemented -# class NodeLinkTable(AbstractObservationComponent): -# """Table with nodes/links as rows and hardware/software status as cols. +class NodeLinkTable(AbstractObservationComponent): + """Table with nodes/links as rows and hardware/software status as cols. -# #todo: write full description + #todo: write full description -# """ + """ -# _FIXED_PARAMETERS = 4 -# _MAX_VAL = 1_000_000 -# _DATA_TYPE = np.int64 + _FIXED_PARAMETERS = 4 + _MAX_VAL = 1_000_000 + _DATA_TYPE = np.int64 -# def __init__(self, env: Primaite): -# super().__init__() + def __init__(self, env: Primaite): + super().__init__() -# # 1. Define the shape of your observation space component -# num_items = self.env.num_links + self.env.num_nodes -# num_columns = self.env.num_services + self._FIXED_PARAMETERS -# observation_shape = (num_items, num_columns) + # 1. Define the shape of your observation space component + num_items = self.env.num_links + self.env.num_nodes + num_columns = self.env.num_services + self._FIXED_PARAMETERS + observation_shape = (num_items, num_columns) -# # 2. Create Observation space -# self.space = spaces.Box( -# low=0, -# high=self._MAX_VAL, -# shape=observation_shape, -# dtype=self._DATA_TYPE, -# ) + # 2. Create Observation space + self.space = spaces.Box( + low=0, + high=self._MAX_VAL, + shape=observation_shape, + dtype=self._DATA_TYPE, + ) -# # 3. Initialise Observation with zeroes -# self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE) + # 3. Initialise Observation with zeroes + self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE) -# def update_obs(self): -# item_index = 0 -# nodes = self.env.nodes -# links = self.env.links -# # Do nodes first -# for _, node in nodes.items(): -# self.current_observation[item_index][0] = int(node.node_id) -# self.current_observation[item_index][1] = node.hardware_state.value -# if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): -# self.current_observation[item_index][2] = node.software_state.value -# self.current_observation[item_index][ -# 3 -# ] = node.file_system_state_observed.value -# else: -# self.current_observation[item_index][2] = 0 -# self.current_observation[item_index][3] = 0 -# service_index = 4 -# if isinstance(node, ServiceNode): -# for service in self.env.services_list: -# if node.has_service(service): -# self.current_observation[item_index][ -# service_index -# ] = node.get_service_state(service).value -# else: -# self.current_observation[item_index][service_index] = 0 -# service_index += 1 -# else: -# # Not a service node -# for service in self.env.services_list: -# self.current_observation[item_index][service_index] = 0 -# service_index += 1 -# item_index += 1 + def update_obs(self): + """Update the observation. -# # Now do links -# for _, link in links.items(): -# self.current_observation[item_index][0] = int(link.get_id()) -# self.current_observation[item_index][1] = 0 -# self.current_observation[item_index][2] = 0 -# self.current_observation[item_index][3] = 0 -# protocol_list = link.get_protocol_list() -# protocol_index = 0 -# for protocol in protocol_list: -# self.current_observation[item_index][ -# protocol_index + 4 -# ] = protocol.get_load() -# protocol_index += 1 -# item_index += 1 + todo: complete description.. + """ + item_index = 0 + nodes = self.env.nodes + links = self.env.links + # Do nodes first + for _, node in nodes.items(): + self.current_observation[item_index][0] = int(node.node_id) + self.current_observation[item_index][1] = node.hardware_state.value + if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): + self.current_observation[item_index][2] = node.software_state.value + self.current_observation[item_index][ + 3 + ] = node.file_system_state_observed.value + else: + self.current_observation[item_index][2] = 0 + self.current_observation[item_index][3] = 0 + service_index = 4 + if isinstance(node, ServiceNode): + for service in self.env.services_list: + if node.has_service(service): + self.current_observation[item_index][ + service_index + ] = node.get_service_state(service).value + else: + self.current_observation[item_index][service_index] = 0 + service_index += 1 + else: + # Not a service node + for service in self.env.services_list: + self.current_observation[item_index][service_index] = 0 + service_index += 1 + item_index += 1 + + # Now do links + for _, link in links.items(): + self.current_observation[item_index][0] = int(link.get_id()) + self.current_observation[item_index][1] = 0 + self.current_observation[item_index][2] = 0 + self.current_observation[item_index][3] = 0 + protocol_list = link.get_protocol_list() + protocol_index = 0 + for protocol in protocol_list: + self.current_observation[item_index][ + protocol_index + 4 + ] = protocol.get_load() + protocol_index += 1 + item_index += 1 -# class NodeStatuses(AbstractObservationComponent): -# _DATA_TYPE = np.int64 +class NodeStatuses(AbstractObservationComponent): + """todo: complete description.""" -# def __init__(self): -# super().__init__() + _DATA_TYPE = np.int64 -# # 1. Define the shape of your observation space component -# shape = [ -# len(HardwareState) + 1, -# len(SoftwareState) + 1, -# len(FileSystemState) + 1, -# ] -# services_shape = [len(SoftwareState) + 1] * self.env.num_services -# shape = shape + services_shape + def __init__(self): + super().__init__() -# # 2. Create Observation space -# self.space = spaces.MultiDiscrete(shape) + # 1. Define the shape of your observation space component + shape = [ + len(HardwareState) + 1, + len(SoftwareState) + 1, + len(FileSystemState) + 1, + ] + services_shape = [len(SoftwareState) + 1] * self.env.num_services + shape = shape + services_shape -# # 3. Initialise observation with zeroes -# self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + # 2. Create Observation space + self.space = spaces.MultiDiscrete(shape) -# def update_obs(self): -# obs = [] -# for _, node in self.env.nodes.items(): -# hardware_state = node.hardware_state.value -# software_state = 0 -# file_system_state = 0 -# service_states = [0] * self.env.num_services + # 3. Initialise observation with zeroes + self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) -# if isinstance(node, ActiveNode): -# software_state = node.software_state.value -# file_system_state = node.file_system_state_observed.value + def update_obs(self): + """todo: complete description.""" + obs = [] + for _, node in self.env.nodes.items(): + hardware_state = node.hardware_state.value + software_state = 0 + file_system_state = 0 + service_states = [0] * self.env.num_services -# if isinstance(node, ServiceNode): -# for i, service in enumerate(self.env.services_list): -# if node.has_service(service): -# service_states[i] = node.get_service_state(service).value -# obs.extend([hardware_state, software_state, file_system_state, *service_states]) -# self.current_observation[:] = obs + if isinstance(node, ActiveNode): + software_state = node.software_state.value + file_system_state = node.file_system_state_observed.value + + if isinstance(node, ServiceNode): + for i, service in enumerate(self.env.services_list): + if node.has_service(service): + service_states[i] = node.get_service_state(service).value + obs.extend([hardware_state, software_state, file_system_state, *service_states]) + self.current_observation[:] = obs -# class LinkTrafficLevels(AbstractObservationComponent): -# _DATA_TYPE = np.int64 +class LinkTrafficLevels(AbstractObservationComponent): + """todo: complete description.""" -# def __init__( -# self, combine_service_traffic: bool = False, quantisation_levels: int = 5 -# ): -# super().__init__() -# self._combine_service_traffic: bool = combine_service_traffic -# self._quantisation_levels: int = quantisation_levels -# self._entries_per_link: int = 1 + _DATA_TYPE = np.int64 -# if not self._combine_service_traffic: -# self._entries_per_link = self.env.num_services + def __init__( + self, combine_service_traffic: bool = False, quantisation_levels: int = 5 + ): + super().__init__() + self._combine_service_traffic: bool = combine_service_traffic + self._quantisation_levels: int = quantisation_levels + self._entries_per_link: int = 1 -# # 1. Define the shape of your observation space component -# shape = ( -# [self._quantisation_levels] * self.env.num_links * self._entries_per_link -# ) + if not self._combine_service_traffic: + self._entries_per_link = self.env.num_services -# # 2. Create Observation space -# self.space = spaces.MultiDiscrete(shape) + # 1. Define the shape of your observation space component + shape = ( + [self._quantisation_levels] * self.env.num_links * self._entries_per_link + ) -# # 3. Initialise observation with zeroes -# self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + # 2. Create Observation space + self.space = spaces.MultiDiscrete(shape) -# def update_obs(self): -# obs = [] -# for _, link in self.env.links.items(): -# bandwidth = link.bandwidth -# if self._combine_service_traffic: -# loads = [link.get_current_load()] -# else: -# loads = [protocol.get_load() for protocol in link.protocol_list] + # 3. Initialise observation with zeroes + self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) -# for load in loads: -# if load <= 0: -# traffic_level = 0 -# elif load >= bandwidth: -# traffic_level = self._quantisation_levels - 1 -# else: -# traffic_level = (load / bandwidth) // ( -# 1 / (self._quantisation_levels - 1) -# ) + 1 + def update_obs(self): + """todo: complete description.""" + obs = [] + for _, link in self.env.links.items(): + bandwidth = link.bandwidth + if self._combine_service_traffic: + loads = [link.get_current_load()] + else: + loads = [protocol.get_load() for protocol in link.protocol_list] -# obs.append(int(traffic_level)) + for load in loads: + if load <= 0: + traffic_level = 0 + elif load >= bandwidth: + traffic_level = self._quantisation_levels - 1 + else: + traffic_level = (load / bandwidth) // ( + 1 / (self._quantisation_levels - 1) + ) + 1 -# self.current_observation[:] = obs + obs.append(int(traffic_level)) + + self.current_observation[:] = obs -# class ObservationsHandler: -# class registry(Enum): -# NODE_LINK_TABLE: NodeLinkTable -# NODE_STATUSES: NodeStatuses -# LINK_TRAFFIC_LEVELS: LinkTrafficLevels +class ObservationsHandler: + """todo: complete description.""" -# def __init__(self): -# ... -# # i can access the registry items like this: -# # self.registry.LINK_TRAFFIC_LEVELS + class registry(Enum): + """todo: complete description.""" -# def update_obs(self): -# ... + NODE_LINK_TABLE: NodeLinkTable + NODE_STATUSES: NodeStatuses + LINK_TRAFFIC_LEVELS: LinkTrafficLevels -# def register(self): -# ... + def __init__(self): + """todo: complete description.""" + """Initialise the handler without any components yet. They""" + self.registered_obs_components: List[AbstractObservationComponent] = [] + self.space: spaces.Space + self.current_observation: Tuple[np.ndarray] + # i can access the registry items like this: + # self.registry.LINK_TRAFFIC_LEVELS -# def deregister(self, observation: AbstractObservationComponent): -# ... + def update_obs(self): + """todo: complete description.""" + current_obs = [] + for obs in self.registered_obs_components: + obs.update_obs() + current_obs.append(obs.current_observation) + self.current_observation = tuple(current_obs) -# def export(self): -# ... + def register(self, obs_component: AbstractObservationComponent): + """todo: complete description.""" + self.registered_obs_components.append(obs_component) + self.update_space() + + def deregister(self, obs_component: AbstractObservationComponent): + """todo: complete description.""" + self.registered_obs_components.remove(obs_component) + self.update_space() + + def update_space(self): + """todo: complete description.""" + component_spaces = [] + for obs_comp in self.registered_obs_components: + component_spaces.append(obs_comp.space) + self.space = spaces.Tuple(component_spaces) From f72a80c9d2a8ddeac8bb81e58f4dd87a86271610 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Thu, 1 Jun 2023 16:27:25 +0100 Subject: [PATCH 09/37] 1443 - added in print test statements --- src/primaite/agents/utils.py | 2 ++ src/primaite/environment/primaite_env.py | 24 ++++++++++++++--- .../single_action_space_lay_down_config.yaml | 26 ++++++++++++++++++- .../single_action_space_main_config.yaml | 2 +- 4 files changed, 48 insertions(+), 6 deletions(-) diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index d3924b24..1ada88ba 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -414,6 +414,8 @@ def is_valid_node_action(action): node_property = action_r[1] node_action = action_r[2] + # print("node property", node_property, "\nnode action", node_action) + if node_property == "NONE": return False if node_action == "NONE": diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 49c45f3e..5d783af1 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -42,6 +42,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod from primaite.transactions.transaction import Transaction _LOGGER = logging.getLogger(__name__) +_LOGGER.setLevel(logging.INFO) class Primaite(Env): @@ -235,6 +236,7 @@ class Primaite(Env): # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa self.action_dict = self.create_node_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) + print(self.action_space, "NODE action space") elif self.action_type == ActionType.ACL: _LOGGER.info("Action space type ACL selected") # Terms (for ACL action space): @@ -246,11 +248,12 @@ class Primaite(Env): # [0, num ports] - Port (0 = any, then 1 -> x resolving to port) self.action_dict = self.create_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) + print(self.action_space, "ACL action space") else: - _LOGGER.info("Action space type ANY selected - Node + ACL") + _LOGGER.warning("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) - + print(self.action_space, "ANY action space") # Set up a csv to store the results of the training try: now = datetime.now() # current date and time @@ -450,6 +453,7 @@ class Primaite(Env): Args: _action: The action space from the agent """ + # print("intepret action") # At the moment, actions are only affecting nodes if self.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) @@ -464,6 +468,7 @@ class Primaite(Env): ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: + print("invalid action type found") logging.error("Invalid action type found") def apply_actions_to_nodes(self, _action): @@ -1084,6 +1089,7 @@ class Primaite(Env): item: A config data item representing action info """ self.action_type = ActionType[action_info["type"]] + print("action type selected: ", self.action_type) def get_steps_info(self, steps_info): """ @@ -1187,9 +1193,8 @@ class Primaite(Env): """ # reserve 0 action to be a nothing action - # Used to be {0: [1, 0, 0, 0]} actions = {0: [1, 0, 0, 0]} - + # print("node dict function call", self.num_nodes + 1) action_key = 1 for node in range(1, self.num_nodes + 1): # 4 node properties (NONE, OPERATING, OS, SERVICE) @@ -1197,11 +1202,14 @@ class Primaite(Env): # Node Actions either: # (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state # Use MAX to ensure we get them all + # print(self.num_services, "num services") for node_action in range(4): for service_state in range(self.num_services): action = [node, node_property, node_action, service_state] # check to see if its a nothing aciton (has no effect) + # print("action node",action) if is_valid_node_action(action): + print("true") actions[action_key] = action action_key += 1 @@ -1213,6 +1221,7 @@ class Primaite(Env): actions = {0: [0, 0, 0, 0, 0, 0]} action_key = 1 + # print("node count",self.num_nodes + 1) # 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE for action_decision in range(3): # 2 possible action permissions 0 = DENY, 1 = CREATE @@ -1230,10 +1239,13 @@ class Primaite(Env): protocol, port, ] + # print("action acl", action) # Check to see if its an action we want to include as possible i.e. not a nothing action if is_valid_acl_action_extra(action): + print("true") actions[action_key] = action action_key += 1 + # print("false") return actions @@ -1247,6 +1259,8 @@ class Primaite(Env): node_action_dict = self.create_node_action_dict() acl_action_dict = self.create_acl_action_dict() + print(len(node_action_dict), len(acl_action_dict)) + # Change node keys to not overlap with acl keys # Only 1 nothing action (key 0) is required, remove the other new_node_action_dict = { @@ -1257,4 +1271,6 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} + logging.warning("logging is working") + # print(len(list(combined_action_dict.values()))) return combined_action_dict diff --git a/tests/config/single_action_space_lay_down_config.yaml b/tests/config/single_action_space_lay_down_config.yaml index 058b790b..6d44356f 100644 --- a/tests/config/single_action_space_lay_down_config.yaml +++ b/tests/config/single_action_space_lay_down_config.yaml @@ -1,5 +1,5 @@ - itemType: ACTIONS - type: NODE + type: ANY - itemType: STEPS steps: 15 - itemType: PORTS @@ -27,3 +27,27 @@ - node: '1' x_pos: 309 y_pos: 78 +- itemType: RED_POL + id: '1' + startStep: 1 + endStep: 3 + targetNodeId: '1' + initiator: DIRECT + type: FILE + protocol: NA + state: CORRUPT + sourceNodeId: NA + sourceNodeService: NA + sourceNodeServiceState: NA +- itemType: RED_POL + id: '2' + startStep: 3 + endStep: 15 + targetNodeId: '1' + initiator: DIRECT + type: FILE + protocol: NA + state: GOOD + sourceNodeId: NA + sourceNodeService: NA + sourceNodeServiceState: NA diff --git a/tests/config/single_action_space_main_config.yaml b/tests/config/single_action_space_main_config.yaml index 6f6bb4e6..7fcc002f 100644 --- a/tests/config/single_action_space_main_config.yaml +++ b/tests/config/single_action_space_main_config.yaml @@ -11,7 +11,7 @@ numEpisodes: 1 # Time delay between steps (for generic agents) timeDelay: 1 # Filename of the scenario / laydown -configFilename: one_node_states_on_off_lay_down_config.yaml +configFilename: single_action_space_lay_down_config.yaml # Type of session to be run (TRAINING or EVALUATION) sessionType: TRAINING # Determine whether to load an agent from file From bab6c27f06232d42a17b357f8562087f486f4a56 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 16:42:10 +0100 Subject: [PATCH 10/37] Integrate obs handler with Primaite Env --- src/primaite/config/config_1_DDOS_BASIC.yaml | 8 + src/primaite/environment/observations.py | 90 +++++- src/primaite/environment/primaite_env.py | 282 ++----------------- 3 files changed, 115 insertions(+), 265 deletions(-) diff --git a/src/primaite/config/config_1_DDOS_BASIC.yaml b/src/primaite/config/config_1_DDOS_BASIC.yaml index ada813f3..a1961df3 100644 --- a/src/primaite/config/config_1_DDOS_BASIC.yaml +++ b/src/primaite/config/config_1_DDOS_BASIC.yaml @@ -1,5 +1,13 @@ - itemType: ACTIONS type: NODE +- itemType: OBSERVATION_SPACE + components: + - name: NODE_LINK_TABLE + - name: NODE_STATUSES + - name: LINK_TRAFFIC_LEVELS + options: + - combine_service_traffic : False + - quantisation_levels : 7 - itemType: STEPS steps: 128 - itemType: PORTS diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 338c11a1..94c2730f 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -35,7 +35,23 @@ class AbstractObservationComponent(ABC): class NodeLinkTable(AbstractObservationComponent): """Table with nodes/links as rows and hardware/software status as cols. - #todo: write full description + Initialise the observation space with the BOX option chosen. + + This will create the observation space formatted as a table of integers. + There is one row per node, followed by one row per link. + Columns are as follows: + * node/link ID + * node hardware status / 0 for links + * node operating system status (if active/service) / 0 for links + * node file system status (active/service only) / 0 for links + * node service1 status / traffic load from that service for links + * node service2 status / traffic load from that service for links + * ... + * node serviceN status / traffic load from that service for links + + For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be + ``(12, 7)`` + #todo: clean up description """ @@ -44,7 +60,7 @@ class NodeLinkTable(AbstractObservationComponent): _DATA_TYPE = np.int64 def __init__(self, env: Primaite): - super().__init__() + super().__init__(env) # 1. Define the shape of your observation space component num_items = self.env.num_links + self.env.num_nodes @@ -65,6 +81,10 @@ class NodeLinkTable(AbstractObservationComponent): def update_obs(self): """Update the observation. + Update the environment's observation state based on the current status of nodes and links. + + The structure of the observation space is described in :func:`~_init_box_observations` + This function can only be called if the observation space setting is set to BOX. todo: complete description.. """ item_index = 0 @@ -116,12 +136,20 @@ class NodeLinkTable(AbstractObservationComponent): class NodeStatuses(AbstractObservationComponent): - """todo: complete description.""" + """todo: complete description. + + This will create the observation space with node observations followed by link observations. + Each node has 3 elements in the observation space plus 1 per service, more specifically: + * hardware state + * operating system state + * file system state + * service states (one per service) + """ _DATA_TYPE = np.int64 - def __init__(self): - super().__init__() + def __init__(self, env: Primaite): + super().__init__(env) # 1. Define the shape of your observation space component shape = [ @@ -139,7 +167,15 @@ class NodeStatuses(AbstractObservationComponent): self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) def update_obs(self): - """todo: complete description.""" + """todo: complete description. + + Update the environment's observation state based on the current status of nodes and links. + + The structure of the observation space is described in :func:`~_init_multidiscrete_observations` + This function can only be called if the observation space setting is set to MULTIDISCRETE. + + + """ obs = [] for _, node in self.env.nodes.items(): hardware_state = node.hardware_state.value @@ -160,14 +196,26 @@ class NodeStatuses(AbstractObservationComponent): class LinkTrafficLevels(AbstractObservationComponent): - """todo: complete description.""" + """todo: complete description. + + Each link has one element in the observation space, corresponding to the traffic load, + it can take the following values: + 0 = No traffic (0% of bandwidth) + 1 = No traffic (0%-33% of bandwidth) + 2 = No traffic (33%-66% of bandwidth) + 3 = No traffic (66%-100% of bandwidth) + 4 = No traffic (100% of bandwidth) + """ _DATA_TYPE = np.int64 def __init__( - self, combine_service_traffic: bool = False, quantisation_levels: int = 5 + self, + env: Primaite, + combine_service_traffic: bool = False, + quantisation_levels: int = 5, ): - super().__init__() + super().__init__(env) self._combine_service_traffic: bool = combine_service_traffic self._quantisation_levels: int = quantisation_levels self._entries_per_link: int = 1 @@ -212,7 +260,7 @@ class LinkTrafficLevels(AbstractObservationComponent): class ObservationsHandler: - """todo: complete description.""" + """Component-based observation space handler.""" class registry(Enum): """todo: complete description.""" @@ -254,3 +302,25 @@ class ObservationsHandler: for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) self.space = spaces.Tuple(component_spaces) + + @classmethod + def from_config(cls, obs_space_config): + """todo: complete description. + + This method parses config items related to the observation space, then + creates the necessary components and adds them to the observation handler. + """ + # Instantiate the handler + handler = cls() + + for component_cfg in obs_space_config["components"]: + # Figure out which class can instantiate the desired component + comp_type = component_cfg["name"] + comp_class = cls.registry[comp_type].value + + # Create the component with options from the YAML + component = comp_class(**component_cfg["options"]) + + handler.register(component) + + return handler diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 56893ee9..afa04060 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -6,7 +6,7 @@ import csv import logging import os.path from datetime import datetime -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import networkx as nx import numpy as np @@ -23,11 +23,11 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, - ObservationType, Priority, SoftwareState, ) from primaite.common.service import Service +from primaite.environment.observations import ObservationsHandler from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode @@ -149,8 +149,8 @@ class Primaite(Env): # The action type self.action_type = 0 - # Observation type, by default box. - self.observation_type = ObservationType.BOX + # todo: proper description here + self.obs_handler: ObservationsHandler # Open the config file and build the environment laydown try: @@ -161,6 +161,10 @@ class Primaite(Env): _LOGGER.error("Could not load the environment configuration") _LOGGER.error("Exception occured", exc_info=True) + # If it doesn't exist after parsing config, create default obs space. + if self.get("obs_handler") is None: + self.configure_obs_space() + # Store the node objects as node attributes # (This is so we can access them as objects) for node in self.network: @@ -641,252 +645,17 @@ class Primaite(Env): else: pass - def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Initialise the observation space with the BOX option chosen. - - This will create the observation space formatted as a table of integers. - There is one row per node, followed by one row per link. - Columns are as follows: - * node/link ID - * node hardware status / 0 for links - * node operating system status (if active/service) / 0 for links - * node file system status (active/service only) / 0 for links - * node service1 status / traffic load from that service for links - * node service2 status / traffic load from that service for links - * ... - * node serviceN status / traffic load from that service for links - - For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be - ``(12, 7)`` - - :return: Box gym observation - :rtype: gym.spaces.Box - :return: Initial observation with all entires set to 0 - :rtype: numpy.Array - """ - _LOGGER.info("Observation space type BOX selected") - - # 1. Determine observation shape from laydown - num_items = self.num_links + self.num_nodes - num_observation_parameters = ( - self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS - ) - observation_shape = (num_items, num_observation_parameters) - - # 2. Create observation space & zeroed out sample from space. - observation_space = spaces.Box( - low=0, - high=self.OBSERVATION_SPACE_HIGH_VALUE, - shape=observation_shape, - dtype=np.int64, - ) - initial_observation = np.zeros(observation_shape, dtype=np.int64) - - return observation_space, initial_observation - - def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Initialise the observation space with the MULTIDISCRETE option chosen. - - This will create the observation space with node observations followed by link observations. - Each node has 3 elements in the observation space plus 1 per service, more specifically: - * hardware state - * operating system state - * file system state - * service states (one per service) - Each link has one element in the observation space, corresponding to the traffic load, - it can take the following values: - 0 = No traffic (0% of bandwidth) - 1 = No traffic (0%-33% of bandwidth) - 2 = No traffic (33%-66% of bandwidth) - 3 = No traffic (66%-100% of bandwidth) - 4 = No traffic (100% of bandwidth) - - For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be - ``(37,)`` - - :return: MultiDiscrete gym observation - :rtype: gym.spaces.MultiDiscrete - :return: Initial observation with all entires set to 0 - :rtype: numpy.Array - """ - _LOGGER.info("Observation space MULTIDISCRETE selected") - - # 1. Determine observation shape from laydown - node_obs_shape = [ - len(HardwareState) + 1, - len(SoftwareState) + 1, - len(FileSystemState) + 1, - ] - node_services = [len(SoftwareState) + 1] * self.num_services - node_obs_shape = node_obs_shape + node_services - # the magic number 5 refers to 5 states of quantisation of traffic amount. - # (zero, low, medium, high, fully utilised/overwhelmed) - link_obs_shape = [5] * self.num_links - observation_shape = node_obs_shape * self.num_nodes + link_obs_shape - - # 2. Create observation space & zeroed out sample from space. - observation_space = spaces.MultiDiscrete(observation_shape) - initial_observation = np.zeros(len(observation_shape), dtype=np.int64) - - return observation_space, initial_observation - def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """Build the observation space based on network laydown and provide initial obs. - - This method uses the object's `num_links`, `num_nodes`, `num_services`, - `OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type` - attributes to figure out the correct shape and format for the observation space. - - :raises ValueError: If the env's `observation_type` attribute is not set to a valid `enums.ObservationType` - :return: Gym observation space - :rtype: gym.spaces.Space - :return: Initial observation with all entires set to 0 - :rtype: numpy.Array - """ - if self.observation_type == ObservationType.BOX: - observation_space, initial_observation = self._init_box_observations() - return observation_space, initial_observation - elif self.observation_type == ObservationType.MULTIDISCRETE: - ( - observation_space, - initial_observation, - ) = self._init_multidiscrete_observations() - return observation_space, initial_observation - else: - errmsg = ( - f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}" - f", got {self.observation_type} instead" - ) - _LOGGER.error(errmsg) - raise ValueError(errmsg) - - def _update_env_obs_box(self): - """Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_box_observations` - This function can only be called if the observation space setting is set to BOX. - - :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` - """ - assert self.observation_type == ObservationType.BOX - item_index = 0 - - # Do nodes first - for node_key, node in self.nodes.items(): - self.env_obs[item_index][0] = int(node.node_id) - self.env_obs[item_index][1] = node.hardware_state.value - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.env_obs[item_index][2] = node.software_state.value - self.env_obs[item_index][3] = node.file_system_state_observed.value - else: - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - service_index = 4 - if isinstance(node, ServiceNode): - for service in self.services_list: - if node.has_service(service): - self.env_obs[item_index][ - service_index - ] = node.get_service_state(service).value - else: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - else: - # Not a service node - for service in self.services_list: - self.env_obs[item_index][service_index] = 0 - service_index += 1 - item_index += 1 - - # Now do links - for link_key, link in self.links.items(): - self.env_obs[item_index][0] = int(link.get_id()) - self.env_obs[item_index][1] = 0 - self.env_obs[item_index][2] = 0 - self.env_obs[item_index][3] = 0 - protocol_list = link.get_protocol_list() - protocol_index = 0 - for protocol in protocol_list: - self.env_obs[item_index][protocol_index + 4] = protocol.get_load() - protocol_index += 1 - item_index += 1 - - def _update_env_obs_multidiscrete(self): - """Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_multidiscrete_observations` - This function can only be called if the observation space setting is set to MULTIDISCRETE. - - :raises AssertionError: If this function is called when the environment has the incorrect ``observation_type`` - """ - assert self.observation_type == ObservationType.MULTIDISCRETE - obs = [] - # 1. Set nodes - # Each node has the following variables in the observation space: - # - Hardware state - # - Software state - # - File System state - # - Service 1 state - # - Service 2 state - # - ... - # - Service N state - for node_key, node in self.nodes.items(): - hardware_state = node.hardware_state.value - software_state = 0 - file_system_state = 0 - services_states = [0] * self.num_services - - if isinstance( - node, ActiveNode - ): # ServiceNode is a subclass of ActiveNode so no need to check that also - software_state = node.software_state.value - file_system_state = node.file_system_state_observed.value - - if isinstance(node, ServiceNode): - for i, service in enumerate(self.services_list): - if node.has_service(service): - services_states[i] = node.get_service_state(service).value - - obs.extend( - [ - hardware_state, - software_state, - file_system_state, - *services_states, - ] - ) - - # 2. Set links - # Each link has just one variable in the observation space, it represents the traffic amount - # In order for the space to be fully MultiDiscrete, the amount of - # traffic on each link is quantised into a few levels: - # 0: no traffic (0% of bandwidth) - # 1: low traffic (0-33% of bandwidth) - # 2: medium traffic (33-66% of bandwidth) - # 3: high traffic (66-100% of bandwidth) - # 4: max traffic/overloaded (100% of bandwidth) - - for link_key, link in self.links.items(): - bandwidth = link.bandwidth - load = link.get_current_load() - - if load <= 0: - traffic_level = 0 - elif load >= bandwidth: - traffic_level = 4 - else: - traffic_level = (load / bandwidth) // (1 / 3) + 1 - - obs.append(int(traffic_level)) - - self.env_obs = np.asarray(obs) + """todo: write docstring.""" + return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): - """Updates the observation space based on the node and link status.""" - if self.observation_type == ObservationType.BOX: - self._update_env_obs_box() - elif self.observation_type == ObservationType.MULTIDISCRETE: - self._update_env_obs_multidiscrete() + """Updates the observation space based on the node and link status. + + todo: better docstring + """ + self.obs_handler.update_obs() + self.env_obs = self.obs_handler.current_observation def load_config(self): """Loads config data in order to build the environment configuration.""" @@ -921,9 +690,9 @@ class Primaite(Env): elif item["itemType"] == "ACTIONS": # Get the action information self.get_action_info(item) - elif item["itemType"] == "OBSERVATIONS": + elif item["itemType"] == "OBSERVATION_SPACE": # Get the observation information - self.get_observation_info(item) + self.configure_obs_space(item) elif item["itemType"] == "STEPS": # Get the steps information self.get_steps_info(item) @@ -1256,13 +1025,16 @@ class Primaite(Env): """ self.action_type = ActionType[action_info["type"]] - def get_observation_info(self, observation_info): - """Extracts observation_info. + def configure_obs_space(self, observation_config: Optional[Dict] = None): + """todo: better docstring.""" + if observation_config is None: + observation_config = { + "components": [ + {"name": "NODE_LINK_TABLE"}, + ] + } - :param observation_info: Config item that defines which type of observation space to use - :type observation_info: str - """ - self.observation_type = ObservationType[observation_info["type"]] + self.obs_handler = ObservationsHandler[observation_config] def get_steps_info(self, steps_info): """ From e43649a838afd7b3eaa624e669914dd852b427cd Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 17:42:35 +0100 Subject: [PATCH 11/37] Fix trying to init obs before building network --- src/primaite/environment/observations.py | 74 +++++++++++++----------- src/primaite/environment/primaite_env.py | 39 +++++++------ 2 files changed, 61 insertions(+), 52 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 94c2730f..a1b0d9ac 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,17 +1,22 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod -from enum import Enum -from typing import List, Tuple +from typing import TYPE_CHECKING, List, Tuple import numpy as np from gym import spaces from primaite.common.enums import FileSystemState, HardwareState, SoftwareState -from primaite.environment.primaite_env import Primaite from primaite.nodes.active_node import ActiveNode from primaite.nodes.service_node import ServiceNode +# This dependency is only needed for type hints, +# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking +# Therefore, this avoids circular dependency problem. +if TYPE_CHECKING: + from primaite.environment.primaite_env import Primaite + + _LOGGER = logging.getLogger(__name__) @@ -19,9 +24,9 @@ class AbstractObservationComponent(ABC): """Represents a part of the PrimAITE observation space.""" @abstractmethod - def __init__(self, env: Primaite): + def __init__(self, env: "Primaite"): _LOGGER.info(f"Initialising {self} observation component") - self.env: Primaite = env + self.env: "Primaite" = env self.space: spaces.Space self.current_observation: np.ndarray # type might be too restrictive? return NotImplemented @@ -51,7 +56,7 @@ class NodeLinkTable(AbstractObservationComponent): For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be ``(12, 7)`` - #todo: clean up description + #TODO: clean up description """ @@ -59,7 +64,7 @@ class NodeLinkTable(AbstractObservationComponent): _MAX_VAL = 1_000_000 _DATA_TYPE = np.int64 - def __init__(self, env: Primaite): + def __init__(self, env: "Primaite"): super().__init__(env) # 1. Define the shape of your observation space component @@ -76,16 +81,16 @@ class NodeLinkTable(AbstractObservationComponent): ) # 3. Initialise Observation with zeroes - self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE) + self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) - def update_obs(self): + def update(self): """Update the observation. Update the environment's observation state based on the current status of nodes and links. The structure of the observation space is described in :func:`~_init_box_observations` This function can only be called if the observation space setting is set to BOX. - todo: complete description.. + TODO: complete description.. """ item_index = 0 nodes = self.env.nodes @@ -136,7 +141,7 @@ class NodeLinkTable(AbstractObservationComponent): class NodeStatuses(AbstractObservationComponent): - """todo: complete description. + """TODO: complete description. This will create the observation space with node observations followed by link observations. Each node has 3 elements in the observation space plus 1 per service, more specifically: @@ -148,7 +153,7 @@ class NodeStatuses(AbstractObservationComponent): _DATA_TYPE = np.int64 - def __init__(self, env: Primaite): + def __init__(self, env: "Primaite"): super().__init__(env) # 1. Define the shape of your observation space component @@ -166,8 +171,8 @@ class NodeStatuses(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) - def update_obs(self): - """todo: complete description. + def update(self): + """TODO: complete description. Update the environment's observation state based on the current status of nodes and links. @@ -196,7 +201,7 @@ class NodeStatuses(AbstractObservationComponent): class LinkTrafficLevels(AbstractObservationComponent): - """todo: complete description. + """TODO: complete description. Each link has one element in the observation space, corresponding to the traffic load, it can take the following values: @@ -211,7 +216,7 @@ class LinkTrafficLevels(AbstractObservationComponent): def __init__( self, - env: Primaite, + env: "Primaite", combine_service_traffic: bool = False, quantisation_levels: int = 5, ): @@ -234,8 +239,8 @@ class LinkTrafficLevels(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) - def update_obs(self): - """todo: complete description.""" + def update(self): + """TODO: complete description.""" obs = [] for _, link in self.env.links.items(): bandwidth = link.bandwidth @@ -262,15 +267,14 @@ class LinkTrafficLevels(AbstractObservationComponent): class ObservationsHandler: """Component-based observation space handler.""" - class registry(Enum): - """todo: complete description.""" - - NODE_LINK_TABLE: NodeLinkTable - NODE_STATUSES: NodeStatuses - LINK_TRAFFIC_LEVELS: LinkTrafficLevels + registry = { + "NODE_LINK_TABLE": NodeLinkTable, + "NODE_STATUSES": NodeStatuses, + "LINK_TRAFFIC_LEVELS": LinkTrafficLevels, + } def __init__(self): - """todo: complete description.""" + """TODO: complete description.""" """Initialise the handler without any components yet. They""" self.registered_obs_components: List[AbstractObservationComponent] = [] self.space: spaces.Space @@ -279,33 +283,33 @@ class ObservationsHandler: # self.registry.LINK_TRAFFIC_LEVELS def update_obs(self): - """todo: complete description.""" + """TODO: complete description.""" current_obs = [] for obs in self.registered_obs_components: - obs.update_obs() + obs.update() current_obs.append(obs.current_observation) self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): - """todo: complete description.""" + """TODO: complete description.""" self.registered_obs_components.append(obs_component) self.update_space() def deregister(self, obs_component: AbstractObservationComponent): - """todo: complete description.""" + """TODO: complete description.""" self.registered_obs_components.remove(obs_component) self.update_space() def update_space(self): - """todo: complete description.""" + """TODO: complete description.""" component_spaces = [] for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) self.space = spaces.Tuple(component_spaces) @classmethod - def from_config(cls, obs_space_config): - """todo: complete description. + def from_config(cls, env: "Primaite", obs_space_config: dict): + """TODO: complete description. This method parses config items related to the observation space, then creates the necessary components and adds them to the observation handler. @@ -316,11 +320,13 @@ class ObservationsHandler: for component_cfg in obs_space_config["components"]: # Figure out which class can instantiate the desired component comp_type = component_cfg["name"] - comp_class = cls.registry[comp_type].value + comp_class = cls.registry[comp_type] # Create the component with options from the YAML - component = comp_class(**component_cfg["options"]) + options = component_cfg.get("options") or {} + component = comp_class(env, **options) handler.register(component) + handler.update_obs() return handler diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index afa04060..0107920f 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -149,7 +149,8 @@ class Primaite(Env): # The action type self.action_type = 0 - # todo: proper description here + # TODO: proper description here + self.obs_config: dict self.obs_handler: ObservationsHandler # Open the config file and build the environment laydown @@ -161,10 +162,6 @@ class Primaite(Env): _LOGGER.error("Could not load the environment configuration") _LOGGER.error("Exception occured", exc_info=True) - # If it doesn't exist after parsing config, create default obs space. - if self.get("obs_handler") is None: - self.configure_obs_space() - # Store the node objects as node attributes # (This is so we can access them as objects) for node in self.network: @@ -195,6 +192,10 @@ class Primaite(Env): _LOGGER.error("Exception occured", exc_info=True) print("Could not save network diagram") + # # If it doesn't exist after parsing config, create default obs space. + # if getattr(self, "obs_handler", None) is None: + # self.configure_obs_space() + # Initiate observation space self.observation_space, self.env_obs = self.init_observations() @@ -646,13 +647,22 @@ class Primaite(Env): pass def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """todo: write docstring.""" + """TODO: write docstring.""" + if getattr(self, "obs_config", None) is None: + self.obs_config = { + "components": [ + {"name": "NODE_LINK_TABLE"}, + ] + } + + self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) + return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): """Updates the observation space based on the node and link status. - todo: better docstring + TODO: better docstring """ self.obs_handler.update_obs() self.env_obs = self.obs_handler.current_observation @@ -692,7 +702,7 @@ class Primaite(Env): self.get_action_info(item) elif item["itemType"] == "OBSERVATION_SPACE": # Get the observation information - self.configure_obs_space(item) + self.save_obs_config(item) elif item["itemType"] == "STEPS": # Get the steps information self.get_steps_info(item) @@ -1025,16 +1035,9 @@ class Primaite(Env): """ self.action_type = ActionType[action_info["type"]] - def configure_obs_space(self, observation_config: Optional[Dict] = None): - """todo: better docstring.""" - if observation_config is None: - observation_config = { - "components": [ - {"name": "NODE_LINK_TABLE"}, - ] - } - - self.obs_handler = ObservationsHandler[observation_config] + def save_obs_config(self, obs_config: Optional[Dict] = None): + """TODO: better docstring.""" + self.obs_config = obs_config def get_steps_info(self, steps_info): """ From 0a804e714d8ef7b5d9a14920f1657a0c1595a543 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 17:50:18 +0100 Subject: [PATCH 12/37] Better Obs default handling --- src/primaite/environment/primaite_env.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 0107920f..81557075 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -47,15 +47,12 @@ _LOGGER = logging.getLogger(__name__) class Primaite(Env): """PRIMmary AI Training Evironment (Primaite) class.""" - # Observation / Action Space contants - OBSERVATION_SPACE_FIXED_PARAMETERS = 4 + # Action Space contants ACTION_SPACE_NODE_PROPERTY_VALUES = 5 ACTION_SPACE_NODE_ACTION_VALUES = 4 ACTION_SPACE_ACL_ACTION_VALUES = 3 ACTION_SPACE_ACL_PERMISSION_VALUES = 2 - OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space - def __init__(self, _config_values, _transaction_list): """ Init. @@ -149,8 +146,11 @@ class Primaite(Env): # The action type self.action_type = 0 - # TODO: proper description here - self.obs_config: dict + # stores the observation config from the yaml, default is NODE_LINK_TABLE + self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]} + + # Observation Handler manages the user-configurable observation space. + # It will be initialised later. self.obs_handler: ObservationsHandler # Open the config file and build the environment laydown @@ -192,10 +192,6 @@ class Primaite(Env): _LOGGER.error("Exception occured", exc_info=True) print("Could not save network diagram") - # # If it doesn't exist after parsing config, create default obs space. - # if getattr(self, "obs_handler", None) is None: - # self.configure_obs_space() - # Initiate observation space self.observation_space, self.env_obs = self.init_observations() @@ -648,13 +644,6 @@ class Primaite(Env): def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: """TODO: write docstring.""" - if getattr(self, "obs_config", None) is None: - self.obs_config = { - "components": [ - {"name": "NODE_LINK_TABLE"}, - ] - } - self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) return self.obs_handler.space, self.obs_handler.current_observation From d473794963d55957baa1954fb7f873a3940a0ac6 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 18:01:47 +0100 Subject: [PATCH 13/37] Let single-component spaces not use Tuple Spaces --- src/primaite/environment/observations.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index a1b0d9ac..5bad056c 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,7 +1,7 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, List, Tuple, Union import numpy as np from gym import spaces @@ -278,7 +278,7 @@ class ObservationsHandler: """Initialise the handler without any components yet. They""" self.registered_obs_components: List[AbstractObservationComponent] = [] self.space: spaces.Space - self.current_observation: Tuple[np.ndarray] + self.current_observation: Union[Tuple[np.ndarray], np.ndarray] # i can access the registry items like this: # self.registry.LINK_TRAFFIC_LEVELS @@ -288,7 +288,12 @@ class ObservationsHandler: for obs in self.registered_obs_components: obs.update() current_obs.append(obs.current_observation) - self.current_observation = tuple(current_obs) + + # If there is only one component, don't use a tuple, just pass through that component's obs. + if len(current_obs) == 1: + self.current_observation = current_obs[0] + else: + self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): """TODO: complete description.""" @@ -305,7 +310,12 @@ class ObservationsHandler: component_spaces = [] for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) - self.space = spaces.Tuple(component_spaces) + + # If there is only one component, don't use a tuple space, just pass through that component's space. + if len(component_spaces) == 1: + self.space = component_spaces[0] + else: + self.space = spaces.Tuple(component_spaces) @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): From 084112a2e4e2fb419af81d2c6f6ffd8909c6b2d6 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 21:28:38 +0100 Subject: [PATCH 14/37] Add docstrings to new observation code --- src/primaite/environment/observations.py | 133 ++++++++++++++++------- src/primaite/environment/primaite_env.py | 25 +++-- 2 files changed, 109 insertions(+), 49 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 5bad056c..c4402b69 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -33,18 +33,16 @@ class AbstractObservationComponent(ABC): @abstractmethod def update(self): - """Look at the environment and update the current observation value.""" + """Update the observation based on the current state of the environment.""" self.current_observation = NotImplemented class NodeLinkTable(AbstractObservationComponent): - """Table with nodes/links as rows and hardware/software status as cols. - - Initialise the observation space with the BOX option chosen. + """Table with nodes and links as rows and hardware/software status as cols. This will create the observation space formatted as a table of integers. There is one row per node, followed by one row per link. - Columns are as follows: + The number of columns is 4 plus one per service. They are: * node/link ID * node hardware status / 0 for links * node operating system status (if active/service) / 0 for links @@ -56,8 +54,6 @@ class NodeLinkTable(AbstractObservationComponent): For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be ``(12, 7)`` - #TODO: clean up description - """ _FIXED_PARAMETERS = 4 @@ -84,13 +80,9 @@ class NodeLinkTable(AbstractObservationComponent): self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) def update(self): - """Update the observation. + """Update the observation based on current environment state. - Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_box_observations` - This function can only be called if the observation space setting is set to BOX. - TODO: complete description.. + The structure of the observation space is described in :class:`.NodeLinkTable` """ item_index = 0 nodes = self.env.nodes @@ -141,14 +133,30 @@ class NodeLinkTable(AbstractObservationComponent): class NodeStatuses(AbstractObservationComponent): - """TODO: complete description. + """Flat list of nodes' hardware, OS, file system, and service states. - This will create the observation space with node observations followed by link observations. - Each node has 3 elements in the observation space plus 1 per service, more specifically: - * hardware state - * operating system state - * file system state - * service states (one per service) + The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by + integers. + Each node has 3 elements plus 1 per service. It will have the following structure: + .. code-block:: + [ + node1 hardware state, + node1 OS state, + node1 file system state, + node1 service1 state, + node1 service2 state, + node1 serviceN state (one for each service), + node2 hardware state, + node2 OS state, + node2 file system state, + node2 service1 state, + node2 service2 state, + node2 serviceN state (one for each service), + ... + ] + + :param env: The environment that forms the basis of the observations + :type env: Primaite """ _DATA_TYPE = np.int64 @@ -172,14 +180,9 @@ class NodeStatuses(AbstractObservationComponent): self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) def update(self): - """TODO: complete description. - - Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_multidiscrete_observations` - This function can only be called if the observation space setting is set to MULTIDISCRETE. - + """Update the observation based on current environment state. + The structure of the observation space is described in :class:`.NodeStatuses` """ obs = [] for _, node in self.env.nodes.items(): @@ -201,15 +204,28 @@ class NodeStatuses(AbstractObservationComponent): class LinkTrafficLevels(AbstractObservationComponent): - """TODO: complete description. + """Flat list of traffic levels encoded into banded categories. - Each link has one element in the observation space, corresponding to the traffic load, - it can take the following values: + For each link, total traffic or traffic per service is encoded into a categorical value. + For example, if ``quantisation_levels=5``, the traffic levels represent these values: 0 = No traffic (0% of bandwidth) 1 = No traffic (0%-33% of bandwidth) 2 = No traffic (33%-66% of bandwidth) 3 = No traffic (66%-100% of bandwidth) 4 = No traffic (100% of bandwidth) + + .. note:: + The lowest category always corresponds to no traffic and the highest category to the link being at max capacity. + Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories. + + :param env: The environment that forms the basis of the observations + :type env: Primaite + :param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually, + defaults to False + :type combine_service_traffic: bool, optional + :param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value , + defaults to 5 + :type quantisation_levels: int, optional """ _DATA_TYPE = np.int64 @@ -220,7 +236,10 @@ class LinkTrafficLevels(AbstractObservationComponent): combine_service_traffic: bool = False, quantisation_levels: int = 5, ): + assert quantisation_levels >= 3 + super().__init__(env) + self._combine_service_traffic: bool = combine_service_traffic self._quantisation_levels: int = quantisation_levels self._entries_per_link: int = 1 @@ -240,7 +259,10 @@ class LinkTrafficLevels(AbstractObservationComponent): self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) def update(self): - """TODO: complete description.""" + """Update the observation based on current environment state. + + The structure of the observation space is described in :class:`.LinkTrafficLevels` + """ obs = [] for _, link in self.env.links.items(): bandwidth = link.bandwidth @@ -265,7 +287,11 @@ class LinkTrafficLevels(AbstractObservationComponent): class ObservationsHandler: - """Component-based observation space handler.""" + """Component-based observation space handler. + + This allows users to configure observation spaces by mixing and matching components. + Each component can also define further parameters to make them more flexible. + """ registry = { "NODE_LINK_TABLE": NodeLinkTable, @@ -274,8 +300,6 @@ class ObservationsHandler: } def __init__(self): - """TODO: complete description.""" - """Initialise the handler without any components yet. They""" self.registered_obs_components: List[AbstractObservationComponent] = [] self.space: spaces.Space self.current_observation: Union[Tuple[np.ndarray], np.ndarray] @@ -283,7 +307,7 @@ class ObservationsHandler: # self.registry.LINK_TRAFFIC_LEVELS def update_obs(self): - """TODO: complete description.""" + """Fetch fresh information about the environment.""" current_obs = [] for obs in self.registered_obs_components: obs.update() @@ -296,17 +320,26 @@ class ObservationsHandler: self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): - """TODO: complete description.""" + """Add a component for this handler to track. + + :param obs_component: The component to add. + :type obs_component: AbstractObservationComponent + """ self.registered_obs_components.append(obs_component) self.update_space() def deregister(self, obs_component: AbstractObservationComponent): - """TODO: complete description.""" + """Remove a component from this handler. + + :param obs_component: Which component to remove. It must exist within this object's + ``registered_obs_components`` attribute. + :type obs_component: AbstractObservationComponent + """ self.registered_obs_components.remove(obs_component) self.update_space() def update_space(self): - """TODO: complete description.""" + """Rebuild the handler's composite observation space from its components.""" component_spaces = [] for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) @@ -319,10 +352,28 @@ class ObservationsHandler: @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): - """TODO: complete description. + """Parse a config dictinary, return a new observation handler populated with new observation component objects. - This method parses config items related to the observation space, then - creates the necessary components and adds them to the observation handler. + The expected format for the config dictionary is: + + ..code-block::python + config = { + components: [ + { + "name": "" + }, + { + "name": "" + "options": {"opt1": val1, "opt2": val2} + }, + { + ... + }, + ] + } + + :return: Observation handler + :rtype: primaite.environment.observations.ObservationsHandler """ # Instantiate the handler handler = cls() diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 81557075..8cff91d8 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -6,7 +6,7 @@ import csv import logging import os.path from datetime import datetime -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple import networkx as nx import numpy as np @@ -643,16 +643,17 @@ class Primaite(Env): pass def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """TODO: write docstring.""" + """Create the environment's observation handler. + + :return: The observation space, initial observation (zeroed out array with the correct shape) + :rtype: Tuple[spaces.Space, np.ndarray] + """ self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): - """Updates the observation space based on the node and link status. - - TODO: better docstring - """ + """Updates the observation space based on the node and link status.""" self.obs_handler.update_obs() self.env_obs = self.obs_handler.current_observation @@ -1024,8 +1025,16 @@ class Primaite(Env): """ self.action_type = ActionType[action_info["type"]] - def save_obs_config(self, obs_config: Optional[Dict] = None): - """TODO: better docstring.""" + def save_obs_config(self, obs_config: dict): + """Cache the config for the observation space. + + This is necessary as the observation space can't be built while reading the config, + it must be done after all the nodes, links, and services have been initialised. + + :param obs_config: Parsed config relating to the observation space. The format is described in + :py:meth:`primaite.environment.observations.ObservationsHandler.from_config` + :type obs_config: dict + """ self.obs_config = obs_config def get_steps_info(self, steps_info): From ac31c996a740ceb3b9314fe4c6b9493285bdc98a Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 21:42:34 +0100 Subject: [PATCH 15/37] Update docs page on observations --- docs/source/about.rst | 49 +++++++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/docs/source/about.rst b/docs/source/about.rst index 8cc08b13..ee84d880 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -182,16 +182,13 @@ All ACL rules are considered when applying an IER. Logic follows the order of ru Observation Spaces ****************** +The observation space provides the blue agent with information about the current status of nodes and links. -The OpenAI Gym observation space provides the status of all nodes and links across the whole system:​ +PrimAITE builds on top of Gym Spaces to create an observation space that is easily configurable for users. It's made up of components which are managed by the :py:class:`primaite.environment.observations.ObservationHandler`. Each training scenario can define its own observation space, and the user can choose which information to inlude, and how it should be formatted. -* Nodes (in terms of hardware state, Software State, file system state and services state) ​ -* Links (in terms of current loading for each service/protocol) - -The observation space can be configured as a ``gym.spaces.Box`` or ``gym.spaces.MultiDiscrete``, by setting the ``OBSERVATIONS`` parameter in the laydown config. - -Box-type observation space --------------------------- +NodeLinkTable component +----------------------- +For example, the :py:class:`primaite.environment.observations.NodeLinkTable` component represents the status of nodes and links as a ``gym.spaces.Box`` with an example format shown below: An example observation space is provided below: @@ -249,8 +246,6 @@ An example observation space is provided below: - 5000 - 0 -The observation space is a 6 x 6 Box type (OpenAI Gym Space) in this example. This is made up from the node and link information detailed below. - For the nodes, the following values are represented: * ID @@ -290,9 +285,9 @@ For the links, the following statuses are represented: * SoftwareState = N/A * Protocol = loading in bits/s -MultiDiscrete-type observation space ------------------------------------- -The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by integers. +NodeStatus component +---------------------- +This is a MultiDiscrete observation space that can be though of as a one-dimensional vector of discrete states, represented by integers. The example above would have the following structure: .. code-block:: @@ -301,9 +296,6 @@ The example above would have the following structure: node1_info node2_info node3_info - link1_status - link2_status - link3_status ] Each ``node_info`` contains the following: @@ -318,7 +310,25 @@ Each ``node_info`` contains the following: service2_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED) ] -Each ``link_status`` is just a number from 0-4 representing the network load in relation to bandwidth. +In a network with three nodes and two services, the full observation space would have 15 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example: + +.. code-block:: + + gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4]) + +LinkTrafficLevels +----------------- +This component is a MultiDiscrete space showing the traffic flow levels on the links in the network, after applying a threshold to convert it from a continuous to a discrete value. +The number of bins can be customised with 5 being the default. It has the following strucutre: +.. code-block:: + + [ + link1_status + link2_status + link3_status + ] + +Each ``link_status`` is a number from 0-4 representing the network load in relation to bandwidth. .. code-block:: @@ -328,12 +338,11 @@ Each ``link_status`` is just a number from 0-4 representing the network load in 3 = high traffic (<100%) 4 = max traffic/ overwhelmed (100%) -The full observation space would have 15 node-related elements and 3 link-related elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example: +If the network has three links, the full observation space would have 3 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example: .. code-block:: - gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4,5,5,5]) - + gym.spaces.MultiDiscrete([5,5,5]) Action Spaces ************** From 16777b30add51902816a0518e223d211dcb2a18b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Jun 2023 21:56:05 +0100 Subject: [PATCH 16/37] begin updating observations tests --- tests/test_observation_space.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 6a187761..a13121b9 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -1,5 +1,6 @@ """Test env creation and behaviour with different observation spaces.""" +from primaite.environment.observations import NodeStatuses, ObservationsHandler from tests import TEST_CONFIG_ROOT from tests.conftest import _get_primaite_env_from_config @@ -32,3 +33,13 @@ def test_creating_env_with_multidiscrete_obs(): # the nodes have hardware, OS, FS, and service, the links just have bandwidth, # therefore we need 3*4 + 2 observations assert env.env_obs.shape == (3 * 4 + 2,) + + +def test_component_registration(): + """Test that we can register and deregister a component.""" + handler = ObservationsHandler() + component = NodeStatuses() + handler.register(component) + assert component in handler.registered_obs_components + handler.deregister(component) + assert component not in handler.registered_obs_components From 602bf9ba9a3bce52cfe2a7fbc715f8203e31884a Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 2 Jun 2023 09:10:53 +0100 Subject: [PATCH 17/37] Edit configs for observation space --- src/primaite/config/config_1_DDOS_BASIC.yaml | 8 -------- .../laydown_with_LINK_TRAFFIC_LEVELS.yaml} | 5 +++-- .../laydown_with_NODE_LINK_TABLE.yaml} | 8 ++++++-- 3 files changed, 9 insertions(+), 12 deletions(-) rename tests/config/{box_obs_space_laydown_config.yaml => obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml} (94%) rename tests/config/{multidiscrete_obs_space_laydown_config.yaml => obs_tests/laydown_with_NODE_LINK_TABLE.yaml} (87%) diff --git a/src/primaite/config/config_1_DDOS_BASIC.yaml b/src/primaite/config/config_1_DDOS_BASIC.yaml index a1961df3..ada813f3 100644 --- a/src/primaite/config/config_1_DDOS_BASIC.yaml +++ b/src/primaite/config/config_1_DDOS_BASIC.yaml @@ -1,13 +1,5 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATION_SPACE - components: - - name: NODE_LINK_TABLE - - name: NODE_STATUSES - - name: LINK_TRAFFIC_LEVELS - options: - - combine_service_traffic : False - - quantisation_levels : 7 - itemType: STEPS steps: 128 - itemType: PORTS diff --git a/tests/config/box_obs_space_laydown_config.yaml b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml similarity index 94% rename from tests/config/box_obs_space_laydown_config.yaml rename to tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml index 203bc0e7..d1909125 100644 --- a/tests/config/box_obs_space_laydown_config.yaml +++ b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml @@ -1,7 +1,8 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATIONS - type: BOX +- itemType: OBSERVATION_SPACE + components: + - name: NODE_STATUSES - itemType: STEPS steps: 5 - itemType: PORTS diff --git a/tests/config/multidiscrete_obs_space_laydown_config.yaml b/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml similarity index 87% rename from tests/config/multidiscrete_obs_space_laydown_config.yaml rename to tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml index 38438d6d..36fb8199 100644 --- a/tests/config/multidiscrete_obs_space_laydown_config.yaml +++ b/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml @@ -1,7 +1,11 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATIONS - type: MULTIDISCRETE +- itemType: OBSERVATION_SPACE + components: + - name: NODE_LINK_TABLE + options: + - combine_service_traffic: false + - quantisation_levels: 8 - itemType: STEPS steps: 5 - itemType: PORTS From d0c11a14ed2214ab0257f268bbed5ecca1783f15 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Fri, 2 Jun 2023 09:51:15 +0100 Subject: [PATCH 18/37] 893 - added ANY to enums.py --- src/primaite/common/enums.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 20660e86..0e43ea38 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -82,6 +82,7 @@ class ActionType(Enum): NODE = 0 ACL = 1 + ANY = 2 class FileSystemState(Enum): From 66fdae5df17f18ff7d5674b6e8088aa8b973c835 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Fri, 2 Jun 2023 11:55:31 +0100 Subject: [PATCH 19/37] 893 - added test which shows the new action space has been created when ANY is selected in single_action_space_lay_down_config.yaml --- src/primaite/environment/primaite_env.py | 7 ++++--- tests/test_single_action_space.py | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 5d783af1..9b0bbeec 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -249,11 +249,13 @@ class Primaite(Env): self.action_dict = self.create_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) print(self.action_space, "ACL action space") - else: - _LOGGER.warning("Action space type ANY selected - Node + ACL") + elif self.action_type == ActionType.ANY: + _LOGGER.info("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) print(self.action_space, "ANY action space") + else: + _LOGGER.info("Invalid action type selected") # Set up a csv to store the results of the training try: now = datetime.now() # current date and time @@ -1271,6 +1273,5 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} - logging.warning("logging is working") # print(len(list(combined_action_dict.values()))) return combined_action_dict diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 8c87d57b..203a6232 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -9,4 +9,18 @@ def test_single_action_space(): lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ) - print("Average Reward:", env.average_reward) + """ + nv.action_space.n is the total number of actions in the Discrete action space + This is the number of actions the agent has to choose from. + + The total number of actions that an agent can type when a NODE action type is selected is: 6 + The total number of actions that an agent can take when an ACL action type is selected is: 7 + + These action spaces are combined and the total number of actions is: 12 + This is due to both actions containing the action to "Do nothing", so it needs to be removed from one of the spaces, + to avoid duplicate actions. + + As a result, 12 is the total number of action spaces. + """ + # e + assert env.action_space.n == 12 From 14096b3dd126796505934868a4ae23baf40368d5 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 2 Jun 2023 12:59:01 +0100 Subject: [PATCH 20/37] Add tests for observations --- pytest.ini | 2 + src/primaite/environment/observations.py | 13 +- .../laydown_with_LINK_TRAFFIC_LEVELS.yaml | 43 ++++- .../laydown_with_NODE_LINK_TABLE.yaml | 11 +- .../obs_tests/laydown_with_NODE_STATUSES.yaml | 107 +++++++++++ .../obs_tests/laydown_without_obs_space.yaml | 74 ++++++++ .../obs_tests/main_config_no_agent.yaml | 89 +++++++++ tests/test_observation_space.py | 169 +++++++++++++++--- 8 files changed, 476 insertions(+), 32 deletions(-) create mode 100644 tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml create mode 100644 tests/config/obs_tests/laydown_without_obs_space.yaml create mode 100644 tests/config/obs_tests/main_config_no_agent.yaml diff --git a/pytest.ini b/pytest.ini index e618d7a5..b5fae8d0 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,5 @@ [pytest] testpaths = tests +markers = + env_config_paths diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index c4402b69..a467a5db 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -165,14 +165,15 @@ class NodeStatuses(AbstractObservationComponent): super().__init__(env) # 1. Define the shape of your observation space component - shape = [ + node_shape = [ len(HardwareState) + 1, len(SoftwareState) + 1, len(FileSystemState) + 1, ] services_shape = [len(SoftwareState) + 1] * self.env.num_services - shape = shape + services_shape + node_shape = node_shape + services_shape + shape = node_shape * self.env.num_nodes # 2. Create Observation space self.space = spaces.MultiDiscrete(shape) @@ -199,7 +200,9 @@ class NodeStatuses(AbstractObservationComponent): for i, service in enumerate(self.env.services_list): if node.has_service(service): service_states[i] = node.get_service_state(service).value - obs.extend([hardware_state, software_state, file_system_state, *service_states]) + obs.extend( + [hardware_state, software_state, file_system_state, *service_states] + ) self.current_observation[:] = obs @@ -303,8 +306,6 @@ class ObservationsHandler: self.registered_obs_components: List[AbstractObservationComponent] = [] self.space: spaces.Space self.current_observation: Union[Tuple[np.ndarray], np.ndarray] - # i can access the registry items like this: - # self.registry.LINK_TRAFFIC_LEVELS def update_obs(self): """Fetch fresh information about the environment.""" @@ -318,6 +319,7 @@ class ObservationsHandler: self.current_observation = current_obs[0] else: self.current_observation = tuple(current_obs) + # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. def register(self, obs_component: AbstractObservationComponent): """Add a component for this handler to track. @@ -349,6 +351,7 @@ class ObservationsHandler: self.space = component_spaces[0] else: self.space = spaces.Tuple(component_spaces) + # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): diff --git a/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml index d1909125..516bf5cc 100644 --- a/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml +++ b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml @@ -2,15 +2,20 @@ type: NODE - itemType: OBSERVATION_SPACE components: - - name: NODE_STATUSES + - name: LINK_TRAFFIC_LEVELS + options: + combine_service_traffic: false + quantisation_levels: 8 - itemType: STEPS steps: 5 - itemType: PORTS portsList: - port: '80' + - port: '53' - itemType: SERVICES serviceList: - name: TCP + - name: UDP ######################################## # Nodes @@ -28,6 +33,9 @@ - name: TCP port: '80' state: GOOD + - name: UDP + port: '53' + state: GOOD - itemType: NODE node_id: '2' name: SERVER @@ -42,6 +50,9 @@ - name: TCP port: '80' state: GOOD + - name: UDP + port: '53' + state: GOOD - itemType: NODE node_id: '3' name: SWITCH1 @@ -67,3 +78,33 @@ bandwidth: 1000 source: '3' destination: '2' + +######################################### +# IERS +- itemType: GREEN_IER + id: '5' + startStep: 0 + endStep: 5 + load: 20 + protocol: TCP + port: '80' + source: '1' + destination: '2' + missionCriticality: 5 + +######################################### +# ACL Rules +- itemType: ACL_RULE + id: '6' + permission: ALLOW + source: 192.168.1.1 + destination: 192.168.1.2 + protocol: TCP + port: 80 +- itemType: ACL_RULE + id: '7' + permission: ALLOW + source: 192.168.1.2 + destination: 192.168.1.1 + protocol: TCP + port: 80 diff --git a/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml index 36fb8199..0ceefbfa 100644 --- a/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml +++ b/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml @@ -3,17 +3,16 @@ - itemType: OBSERVATION_SPACE components: - name: NODE_LINK_TABLE - options: - - combine_service_traffic: false - - quantisation_levels: 8 - itemType: STEPS steps: 5 - itemType: PORTS portsList: - port: '80' + - port: '53' - itemType: SERVICES serviceList: - name: TCP + - name: UDP ######################################## # Nodes @@ -31,6 +30,9 @@ - name: TCP port: '80' state: GOOD + - name: UDP + port: '53' + state: GOOD - itemType: NODE node_id: '2' name: SERVER @@ -45,6 +47,9 @@ - name: TCP port: '80' state: GOOD + - name: UDP + port: '53' + state: GOOD - itemType: NODE node_id: '3' name: SWITCH1 diff --git a/tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml b/tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml new file mode 100644 index 00000000..56ff3725 --- /dev/null +++ b/tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml @@ -0,0 +1,107 @@ +- itemType: ACTIONS + type: NODE +- itemType: OBSERVATION_SPACE + components: + - name: NODE_STATUSES +- itemType: STEPS + steps: 5 +- itemType: PORTS + portsList: + - port: '80' + - port: '53' +- itemType: SERVICES + serviceList: + - name: TCP + - name: UDP + +######################################## +# Nodes +- itemType: NODE + node_id: '1' + name: PC1 + node_class: SERVICE + node_type: COMPUTER + priority: P5 + hardware_state: 'ON' + ip_address: 192.168.1.1 + software_state: COMPROMISED + file_system_state: GOOD + services: + - name: TCP + port: '80' + state: GOOD + - name: UDP + port: '53' + state: GOOD +- itemType: NODE + node_id: '2' + name: SERVER + node_class: SERVICE + node_type: SERVER + priority: P5 + hardware_state: 'ON' + ip_address: 192.168.1.2 + software_state: GOOD + file_system_state: GOOD + services: + - name: TCP + port: '80' + state: GOOD + - name: UDP + port: '53' + state: OVERWHELMED +- itemType: NODE + node_id: '3' + name: SWITCH1 + node_class: ACTIVE + node_type: SWITCH + priority: P2 + hardware_state: 'ON' + ip_address: 192.168.1.3 + software_state: GOOD + file_system_state: GOOD + +######################################## +# Links +- itemType: LINK + id: '4' + name: link1 + bandwidth: 1000 + source: '1' + destination: '3' +- itemType: LINK + id: '5' + name: link2 + bandwidth: 1000 + source: '3' + destination: '2' + +######################################### +# IERS +- itemType: GREEN_IER + id: '5' + startStep: 0 + endStep: 5 + load: 20 + protocol: TCP + port: '80' + source: '1' + destination: '2' + missionCriticality: 5 + +######################################### +# ACL Rules +- itemType: ACL_RULE + id: '6' + permission: ALLOW + source: 192.168.1.1 + destination: 192.168.1.2 + protocol: TCP + port: 80 +- itemType: ACL_RULE + id: '7' + permission: ALLOW + source: 192.168.1.2 + destination: 192.168.1.1 + protocol: TCP + port: 80 diff --git a/tests/config/obs_tests/laydown_without_obs_space.yaml b/tests/config/obs_tests/laydown_without_obs_space.yaml new file mode 100644 index 00000000..3ef214da --- /dev/null +++ b/tests/config/obs_tests/laydown_without_obs_space.yaml @@ -0,0 +1,74 @@ +- itemType: ACTIONS + type: NODE +- itemType: STEPS + steps: 5 +- itemType: PORTS + portsList: + - port: '80' + - port: '53' +- itemType: SERVICES + serviceList: + - name: TCP + - name: UDP + +######################################## +# Nodes +- itemType: NODE + node_id: '1' + name: PC1 + node_class: SERVICE + node_type: COMPUTER + priority: P5 + hardware_state: 'ON' + ip_address: 192.168.1.1 + software_state: GOOD + file_system_state: GOOD + services: + - name: TCP + port: '80' + state: GOOD + - name: UDP + port: '53' + state: GOOD +- itemType: NODE + node_id: '2' + name: SERVER + node_class: SERVICE + node_type: SERVER + priority: P5 + hardware_state: 'ON' + ip_address: 192.168.1.2 + software_state: GOOD + file_system_state: GOOD + services: + - name: TCP + port: '80' + state: GOOD + - name: UDP + port: '53' + state: GOOD +- itemType: NODE + node_id: '3' + name: SWITCH1 + node_class: ACTIVE + node_type: SWITCH + priority: P2 + hardware_state: 'ON' + ip_address: 192.168.1.3 + software_state: GOOD + file_system_state: GOOD + +######################################## +# Links +- itemType: LINK + id: '4' + name: link1 + bandwidth: 1000 + source: '1' + destination: '3' +- itemType: LINK + id: '5' + name: link2 + bandwidth: 1000 + source: '3' + destination: '2' diff --git a/tests/config/obs_tests/main_config_no_agent.yaml b/tests/config/obs_tests/main_config_no_agent.yaml new file mode 100644 index 00000000..f632dca9 --- /dev/null +++ b/tests/config/obs_tests/main_config_no_agent.yaml @@ -0,0 +1,89 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: NONE +# Number of episodes to run per session +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1000000000 + +# Reward values +# Generic +allOk: 0 +# Node Hardware State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node Software or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index a13121b9..314728ae 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -1,45 +1,168 @@ """Test env creation and behaviour with different observation spaces.""" +import numpy as np +import pytest -from primaite.environment.observations import NodeStatuses, ObservationsHandler +from primaite.environment.observations import ( + NodeLinkTable, + NodeStatuses, + ObservationsHandler, +) +from primaite.environment.primaite_env import Primaite from tests import TEST_CONFIG_ROOT from tests.conftest import _get_primaite_env_from_config -def test_creating_env_with_box_obs(): - """Try creating env with box observation space.""" +@pytest.fixture +def env(request): + """Build Primaite environment for integration tests of observation space.""" + marker = request.node.get_closest_marker("env_config_paths") + main_config_path = marker.args[0]["main_config_path"] + lay_down_config_path = marker.args[0]["lay_down_config_path"] env = _get_primaite_env_from_config( - main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "box_obs_space_laydown_config.yaml", + main_config_path=main_config_path, + lay_down_config_path=lay_down_config_path, ) - env.update_environent_obs() - - # we have three nodes and two links, with one service - # therefore the box observation space will have: - # * 5 columns (four fixed and one for the service) - # * 5 rows (3 nodes + 2 links) - assert env.env_obs.shape == (5, 5) + yield env -def test_creating_env_with_multidiscrete_obs(): - """Try creating env with MultiDiscrete observation space.""" - env = _get_primaite_env_from_config( +@pytest.mark.env_config_paths( + dict( main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT - / "multidiscrete_obs_space_laydown_config.yaml", + / "obs_tests/laydown_without_obs_space.yaml", ) +) +def test_default_obs_space(env: Primaite): + """Create environment with no obs space defined in config and check that the default obs space was created.""" env.update_environent_obs() - # we have three nodes and two links, with one service - # the nodes have hardware, OS, FS, and service, the links just have bandwidth, - # therefore we need 3*4 + 2 observations - assert env.env_obs.shape == (3 * 4 + 2,) + components = env.obs_handler.registered_obs_components + + assert len(components) == 1 + assert isinstance(components[0], NodeLinkTable) -def test_component_registration(): - """Test that we can register and deregister a component.""" +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "obs_tests/laydown_without_obs_space.yaml", + ) +) +def test_registering_components(env: Primaite): + """Test regitering and deregistering a component.""" handler = ObservationsHandler() - component = NodeStatuses() + component = NodeStatuses(env) handler.register(component) assert component in handler.registered_obs_components handler.deregister(component) assert component not in handler.registered_obs_components + + +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "obs_tests/laydown_with_NODE_LINK_TABLE.yaml", + ) +) +class TestNodeLinkTable: + """Test the NodeLinkTable observation component (in isolation).""" + + def test_obs_shape(self, env: Primaite): + """Try creating env with box observation space.""" + env.update_environent_obs() + + # we have three nodes and two links, with two service + # therefore the box observation space will have: + # * 5 rows (3 nodes + 2 links) + # * 6 columns (four fixed and two for the services) + assert env.env_obs.shape == (5, 6) + + # def test_value(self, env: Primaite): + # """""" + # ... + + +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "obs_tests/laydown_with_NODE_STATUSES.yaml", + ) +) +class TestNodeStatuses: + """Test the NodeStatuses observation component (in isolation).""" + + def test_obs_shape(self, env: Primaite): + """Try creating env with NodeStatuses as the only component.""" + assert env.env_obs.shape == (15) + + def test_values(self, env: Primaite): + """Test that the hardware and software states are encoded correctly. + + The laydown has: + * one node with a compromised operating system state + * one node with two services, and the second service is overwhelmed. + * all other states are good or null + Therefore, the expected state is: + * node 1: + * hardware = good (1) + * OS = compromised (3) + * file system = good (1) + * service 1 = good (1) + * service 2 = good (1) + * node 2: + * hardware = good (1) + * OS = good (1) + * file system = good (1) + * service 1 = good (1) + * service 2 = overwhelmed (4) + * node 3 (switch): + * hardware = good (1) + * OS = good (1) + * file system = good (1) + * service 1 = n/a (0) + * service 2 = n/a (0) + """ + act = np.asarray([0, 0, 0, 0]) + obs, _, _, _ = env.step(act) + assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]) + + +@pytest.mark.env_config_paths( + dict( + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml", + ) +) +class TestLinkTrafficLevels: + """Test the LinkTrafficLevels observation component (in isolation).""" + + def test_obs_shape(self, env: Primaite): + """Try creating env with MultiDiscrete observation space.""" + env.update_environent_obs() + + # we have two links and two services, so the shape should be 2 * 2 + assert env.env_obs.shape == (2 * 2,) + + def test_values(self, env: Primaite): + """Test that traffic values are encoded correctly. + + The laydown has: + * two services + * three nodes + * two links + * an IER trying to send 20 bits of data over both links the whole time (via the first service) + * link bandwidth of 1000, therefore the utilisation is 2% + """ + act = np.asarray([0, 0, 0, 0]) + obs, reward, done, info = env.step(act) + obs, reward, done, info = env.step(act) + + # the observation space has combine_service_traffic set to False, so the space has this format: + # [link1_service1, link1_service2, link2_service1, link2_service2] + # we send 20 bits of data via link1 and link2 on service 1. + # therefore the first and third elements should be 1 and all others 0 + assert np.array_equal(obs, [1, 0, 1, 0]) From 73adfbb6dd7406141b202b5882eca304cc9cacb8 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 2 Jun 2023 13:08:11 +0100 Subject: [PATCH 21/37] Get observation tests passing --- tests/test_observation_space.py | 64 ++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 314728ae..3fe71003 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -79,9 +79,65 @@ class TestNodeLinkTable: # * 6 columns (four fixed and two for the services) assert env.env_obs.shape == (5, 6) - # def test_value(self, env: Primaite): - # """""" - # ... + def test_value(self, env: Primaite): + """Test that the observation is generated correctly. + + The laydown has: + * 3 nodes (2 service nodes and 1 active node) + * 2 services + * 2 links + + Both nodes have both services, and all states are GOOD, therefore the expected observation value is: + + * Node 1: + * 1 (id) + * 1 (good hardware state) + * 1 (good OS state) + * 1 (good file system state) + * 1 (good service1 state) + * 1 (good service2 state) + * Node 2: + * 2 (id) + * 1 (good hardware state) + * 1 (good OS state) + * 1 (good file system state) + * 1 (good service1 state) + * 1 (good service2 state) + * Node 3 (active node): + * 3 (id) + * 1 (good hardware state) + * 1 (good OS state) + * 1 (good file system state) + * 0 (doesn't have service1) + * 0 (doesn't have service2) + * Link 1: + * 4 (id) + * 0 (n/a hardware state) + * 0 (n/a OS state) + * 0 (n/a file system state) + * 0 (no traffic for service1) + * 0 (no traffic for service2) + * Link 2: + * 5 (id) + * 0 (good hardware state) + * 0 (good OS state) + * 0 (good file system state) + * 0 (no traffic service1) + * 0 (no traffic for service2) + """ + act = np.asarray([0, 0, 0, 0]) + obs, reward, done, info = env.step(act) + + assert np.array_equal( + obs, + [ + [1, 1, 1, 1, 1, 1], + [2, 1, 1, 1, 1, 1], + [3, 1, 1, 1, 0, 0], + [4, 0, 0, 0, 0, 0], + [5, 0, 0, 0, 0, 0], + ], + ) @pytest.mark.env_config_paths( @@ -96,7 +152,7 @@ class TestNodeStatuses: def test_obs_shape(self, env: Primaite): """Try creating env with NodeStatuses as the only component.""" - assert env.env_obs.shape == (15) + assert env.env_obs.shape == (15,) def test_values(self, env: Primaite): """Test that the hardware and software states are encoded correctly. From f14910ca965787985737fa632aabe84ff9b41b72 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 2 Jun 2023 13:15:38 +0100 Subject: [PATCH 22/37] Fix Link Traffic Levels observation encoding --- src/primaite/environment/observations.py | 2 +- .../obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml | 2 +- tests/test_observation_space.py | 11 ++++++----- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index a467a5db..a598d6db 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -281,7 +281,7 @@ class LinkTrafficLevels(AbstractObservationComponent): traffic_level = self._quantisation_levels - 1 else: traffic_level = (load / bandwidth) // ( - 1 / (self._quantisation_levels - 1) + 1 / (self._quantisation_levels - 2) ) + 1 obs.append(int(traffic_level)) diff --git a/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml index 516bf5cc..e65ea306 100644 --- a/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml +++ b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml @@ -85,7 +85,7 @@ id: '5' startStep: 0 endStep: 5 - load: 20 + load: 999 protocol: TCP port: '80' source: '1' diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 3fe71003..ae862c96 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -210,8 +210,8 @@ class TestLinkTrafficLevels: * two services * three nodes * two links - * an IER trying to send 20 bits of data over both links the whole time (via the first service) - * link bandwidth of 1000, therefore the utilisation is 2% + * an IER trying to send 999 bits of data over both links the whole time (via the first service) + * link bandwidth of 1000, therefore the utilisation is 99.9% """ act = np.asarray([0, 0, 0, 0]) obs, reward, done, info = env.step(act) @@ -219,6 +219,7 @@ class TestLinkTrafficLevels: # the observation space has combine_service_traffic set to False, so the space has this format: # [link1_service1, link1_service2, link2_service1, link2_service2] - # we send 20 bits of data via link1 and link2 on service 1. - # therefore the first and third elements should be 1 and all others 0 - assert np.array_equal(obs, [1, 0, 1, 0]) + # we send 999 bits of data via link1 and link2 on service 1. + # therefore the first and third elements should be 6 and all others 0 + # (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%) + assert np.array_equal(obs, [6, 0, 6, 0]) From eaa192eeec9ef060d8c175c4480911ca43056937 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 2 Jun 2023 13:23:03 +0100 Subject: [PATCH 23/37] Update docs with configurable obs space info --- docs/source/config.rst | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/source/config.rst b/docs/source/config.rst index 88399973..8a8515ca 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -288,6 +288,28 @@ The config_[name].yaml file consists of the following attributes: Determines whether a NODE or ACL action space format is adopted for the session +* **itemType: OBSERVATION_SPACE** [dict] + + Allows for user to configure observation space by combining one or more observation components. List of available + components is is :py:mod:'primaite.environment.observations'. + + The observation space config item should have a ``components`` key which is a list of components. Each component + config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the + component while it is being initialised. + + This example illustrates the correct format for the observation space config item + +.. code-block::yaml + + - itemType: OBSERVATION_SPACE + components: + - name: LINK_TRAFFIC_LEVELS + options: + combine_service_traffic: false + quantisation_levels: 8 + - name: NODE_STATUSES + - name: LINK_TRAFFIC_LEVELS + * **itemType: STEPS** [int] Determines the number of steps to run in each episode of the session From 1a7d629d5ac7f51150e987edbdaaf4481eaba953 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 11:00:41 +0100 Subject: [PATCH 24/37] 893 - added new tests to test action space size and node is completing both sets of actions in a single episode and created new main config --- src/primaite/environment/primaite_env.py | 1 + ..._space_fixed_blue_actions_main_config.yaml | 89 +++++++++++++++++++ .../single_action_space_lay_down_config.yaml | 50 ++++++----- tests/conftest.py | 40 +++++++++ tests/test_single_action_space.py | 57 +++++++++--- 5 files changed, 199 insertions(+), 38 deletions(-) create mode 100644 tests/config/single_action_space_fixed_blue_actions_main_config.yaml diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 9b0bbeec..be16590f 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1273,5 +1273,6 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} + print("combined_action_dict entry", combined_action_dict.items()) # print(len(list(combined_action_dict.values()))) return combined_action_dict diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml new file mode 100644 index 00000000..becbc0f3 --- /dev/null +++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml @@ -0,0 +1,89 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: GENERIC_TEST +# Number of episodes to run per session +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: single_action_space_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1000000000 + +# Reward values +# Generic +allOk: 0 +# Node Operating State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node O/S or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/config/single_action_space_lay_down_config.yaml b/tests/config/single_action_space_lay_down_config.yaml index 6d44356f..ab3b170e 100644 --- a/tests/config/single_action_space_lay_down_config.yaml +++ b/tests/config/single_action_space_lay_down_config.yaml @@ -15,39 +15,41 @@ node_type: COMPUTER priority: P1 hardware_state: 'ON' + ip_address: 192.168.0.14 + software_state: GOOD + file_system_state: GOOD + services: + - name: ftp + port: '21' + state: COMPROMISED +- itemType: NODE + node_id: '2' + name: server_1 + node_class: SERVICE + node_type: SERVER + priority: P1 + hardware_state: 'ON' ip_address: 192.168.0.1 software_state: GOOD file_system_state: GOOD services: - name: ftp port: '21' - state: GOOD + state: COMPROMISED - itemType: POSITION positions: - node: '1' x_pos: 309 y_pos: 78 -- itemType: RED_POL - id: '1' - startStep: 1 - endStep: 3 - targetNodeId: '1' - initiator: DIRECT - type: FILE - protocol: NA - state: CORRUPT - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA -- itemType: RED_POL - id: '2' - startStep: 3 + - node: '2' + x_pos: 200 + y_pos: 78 +- itemType: RED_IER + id: '3' + startStep: 2 endStep: 15 - targetNodeId: '1' - initiator: DIRECT - type: FILE - protocol: NA - state: GOOD - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA + load: 1000 + protocol: ftp + port: CORRUPT + source: '1' + destination: '2' diff --git a/tests/conftest.py b/tests/conftest.py index 1e987223..dd732e78 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -164,10 +164,13 @@ def _get_primaite_env_from_config( # Load in config data load_config_values() env = Primaite(config_values, []) + # Get the number of steps (which is stored in the child config file) config_values.num_steps = env.episode_steps if env.config_values.agent_identifier == "GENERIC": run_generic(env, config_values) + elif env.config_values.agent_identifier == "GENERIC_TEST": + run_generic_set_actions(env, config_values) return env @@ -197,3 +200,40 @@ def run_generic(env, config_values): # env.reset() # env.close() + + +def run_generic_set_actions(env, config_values): + """Run against a generic agent with specified blue agent actions.""" + # Reset the environment at the start of the episode + # env.reset() + for episode in range(0, config_values.num_episodes): + for step in range(0, config_values.num_steps): + # Send the observation space to the agent to get an action + # TEMP - random action for now + # action = env.blue_agent_action(obs) + action = 0 + if step == 5: + # [1, 1, 2, 1, 1, 1] + # Creates an ACL rule + # Deny traffic from server_1 to node_1 on port FTP + action = 7 + elif step == 7: + # [1, 1, 2, 0] Node Action + # Sets Node 1 Hardware State to OFF + # Does not resolve any service + action = 16 + print(action, "ran") + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + + # Introduce a delay between steps + time.sleep(config_values.time_delay / 1000) + + # Reset the environment at the end of the episode + # env.reset() + + # env.close() diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 203a6232..fda4c96c 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -1,26 +1,55 @@ +from primaite.common.enums import HardwareState from tests import TEST_CONFIG_ROOT from tests.conftest import _get_primaite_env_from_config -def test_single_action_space(): +def test_single_action_space_is_valid(): """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" env = _get_primaite_env_from_config( main_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ) - """ - nv.action_space.n is the total number of actions in the Discrete action space - This is the number of actions the agent has to choose from. + # Retrieve the action space dictionary values from environment + env_action_space_dict = env.action_dict.values() + # Flags to check the conditions of the action space + contains_acl_actions = False + contains_node_actions = False + both_action_spaces = False + # Loop through each element of the list (which is every value from the dictionary) + for dict_item in env_action_space_dict: + # Node action detected + if len(dict_item) == 4: + contains_node_actions = True + # Link action detected + elif len(dict_item) == 6: + contains_acl_actions = True + # If both are there then the ANY action type is working + if contains_node_actions and contains_acl_actions: + both_action_spaces = True + # Check condition should be True + assert both_action_spaces - The total number of actions that an agent can type when a NODE action type is selected is: 6 - The total number of actions that an agent can take when an ACL action type is selected is: 7 - These action spaces are combined and the total number of actions is: 12 - This is due to both actions containing the action to "Do nothing", so it needs to be removed from one of the spaces, - to avoid duplicate actions. - - As a result, 12 is the total number of action spaces. - """ - # e - assert env.action_space.n == 12 +def test_agent_is_executing_actions_from_both_spaces(): + """Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL).""" + env = _get_primaite_env_from_config( + main_config_path=TEST_CONFIG_ROOT + / "single_action_space_fixed_blue_actions_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "single_action_space_lay_down_config.yaml", + ) + # Retrieve hardware state of computer_1 node in laydown config + # Agent turned this off in Step 5 + computer_node_hardware_state = env.nodes["1"].hardware_state + # Retrieve the Access Control List object stored by the environment at the end of the episode + access_control_list = env.acl + # Use the Access Control List object acl object attribute to get dictionary + # Use dictionary.values() to get total list of all items in the dictionary + acl_rules_list = access_control_list.acl.values() + # Length of this list tells you how many items are in the dictionary + # This number is the frequency of Access Control Rules in the environment + # In the scenario, we specified that the agent should create only 1 acl rule + num_of_rules = len(acl_rules_list) + # Therefore these statements below MUST be true + assert computer_node_hardware_state == HardwareState.OFF and num_of_rules == 1 From dc7be7d8e6e210a3ed6eec7d8236d3886789202e Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 11:10:38 +0100 Subject: [PATCH 25/37] 893 - set the action_space to NOTHING so test_reward.py passes and removed unnecessary test print statements --- src/primaite/environment/primaite_env.py | 20 +------------------- tests/conftest.py | 5 ++--- 2 files changed, 3 insertions(+), 22 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index be16590f..33af7d89 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -236,7 +236,6 @@ class Primaite(Env): # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa self.action_dict = self.create_node_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) - print(self.action_space, "NODE action space") elif self.action_type == ActionType.ACL: _LOGGER.info("Action space type ACL selected") # Terms (for ACL action space): @@ -248,12 +247,10 @@ class Primaite(Env): # [0, num ports] - Port (0 = any, then 1 -> x resolving to port) self.action_dict = self.create_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) - print(self.action_space, "ACL action space") elif self.action_type == ActionType.ANY: _LOGGER.info("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) - print(self.action_space, "ANY action space") else: _LOGGER.info("Invalid action type selected") # Set up a csv to store the results of the training @@ -455,7 +452,6 @@ class Primaite(Env): Args: _action: The action space from the agent """ - # print("intepret action") # At the moment, actions are only affecting nodes if self.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) @@ -470,7 +466,6 @@ class Primaite(Env): ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: - print("invalid action type found") logging.error("Invalid action type found") def apply_actions_to_nodes(self, _action): @@ -1091,7 +1086,6 @@ class Primaite(Env): item: A config data item representing action info """ self.action_type = ActionType[action_info["type"]] - print("action type selected: ", self.action_type) def get_steps_info(self, steps_info): """ @@ -1196,7 +1190,6 @@ class Primaite(Env): """ # reserve 0 action to be a nothing action actions = {0: [1, 0, 0, 0]} - # print("node dict function call", self.num_nodes + 1) action_key = 1 for node in range(1, self.num_nodes + 1): # 4 node properties (NONE, OPERATING, OS, SERVICE) @@ -1204,14 +1197,11 @@ class Primaite(Env): # Node Actions either: # (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state # Use MAX to ensure we get them all - # print(self.num_services, "num services") for node_action in range(4): for service_state in range(self.num_services): action = [node, node_property, node_action, service_state] - # check to see if its a nothing aciton (has no effect) - # print("action node",action) + # check to see if it's a nothing action (has no effect) if is_valid_node_action(action): - print("true") actions[action_key] = action action_key += 1 @@ -1223,7 +1213,6 @@ class Primaite(Env): actions = {0: [0, 0, 0, 0, 0, 0]} action_key = 1 - # print("node count",self.num_nodes + 1) # 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE for action_decision in range(3): # 2 possible action permissions 0 = DENY, 1 = CREATE @@ -1241,13 +1230,10 @@ class Primaite(Env): protocol, port, ] - # print("action acl", action) # Check to see if its an action we want to include as possible i.e. not a nothing action if is_valid_acl_action_extra(action): - print("true") actions[action_key] = action action_key += 1 - # print("false") return actions @@ -1261,8 +1247,6 @@ class Primaite(Env): node_action_dict = self.create_node_action_dict() acl_action_dict = self.create_acl_action_dict() - print(len(node_action_dict), len(acl_action_dict)) - # Change node keys to not overlap with acl keys # Only 1 nothing action (key 0) is required, remove the other new_node_action_dict = { @@ -1273,6 +1257,4 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} - print("combined_action_dict entry", combined_action_dict.items()) - # print(len(list(combined_action_dict.values()))) return combined_action_dict diff --git a/tests/conftest.py b/tests/conftest.py index dd732e78..3a99bcf6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -184,8 +184,8 @@ def run_generic(env, config_values): # Send the observation space to the agent to get an action # TEMP - random action for now # action = env.blue_agent_action(obs) - action = env.action_space.sample() - + # action = env.action_space.sample() + action = 0 # Run the simulation step on the live environment obs, reward, done, info = env.step(action) @@ -222,7 +222,6 @@ def run_generic_set_actions(env, config_values): # Sets Node 1 Hardware State to OFF # Does not resolve any service action = 16 - print(action, "ran") # Run the simulation step on the live environment obs, reward, done, info = env.step(action) From 5add9d620ce2a8e70da43fcb226a487af89a774d Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 11:57:04 +0100 Subject: [PATCH 26/37] 893 - updated the docs to reflect changes made to action space --- docs/source/about.rst | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/docs/source/about.rst b/docs/source/about.rst index 8cc08b13..242d34bb 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -334,6 +334,8 @@ The full observation space would have 15 node-related elements and 3 link-relate gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4,5,5,5]) + * Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...} + The placeholders inside the list under the key '1' mean the following: Action Spaces ************** @@ -342,29 +344,38 @@ The action space available to the blue agent comes in two types: 1. Node-based 2. Access Control List + 3. Any The choice of action space used during a training session is determined in the config_[name].yaml file. **Node-Based** -The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is an OpenAI Gym multidiscrete type, as follows: +The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is an OpenAI Gym spaces.Discrete type, as follows: - * [0, num nodes] - Node ID (0 = nothing, node ID) - * [0, 4] - What property it's acting on (0 = nothing, 1 = state, 2 = SoftwareState, 3 = service state, 4 = file system state) - * [0, 3] - Action on property (0 = nothing, 1 = on / scan, 2 = off / repair, 3 = reset / patch / restore) - * [0, num services] - Resolves to service ID (0 = nothing, resolves to service) + * Dictionary item {... ,1: [x1, x2, x3,x4] ...} + The placeholders inside the list under the key '1' mean the following: + + * [0, num nodes] - Node ID (0 = nothing, node ID) + * [0, 4] - What property it's acting on (0 = nothing, 1 = state, 2 = SoftwareState, 3 = service state, 4 = file system state) + * [0, 3] - Action on property (0 = nothing, 1 = on / scan, 2 = off / repair, 3 = reset / patch / restore) + * [0, num services] - Resolves to service ID (0 = nothing, resolves to service) **Access Control List** -The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI multidiscrete type, as follows: +The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI spaces.Discrete type, as follows: - * [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) - * [0, 1] - Permission (0 = DENY, 1 = ALLOW) - * [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses) - * [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) - * [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) - * [0, num ports] - Port (0 = any, then 1 -> x resolving to port) + * [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) + * [0, 1] - Permission (0 = DENY, 1 = ALLOW) + * [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses) + * [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) + * [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) + * [0, num ports] - Port (0 = any, then 1 -> x resolving to port) + +**ANY** +The agent is able to carry out both **Node-Based** and **Access Control List** operations. + +This means the dictionary will contain key-value pairs in the format of BOTH Node-Based and Access Control List as seen above. Rewards ******* From e0ed97be36ff891cc3e327979d4cf80d54cd223c Mon Sep 17 00:00:00 2001 From: Sunil Samra Date: Tue, 6 Jun 2023 12:07:22 +0000 Subject: [PATCH 27/37] Apply suggestions from code review --- docs/source/about.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/about.rst b/docs/source/about.rst index 242d34bb..a59701f6 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -344,7 +344,7 @@ The action space available to the blue agent comes in two types: 1. Node-based 2. Access Control List - 3. Any + 3. Any (Agent can take both node-based and ACL-based actions) The choice of action space used during a training session is determined in the config_[name].yaml file. From 58a87ee0c8ad8ede1b20bcfd97f557416221b95d Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 13:12:28 +0100 Subject: [PATCH 28/37] 893 - applied changes raised during PR --- docs/source/about.rst | 9 ++++----- src/primaite/environment/primaite_env.py | 14 +++++++------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/docs/source/about.rst b/docs/source/about.rst index 242d34bb..a16fadd3 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -334,9 +334,6 @@ The full observation space would have 15 node-related elements and 3 link-relate gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4,5,5,5]) - * Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...} - The placeholders inside the list under the key '1' mean the following: - Action Spaces ************** @@ -344,7 +341,7 @@ The action space available to the blue agent comes in two types: 1. Node-based 2. Access Control List - 3. Any + 3. Any (Agent can take both node-based and ACL-based actions) The choice of action space used during a training session is determined in the config_[name].yaml file. @@ -364,6 +361,8 @@ The agent is able to influence the status of nodes by switching them off, resett The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI spaces.Discrete type, as follows: + * Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...} + The placeholders inside the list under the key '1' mean the following: * [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) * [0, 1] - Permission (0 = DENY, 1 = ALLOW) @@ -375,7 +374,7 @@ The blue agent is able to influence the configuration of the Access Control List **ANY** The agent is able to carry out both **Node-Based** and **Access Control List** operations. -This means the dictionary will contain key-value pairs in the format of BOTH Node-Based and Access Control List as seen above. +This means the dictionary will contain key-value pairs in the format of BOTH Node-Based and Access Control List as seen above. Rewards ******* diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index d2af20f9..4facb7b2 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -536,14 +536,14 @@ class Primaite(Env): _action: The action space from the agent """ # Convert discrete value back to multidiscrete - multidiscrete_action = self.action_dict[_action] + readable_action = self.action_dict[_action] - action_decision = multidiscrete_action[0] - action_permission = multidiscrete_action[1] - action_source_ip = multidiscrete_action[2] - action_destination_ip = multidiscrete_action[3] - action_protocol = multidiscrete_action[4] - action_port = multidiscrete_action[5] + action_decision = readable_action[0] + action_permission = readable_action[1] + action_source_ip = readable_action[2] + action_destination_ip = readable_action[3] + action_protocol = readable_action[4] + action_port = readable_action[5] if action_decision == 0: # It's decided to do nothing From efd0f6ed08653cbcfb67e2b3b6db17c8680261e6 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 13:21:04 +0100 Subject: [PATCH 29/37] 893 - returned config_values in conftest to move run_generic_set_actions into test_single_action_space.py --- src/primaite/environment/primaite_env.py | 2 +- tests/conftest.py | 40 +------------------ tests/test_observation_space.py | 4 +- tests/test_reward.py | 2 +- tests/test_single_action_space.py | 49 +++++++++++++++++++++++- 5 files changed, 52 insertions(+), 45 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 4facb7b2..1794504a 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -368,7 +368,7 @@ class Primaite(Env): self.step_count, self.config_values, ) - print(f" Step {self.step_count} Reward: {str(reward)}") + # print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count diff --git a/tests/conftest.py b/tests/conftest.py index 3a99bcf6..740f65b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -169,10 +169,8 @@ def _get_primaite_env_from_config( if env.config_values.agent_identifier == "GENERIC": run_generic(env, config_values) - elif env.config_values.agent_identifier == "GENERIC_TEST": - run_generic_set_actions(env, config_values) - return env + return env, config_values def run_generic(env, config_values): @@ -200,39 +198,3 @@ def run_generic(env, config_values): # env.reset() # env.close() - - -def run_generic_set_actions(env, config_values): - """Run against a generic agent with specified blue agent actions.""" - # Reset the environment at the start of the episode - # env.reset() - for episode in range(0, config_values.num_episodes): - for step in range(0, config_values.num_steps): - # Send the observation space to the agent to get an action - # TEMP - random action for now - # action = env.blue_agent_action(obs) - action = 0 - if step == 5: - # [1, 1, 2, 1, 1, 1] - # Creates an ACL rule - # Deny traffic from server_1 to node_1 on port FTP - action = 7 - elif step == 7: - # [1, 1, 2, 0] Node Action - # Sets Node 1 Hardware State to OFF - # Does not resolve any service - action = 16 - # Run the simulation step on the live environment - obs, reward, done, info = env.step(action) - - # Break if done is True - if done: - break - - # Introduce a delay between steps - time.sleep(config_values.time_delay / 1000) - - # Reset the environment at the end of the episode - # env.reset() - - # env.close() diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 6a187761..d6eaa3b7 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -6,7 +6,7 @@ from tests.conftest import _get_primaite_env_from_config def test_creating_env_with_box_obs(): """Try creating env with box observation space.""" - env = _get_primaite_env_from_config( + env, config_values = _get_primaite_env_from_config( main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "box_obs_space_laydown_config.yaml", ) @@ -21,7 +21,7 @@ def test_creating_env_with_box_obs(): def test_creating_env_with_multidiscrete_obs(): """Try creating env with MultiDiscrete observation space.""" - env = _get_primaite_env_from_config( + env, config_values = _get_primaite_env_from_config( main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "multidiscrete_obs_space_laydown_config.yaml", diff --git a/tests/test_reward.py b/tests/test_reward.py index 4925a434..c54ee32f 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -8,7 +8,7 @@ def test_rewards_are_being_penalised_at_each_step_function(): When the initial state is OFF compared to reference state which is ON. """ - env = _get_primaite_env_from_config( + env, config_values = _get_primaite_env_from_config( main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index fda4c96c..5fc6cb7e 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -1,15 +1,57 @@ +import time + from primaite.common.enums import HardwareState from tests import TEST_CONFIG_ROOT from tests.conftest import _get_primaite_env_from_config +def run_generic_set_actions(env, config_values): + """Run against a generic agent with specified blue agent actions.""" + # Reset the environment at the start of the episode + # env.reset() + for episode in range(0, config_values.num_episodes): + for step in range(0, config_values.num_steps): + # Send the observation space to the agent to get an action + # TEMP - random action for now + # action = env.blue_agent_action(obs) + action = 0 + print("Episode:", episode, "\nStep:", step) + if step == 5: + # [1, 1, 2, 1, 1, 1] + # Creates an ACL rule + # Deny traffic from server_1 to node_1 on port FTP + action = 7 + elif step == 7: + # [1, 1, 2, 0] Node Action + # Sets Node 1 Hardware State to OFF + # Does not resolve any service + action = 16 + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + + # Introduce a delay between steps + time.sleep(config_values.time_delay / 1000) + + # Reset the environment at the end of the episode + # env.reset() + + # env.close() + + def test_single_action_space_is_valid(): """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" - env = _get_primaite_env_from_config( + env, config_values = _get_primaite_env_from_config( main_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ) + + run_generic_set_actions(env, config_values) + # Retrieve the action space dictionary values from environment env_action_space_dict = env.action_dict.values() # Flags to check the conditions of the action space @@ -33,12 +75,15 @@ def test_single_action_space_is_valid(): def test_agent_is_executing_actions_from_both_spaces(): """Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL).""" - env = _get_primaite_env_from_config( + env, config_values = _get_primaite_env_from_config( main_config_path=TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ) + + run_generic_set_actions(env, config_values) + # Retrieve hardware state of computer_1 node in laydown config # Agent turned this off in Step 5 computer_node_hardware_state = env.nodes["1"].hardware_state From 2e1bdf2361829b4a5e4732a65a14cb1f615e299f Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 13:23:08 +0100 Subject: [PATCH 30/37] 893 - changed action in conftest.py back to sample of the environment action space --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 740f65b7..41c1bc94 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -182,8 +182,8 @@ def run_generic(env, config_values): # Send the observation space to the agent to get an action # TEMP - random action for now # action = env.blue_agent_action(obs) - # action = env.action_space.sample() - action = 0 + # action = 0 + action = env.action_space.sample() # Run the simulation step on the live environment obs, reward, done, info = env.step(action) From 10585490fe434e07a00f2f779cd5ecde3ad7e384 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 13:47:07 +0100 Subject: [PATCH 31/37] 893 - --- src/primaite/environment/primaite_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index bc52c887..03798c3b 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -368,7 +368,7 @@ class Primaite(Env): self.step_count, self.config_values, ) - # print(f" Step {self.step_count} Reward: {str(reward)}") + print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count From 0817a4cad3a8f39caf6b925ca27c187caafaf90b Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 13:49:22 +0100 Subject: [PATCH 32/37] 893 - added consistent action for test_reward.py --- tests/conftest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 41c1bc94..1b3a06b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -182,8 +182,9 @@ def run_generic(env, config_values): # Send the observation space to the agent to get an action # TEMP - random action for now # action = env.blue_agent_action(obs) - # action = 0 - action = env.action_space.sample() + # action = env.action_space.sample() + action = 0 + # Run the simulation step on the live environment obs, reward, done, info = env.step(action) From e17e5ac4b9e5b451a2d228233c6f0db68d5f0f4c Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 15:54:35 +0100 Subject: [PATCH 33/37] 893 - added new line for assert statements --- src/primaite/environment/primaite_env.py | 1 + tests/test_single_action_space.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 03798c3b..3de1111b 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -423,6 +423,7 @@ class Primaite(Env): _action: The action space from the agent """ # At the moment, actions are only affecting nodes + print("ACTION:", self.action_dict[_action]) if self.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) elif self.action_type == ActionType.ACL: diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 5fc6cb7e..10701a6a 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -19,7 +19,7 @@ def run_generic_set_actions(env, config_values): if step == 5: # [1, 1, 2, 1, 1, 1] # Creates an ACL rule - # Deny traffic from server_1 to node_1 on port FTP + # Allows traffic from server_1 to node_1 on port FTP action = 7 elif step == 7: # [1, 1, 2, 0] Node Action @@ -97,4 +97,5 @@ def test_agent_is_executing_actions_from_both_spaces(): # In the scenario, we specified that the agent should create only 1 acl rule num_of_rules = len(acl_rules_list) # Therefore these statements below MUST be true - assert computer_node_hardware_state == HardwareState.OFF and num_of_rules == 1 + assert computer_node_hardware_state == HardwareState.OFF + assert num_of_rules == 1 From 281bb786127a7cb1970cb7dd6abca224e2a1019f Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Wed, 7 Jun 2023 09:19:30 +0100 Subject: [PATCH 34/37] 893 - removed print statements for demonstration --- src/primaite/environment/primaite_env.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 3de1111b..03798c3b 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -423,7 +423,6 @@ class Primaite(Env): _action: The action space from the agent """ # At the moment, actions are only affecting nodes - print("ACTION:", self.action_dict[_action]) if self.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) elif self.action_type == ActionType.ACL: From 6089fb6950b83678da20dd2d8b1c4fbabf1f36b2 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Wed, 7 Jun 2023 14:39:52 +0100 Subject: [PATCH 35/37] 893 - removed unnecessary functions from utils.py and changed single_action_space_fixed_blue_actions_main_config.yaml back to GENERIC agentIdentifier after PR comments --- src/primaite/agents/utils.py | 386 +----------------- ..._space_fixed_blue_actions_main_config.yaml | 2 +- tests/test_single_action_space.py | 3 +- 3 files changed, 3 insertions(+), 388 deletions(-) diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index 1ada88ba..bb967906 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -1,288 +1,4 @@ -import logging -import os.path -from datetime import datetime - -import numpy as np -import yaml - -from primaite.common.config_values_main import ConfigValuesMain -from primaite.common.enums import ( - ActionType, - HardwareState, - LinkStatus, - NodeHardwareAction, - NodePOLType, - NodeSoftwareAction, - SoftwareState, -) - - -def load_config_values(config_path): - """Loads the config values from the main config file into a config object.""" - config_file_main = open(config_path, "r") - config_data = yaml.safe_load(config_file_main) - # Create a config class - config_values = ConfigValuesMain() - - try: - # Generic - config_values.red_agent_identifier = config_data["redAgentIdentifier"] - config_values.action_type = ActionType[config_data["actionType"]] - config_values.config_filename_use_case = config_data["configFilename"] - # Reward values - # Generic - config_values.all_ok = float(config_data["allOk"]) - # Node Operating State - config_values.off_should_be_on = float(config_data["offShouldBeOn"]) - config_values.off_should_be_resetting = float( - config_data["offShouldBeResetting"] - ) - config_values.on_should_be_off = float(config_data["onShouldBeOff"]) - config_values.on_should_be_resetting = float(config_data["onShouldBeResetting"]) - config_values.resetting_should_be_on = float(config_data["resettingShouldBeOn"]) - config_values.resetting_should_be_off = float( - config_data["resettingShouldBeOff"] - ) - # Node O/S or Service State - config_values.good_should_be_patching = float( - config_data["goodShouldBePatching"] - ) - config_values.good_should_be_compromised = float( - config_data["goodShouldBeCompromised"] - ) - config_values.good_should_be_overwhelmed = float( - config_data["goodShouldBeOverwhelmed"] - ) - config_values.patching_should_be_good = float( - config_data["patchingShouldBeGood"] - ) - config_values.patching_should_be_compromised = float( - config_data["patchingShouldBeCompromised"] - ) - config_values.patching_should_be_overwhelmed = float( - config_data["patchingShouldBeOverwhelmed"] - ) - config_values.compromised_should_be_good = float( - config_data["compromisedShouldBeGood"] - ) - config_values.compromised_should_be_patching = float( - config_data["compromisedShouldBePatching"] - ) - config_values.compromised_should_be_overwhelmed = float( - config_data["compromisedShouldBeOverwhelmed"] - ) - config_values.compromised = float(config_data["compromised"]) - config_values.overwhelmed_should_be_good = float( - config_data["overwhelmedShouldBeGood"] - ) - config_values.overwhelmed_should_be_patching = float( - config_data["overwhelmedShouldBePatching"] - ) - config_values.overwhelmed_should_be_compromised = float( - config_data["overwhelmedShouldBeCompromised"] - ) - config_values.overwhelmed = float(config_data["overwhelmed"]) - # IER status - config_values.red_ier_running = float(config_data["redIerRunning"]) - config_values.green_ier_blocked = float(config_data["greenIerBlocked"]) - # Patching / Reset durations - config_values.os_patching_duration = int(config_data["osPatchingDuration"]) - config_values.node_reset_duration = int(config_data["nodeResetDuration"]) - config_values.service_patching_duration = int( - config_data["servicePatchingDuration"] - ) - - except Exception as e: - print(f"Could not save load config data: {e} ") - - return config_values - - -def configure_logging(log_name): - """Configures logging.""" - try: - now = datetime.now() # current date and time - time = now.strftime("%Y%m%d_%H%M%S") - filename = "/app/logs/" + log_name + "/" + time + ".log" - path = f"/app/logs/{log_name}" - is_dir = os.path.isdir(path) - if not is_dir: - os.makedirs(path) - logging.basicConfig( - filename=filename, - filemode="w", - format="%(asctime)s - %(levelname)s - %(message)s", - datefmt="%d-%b-%y %H:%M:%S", - level=logging.INFO, - ) - except Exception as e: - print("ERROR: Could not start logging", e) - - -def transform_change_obs_readable(obs): - """Transform list of transactions to readable list of each observation property. - - example: - np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] - """ - ids = [i for i in obs[:, 0]] - operating_states = [HardwareState(i).name for i in obs[:, 1]] - os_states = [SoftwareState(i).name for i in obs[:, 2]] - new_obs = [ids, operating_states, os_states] - - for service in range(3, obs.shape[1]): - # Links bit/s don't have a service state - service_states = [ - SoftwareState(i).name if i <= 4 else i for i in obs[:, service] - ] - new_obs.append(service_states) - - return new_obs - - -def transform_obs_readable(obs): - """ - Transform obs readable function. - - example: - np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]. - """ - changed_obs = transform_change_obs_readable(obs) - new_obs = list(zip(*changed_obs)) - # Convert list of tuples to list of lists - new_obs = [list(i) for i in new_obs] - - return new_obs - - -def convert_to_new_obs(obs, num_nodes=10): - """Convert original gym Box observation space to new multiDiscrete observation space.""" - # Remove ID columns, remove links and flatten to MultiDiscrete observation space - new_obs = obs[:num_nodes, 1:].flatten() - return new_obs - - -def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): - """ - Convert to old observation, links filled with 0's as no information is included in new observation space. - - example: - obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1]) - - new_obs = array([[ 1, 1, 1, 1], - [ 2, 1, 1, 1], - [ 3, 1, 1, 1], - ... - [20, 0, 0, 0]]) - """ - # Convert back to more readable, original format - reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2) - - # Add empty links back and add node ID back - s = np.zeros( - [reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], - dtype=np.int64, - ) - s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back - s[:num_nodes, 1:] = reshaped_nodes # put values back in - new_obs = s - - # Add links back in - links = obs[-num_links:] - # Links will be added to the last protocol/service slot but they are not specific to that service - new_obs[num_nodes:, -1] = links - - return new_obs - - -def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): - """Return string describing change between two observations. - - example: - obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) - obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]]) - output = 'ID 1: SERVICE 2 set to GOOD' - """ - obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services) - obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services) - list_of_changes = [] - for n, row in enumerate(obs1 - obs2): - if row.any() != 0: - relevant_changes = np.where(row != 0, obs2[n], -1) - relevant_changes[0] = obs2[n, 0] # ID is always relevant - is_link = relevant_changes[0] > num_nodes - desc = _describe_obs_change_helper(relevant_changes, is_link) - list_of_changes.append(desc) - - change_string = "\n ".join(list_of_changes) - if len(list_of_changes) > 0: - change_string = "\n " + change_string - return change_string - - -def _describe_obs_change_helper(obs_change, is_link): - """ - Helper funcion to describe what has changed. - - example: - [ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD" - - Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.' - """ - # Indexes where a change has occured, not including 0th index - index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] - # Node pol types, Indexes >= 3 are service nodes - node_pol_types = [ - NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) - for i in index_changed - ] - # Account for hardware states, software sattes and links - states = [ - LinkStatus(obs_change[i]).name - if is_link - else HardwareState(obs_change[i]).name - if i == 1 - else SoftwareState(obs_change[i]).name - for i in index_changed - ] - - if not is_link: - desc = f"ID {obs_change[0]}:" - for node_pol_type, state in list(zip(node_pol_types, states)): - desc = desc + " " + node_pol_type + " changed to " + state + "." - else: - desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}." - - return desc - - -def transform_action_node_enum(action): - """ - Convert a node action from readable string format, to enumerated format. - - example: - [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] - """ - action_node_id = action[0] - action_node_property = NodePOLType[action[1]].value - - if action[1] == "OPERATING": - property_action = NodeHardwareAction[action[2]].value - elif action[1] == "OS" or action[1] == "SERVICE": - property_action = NodeSoftwareAction[action[2]].value - else: - property_action = 0 - - action_service_index = action[3] - - new_action = [ - action_node_id, - action_node_property, - property_action, - action_service_index, - ] - - return new_action +from primaite.common.enums import NodeHardwareAction, NodePOLType, NodeSoftwareAction def transform_action_node_readable(action): @@ -307,31 +23,6 @@ def transform_action_node_readable(action): return new_action -def node_action_description(action): - """Generate string describing a node-based action.""" - if isinstance(action[1], (int, np.int64)): - # transform action to readable format - action = transform_action_node_readable(action) - - node_id = action[0] - node_property = action[1] - property_action = action[2] - service_id = action[3] - - if property_action == "NONE": - return "" - if node_property == "OPERATING" or node_property == "OS": - description = f"NODE {node_id}, {node_property}, SET TO {property_action}" - elif node_property == "SERVICE": - description = ( - f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" - ) - else: - return "" - - return description - - def transform_action_acl_readable(action): """ Transform an ACL action to a more readable format. @@ -354,52 +45,6 @@ def transform_action_acl_readable(action): return new_action -def transform_action_acl_enum(action): - """Convert a acl action from readable string format, to enumerated format.""" - action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2} - action_permissions = {"DENY": 0, "ALLOW": 1} - - action_decision = action_decisions[action[0]] - action_permission = action_permissions[action[1]] - - # For IPs, Ports and Protocols, ANY has value 0, otherwise its just an index - new_action = [action_decision, action_permission] + list(action[2:6]) - for n, val in enumerate(list(action[2:6])): - if val == "ANY": - new_action[n + 2] = 0 - - new_action = np.array(new_action) - return new_action - - -def acl_action_description(action): - """Generate string describing a acl-based action.""" - if isinstance(action[0], (int, np.int64)): - # transform action to readable format - action = transform_action_acl_readable(action) - if action[0] == "NONE": - description = "NO ACL RULE APPLIED" - else: - description = ( - f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]}," - f" for protocol/service index {action[4]} on port index {action[5]}" - ) - - return description - - -def get_node_of_ip(ip, node_dict): - """ - Get the node ID of an IP address. - - node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes) - """ - for node_key, node_value in node_dict.items(): - node_ip = node_value.get_ip_address() - if node_ip == ip: - return node_key - - def is_valid_node_action(action): """Is the node action an actual valid action. @@ -480,32 +125,3 @@ def is_valid_acl_action_extra(action): return False return True - - -def get_new_action(old_action, action_dict): - """Get new action (e.g. 32) from old action e.g. [1,1,1,0]. - - old_action can be either node or acl action type. - """ - for key, val in action_dict.items(): - if list(val) == list(old_action): - return key - # Not all possible actions are included in dict, only valid action are - # if action is not in the dict, its an invalid action so return 0 - return 0 - - -def get_action_description(action, action_dict): - """Get a string describing/explaining what an action is doing in words.""" - action_array = action_dict[action] - if len(action_array) == 4: - # node actions have length 4 - action_description = node_action_description(action_array) - elif len(action_array) == 6: - # acl actions have length 6 - action_description = acl_action_description(action_array) - else: - # Should never happen - action_description = "Unrecognised action" - - return action_description diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml index becbc0f3..7fcc002f 100644 --- a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml +++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml @@ -5,7 +5,7 @@ # "STABLE_BASELINES3_PPO" # "STABLE_BASELINES3_A2C" # "GENERIC" -agentIdentifier: GENERIC_TEST +agentIdentifier: GENERIC # Number of episodes to run per session numEpisodes: 1 # Time delay between steps (for generic agents) diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 10701a6a..75d57f5d 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -81,9 +81,8 @@ def test_agent_is_executing_actions_from_both_spaces(): lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ) - + # Run environment with specified fixed blue agent actions only run_generic_set_actions(env, config_values) - # Retrieve hardware state of computer_1 node in laydown config # Agent turned this off in Step 5 computer_node_hardware_state = env.nodes["1"].hardware_state From 4329c65211009b84b6950960cc8426dd8208cb4c Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 7 Jun 2023 15:25:11 +0100 Subject: [PATCH 36/37] Apply suggestions from code review. --- src/primaite/common/config_values_main.py | 1 + src/primaite/environment/observations.py | 25 ++-- src/primaite/environment/primaite_env.py | 13 +-- src/primaite/main.py | 4 + ...n_with_NODE_STATUSES.yaml => laydown.yaml} | 5 +- .../laydown_with_LINK_TRAFFIC_LEVELS.yaml | 110 ------------------ .../laydown_with_NODE_LINK_TABLE.yaml | 77 ------------ .../obs_tests/laydown_without_obs_space.yaml | 74 ------------ .../main_config_LINK_TRAFFIC_LEVELS.yaml | 96 +++++++++++++++ .../main_config_NODE_LINK_TABLE.yaml | 93 +++++++++++++++ .../obs_tests/main_config_NODE_STATUSES.yaml | 93 +++++++++++++++ ...gent.yaml => main_config_without_obs.yaml} | 2 +- tests/conftest.py | 4 + tests/test_observation_space.py | 49 ++++---- 14 files changed, 338 insertions(+), 308 deletions(-) rename tests/config/obs_tests/{laydown_with_NODE_STATUSES.yaml => laydown.yaml} (95%) delete mode 100644 tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml delete mode 100644 tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml delete mode 100644 tests/config/obs_tests/laydown_without_obs_space.yaml create mode 100644 tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml create mode 100644 tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml create mode 100644 tests/config/obs_tests/main_config_NODE_STATUSES.yaml rename tests/config/obs_tests/{main_config_no_agent.yaml => main_config_without_obs.yaml} (98%) diff --git a/src/primaite/common/config_values_main.py b/src/primaite/common/config_values_main.py index 3493f9d2..f822b77f 100644 --- a/src/primaite/common/config_values_main.py +++ b/src/primaite/common/config_values_main.py @@ -9,6 +9,7 @@ class ConfigValuesMain(object): """Init.""" # Generic self.agent_identifier = "" # the agent in use + self.observation_config = None # observation space config self.num_episodes = 0 # number of episodes to train over self.num_steps = 0 # number of steps in an episode self.time_delay = 0 # delay between steps (ms) - applies to generic agents only diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index a598d6db..9e71ef1b 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,7 +1,7 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Tuple, Union +from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union import numpy as np from gym import spaces @@ -56,9 +56,9 @@ class NodeLinkTable(AbstractObservationComponent): ``(12, 7)`` """ - _FIXED_PARAMETERS = 4 - _MAX_VAL = 1_000_000 - _DATA_TYPE = np.int64 + _FIXED_PARAMETERS: int = 4 + _MAX_VAL: int = 1_000_000 + _DATA_TYPE: type = np.int64 def __init__(self, env: "Primaite"): super().__init__(env) @@ -159,7 +159,7 @@ class NodeStatuses(AbstractObservationComponent): :type env: Primaite """ - _DATA_TYPE = np.int64 + _DATA_TYPE: type = np.int64 def __init__(self, env: "Primaite"): super().__init__(env) @@ -231,7 +231,7 @@ class LinkTrafficLevels(AbstractObservationComponent): :type quantisation_levels: int, optional """ - _DATA_TYPE = np.int64 + _DATA_TYPE: type = np.int64 def __init__( self, @@ -239,7 +239,14 @@ class LinkTrafficLevels(AbstractObservationComponent): combine_service_traffic: bool = False, quantisation_levels: int = 5, ): - assert quantisation_levels >= 3 + if quantisation_levels < 3: + _msg = ( + f"quantisation_levels must be 3 or more because the lowest and highest levels are " + f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. " + f"Resetting to default value (5)" + ) + _LOGGER.warning(_msg) + quantisation_levels = 5 super().__init__(env) @@ -296,7 +303,7 @@ class ObservationsHandler: Each component can also define further parameters to make them more flexible. """ - registry = { + _REGISTRY: Final[Dict[str, type]] = { "NODE_LINK_TABLE": NodeLinkTable, "NODE_STATUSES": NodeStatuses, "LINK_TRAFFIC_LEVELS": LinkTrafficLevels, @@ -384,7 +391,7 @@ class ObservationsHandler: for component_cfg in obs_space_config["components"]: # Figure out which class can instantiate the desired component comp_type = component_cfg["name"] - comp_class = cls.registry[comp_type] + comp_class = cls._REGISTRY[comp_type] # Create the component with options from the YAML options = component_cfg.get("options") or {} diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 0ff58100..7995c4f7 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -48,10 +48,10 @@ class Primaite(Env): """PRIMmary AI Training Evironment (Primaite) class.""" # Action Space contants - ACTION_SPACE_NODE_PROPERTY_VALUES = 5 - ACTION_SPACE_NODE_ACTION_VALUES = 4 - ACTION_SPACE_ACL_ACTION_VALUES = 3 - ACTION_SPACE_ACL_PERMISSION_VALUES = 2 + ACTION_SPACE_NODE_PROPERTY_VALUES: int = 5 + ACTION_SPACE_NODE_ACTION_VALUES: int = 4 + ACTION_SPACE_ACL_ACTION_VALUES: int = 3 + ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__(self, _config_values, _transaction_list): """ @@ -148,6 +148,8 @@ class Primaite(Env): # stores the observation config from the yaml, default is NODE_LINK_TABLE self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]} + if self.config_values.observation_config is not None: + self.obs_config = self.config_values.observation_config # Observation Handler manages the user-configurable observation space. # It will be initialised later. @@ -690,9 +692,6 @@ class Primaite(Env): elif item["itemType"] == "ACTIONS": # Get the action information self.get_action_info(item) - elif item["itemType"] == "OBSERVATION_SPACE": - # Get the observation information - self.save_obs_config(item) elif item["itemType"] == "STEPS": # Get the steps information self.get_steps_info(item) diff --git a/src/primaite/main.py b/src/primaite/main.py index c963dd00..5f8aa5e2 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -163,6 +163,10 @@ def load_config_values(): try: # Generic config_values.agent_identifier = config_data["agentIdentifier"] + if "observationSpace" in config_data: + config_values.observation_config = config_data["observationSpace"] + else: + config_values.observation_config = None config_values.num_episodes = int(config_data["numEpisodes"]) config_values.time_delay = int(config_data["timeDelay"]) config_values.config_filename_use_case = ( diff --git a/tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml b/tests/config/obs_tests/laydown.yaml similarity index 95% rename from tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml rename to tests/config/obs_tests/laydown.yaml index 56ff3725..d3b131db 100644 --- a/tests/config/obs_tests/laydown_with_NODE_STATUSES.yaml +++ b/tests/config/obs_tests/laydown.yaml @@ -1,8 +1,5 @@ - itemType: ACTIONS type: NODE -- itemType: OBSERVATION_SPACE - components: - - name: NODE_STATUSES - itemType: STEPS steps: 5 - itemType: PORTS @@ -82,7 +79,7 @@ id: '5' startStep: 0 endStep: 5 - load: 20 + load: 999 protocol: TCP port: '80' source: '1' diff --git a/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml deleted file mode 100644 index e65ea306..00000000 --- a/tests/config/obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml +++ /dev/null @@ -1,110 +0,0 @@ -- itemType: ACTIONS - type: NODE -- itemType: OBSERVATION_SPACE - components: - - name: LINK_TRAFFIC_LEVELS - options: - combine_service_traffic: false - quantisation_levels: 8 -- itemType: STEPS - steps: 5 -- itemType: PORTS - portsList: - - port: '80' - - port: '53' -- itemType: SERVICES - serviceList: - - name: TCP - - name: UDP - -######################################## -# Nodes -- itemType: NODE - node_id: '1' - name: PC1 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.1 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '2' - name: SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.2 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '3' - name: SWITCH1 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.3 - software_state: GOOD - file_system_state: GOOD - -######################################## -# Links -- itemType: LINK - id: '4' - name: link1 - bandwidth: 1000 - source: '1' - destination: '3' -- itemType: LINK - id: '5' - name: link2 - bandwidth: 1000 - source: '3' - destination: '2' - -######################################### -# IERS -- itemType: GREEN_IER - id: '5' - startStep: 0 - endStep: 5 - load: 999 - protocol: TCP - port: '80' - source: '1' - destination: '2' - missionCriticality: 5 - -######################################### -# ACL Rules -- itemType: ACL_RULE - id: '6' - permission: ALLOW - source: 192.168.1.1 - destination: 192.168.1.2 - protocol: TCP - port: 80 -- itemType: ACL_RULE - id: '7' - permission: ALLOW - source: 192.168.1.2 - destination: 192.168.1.1 - protocol: TCP - port: 80 diff --git a/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml deleted file mode 100644 index 0ceefbfa..00000000 --- a/tests/config/obs_tests/laydown_with_NODE_LINK_TABLE.yaml +++ /dev/null @@ -1,77 +0,0 @@ -- itemType: ACTIONS - type: NODE -- itemType: OBSERVATION_SPACE - components: - - name: NODE_LINK_TABLE -- itemType: STEPS - steps: 5 -- itemType: PORTS - portsList: - - port: '80' - - port: '53' -- itemType: SERVICES - serviceList: - - name: TCP - - name: UDP - -######################################## -# Nodes -- itemType: NODE - node_id: '1' - name: PC1 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.1 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '2' - name: SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.2 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '3' - name: SWITCH1 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.3 - software_state: GOOD - file_system_state: GOOD - -######################################## -# Links -- itemType: LINK - id: '4' - name: link1 - bandwidth: 1000 - source: '1' - destination: '3' -- itemType: LINK - id: '5' - name: link2 - bandwidth: 1000 - source: '3' - destination: '2' diff --git a/tests/config/obs_tests/laydown_without_obs_space.yaml b/tests/config/obs_tests/laydown_without_obs_space.yaml deleted file mode 100644 index 3ef214da..00000000 --- a/tests/config/obs_tests/laydown_without_obs_space.yaml +++ /dev/null @@ -1,74 +0,0 @@ -- itemType: ACTIONS - type: NODE -- itemType: STEPS - steps: 5 -- itemType: PORTS - portsList: - - port: '80' - - port: '53' -- itemType: SERVICES - serviceList: - - name: TCP - - name: UDP - -######################################## -# Nodes -- itemType: NODE - node_id: '1' - name: PC1 - node_class: SERVICE - node_type: COMPUTER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.1 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '2' - name: SERVER - node_class: SERVICE - node_type: SERVER - priority: P5 - hardware_state: 'ON' - ip_address: 192.168.1.2 - software_state: GOOD - file_system_state: GOOD - services: - - name: TCP - port: '80' - state: GOOD - - name: UDP - port: '53' - state: GOOD -- itemType: NODE - node_id: '3' - name: SWITCH1 - node_class: ACTIVE - node_type: SWITCH - priority: P2 - hardware_state: 'ON' - ip_address: 192.168.1.3 - software_state: GOOD - file_system_state: GOOD - -######################################## -# Links -- itemType: LINK - id: '4' - name: link1 - bandwidth: 1000 - source: '1' - destination: '3' -- itemType: LINK - id: '5' - name: link2 - bandwidth: 1000 - source: '3' - destination: '2' diff --git a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml new file mode 100644 index 00000000..cdb741f3 --- /dev/null +++ b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml @@ -0,0 +1,96 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: NONE +# Number of episodes to run per session +observationSpace: + components: + - name: LINK_TRAFFIC_LEVELS + options: + combine_service_traffic: false + quantisation_levels: 8 + +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1_000_000_000 + +# Reward values +# Generic +allOk: 0 +# Node Hardware State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node Software or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml new file mode 100644 index 00000000..19d220c8 --- /dev/null +++ b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml @@ -0,0 +1,93 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: NONE +# Number of episodes to run per session +observationSpace: + components: + - name: NODE_LINK_TABLE + +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1_000_000_000 + +# Reward values +# Generic +allOk: 0 +# Node Hardware State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node Software or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml new file mode 100644 index 00000000..25520ccc --- /dev/null +++ b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml @@ -0,0 +1,93 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: NONE +# Number of episodes to run per session +observationSpace: + components: + - name: NODE_STATUSES + +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1_000_000_000 + +# Reward values +# Generic +allOk: 0 +# Node Hardware State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node Software or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/config/obs_tests/main_config_no_agent.yaml b/tests/config/obs_tests/main_config_without_obs.yaml similarity index 98% rename from tests/config/obs_tests/main_config_no_agent.yaml rename to tests/config/obs_tests/main_config_without_obs.yaml index f632dca9..43ee251f 100644 --- a/tests/config/obs_tests/main_config_no_agent.yaml +++ b/tests/config/obs_tests/main_config_without_obs.yaml @@ -21,7 +21,7 @@ agentLoadFile: C:\[Path]\[agent_saved_filename.zip] # Environment config values # The high value for the observation space -observationSpaceHighValue: 1000000000 +observationSpaceHighValue: 1_000_000_000 # Reward values # Generic diff --git a/tests/conftest.py b/tests/conftest.py index 1e987223..f3728b63 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,10 @@ def _get_primaite_env_from_config( def load_config_values(): config_values.agent_identifier = config_data["agentIdentifier"] + if "observationSpace" in config_data: + config_values.observation_config = config_data["observationSpace"] + else: + config_values.observation_config = None config_values.num_episodes = int(config_data["numEpisodes"]) config_values.time_delay = int(config_data["timeDelay"]) config_values.config_filename_use_case = lay_down_config_path diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index ae862c96..dcf98ae1 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -27,9 +27,8 @@ def env(request): @pytest.mark.env_config_paths( dict( - main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "obs_tests/laydown_without_obs_space.yaml", + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) def test_default_obs_space(env: Primaite): @@ -44,9 +43,8 @@ def test_default_obs_space(env: Primaite): @pytest.mark.env_config_paths( dict( - main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "obs_tests/laydown_without_obs_space.yaml", + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) def test_registering_components(env: Primaite): @@ -61,9 +59,9 @@ def test_registering_components(env: Primaite): @pytest.mark.env_config_paths( dict( - main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "obs_tests/laydown_with_NODE_LINK_TABLE.yaml", + main_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_NODE_LINK_TABLE.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) class TestNodeLinkTable: @@ -92,17 +90,17 @@ class TestNodeLinkTable: * Node 1: * 1 (id) * 1 (good hardware state) - * 1 (good OS state) + * 3 (compromised OS state) * 1 (good file system state) - * 1 (good service1 state) - * 1 (good service2 state) + * 1 (good TCP state) + * 1 (good UDP state) * Node 2: * 2 (id) * 1 (good hardware state) * 1 (good OS state) * 1 (good file system state) - * 1 (good service1 state) - * 1 (good service2 state) + * 1 (good TCP state) + * 4 (overwhelmed UDP state) * Node 3 (active node): * 3 (id) * 1 (good hardware state) @@ -115,14 +113,14 @@ class TestNodeLinkTable: * 0 (n/a hardware state) * 0 (n/a OS state) * 0 (n/a file system state) - * 0 (no traffic for service1) + * 999 (999 traffic for service1) * 0 (no traffic for service2) * Link 2: * 5 (id) * 0 (good hardware state) * 0 (good OS state) * 0 (good file system state) - * 0 (no traffic service1) + * 999 (999 traffic service1) * 0 (no traffic for service2) """ act = np.asarray([0, 0, 0, 0]) @@ -131,20 +129,19 @@ class TestNodeLinkTable: assert np.array_equal( obs, [ - [1, 1, 1, 1, 1, 1], - [2, 1, 1, 1, 1, 1], + [1, 1, 3, 1, 1, 1], + [2, 1, 1, 1, 1, 4], [3, 1, 1, 1, 0, 0], - [4, 0, 0, 0, 0, 0], - [5, 0, 0, 0, 0, 0], + [4, 0, 0, 0, 999, 0], + [5, 0, 0, 0, 999, 0], ], ) @pytest.mark.env_config_paths( dict( - main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "obs_tests/laydown_with_NODE_STATUSES.yaml", + main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) class TestNodeStatuses: @@ -188,9 +185,9 @@ class TestNodeStatuses: @pytest.mark.env_config_paths( dict( - main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml", + main_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ) ) class TestLinkTrafficLevels: From a597cf95d75c03309d0328381c321f1cd7ddd7b6 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 9 Jun 2023 10:28:24 +0100 Subject: [PATCH 37/37] Fix obs tests with new changes --- tests/test_observation_space.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index dcf98ae1..0df59b72 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -18,7 +18,7 @@ def env(request): marker = request.node.get_closest_marker("env_config_paths") main_config_path = marker.args[0]["main_config_path"] lay_down_config_path = marker.args[0]["lay_down_config_path"] - env = _get_primaite_env_from_config( + env, _ = _get_primaite_env_from_config( main_config_path=main_config_path, lay_down_config_path=lay_down_config_path, ) @@ -123,8 +123,8 @@ class TestNodeLinkTable: * 999 (999 traffic service1) * 0 (no traffic for service2) """ - act = np.asarray([0, 0, 0, 0]) - obs, reward, done, info = env.step(act) + # act = np.asarray([0,]) + obs, reward, done, info = env.step(0) # apply the 'do nothing' action assert np.array_equal( obs, @@ -178,8 +178,7 @@ class TestNodeStatuses: * service 1 = n/a (0) * service 2 = n/a (0) """ - act = np.asarray([0, 0, 0, 0]) - obs, _, _, _ = env.step(act) + obs, _, _, _ = env.step(0) # apply the 'do nothing' action assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]) @@ -210,9 +209,8 @@ class TestLinkTrafficLevels: * an IER trying to send 999 bits of data over both links the whole time (via the first service) * link bandwidth of 1000, therefore the utilisation is 99.9% """ - act = np.asarray([0, 0, 0, 0]) - obs, reward, done, info = env.step(act) - obs, reward, done, info = env.step(act) + obs, reward, done, info = env.step(0) + obs, reward, done, info = env.step(0) # the observation space has combine_service_traffic set to False, so the space has this format: # [link1_service1, link1_service2, link2_service1, link2_service2]