From 3cd5864f25f966c9118b993e488ad74e08e0e764 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Fri, 26 May 2023 10:17:45 +0100 Subject: [PATCH 01/19] 1429 - created new branch from dev, added enums to enums.py, created agents package and utils.py file, added option to primaite_env.py for ANY action type and changed the action spaces are defined using ADSP branch --- src/primaite/agents/__init__.py | 0 src/primaite/agents/utils.py | 509 +++++++++++++++++++ src/primaite/common/enums.py | 26 + src/primaite/environment/primaite_env.py | 118 ++++- tests/config/single_action_space_config.yaml | 89 ++++ tests/test_single_action_space.py | 12 + 6 files changed, 735 insertions(+), 19 deletions(-) create mode 100644 src/primaite/agents/__init__.py create mode 100644 src/primaite/agents/utils.py create mode 100644 tests/config/single_action_space_config.yaml create mode 100644 tests/test_single_action_space.py diff --git a/src/primaite/agents/__init__.py b/src/primaite/agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py new file mode 100644 index 00000000..d3924b24 --- /dev/null +++ b/src/primaite/agents/utils.py @@ -0,0 +1,509 @@ +import logging +import os.path +from datetime import datetime + +import numpy as np +import yaml + +from primaite.common.config_values_main import ConfigValuesMain +from primaite.common.enums import ( + ActionType, + HardwareState, + LinkStatus, + NodeHardwareAction, + NodePOLType, + NodeSoftwareAction, + SoftwareState, +) + + +def load_config_values(config_path): + """Loads the config values from the main config file into a config object.""" + config_file_main = open(config_path, "r") + config_data = yaml.safe_load(config_file_main) + # Create a config class + config_values = ConfigValuesMain() + + try: + # Generic + config_values.red_agent_identifier = config_data["redAgentIdentifier"] + config_values.action_type = ActionType[config_data["actionType"]] + config_values.config_filename_use_case = config_data["configFilename"] + # Reward values + # Generic + config_values.all_ok = float(config_data["allOk"]) + # Node Operating State + config_values.off_should_be_on = float(config_data["offShouldBeOn"]) + config_values.off_should_be_resetting = float( + config_data["offShouldBeResetting"] + ) + config_values.on_should_be_off = float(config_data["onShouldBeOff"]) + config_values.on_should_be_resetting = float(config_data["onShouldBeResetting"]) + config_values.resetting_should_be_on = float(config_data["resettingShouldBeOn"]) + config_values.resetting_should_be_off = float( + config_data["resettingShouldBeOff"] + ) + # Node O/S or Service State + config_values.good_should_be_patching = float( + config_data["goodShouldBePatching"] + ) + config_values.good_should_be_compromised = float( + config_data["goodShouldBeCompromised"] + ) + config_values.good_should_be_overwhelmed = float( + config_data["goodShouldBeOverwhelmed"] + ) + config_values.patching_should_be_good = float( + config_data["patchingShouldBeGood"] + ) + config_values.patching_should_be_compromised = float( + config_data["patchingShouldBeCompromised"] + ) + config_values.patching_should_be_overwhelmed = float( + config_data["patchingShouldBeOverwhelmed"] + ) + config_values.compromised_should_be_good = float( + config_data["compromisedShouldBeGood"] + ) + config_values.compromised_should_be_patching = float( + config_data["compromisedShouldBePatching"] + ) + config_values.compromised_should_be_overwhelmed = float( + config_data["compromisedShouldBeOverwhelmed"] + ) + config_values.compromised = float(config_data["compromised"]) + config_values.overwhelmed_should_be_good = float( + config_data["overwhelmedShouldBeGood"] + ) + config_values.overwhelmed_should_be_patching = float( + config_data["overwhelmedShouldBePatching"] + ) + config_values.overwhelmed_should_be_compromised = float( + config_data["overwhelmedShouldBeCompromised"] + ) + config_values.overwhelmed = float(config_data["overwhelmed"]) + # IER status + config_values.red_ier_running = float(config_data["redIerRunning"]) + config_values.green_ier_blocked = float(config_data["greenIerBlocked"]) + # Patching / Reset durations + config_values.os_patching_duration = int(config_data["osPatchingDuration"]) + config_values.node_reset_duration = int(config_data["nodeResetDuration"]) + config_values.service_patching_duration = int( + config_data["servicePatchingDuration"] + ) + + except Exception as e: + print(f"Could not save load config data: {e} ") + + return config_values + + +def configure_logging(log_name): + """Configures logging.""" + try: + now = datetime.now() # current date and time + time = now.strftime("%Y%m%d_%H%M%S") + filename = "/app/logs/" + log_name + "/" + time + ".log" + path = f"/app/logs/{log_name}" + is_dir = os.path.isdir(path) + if not is_dir: + os.makedirs(path) + logging.basicConfig( + filename=filename, + filemode="w", + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%d-%b-%y %H:%M:%S", + level=logging.INFO, + ) + except Exception as e: + print("ERROR: Could not start logging", e) + + +def transform_change_obs_readable(obs): + """Transform list of transactions to readable list of each observation property. + + example: + np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] + """ + ids = [i for i in obs[:, 0]] + operating_states = [HardwareState(i).name for i in obs[:, 1]] + os_states = [SoftwareState(i).name for i in obs[:, 2]] + new_obs = [ids, operating_states, os_states] + + for service in range(3, obs.shape[1]): + # Links bit/s don't have a service state + service_states = [ + SoftwareState(i).name if i <= 4 else i for i in obs[:, service] + ] + new_obs.append(service_states) + + return new_obs + + +def transform_obs_readable(obs): + """ + Transform obs readable function. + + example: + np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]. + """ + changed_obs = transform_change_obs_readable(obs) + new_obs = list(zip(*changed_obs)) + # Convert list of tuples to list of lists + new_obs = [list(i) for i in new_obs] + + return new_obs + + +def convert_to_new_obs(obs, num_nodes=10): + """Convert original gym Box observation space to new multiDiscrete observation space.""" + # Remove ID columns, remove links and flatten to MultiDiscrete observation space + new_obs = obs[:num_nodes, 1:].flatten() + return new_obs + + +def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): + """ + Convert to old observation, links filled with 0's as no information is included in new observation space. + + example: + obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1]) + + new_obs = array([[ 1, 1, 1, 1], + [ 2, 1, 1, 1], + [ 3, 1, 1, 1], + ... + [20, 0, 0, 0]]) + """ + # Convert back to more readable, original format + reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2) + + # Add empty links back and add node ID back + s = np.zeros( + [reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], + dtype=np.int64, + ) + s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back + s[:num_nodes, 1:] = reshaped_nodes # put values back in + new_obs = s + + # Add links back in + links = obs[-num_links:] + # Links will be added to the last protocol/service slot but they are not specific to that service + new_obs[num_nodes:, -1] = links + + return new_obs + + +def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): + """Return string describing change between two observations. + + example: + obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) + obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]]) + output = 'ID 1: SERVICE 2 set to GOOD' + """ + obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services) + obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services) + list_of_changes = [] + for n, row in enumerate(obs1 - obs2): + if row.any() != 0: + relevant_changes = np.where(row != 0, obs2[n], -1) + relevant_changes[0] = obs2[n, 0] # ID is always relevant + is_link = relevant_changes[0] > num_nodes + desc = _describe_obs_change_helper(relevant_changes, is_link) + list_of_changes.append(desc) + + change_string = "\n ".join(list_of_changes) + if len(list_of_changes) > 0: + change_string = "\n " + change_string + return change_string + + +def _describe_obs_change_helper(obs_change, is_link): + """ + Helper funcion to describe what has changed. + + example: + [ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD" + + Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.' + """ + # Indexes where a change has occured, not including 0th index + index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] + # Node pol types, Indexes >= 3 are service nodes + node_pol_types = [ + NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) + for i in index_changed + ] + # Account for hardware states, software sattes and links + states = [ + LinkStatus(obs_change[i]).name + if is_link + else HardwareState(obs_change[i]).name + if i == 1 + else SoftwareState(obs_change[i]).name + for i in index_changed + ] + + if not is_link: + desc = f"ID {obs_change[0]}:" + for node_pol_type, state in list(zip(node_pol_types, states)): + desc = desc + " " + node_pol_type + " changed to " + state + "." + else: + desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}." + + return desc + + +def transform_action_node_enum(action): + """ + Convert a node action from readable string format, to enumerated format. + + example: + [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] + """ + action_node_id = action[0] + action_node_property = NodePOLType[action[1]].value + + if action[1] == "OPERATING": + property_action = NodeHardwareAction[action[2]].value + elif action[1] == "OS" or action[1] == "SERVICE": + property_action = NodeSoftwareAction[action[2]].value + else: + property_action = 0 + + action_service_index = action[3] + + new_action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] + + return new_action + + +def transform_action_node_readable(action): + """ + Convert a node action from enumerated format to readable format. + + example: + [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] + """ + action_node_property = NodePOLType(action[1]).name + + if action_node_property == "OPERATING": + property_action = NodeHardwareAction(action[2]).name + elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[ + 2 + ] <= 1: + property_action = NodeSoftwareAction(action[2]).name + else: + property_action = "NONE" + + new_action = [action[0], action_node_property, property_action, action[3]] + return new_action + + +def node_action_description(action): + """Generate string describing a node-based action.""" + if isinstance(action[1], (int, np.int64)): + # transform action to readable format + action = transform_action_node_readable(action) + + node_id = action[0] + node_property = action[1] + property_action = action[2] + service_id = action[3] + + if property_action == "NONE": + return "" + if node_property == "OPERATING" or node_property == "OS": + description = f"NODE {node_id}, {node_property}, SET TO {property_action}" + elif node_property == "SERVICE": + description = ( + f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" + ) + else: + return "" + + return description + + +def transform_action_acl_readable(action): + """ + Transform an ACL action to a more readable format. + + example: + [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] + """ + action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"} + action_permissions = {0: "DENY", 1: "ALLOW"} + + action_decision = action_decisions[action[0]] + action_permission = action_permissions[action[1]] + + # For IPs, Ports and Protocols, 0 means any, otherwise its just an index + new_action = [action_decision, action_permission] + list(action[2:6]) + for n, val in enumerate(list(action[2:6])): + if val == 0: + new_action[n + 2] = "ANY" + + return new_action + + +def transform_action_acl_enum(action): + """Convert a acl action from readable string format, to enumerated format.""" + action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2} + action_permissions = {"DENY": 0, "ALLOW": 1} + + action_decision = action_decisions[action[0]] + action_permission = action_permissions[action[1]] + + # For IPs, Ports and Protocols, ANY has value 0, otherwise its just an index + new_action = [action_decision, action_permission] + list(action[2:6]) + for n, val in enumerate(list(action[2:6])): + if val == "ANY": + new_action[n + 2] = 0 + + new_action = np.array(new_action) + return new_action + + +def acl_action_description(action): + """Generate string describing a acl-based action.""" + if isinstance(action[0], (int, np.int64)): + # transform action to readable format + action = transform_action_acl_readable(action) + if action[0] == "NONE": + description = "NO ACL RULE APPLIED" + else: + description = ( + f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]}," + f" for protocol/service index {action[4]} on port index {action[5]}" + ) + + return description + + +def get_node_of_ip(ip, node_dict): + """ + Get the node ID of an IP address. + + node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes) + """ + for node_key, node_value in node_dict.items(): + node_ip = node_value.get_ip_address() + if node_ip == ip: + return node_key + + +def is_valid_node_action(action): + """Is the node action an actual valid action. + + Only uses information about the action to determine if the action has an effect + + Does NOT consider: + - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch + - Node already being in that state (turning an ON node ON) + """ + action_r = transform_action_node_readable(action) + + node_property = action_r[1] + node_action = action_r[2] + + if node_property == "NONE": + return False + if node_action == "NONE": + return False + if node_property == "OPERATING" and node_action == "PATCHING": + # Operating State cannot PATCH + return False + if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]: + # Software States can only do Nothing or Patch + return False + return True + + +def is_valid_acl_action(action): + """ + Is the ACL action an actual valid action. + + Only uses information about the action to determine if the action has an effect. + + Does NOT consider: + - Trying to create identical rules + - Trying to create a rule which is a subset of another rule (caused by "ANY") + """ + action_r = transform_action_acl_readable(action) + + action_decision = action_r[0] + action_permission = action_r[1] + action_source_id = action_r[2] + action_destination_id = action_r[3] + + if action_decision == "NONE": + return False + if ( + action_source_id == action_destination_id + and action_source_id != "ANY" + and action_destination_id != "ANY" + ): + # ACL rule towards itself + return False + if action_permission == "DENY": + # DENY is unnecessary, we can create and delete allow rules instead + # No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY. + return False + + return True + + +def is_valid_acl_action_extra(action): + """Harsher version of valid acl actions, does not allow action.""" + if is_valid_acl_action(action) is False: + return False + + action_r = transform_action_acl_readable(action) + action_protocol = action_r[4] + action_port = action_r[5] + + # Don't allow protocols or ports to be ANY + # in the future we might want to do the opposite, and only have ANY option for ports and service + if action_protocol == "ANY": + return False + if action_port == "ANY": + return False + + return True + + +def get_new_action(old_action, action_dict): + """Get new action (e.g. 32) from old action e.g. [1,1,1,0]. + + old_action can be either node or acl action type. + """ + for key, val in action_dict.items(): + if list(val) == list(old_action): + return key + # Not all possible actions are included in dict, only valid action are + # if action is not in the dict, its an invalid action so return 0 + return 0 + + +def get_action_description(action, action_dict): + """Get a string describing/explaining what an action is doing in words.""" + action_array = action_dict[action] + if len(action_array) == 4: + # node actions have length 4 + action_description = node_action_description(action_array) + elif len(action_array) == 6: + # acl actions have length 6 + action_description = acl_action_description(action_array) + else: + # Should never happen + action_description = "Unrecognised action" + + return action_description diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 0e00c9e4..0aebf2a4 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -91,3 +91,29 @@ class FileSystemState(Enum): DESTROYED = 3 REPAIRING = 4 RESTORING = 5 + + +class NodeHardwareAction(Enum): + """Node hardware action.""" + + NONE = 0 + ON = 1 + OFF = 2 + RESET = 3 + + +class NodeSoftwareAction(Enum): + """Node software action.""" + + NONE = 0 + PATCHING = 1 + + +class LinkStatus(Enum): + """Link traffic status.""" + + NONE = 0 + LOW = 1 + MEDIUM = 2 + HIGH = 3 + OVERLOAD = 4 diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 99c7c09f..1c72bba0 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -15,6 +15,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -232,15 +233,9 @@ class Primaite(Env): # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa - self.action_space = spaces.MultiDiscrete( - [ - self.num_nodes, - self.ACTION_SPACE_NODE_PROPERTY_VALUES, - self.ACTION_SPACE_NODE_ACTION_VALUES, - self.num_services, - ] - ) - else: + self.action_dict = self.create_node_action_dict() + self.action_space = spaces.Discrete(len(self.action_dict)) + elif self.action_type == ActionType.ACL: _LOGGER.info("Action space type ACL selected") # Terms (for ACL action space): # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) @@ -249,16 +244,12 @@ class Primaite(Env): # [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) # [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) # [0, num ports] - Port (0 = any, then 1 -> x resolving to port) - self.action_space = spaces.MultiDiscrete( - [ - self.ACTION_SPACE_ACL_ACTION_VALUES, - self.ACTION_SPACE_ACL_PERMISSION_VALUES, - self.num_nodes + 1, - self.num_nodes + 1, - self.num_services + 1, - self.num_ports + 1, - ] - ) + self.action_dict = self.create_acl_action_dict() + self.action_space = spaces.Discrete(len(self.action_dict)) + else: + _LOGGER.info("Action space type ANY selected") + self.action_dict = self.create_node_and_acl_action_dict() + self.action_space = spaces.Discrete(len(self.action_dict)) # Set up a csv to store the results of the training try: @@ -1163,3 +1154,92 @@ class Primaite(Env): else: # Bad formatting pass + + def create_node_action_dict(self): + """ + Creates a dictionary mapping each possible discrete action to more readable multidiscrete action. + + Note: Only actions that have the potential to change the state exist in the mapping (except for key 0) + + example return: + {0: [1, 0, 0, 0], + 1: [1, 1, 1, 0], + 2: [1, 1, 2, 0], + 3: [1, 1, 3, 0], + 4: [1, 2, 1, 0], + 5: [1, 3, 1, 0], + ... + } + + """ + # reserve 0 action to be a nothing action + actions = {0: [1, 0, 0, 0]} + + action_key = 1 + for node in range(1, self.num_nodes + 1): + # 4 node properties (NONE, OPERATING, OS, SERVICE) + for node_property in range(4): + # Node Actions either: + # (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state + # Use MAX to ensure we get them all + for node_action in range(4): + for service_state in range(self.num_services): + action = [node, node_property, node_action, service_state] + # check to see if its a nothing aciton (has no effect) + if is_valid_node_action(action): + actions[action_key] = action + action_key += 1 + + return actions + + def create_acl_action_dict(self): + """Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.""" + # reserve 0 action to be a nothing action + actions = {0: [0, 0, 0, 0, 0, 0]} + + action_key = 1 + # 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE + for action_decision in range(3): + # 2 possible action permissions 0 = DENY, 1 = CREATE + for action_permission in range(2): + # Number of nodes + 1 (for any) + for source_ip in range(self.num_nodes + 1): + for dest_ip in range(self.num_nodes + 1): + for protocol in range(self.num_services + 1): + for port in range(self.num_ports + 1): + action = [ + action_decision, + action_permission, + source_ip, + dest_ip, + protocol, + port, + ] + # Check to see if its an action we want to include as possible i.e. not a nothing action + if is_valid_acl_action_extra(action): + actions[action_key] = action + action_key += 1 + + return actions + + def create_node_and_acl_action_dict(self): + """ + Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action. + + The dictionary contains actions of both Node and ACL action types. + + """ + node_action_dict = self.create_node_action_dict() + acl_action_dict = self.create_acl_action_dict() + + # Change node keys to not overlap with acl keys + # Only 1 nothing action (key 0) is required, remove the other + new_node_action_dict = { + k + len(acl_action_dict) - 1: v + for k, v in node_action_dict.items() + if k != 0 + } + + # Combine the Node dict and ACL dict + combined_action_dict = {**acl_action_dict, **new_node_action_dict} + return combined_action_dict diff --git a/tests/config/single_action_space_config.yaml b/tests/config/single_action_space_config.yaml new file mode 100644 index 00000000..6f6bb4e6 --- /dev/null +++ b/tests/config/single_action_space_config.yaml @@ -0,0 +1,89 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: GENERIC +# Number of episodes to run per session +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1000000000 + +# Reward values +# Generic +allOk: 0 +# Node Operating State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node O/S or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py new file mode 100644 index 00000000..3ec1dc2e --- /dev/null +++ b/tests/test_single_action_space.py @@ -0,0 +1,12 @@ +from tests import TEST_CONFIG_ROOT +from tests.conftest import _get_primaite_env_from_config + + +def test_single_action_space(): + """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" + env = _get_primaite_env_from_config( + main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "one_node_states_on_off_lay_down_config.yaml", + ) + print("Average Reward:", env.average_reward) From e2fb03b9bd2b91c826b24650ea0884da30f9f022 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Fri, 26 May 2023 14:29:02 +0100 Subject: [PATCH 02/19] 1429 - added code from ADSP branch to primaite_env.py and added NONE = 0 to NodePOLType in enums.py --- src/primaite/common/enums.py | 1 + src/primaite/environment/primaite_env.py | 38 ++++++++++++++++-------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 0aebf2a4..20660e86 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -49,6 +49,7 @@ class SoftwareState(Enum): class NodePOLType(Enum): """Node Pattern of Life type enumeration.""" + NONE = 0 OPERATING = 1 OS = 2 SERVICE = 3 diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 1c72bba0..0ebcd973 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -247,7 +247,7 @@ class Primaite(Env): self.action_dict = self.create_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info("Action space type ANY selected") + _LOGGER.info("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) @@ -453,8 +453,18 @@ class Primaite(Env): # At the moment, actions are only affecting nodes if self.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) - else: + elif self.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) + elif ( + len(self.action_dict[_action]) == 6 + ): # ACL actions in multidiscrete form have len 6 + self.apply_actions_to_acl(_action) + elif ( + len(self.action_dict[_action]) == 4 + ): # Node actions in multdiscrete (array) from have len 4 + self.apply_actions_to_nodes(_action) + else: + logging.error("Invalid action type found") def apply_actions_to_nodes(self, _action): """ @@ -463,10 +473,11 @@ class Primaite(Env): Args: _action: The action space from the agent """ - node_id = _action[0] - node_property = _action[1] - property_action = _action[2] - service_index = _action[3] + readable_action = self.action_dict[_action] + node_id = readable_action[0] + node_property = readable_action[1] + property_action = readable_action[2] + service_index = readable_action[3] # Check that the action is requesting a valid node try: @@ -552,12 +563,15 @@ class Primaite(Env): Args: _action: The action space from the agent """ - action_decision = _action[0] - action_permission = _action[1] - action_source_ip = _action[2] - action_destination_ip = _action[3] - action_protocol = _action[4] - action_port = _action[5] + # Convert discrete value back to multidiscrete + multidiscrete_action = self.action_dict[_action] + + action_decision = multidiscrete_action[0] + action_permission = multidiscrete_action[1] + action_source_ip = multidiscrete_action[2] + action_destination_ip = multidiscrete_action[3] + action_protocol = multidiscrete_action[4] + action_port = multidiscrete_action[5] if action_decision == 0: # It's decided to do nothing From 20d13f42a2b8a797a7ffcb970d0d1a653770ee29 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Wed, 31 May 2023 13:15:25 +0100 Subject: [PATCH 03/19] 1443 - added changes from ADSP to observation space in primaite_env.py --- src/primaite/environment/primaite_env.py | 41 +++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 0ebcd973..84b485bd 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -207,6 +207,7 @@ class Primaite(Env): # Calculate the number of items that need to be included in the # observation space + """ num_items = self.num_links + self.num_nodes # Set the number of observation parameters, being # of services plus id, # hardware state, file system state and SoftwareState (i.e. 4) @@ -221,6 +222,23 @@ class Primaite(Env): shape=self.observation_shape, dtype=np.int64, ) + """ + self.num_observation_parameters = ( + self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS + ) + # Define the observation space: + # There are: + # 4 Operating States (ON/OFF/RESETTING) + NONE (0) + # 4 OS States (GOOD/PATCHING/COMPROMISED) + NONE + # 5 Service States (NONE/GOOD/PATCHING/COMPROMISED/OVERWHELMED) + NONE + # There can be any number of services + # There are 5 node types No Traffic, Low Traffic, Medium Traffic, High Traffic, Overloaded/max traffic + self.observation_space = spaces.MultiDiscrete( + ([4, 4] + [5] * self.num_services) * self.num_nodes + [5] * self.num_links + ) + + # Define the observation shape + self.observation_shape = self.observation_space.sample().shape # This is the observation that is sent back via the rest and step functions self.env_obs = np.zeros(self.observation_shape, dtype=np.int64) @@ -396,7 +414,7 @@ class Primaite(Env): self.step_count, self.config_values, ) - # print(f" Step {self.step_count} Reward: {str(reward)}") + print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -678,6 +696,19 @@ class Primaite(Env): def update_environent_obs(self): """Updates the observation space based on the node and link status.""" + # Convert back to more readable, original format + reshaped_nodes = self.env_obs[: -self.num_links].reshape( + self.num_nodes, self.num_services + 2 + ) + + # Add empty links back and add node ID back + s = np.zeros( + [reshaped_nodes.shape[0] + self.num_links, reshaped_nodes.shape[1] + 1], + dtype=np.int64, + ) + s[:, 0] = range(1, self.num_nodes + self.num_links + 1) # Adding ID back + s[: self.num_nodes, 1:] = reshaped_nodes # put values back in + self.env_obs = s item_index = 0 # Do nodes first @@ -720,6 +751,13 @@ class Primaite(Env): protocol_index += 1 item_index += 1 + # Remove ID columns, remove links and flatten to 1D array + node_obs = self.env_obs[: self.num_nodes, 1:].flatten() + # Remove ID, remove all data except link traffic status + link_obs = self.env_obs[self.num_nodes :, 3:].flatten() + # Combine nodes and links + self.env_obs = np.append(node_obs, link_obs) + def load_config(self): """Loads config data in order to build the environment configuration.""" for item in self.config_data: @@ -1187,6 +1225,7 @@ class Primaite(Env): """ # reserve 0 action to be a nothing action + # Used to be {0: [1, 0, 0, 0]} actions = {0: [1, 0, 0, 0]} action_key = 1 From ae2f4d472ec78d67c4268352cd773fe097572bf2 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Wed, 31 May 2023 14:11:15 +0100 Subject: [PATCH 04/19] 1443 - reverted changes made to observation space and added config files for testing --- src/primaite/environment/primaite_env.py | 38 ------------------- .../single_action_space_lay_down_config.yaml | 29 ++++++++++++++ ...l => single_action_space_main_config.yaml} | 0 tests/test_single_action_space.py | 4 +- 4 files changed, 31 insertions(+), 40 deletions(-) create mode 100644 tests/config/single_action_space_lay_down_config.yaml rename tests/config/{single_action_space_config.yaml => single_action_space_main_config.yaml} (100%) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 84b485bd..49c45f3e 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -207,7 +207,6 @@ class Primaite(Env): # Calculate the number of items that need to be included in the # observation space - """ num_items = self.num_links + self.num_nodes # Set the number of observation parameters, being # of services plus id, # hardware state, file system state and SoftwareState (i.e. 4) @@ -222,23 +221,6 @@ class Primaite(Env): shape=self.observation_shape, dtype=np.int64, ) - """ - self.num_observation_parameters = ( - self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS - ) - # Define the observation space: - # There are: - # 4 Operating States (ON/OFF/RESETTING) + NONE (0) - # 4 OS States (GOOD/PATCHING/COMPROMISED) + NONE - # 5 Service States (NONE/GOOD/PATCHING/COMPROMISED/OVERWHELMED) + NONE - # There can be any number of services - # There are 5 node types No Traffic, Low Traffic, Medium Traffic, High Traffic, Overloaded/max traffic - self.observation_space = spaces.MultiDiscrete( - ([4, 4] + [5] * self.num_services) * self.num_nodes + [5] * self.num_links - ) - - # Define the observation shape - self.observation_shape = self.observation_space.sample().shape # This is the observation that is sent back via the rest and step functions self.env_obs = np.zeros(self.observation_shape, dtype=np.int64) @@ -696,19 +678,6 @@ class Primaite(Env): def update_environent_obs(self): """Updates the observation space based on the node and link status.""" - # Convert back to more readable, original format - reshaped_nodes = self.env_obs[: -self.num_links].reshape( - self.num_nodes, self.num_services + 2 - ) - - # Add empty links back and add node ID back - s = np.zeros( - [reshaped_nodes.shape[0] + self.num_links, reshaped_nodes.shape[1] + 1], - dtype=np.int64, - ) - s[:, 0] = range(1, self.num_nodes + self.num_links + 1) # Adding ID back - s[: self.num_nodes, 1:] = reshaped_nodes # put values back in - self.env_obs = s item_index = 0 # Do nodes first @@ -751,13 +720,6 @@ class Primaite(Env): protocol_index += 1 item_index += 1 - # Remove ID columns, remove links and flatten to 1D array - node_obs = self.env_obs[: self.num_nodes, 1:].flatten() - # Remove ID, remove all data except link traffic status - link_obs = self.env_obs[self.num_nodes :, 3:].flatten() - # Combine nodes and links - self.env_obs = np.append(node_obs, link_obs) - def load_config(self): """Loads config data in order to build the environment configuration.""" for item in self.config_data: diff --git a/tests/config/single_action_space_lay_down_config.yaml b/tests/config/single_action_space_lay_down_config.yaml new file mode 100644 index 00000000..058b790b --- /dev/null +++ b/tests/config/single_action_space_lay_down_config.yaml @@ -0,0 +1,29 @@ +- itemType: ACTIONS + type: NODE +- itemType: STEPS + steps: 15 +- itemType: PORTS + portsList: + - port: '21' +- itemType: SERVICES + serviceList: + - name: ftp +- itemType: NODE + node_id: '1' + name: node + node_class: SERVICE + node_type: COMPUTER + priority: P1 + hardware_state: 'ON' + ip_address: 192.168.0.1 + software_state: GOOD + file_system_state: GOOD + services: + - name: ftp + port: '21' + state: GOOD +- itemType: POSITION + positions: + - node: '1' + x_pos: 309 + y_pos: 78 diff --git a/tests/config/single_action_space_config.yaml b/tests/config/single_action_space_main_config.yaml similarity index 100% rename from tests/config/single_action_space_config.yaml rename to tests/config/single_action_space_main_config.yaml diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 3ec1dc2e..8c87d57b 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -5,8 +5,8 @@ from tests.conftest import _get_primaite_env_from_config def test_single_action_space(): """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" env = _get_primaite_env_from_config( - main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + main_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_lay_down_config.yaml", + / "single_action_space_lay_down_config.yaml", ) print("Average Reward:", env.average_reward) From f72a80c9d2a8ddeac8bb81e58f4dd87a86271610 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Thu, 1 Jun 2023 16:27:25 +0100 Subject: [PATCH 05/19] 1443 - added in print test statements --- src/primaite/agents/utils.py | 2 ++ src/primaite/environment/primaite_env.py | 24 ++++++++++++++--- .../single_action_space_lay_down_config.yaml | 26 ++++++++++++++++++- .../single_action_space_main_config.yaml | 2 +- 4 files changed, 48 insertions(+), 6 deletions(-) diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index d3924b24..1ada88ba 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -414,6 +414,8 @@ def is_valid_node_action(action): node_property = action_r[1] node_action = action_r[2] + # print("node property", node_property, "\nnode action", node_action) + if node_property == "NONE": return False if node_action == "NONE": diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 49c45f3e..5d783af1 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -42,6 +42,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod from primaite.transactions.transaction import Transaction _LOGGER = logging.getLogger(__name__) +_LOGGER.setLevel(logging.INFO) class Primaite(Env): @@ -235,6 +236,7 @@ class Primaite(Env): # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa self.action_dict = self.create_node_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) + print(self.action_space, "NODE action space") elif self.action_type == ActionType.ACL: _LOGGER.info("Action space type ACL selected") # Terms (for ACL action space): @@ -246,11 +248,12 @@ class Primaite(Env): # [0, num ports] - Port (0 = any, then 1 -> x resolving to port) self.action_dict = self.create_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) + print(self.action_space, "ACL action space") else: - _LOGGER.info("Action space type ANY selected - Node + ACL") + _LOGGER.warning("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) - + print(self.action_space, "ANY action space") # Set up a csv to store the results of the training try: now = datetime.now() # current date and time @@ -450,6 +453,7 @@ class Primaite(Env): Args: _action: The action space from the agent """ + # print("intepret action") # At the moment, actions are only affecting nodes if self.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) @@ -464,6 +468,7 @@ class Primaite(Env): ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: + print("invalid action type found") logging.error("Invalid action type found") def apply_actions_to_nodes(self, _action): @@ -1084,6 +1089,7 @@ class Primaite(Env): item: A config data item representing action info """ self.action_type = ActionType[action_info["type"]] + print("action type selected: ", self.action_type) def get_steps_info(self, steps_info): """ @@ -1187,9 +1193,8 @@ class Primaite(Env): """ # reserve 0 action to be a nothing action - # Used to be {0: [1, 0, 0, 0]} actions = {0: [1, 0, 0, 0]} - + # print("node dict function call", self.num_nodes + 1) action_key = 1 for node in range(1, self.num_nodes + 1): # 4 node properties (NONE, OPERATING, OS, SERVICE) @@ -1197,11 +1202,14 @@ class Primaite(Env): # Node Actions either: # (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state # Use MAX to ensure we get them all + # print(self.num_services, "num services") for node_action in range(4): for service_state in range(self.num_services): action = [node, node_property, node_action, service_state] # check to see if its a nothing aciton (has no effect) + # print("action node",action) if is_valid_node_action(action): + print("true") actions[action_key] = action action_key += 1 @@ -1213,6 +1221,7 @@ class Primaite(Env): actions = {0: [0, 0, 0, 0, 0, 0]} action_key = 1 + # print("node count",self.num_nodes + 1) # 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE for action_decision in range(3): # 2 possible action permissions 0 = DENY, 1 = CREATE @@ -1230,10 +1239,13 @@ class Primaite(Env): protocol, port, ] + # print("action acl", action) # Check to see if its an action we want to include as possible i.e. not a nothing action if is_valid_acl_action_extra(action): + print("true") actions[action_key] = action action_key += 1 + # print("false") return actions @@ -1247,6 +1259,8 @@ class Primaite(Env): node_action_dict = self.create_node_action_dict() acl_action_dict = self.create_acl_action_dict() + print(len(node_action_dict), len(acl_action_dict)) + # Change node keys to not overlap with acl keys # Only 1 nothing action (key 0) is required, remove the other new_node_action_dict = { @@ -1257,4 +1271,6 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} + logging.warning("logging is working") + # print(len(list(combined_action_dict.values()))) return combined_action_dict diff --git a/tests/config/single_action_space_lay_down_config.yaml b/tests/config/single_action_space_lay_down_config.yaml index 058b790b..6d44356f 100644 --- a/tests/config/single_action_space_lay_down_config.yaml +++ b/tests/config/single_action_space_lay_down_config.yaml @@ -1,5 +1,5 @@ - itemType: ACTIONS - type: NODE + type: ANY - itemType: STEPS steps: 15 - itemType: PORTS @@ -27,3 +27,27 @@ - node: '1' x_pos: 309 y_pos: 78 +- itemType: RED_POL + id: '1' + startStep: 1 + endStep: 3 + targetNodeId: '1' + initiator: DIRECT + type: FILE + protocol: NA + state: CORRUPT + sourceNodeId: NA + sourceNodeService: NA + sourceNodeServiceState: NA +- itemType: RED_POL + id: '2' + startStep: 3 + endStep: 15 + targetNodeId: '1' + initiator: DIRECT + type: FILE + protocol: NA + state: GOOD + sourceNodeId: NA + sourceNodeService: NA + sourceNodeServiceState: NA diff --git a/tests/config/single_action_space_main_config.yaml b/tests/config/single_action_space_main_config.yaml index 6f6bb4e6..7fcc002f 100644 --- a/tests/config/single_action_space_main_config.yaml +++ b/tests/config/single_action_space_main_config.yaml @@ -11,7 +11,7 @@ numEpisodes: 1 # Time delay between steps (for generic agents) timeDelay: 1 # Filename of the scenario / laydown -configFilename: one_node_states_on_off_lay_down_config.yaml +configFilename: single_action_space_lay_down_config.yaml # Type of session to be run (TRAINING or EVALUATION) sessionType: TRAINING # Determine whether to load an agent from file From d0c11a14ed2214ab0257f268bbed5ecca1783f15 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Fri, 2 Jun 2023 09:51:15 +0100 Subject: [PATCH 06/19] 893 - added ANY to enums.py --- src/primaite/common/enums.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 20660e86..0e43ea38 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -82,6 +82,7 @@ class ActionType(Enum): NODE = 0 ACL = 1 + ANY = 2 class FileSystemState(Enum): From 66fdae5df17f18ff7d5674b6e8088aa8b973c835 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Fri, 2 Jun 2023 11:55:31 +0100 Subject: [PATCH 07/19] 893 - added test which shows the new action space has been created when ANY is selected in single_action_space_lay_down_config.yaml --- src/primaite/environment/primaite_env.py | 7 ++++--- tests/test_single_action_space.py | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 5d783af1..9b0bbeec 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -249,11 +249,13 @@ class Primaite(Env): self.action_dict = self.create_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) print(self.action_space, "ACL action space") - else: - _LOGGER.warning("Action space type ANY selected - Node + ACL") + elif self.action_type == ActionType.ANY: + _LOGGER.info("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) print(self.action_space, "ANY action space") + else: + _LOGGER.info("Invalid action type selected") # Set up a csv to store the results of the training try: now = datetime.now() # current date and time @@ -1271,6 +1273,5 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} - logging.warning("logging is working") # print(len(list(combined_action_dict.values()))) return combined_action_dict diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 8c87d57b..203a6232 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -9,4 +9,18 @@ def test_single_action_space(): lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ) - print("Average Reward:", env.average_reward) + """ + nv.action_space.n is the total number of actions in the Discrete action space + This is the number of actions the agent has to choose from. + + The total number of actions that an agent can type when a NODE action type is selected is: 6 + The total number of actions that an agent can take when an ACL action type is selected is: 7 + + These action spaces are combined and the total number of actions is: 12 + This is due to both actions containing the action to "Do nothing", so it needs to be removed from one of the spaces, + to avoid duplicate actions. + + As a result, 12 is the total number of action spaces. + """ + # e + assert env.action_space.n == 12 From 1a7d629d5ac7f51150e987edbdaaf4481eaba953 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 11:00:41 +0100 Subject: [PATCH 08/19] 893 - added new tests to test action space size and node is completing both sets of actions in a single episode and created new main config --- src/primaite/environment/primaite_env.py | 1 + ..._space_fixed_blue_actions_main_config.yaml | 89 +++++++++++++++++++ .../single_action_space_lay_down_config.yaml | 50 ++++++----- tests/conftest.py | 40 +++++++++ tests/test_single_action_space.py | 57 +++++++++--- 5 files changed, 199 insertions(+), 38 deletions(-) create mode 100644 tests/config/single_action_space_fixed_blue_actions_main_config.yaml diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 9b0bbeec..be16590f 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1273,5 +1273,6 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} + print("combined_action_dict entry", combined_action_dict.items()) # print(len(list(combined_action_dict.values()))) return combined_action_dict diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml new file mode 100644 index 00000000..becbc0f3 --- /dev/null +++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml @@ -0,0 +1,89 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: GENERIC_TEST +# Number of episodes to run per session +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: single_action_space_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1000000000 + +# Reward values +# Generic +allOk: 0 +# Node Operating State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node O/S or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time take to repair the file system +fileSystemRestoringLimit: 5 # The time take to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/config/single_action_space_lay_down_config.yaml b/tests/config/single_action_space_lay_down_config.yaml index 6d44356f..ab3b170e 100644 --- a/tests/config/single_action_space_lay_down_config.yaml +++ b/tests/config/single_action_space_lay_down_config.yaml @@ -15,39 +15,41 @@ node_type: COMPUTER priority: P1 hardware_state: 'ON' + ip_address: 192.168.0.14 + software_state: GOOD + file_system_state: GOOD + services: + - name: ftp + port: '21' + state: COMPROMISED +- itemType: NODE + node_id: '2' + name: server_1 + node_class: SERVICE + node_type: SERVER + priority: P1 + hardware_state: 'ON' ip_address: 192.168.0.1 software_state: GOOD file_system_state: GOOD services: - name: ftp port: '21' - state: GOOD + state: COMPROMISED - itemType: POSITION positions: - node: '1' x_pos: 309 y_pos: 78 -- itemType: RED_POL - id: '1' - startStep: 1 - endStep: 3 - targetNodeId: '1' - initiator: DIRECT - type: FILE - protocol: NA - state: CORRUPT - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA -- itemType: RED_POL - id: '2' - startStep: 3 + - node: '2' + x_pos: 200 + y_pos: 78 +- itemType: RED_IER + id: '3' + startStep: 2 endStep: 15 - targetNodeId: '1' - initiator: DIRECT - type: FILE - protocol: NA - state: GOOD - sourceNodeId: NA - sourceNodeService: NA - sourceNodeServiceState: NA + load: 1000 + protocol: ftp + port: CORRUPT + source: '1' + destination: '2' diff --git a/tests/conftest.py b/tests/conftest.py index 1e987223..dd732e78 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -164,10 +164,13 @@ def _get_primaite_env_from_config( # Load in config data load_config_values() env = Primaite(config_values, []) + # Get the number of steps (which is stored in the child config file) config_values.num_steps = env.episode_steps if env.config_values.agent_identifier == "GENERIC": run_generic(env, config_values) + elif env.config_values.agent_identifier == "GENERIC_TEST": + run_generic_set_actions(env, config_values) return env @@ -197,3 +200,40 @@ def run_generic(env, config_values): # env.reset() # env.close() + + +def run_generic_set_actions(env, config_values): + """Run against a generic agent with specified blue agent actions.""" + # Reset the environment at the start of the episode + # env.reset() + for episode in range(0, config_values.num_episodes): + for step in range(0, config_values.num_steps): + # Send the observation space to the agent to get an action + # TEMP - random action for now + # action = env.blue_agent_action(obs) + action = 0 + if step == 5: + # [1, 1, 2, 1, 1, 1] + # Creates an ACL rule + # Deny traffic from server_1 to node_1 on port FTP + action = 7 + elif step == 7: + # [1, 1, 2, 0] Node Action + # Sets Node 1 Hardware State to OFF + # Does not resolve any service + action = 16 + print(action, "ran") + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + + # Introduce a delay between steps + time.sleep(config_values.time_delay / 1000) + + # Reset the environment at the end of the episode + # env.reset() + + # env.close() diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 203a6232..fda4c96c 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -1,26 +1,55 @@ +from primaite.common.enums import HardwareState from tests import TEST_CONFIG_ROOT from tests.conftest import _get_primaite_env_from_config -def test_single_action_space(): +def test_single_action_space_is_valid(): """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" env = _get_primaite_env_from_config( main_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ) - """ - nv.action_space.n is the total number of actions in the Discrete action space - This is the number of actions the agent has to choose from. + # Retrieve the action space dictionary values from environment + env_action_space_dict = env.action_dict.values() + # Flags to check the conditions of the action space + contains_acl_actions = False + contains_node_actions = False + both_action_spaces = False + # Loop through each element of the list (which is every value from the dictionary) + for dict_item in env_action_space_dict: + # Node action detected + if len(dict_item) == 4: + contains_node_actions = True + # Link action detected + elif len(dict_item) == 6: + contains_acl_actions = True + # If both are there then the ANY action type is working + if contains_node_actions and contains_acl_actions: + both_action_spaces = True + # Check condition should be True + assert both_action_spaces - The total number of actions that an agent can type when a NODE action type is selected is: 6 - The total number of actions that an agent can take when an ACL action type is selected is: 7 - These action spaces are combined and the total number of actions is: 12 - This is due to both actions containing the action to "Do nothing", so it needs to be removed from one of the spaces, - to avoid duplicate actions. - - As a result, 12 is the total number of action spaces. - """ - # e - assert env.action_space.n == 12 +def test_agent_is_executing_actions_from_both_spaces(): + """Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL).""" + env = _get_primaite_env_from_config( + main_config_path=TEST_CONFIG_ROOT + / "single_action_space_fixed_blue_actions_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT + / "single_action_space_lay_down_config.yaml", + ) + # Retrieve hardware state of computer_1 node in laydown config + # Agent turned this off in Step 5 + computer_node_hardware_state = env.nodes["1"].hardware_state + # Retrieve the Access Control List object stored by the environment at the end of the episode + access_control_list = env.acl + # Use the Access Control List object acl object attribute to get dictionary + # Use dictionary.values() to get total list of all items in the dictionary + acl_rules_list = access_control_list.acl.values() + # Length of this list tells you how many items are in the dictionary + # This number is the frequency of Access Control Rules in the environment + # In the scenario, we specified that the agent should create only 1 acl rule + num_of_rules = len(acl_rules_list) + # Therefore these statements below MUST be true + assert computer_node_hardware_state == HardwareState.OFF and num_of_rules == 1 From dc7be7d8e6e210a3ed6eec7d8236d3886789202e Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 11:10:38 +0100 Subject: [PATCH 09/19] 893 - set the action_space to NOTHING so test_reward.py passes and removed unnecessary test print statements --- src/primaite/environment/primaite_env.py | 20 +------------------- tests/conftest.py | 5 ++--- 2 files changed, 3 insertions(+), 22 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index be16590f..33af7d89 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -236,7 +236,6 @@ class Primaite(Env): # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa self.action_dict = self.create_node_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) - print(self.action_space, "NODE action space") elif self.action_type == ActionType.ACL: _LOGGER.info("Action space type ACL selected") # Terms (for ACL action space): @@ -248,12 +247,10 @@ class Primaite(Env): # [0, num ports] - Port (0 = any, then 1 -> x resolving to port) self.action_dict = self.create_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) - print(self.action_space, "ACL action space") elif self.action_type == ActionType.ANY: _LOGGER.info("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) - print(self.action_space, "ANY action space") else: _LOGGER.info("Invalid action type selected") # Set up a csv to store the results of the training @@ -455,7 +452,6 @@ class Primaite(Env): Args: _action: The action space from the agent """ - # print("intepret action") # At the moment, actions are only affecting nodes if self.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) @@ -470,7 +466,6 @@ class Primaite(Env): ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: - print("invalid action type found") logging.error("Invalid action type found") def apply_actions_to_nodes(self, _action): @@ -1091,7 +1086,6 @@ class Primaite(Env): item: A config data item representing action info """ self.action_type = ActionType[action_info["type"]] - print("action type selected: ", self.action_type) def get_steps_info(self, steps_info): """ @@ -1196,7 +1190,6 @@ class Primaite(Env): """ # reserve 0 action to be a nothing action actions = {0: [1, 0, 0, 0]} - # print("node dict function call", self.num_nodes + 1) action_key = 1 for node in range(1, self.num_nodes + 1): # 4 node properties (NONE, OPERATING, OS, SERVICE) @@ -1204,14 +1197,11 @@ class Primaite(Env): # Node Actions either: # (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state # Use MAX to ensure we get them all - # print(self.num_services, "num services") for node_action in range(4): for service_state in range(self.num_services): action = [node, node_property, node_action, service_state] - # check to see if its a nothing aciton (has no effect) - # print("action node",action) + # check to see if it's a nothing action (has no effect) if is_valid_node_action(action): - print("true") actions[action_key] = action action_key += 1 @@ -1223,7 +1213,6 @@ class Primaite(Env): actions = {0: [0, 0, 0, 0, 0, 0]} action_key = 1 - # print("node count",self.num_nodes + 1) # 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE for action_decision in range(3): # 2 possible action permissions 0 = DENY, 1 = CREATE @@ -1241,13 +1230,10 @@ class Primaite(Env): protocol, port, ] - # print("action acl", action) # Check to see if its an action we want to include as possible i.e. not a nothing action if is_valid_acl_action_extra(action): - print("true") actions[action_key] = action action_key += 1 - # print("false") return actions @@ -1261,8 +1247,6 @@ class Primaite(Env): node_action_dict = self.create_node_action_dict() acl_action_dict = self.create_acl_action_dict() - print(len(node_action_dict), len(acl_action_dict)) - # Change node keys to not overlap with acl keys # Only 1 nothing action (key 0) is required, remove the other new_node_action_dict = { @@ -1273,6 +1257,4 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} - print("combined_action_dict entry", combined_action_dict.items()) - # print(len(list(combined_action_dict.values()))) return combined_action_dict diff --git a/tests/conftest.py b/tests/conftest.py index dd732e78..3a99bcf6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -184,8 +184,8 @@ def run_generic(env, config_values): # Send the observation space to the agent to get an action # TEMP - random action for now # action = env.blue_agent_action(obs) - action = env.action_space.sample() - + # action = env.action_space.sample() + action = 0 # Run the simulation step on the live environment obs, reward, done, info = env.step(action) @@ -222,7 +222,6 @@ def run_generic_set_actions(env, config_values): # Sets Node 1 Hardware State to OFF # Does not resolve any service action = 16 - print(action, "ran") # Run the simulation step on the live environment obs, reward, done, info = env.step(action) From 5add9d620ce2a8e70da43fcb226a487af89a774d Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 11:57:04 +0100 Subject: [PATCH 10/19] 893 - updated the docs to reflect changes made to action space --- docs/source/about.rst | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/docs/source/about.rst b/docs/source/about.rst index 8cc08b13..242d34bb 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -334,6 +334,8 @@ The full observation space would have 15 node-related elements and 3 link-relate gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4,5,5,5]) + * Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...} + The placeholders inside the list under the key '1' mean the following: Action Spaces ************** @@ -342,29 +344,38 @@ The action space available to the blue agent comes in two types: 1. Node-based 2. Access Control List + 3. Any The choice of action space used during a training session is determined in the config_[name].yaml file. **Node-Based** -The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is an OpenAI Gym multidiscrete type, as follows: +The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is an OpenAI Gym spaces.Discrete type, as follows: - * [0, num nodes] - Node ID (0 = nothing, node ID) - * [0, 4] - What property it's acting on (0 = nothing, 1 = state, 2 = SoftwareState, 3 = service state, 4 = file system state) - * [0, 3] - Action on property (0 = nothing, 1 = on / scan, 2 = off / repair, 3 = reset / patch / restore) - * [0, num services] - Resolves to service ID (0 = nothing, resolves to service) + * Dictionary item {... ,1: [x1, x2, x3,x4] ...} + The placeholders inside the list under the key '1' mean the following: + + * [0, num nodes] - Node ID (0 = nothing, node ID) + * [0, 4] - What property it's acting on (0 = nothing, 1 = state, 2 = SoftwareState, 3 = service state, 4 = file system state) + * [0, 3] - Action on property (0 = nothing, 1 = on / scan, 2 = off / repair, 3 = reset / patch / restore) + * [0, num services] - Resolves to service ID (0 = nothing, resolves to service) **Access Control List** -The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI multidiscrete type, as follows: +The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI spaces.Discrete type, as follows: - * [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) - * [0, 1] - Permission (0 = DENY, 1 = ALLOW) - * [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses) - * [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) - * [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) - * [0, num ports] - Port (0 = any, then 1 -> x resolving to port) + * [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) + * [0, 1] - Permission (0 = DENY, 1 = ALLOW) + * [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses) + * [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) + * [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) + * [0, num ports] - Port (0 = any, then 1 -> x resolving to port) + +**ANY** +The agent is able to carry out both **Node-Based** and **Access Control List** operations. + +This means the dictionary will contain key-value pairs in the format of BOTH Node-Based and Access Control List as seen above. Rewards ******* From e0ed97be36ff891cc3e327979d4cf80d54cd223c Mon Sep 17 00:00:00 2001 From: Sunil Samra Date: Tue, 6 Jun 2023 12:07:22 +0000 Subject: [PATCH 11/19] Apply suggestions from code review --- docs/source/about.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/about.rst b/docs/source/about.rst index 242d34bb..a59701f6 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -344,7 +344,7 @@ The action space available to the blue agent comes in two types: 1. Node-based 2. Access Control List - 3. Any + 3. Any (Agent can take both node-based and ACL-based actions) The choice of action space used during a training session is determined in the config_[name].yaml file. From 58a87ee0c8ad8ede1b20bcfd97f557416221b95d Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 13:12:28 +0100 Subject: [PATCH 12/19] 893 - applied changes raised during PR --- docs/source/about.rst | 9 ++++----- src/primaite/environment/primaite_env.py | 14 +++++++------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/docs/source/about.rst b/docs/source/about.rst index 242d34bb..a16fadd3 100644 --- a/docs/source/about.rst +++ b/docs/source/about.rst @@ -334,9 +334,6 @@ The full observation space would have 15 node-related elements and 3 link-relate gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4,5,5,5]) - * Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...} - The placeholders inside the list under the key '1' mean the following: - Action Spaces ************** @@ -344,7 +341,7 @@ The action space available to the blue agent comes in two types: 1. Node-based 2. Access Control List - 3. Any + 3. Any (Agent can take both node-based and ACL-based actions) The choice of action space used during a training session is determined in the config_[name].yaml file. @@ -364,6 +361,8 @@ The agent is able to influence the status of nodes by switching them off, resett The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI spaces.Discrete type, as follows: + * Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...} + The placeholders inside the list under the key '1' mean the following: * [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) * [0, 1] - Permission (0 = DENY, 1 = ALLOW) @@ -375,7 +374,7 @@ The blue agent is able to influence the configuration of the Access Control List **ANY** The agent is able to carry out both **Node-Based** and **Access Control List** operations. -This means the dictionary will contain key-value pairs in the format of BOTH Node-Based and Access Control List as seen above. +This means the dictionary will contain key-value pairs in the format of BOTH Node-Based and Access Control List as seen above. Rewards ******* diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index d2af20f9..4facb7b2 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -536,14 +536,14 @@ class Primaite(Env): _action: The action space from the agent """ # Convert discrete value back to multidiscrete - multidiscrete_action = self.action_dict[_action] + readable_action = self.action_dict[_action] - action_decision = multidiscrete_action[0] - action_permission = multidiscrete_action[1] - action_source_ip = multidiscrete_action[2] - action_destination_ip = multidiscrete_action[3] - action_protocol = multidiscrete_action[4] - action_port = multidiscrete_action[5] + action_decision = readable_action[0] + action_permission = readable_action[1] + action_source_ip = readable_action[2] + action_destination_ip = readable_action[3] + action_protocol = readable_action[4] + action_port = readable_action[5] if action_decision == 0: # It's decided to do nothing From efd0f6ed08653cbcfb67e2b3b6db17c8680261e6 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 13:21:04 +0100 Subject: [PATCH 13/19] 893 - returned config_values in conftest to move run_generic_set_actions into test_single_action_space.py --- src/primaite/environment/primaite_env.py | 2 +- tests/conftest.py | 40 +------------------ tests/test_observation_space.py | 4 +- tests/test_reward.py | 2 +- tests/test_single_action_space.py | 49 +++++++++++++++++++++++- 5 files changed, 52 insertions(+), 45 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 4facb7b2..1794504a 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -368,7 +368,7 @@ class Primaite(Env): self.step_count, self.config_values, ) - print(f" Step {self.step_count} Reward: {str(reward)}") + # print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count diff --git a/tests/conftest.py b/tests/conftest.py index 3a99bcf6..740f65b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -169,10 +169,8 @@ def _get_primaite_env_from_config( if env.config_values.agent_identifier == "GENERIC": run_generic(env, config_values) - elif env.config_values.agent_identifier == "GENERIC_TEST": - run_generic_set_actions(env, config_values) - return env + return env, config_values def run_generic(env, config_values): @@ -200,39 +198,3 @@ def run_generic(env, config_values): # env.reset() # env.close() - - -def run_generic_set_actions(env, config_values): - """Run against a generic agent with specified blue agent actions.""" - # Reset the environment at the start of the episode - # env.reset() - for episode in range(0, config_values.num_episodes): - for step in range(0, config_values.num_steps): - # Send the observation space to the agent to get an action - # TEMP - random action for now - # action = env.blue_agent_action(obs) - action = 0 - if step == 5: - # [1, 1, 2, 1, 1, 1] - # Creates an ACL rule - # Deny traffic from server_1 to node_1 on port FTP - action = 7 - elif step == 7: - # [1, 1, 2, 0] Node Action - # Sets Node 1 Hardware State to OFF - # Does not resolve any service - action = 16 - # Run the simulation step on the live environment - obs, reward, done, info = env.step(action) - - # Break if done is True - if done: - break - - # Introduce a delay between steps - time.sleep(config_values.time_delay / 1000) - - # Reset the environment at the end of the episode - # env.reset() - - # env.close() diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 6a187761..d6eaa3b7 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -6,7 +6,7 @@ from tests.conftest import _get_primaite_env_from_config def test_creating_env_with_box_obs(): """Try creating env with box observation space.""" - env = _get_primaite_env_from_config( + env, config_values = _get_primaite_env_from_config( main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "box_obs_space_laydown_config.yaml", ) @@ -21,7 +21,7 @@ def test_creating_env_with_box_obs(): def test_creating_env_with_multidiscrete_obs(): """Try creating env with MultiDiscrete observation space.""" - env = _get_primaite_env_from_config( + env, config_values = _get_primaite_env_from_config( main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "multidiscrete_obs_space_laydown_config.yaml", diff --git a/tests/test_reward.py b/tests/test_reward.py index 4925a434..c54ee32f 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -8,7 +8,7 @@ def test_rewards_are_being_penalised_at_each_step_function(): When the initial state is OFF compared to reference state which is ON. """ - env = _get_primaite_env_from_config( + env, config_values = _get_primaite_env_from_config( main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index fda4c96c..5fc6cb7e 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -1,15 +1,57 @@ +import time + from primaite.common.enums import HardwareState from tests import TEST_CONFIG_ROOT from tests.conftest import _get_primaite_env_from_config +def run_generic_set_actions(env, config_values): + """Run against a generic agent with specified blue agent actions.""" + # Reset the environment at the start of the episode + # env.reset() + for episode in range(0, config_values.num_episodes): + for step in range(0, config_values.num_steps): + # Send the observation space to the agent to get an action + # TEMP - random action for now + # action = env.blue_agent_action(obs) + action = 0 + print("Episode:", episode, "\nStep:", step) + if step == 5: + # [1, 1, 2, 1, 1, 1] + # Creates an ACL rule + # Deny traffic from server_1 to node_1 on port FTP + action = 7 + elif step == 7: + # [1, 1, 2, 0] Node Action + # Sets Node 1 Hardware State to OFF + # Does not resolve any service + action = 16 + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + + # Introduce a delay between steps + time.sleep(config_values.time_delay / 1000) + + # Reset the environment at the end of the episode + # env.reset() + + # env.close() + + def test_single_action_space_is_valid(): """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" - env = _get_primaite_env_from_config( + env, config_values = _get_primaite_env_from_config( main_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ) + + run_generic_set_actions(env, config_values) + # Retrieve the action space dictionary values from environment env_action_space_dict = env.action_dict.values() # Flags to check the conditions of the action space @@ -33,12 +75,15 @@ def test_single_action_space_is_valid(): def test_agent_is_executing_actions_from_both_spaces(): """Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL).""" - env = _get_primaite_env_from_config( + env, config_values = _get_primaite_env_from_config( main_config_path=TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml", lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ) + + run_generic_set_actions(env, config_values) + # Retrieve hardware state of computer_1 node in laydown config # Agent turned this off in Step 5 computer_node_hardware_state = env.nodes["1"].hardware_state From 2e1bdf2361829b4a5e4732a65a14cb1f615e299f Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 13:23:08 +0100 Subject: [PATCH 14/19] 893 - changed action in conftest.py back to sample of the environment action space --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 740f65b7..41c1bc94 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -182,8 +182,8 @@ def run_generic(env, config_values): # Send the observation space to the agent to get an action # TEMP - random action for now # action = env.blue_agent_action(obs) - # action = env.action_space.sample() - action = 0 + # action = 0 + action = env.action_space.sample() # Run the simulation step on the live environment obs, reward, done, info = env.step(action) From 10585490fe434e07a00f2f779cd5ecde3ad7e384 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 13:47:07 +0100 Subject: [PATCH 15/19] 893 - --- src/primaite/environment/primaite_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index bc52c887..03798c3b 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -368,7 +368,7 @@ class Primaite(Env): self.step_count, self.config_values, ) - # print(f" Step {self.step_count} Reward: {str(reward)}") + print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count From 0817a4cad3a8f39caf6b925ca27c187caafaf90b Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 13:49:22 +0100 Subject: [PATCH 16/19] 893 - added consistent action for test_reward.py --- tests/conftest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 41c1bc94..1b3a06b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -182,8 +182,9 @@ def run_generic(env, config_values): # Send the observation space to the agent to get an action # TEMP - random action for now # action = env.blue_agent_action(obs) - # action = 0 - action = env.action_space.sample() + # action = env.action_space.sample() + action = 0 + # Run the simulation step on the live environment obs, reward, done, info = env.step(action) From e17e5ac4b9e5b451a2d228233c6f0db68d5f0f4c Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 6 Jun 2023 15:54:35 +0100 Subject: [PATCH 17/19] 893 - added new line for assert statements --- src/primaite/environment/primaite_env.py | 1 + tests/test_single_action_space.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 03798c3b..3de1111b 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -423,6 +423,7 @@ class Primaite(Env): _action: The action space from the agent """ # At the moment, actions are only affecting nodes + print("ACTION:", self.action_dict[_action]) if self.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) elif self.action_type == ActionType.ACL: diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 5fc6cb7e..10701a6a 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -19,7 +19,7 @@ def run_generic_set_actions(env, config_values): if step == 5: # [1, 1, 2, 1, 1, 1] # Creates an ACL rule - # Deny traffic from server_1 to node_1 on port FTP + # Allows traffic from server_1 to node_1 on port FTP action = 7 elif step == 7: # [1, 1, 2, 0] Node Action @@ -97,4 +97,5 @@ def test_agent_is_executing_actions_from_both_spaces(): # In the scenario, we specified that the agent should create only 1 acl rule num_of_rules = len(acl_rules_list) # Therefore these statements below MUST be true - assert computer_node_hardware_state == HardwareState.OFF and num_of_rules == 1 + assert computer_node_hardware_state == HardwareState.OFF + assert num_of_rules == 1 From 281bb786127a7cb1970cb7dd6abca224e2a1019f Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Wed, 7 Jun 2023 09:19:30 +0100 Subject: [PATCH 18/19] 893 - removed print statements for demonstration --- src/primaite/environment/primaite_env.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 3de1111b..03798c3b 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -423,7 +423,6 @@ class Primaite(Env): _action: The action space from the agent """ # At the moment, actions are only affecting nodes - print("ACTION:", self.action_dict[_action]) if self.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) elif self.action_type == ActionType.ACL: From 6089fb6950b83678da20dd2d8b1c4fbabf1f36b2 Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Wed, 7 Jun 2023 14:39:52 +0100 Subject: [PATCH 19/19] 893 - removed unnecessary functions from utils.py and changed single_action_space_fixed_blue_actions_main_config.yaml back to GENERIC agentIdentifier after PR comments --- src/primaite/agents/utils.py | 386 +----------------- ..._space_fixed_blue_actions_main_config.yaml | 2 +- tests/test_single_action_space.py | 3 +- 3 files changed, 3 insertions(+), 388 deletions(-) diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index 1ada88ba..bb967906 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -1,288 +1,4 @@ -import logging -import os.path -from datetime import datetime - -import numpy as np -import yaml - -from primaite.common.config_values_main import ConfigValuesMain -from primaite.common.enums import ( - ActionType, - HardwareState, - LinkStatus, - NodeHardwareAction, - NodePOLType, - NodeSoftwareAction, - SoftwareState, -) - - -def load_config_values(config_path): - """Loads the config values from the main config file into a config object.""" - config_file_main = open(config_path, "r") - config_data = yaml.safe_load(config_file_main) - # Create a config class - config_values = ConfigValuesMain() - - try: - # Generic - config_values.red_agent_identifier = config_data["redAgentIdentifier"] - config_values.action_type = ActionType[config_data["actionType"]] - config_values.config_filename_use_case = config_data["configFilename"] - # Reward values - # Generic - config_values.all_ok = float(config_data["allOk"]) - # Node Operating State - config_values.off_should_be_on = float(config_data["offShouldBeOn"]) - config_values.off_should_be_resetting = float( - config_data["offShouldBeResetting"] - ) - config_values.on_should_be_off = float(config_data["onShouldBeOff"]) - config_values.on_should_be_resetting = float(config_data["onShouldBeResetting"]) - config_values.resetting_should_be_on = float(config_data["resettingShouldBeOn"]) - config_values.resetting_should_be_off = float( - config_data["resettingShouldBeOff"] - ) - # Node O/S or Service State - config_values.good_should_be_patching = float( - config_data["goodShouldBePatching"] - ) - config_values.good_should_be_compromised = float( - config_data["goodShouldBeCompromised"] - ) - config_values.good_should_be_overwhelmed = float( - config_data["goodShouldBeOverwhelmed"] - ) - config_values.patching_should_be_good = float( - config_data["patchingShouldBeGood"] - ) - config_values.patching_should_be_compromised = float( - config_data["patchingShouldBeCompromised"] - ) - config_values.patching_should_be_overwhelmed = float( - config_data["patchingShouldBeOverwhelmed"] - ) - config_values.compromised_should_be_good = float( - config_data["compromisedShouldBeGood"] - ) - config_values.compromised_should_be_patching = float( - config_data["compromisedShouldBePatching"] - ) - config_values.compromised_should_be_overwhelmed = float( - config_data["compromisedShouldBeOverwhelmed"] - ) - config_values.compromised = float(config_data["compromised"]) - config_values.overwhelmed_should_be_good = float( - config_data["overwhelmedShouldBeGood"] - ) - config_values.overwhelmed_should_be_patching = float( - config_data["overwhelmedShouldBePatching"] - ) - config_values.overwhelmed_should_be_compromised = float( - config_data["overwhelmedShouldBeCompromised"] - ) - config_values.overwhelmed = float(config_data["overwhelmed"]) - # IER status - config_values.red_ier_running = float(config_data["redIerRunning"]) - config_values.green_ier_blocked = float(config_data["greenIerBlocked"]) - # Patching / Reset durations - config_values.os_patching_duration = int(config_data["osPatchingDuration"]) - config_values.node_reset_duration = int(config_data["nodeResetDuration"]) - config_values.service_patching_duration = int( - config_data["servicePatchingDuration"] - ) - - except Exception as e: - print(f"Could not save load config data: {e} ") - - return config_values - - -def configure_logging(log_name): - """Configures logging.""" - try: - now = datetime.now() # current date and time - time = now.strftime("%Y%m%d_%H%M%S") - filename = "/app/logs/" + log_name + "/" + time + ".log" - path = f"/app/logs/{log_name}" - is_dir = os.path.isdir(path) - if not is_dir: - os.makedirs(path) - logging.basicConfig( - filename=filename, - filemode="w", - format="%(asctime)s - %(levelname)s - %(message)s", - datefmt="%d-%b-%y %H:%M:%S", - level=logging.INFO, - ) - except Exception as e: - print("ERROR: Could not start logging", e) - - -def transform_change_obs_readable(obs): - """Transform list of transactions to readable list of each observation property. - - example: - np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] - """ - ids = [i for i in obs[:, 0]] - operating_states = [HardwareState(i).name for i in obs[:, 1]] - os_states = [SoftwareState(i).name for i in obs[:, 2]] - new_obs = [ids, operating_states, os_states] - - for service in range(3, obs.shape[1]): - # Links bit/s don't have a service state - service_states = [ - SoftwareState(i).name if i <= 4 else i for i in obs[:, service] - ] - new_obs.append(service_states) - - return new_obs - - -def transform_obs_readable(obs): - """ - Transform obs readable function. - - example: - np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]. - """ - changed_obs = transform_change_obs_readable(obs) - new_obs = list(zip(*changed_obs)) - # Convert list of tuples to list of lists - new_obs = [list(i) for i in new_obs] - - return new_obs - - -def convert_to_new_obs(obs, num_nodes=10): - """Convert original gym Box observation space to new multiDiscrete observation space.""" - # Remove ID columns, remove links and flatten to MultiDiscrete observation space - new_obs = obs[:num_nodes, 1:].flatten() - return new_obs - - -def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): - """ - Convert to old observation, links filled with 0's as no information is included in new observation space. - - example: - obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1]) - - new_obs = array([[ 1, 1, 1, 1], - [ 2, 1, 1, 1], - [ 3, 1, 1, 1], - ... - [20, 0, 0, 0]]) - """ - # Convert back to more readable, original format - reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2) - - # Add empty links back and add node ID back - s = np.zeros( - [reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], - dtype=np.int64, - ) - s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back - s[:num_nodes, 1:] = reshaped_nodes # put values back in - new_obs = s - - # Add links back in - links = obs[-num_links:] - # Links will be added to the last protocol/service slot but they are not specific to that service - new_obs[num_nodes:, -1] = links - - return new_obs - - -def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): - """Return string describing change between two observations. - - example: - obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) - obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]]) - output = 'ID 1: SERVICE 2 set to GOOD' - """ - obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services) - obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services) - list_of_changes = [] - for n, row in enumerate(obs1 - obs2): - if row.any() != 0: - relevant_changes = np.where(row != 0, obs2[n], -1) - relevant_changes[0] = obs2[n, 0] # ID is always relevant - is_link = relevant_changes[0] > num_nodes - desc = _describe_obs_change_helper(relevant_changes, is_link) - list_of_changes.append(desc) - - change_string = "\n ".join(list_of_changes) - if len(list_of_changes) > 0: - change_string = "\n " + change_string - return change_string - - -def _describe_obs_change_helper(obs_change, is_link): - """ - Helper funcion to describe what has changed. - - example: - [ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD" - - Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.' - """ - # Indexes where a change has occured, not including 0th index - index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] - # Node pol types, Indexes >= 3 are service nodes - node_pol_types = [ - NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) - for i in index_changed - ] - # Account for hardware states, software sattes and links - states = [ - LinkStatus(obs_change[i]).name - if is_link - else HardwareState(obs_change[i]).name - if i == 1 - else SoftwareState(obs_change[i]).name - for i in index_changed - ] - - if not is_link: - desc = f"ID {obs_change[0]}:" - for node_pol_type, state in list(zip(node_pol_types, states)): - desc = desc + " " + node_pol_type + " changed to " + state + "." - else: - desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}." - - return desc - - -def transform_action_node_enum(action): - """ - Convert a node action from readable string format, to enumerated format. - - example: - [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] - """ - action_node_id = action[0] - action_node_property = NodePOLType[action[1]].value - - if action[1] == "OPERATING": - property_action = NodeHardwareAction[action[2]].value - elif action[1] == "OS" or action[1] == "SERVICE": - property_action = NodeSoftwareAction[action[2]].value - else: - property_action = 0 - - action_service_index = action[3] - - new_action = [ - action_node_id, - action_node_property, - property_action, - action_service_index, - ] - - return new_action +from primaite.common.enums import NodeHardwareAction, NodePOLType, NodeSoftwareAction def transform_action_node_readable(action): @@ -307,31 +23,6 @@ def transform_action_node_readable(action): return new_action -def node_action_description(action): - """Generate string describing a node-based action.""" - if isinstance(action[1], (int, np.int64)): - # transform action to readable format - action = transform_action_node_readable(action) - - node_id = action[0] - node_property = action[1] - property_action = action[2] - service_id = action[3] - - if property_action == "NONE": - return "" - if node_property == "OPERATING" or node_property == "OS": - description = f"NODE {node_id}, {node_property}, SET TO {property_action}" - elif node_property == "SERVICE": - description = ( - f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" - ) - else: - return "" - - return description - - def transform_action_acl_readable(action): """ Transform an ACL action to a more readable format. @@ -354,52 +45,6 @@ def transform_action_acl_readable(action): return new_action -def transform_action_acl_enum(action): - """Convert a acl action from readable string format, to enumerated format.""" - action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2} - action_permissions = {"DENY": 0, "ALLOW": 1} - - action_decision = action_decisions[action[0]] - action_permission = action_permissions[action[1]] - - # For IPs, Ports and Protocols, ANY has value 0, otherwise its just an index - new_action = [action_decision, action_permission] + list(action[2:6]) - for n, val in enumerate(list(action[2:6])): - if val == "ANY": - new_action[n + 2] = 0 - - new_action = np.array(new_action) - return new_action - - -def acl_action_description(action): - """Generate string describing a acl-based action.""" - if isinstance(action[0], (int, np.int64)): - # transform action to readable format - action = transform_action_acl_readable(action) - if action[0] == "NONE": - description = "NO ACL RULE APPLIED" - else: - description = ( - f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]}," - f" for protocol/service index {action[4]} on port index {action[5]}" - ) - - return description - - -def get_node_of_ip(ip, node_dict): - """ - Get the node ID of an IP address. - - node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes) - """ - for node_key, node_value in node_dict.items(): - node_ip = node_value.get_ip_address() - if node_ip == ip: - return node_key - - def is_valid_node_action(action): """Is the node action an actual valid action. @@ -480,32 +125,3 @@ def is_valid_acl_action_extra(action): return False return True - - -def get_new_action(old_action, action_dict): - """Get new action (e.g. 32) from old action e.g. [1,1,1,0]. - - old_action can be either node or acl action type. - """ - for key, val in action_dict.items(): - if list(val) == list(old_action): - return key - # Not all possible actions are included in dict, only valid action are - # if action is not in the dict, its an invalid action so return 0 - return 0 - - -def get_action_description(action, action_dict): - """Get a string describing/explaining what an action is doing in words.""" - action_array = action_dict[action] - if len(action_array) == 4: - # node actions have length 4 - action_description = node_action_description(action_array) - elif len(action_array) == 6: - # acl actions have length 6 - action_description = acl_action_description(action_array) - else: - # Should never happen - action_description = "Unrecognised action" - - return action_description diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml index becbc0f3..7fcc002f 100644 --- a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml +++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml @@ -5,7 +5,7 @@ # "STABLE_BASELINES3_PPO" # "STABLE_BASELINES3_A2C" # "GENERIC" -agentIdentifier: GENERIC_TEST +agentIdentifier: GENERIC # Number of episodes to run per session numEpisodes: 1 # Time delay between steps (for generic agents) diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 10701a6a..75d57f5d 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -81,9 +81,8 @@ def test_agent_is_executing_actions_from_both_spaces(): lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ) - + # Run environment with specified fixed blue agent actions only run_generic_set_actions(env, config_values) - # Retrieve hardware state of computer_1 node in laydown config # Agent turned this off in Step 5 computer_node_hardware_state = env.nodes["1"].hardware_state