import logging
import os.path
from datetime import datetime

import numpy as np
import yaml

from primaite.common.config_values_main import ConfigValuesMain
from primaite.common.enums import (
    ActionType,
    HardwareState,
    LinkStatus,
    NodeHardwareAction,
    NodePOLType,
    NodeSoftwareAction,
    SoftwareState,
)


def load_config_values(config_path):
    """Load the main config file into a ``ConfigValuesMain`` object.

    :param config_path: Path of the YAML main config file.
    :return: A ``ConfigValuesMain`` populated from the file. If a key is missing
        or a value cannot be converted, the error is printed and the object is
        returned with only the attributes assigned before the failure.
    :raises OSError: If the file cannot be opened (raised by ``open``; this
        happens before the best-effort ``try`` block).
    """
    # The context manager closes the handle even if YAML parsing fails; the
    # previous bare ``open`` never closed it.
    with open(config_path, "r") as config_file_main:
        config_data = yaml.safe_load(config_file_main)
    # Create a config class
    config_values = ConfigValuesMain()

    try:
        # Generic
        config_values.red_agent_identifier = config_data["redAgentIdentifier"]
        config_values.action_type = ActionType[config_data["actionType"]]
        config_values.config_filename_use_case = config_data["configFilename"]
        # Reward values
        # Generic
        config_values.all_ok = float(config_data["allOk"])
        # Node Operating State
        config_values.off_should_be_on = float(config_data["offShouldBeOn"])
        config_values.off_should_be_resetting = float(
            config_data["offShouldBeResetting"]
        )
        config_values.on_should_be_off = float(config_data["onShouldBeOff"])
        config_values.on_should_be_resetting = float(config_data["onShouldBeResetting"])
        config_values.resetting_should_be_on = float(config_data["resettingShouldBeOn"])
        config_values.resetting_should_be_off = float(
            config_data["resettingShouldBeOff"]
        )
        # Node O/S or Service State
        config_values.good_should_be_patching = float(
            config_data["goodShouldBePatching"]
        )
        config_values.good_should_be_compromised = float(
            config_data["goodShouldBeCompromised"]
        )
        config_values.good_should_be_overwhelmed = float(
            config_data["goodShouldBeOverwhelmed"]
        )
        config_values.patching_should_be_good = float(
            config_data["patchingShouldBeGood"]
        )
        config_values.patching_should_be_compromised = float(
            config_data["patchingShouldBeCompromised"]
        )
        config_values.patching_should_be_overwhelmed = float(
            config_data["patchingShouldBeOverwhelmed"]
        )
        config_values.compromised_should_be_good = float(
            config_data["compromisedShouldBeGood"]
        )
        config_values.compromised_should_be_patching = float(
            config_data["compromisedShouldBePatching"]
        )
        config_values.compromised_should_be_overwhelmed = float(
            config_data["compromisedShouldBeOverwhelmed"]
        )
        config_values.compromised = float(config_data["compromised"])
        config_values.overwhelmed_should_be_good = float(
            config_data["overwhelmedShouldBeGood"]
        )
        config_values.overwhelmed_should_be_patching = float(
            config_data["overwhelmedShouldBePatching"]
        )
        config_values.overwhelmed_should_be_compromised = float(
            config_data["overwhelmedShouldBeCompromised"]
        )
        config_values.overwhelmed = float(config_data["overwhelmed"])
        # IER status
        config_values.red_ier_running = float(config_data["redIerRunning"])
        config_values.green_ier_blocked = float(config_data["greenIerBlocked"])
        # Patching / Reset durations
        config_values.os_patching_duration = int(config_data["osPatchingDuration"])
        config_values.node_reset_duration = int(config_data["nodeResetDuration"])
        config_values.service_patching_duration = int(
            config_data["servicePatchingDuration"]
        )

    except Exception as e:
        # Deliberately best-effort: report the problem and hand back whatever
        # was populated so far instead of aborting the caller.
        print(f"Could not load config data: {e} ")

    return config_values
def configure_logging(log_name, log_root="/app/logs"):
    """Send root-logger output to a fresh, timestamped file.

    The log file is ``<log_root>/<log_name>/<YYYYmmdd_HHMMSS>.log`` and is opened
    in write mode at ``INFO`` level.

    :param log_name: Name of the sub-directory (under ``log_root``) for this log.
    :param log_root: Directory that holds all log sub-directories. Defaults to
        ``/app/logs``, the location that was previously hard-coded.
    :return: ``None``. Any failure is printed rather than raised, so a logging
        problem never stops the caller.
    """
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_dir = f"{log_root}/{log_name}"
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(log_dir, exist_ok=True)
        logging.basicConfig(
            filename=f"{log_dir}/{timestamp}.log",
            filemode="w",
            format="%(asctime)s - %(levelname)s - %(message)s",
            datefmt="%d-%b-%y %H:%M:%S",
            level=logging.INFO,
        )
    except Exception as e:
        print("ERROR: Could not start logging", e)
def transform_change_obs_readable(obs):
    """Turn a 2-D observation into per-property lists of readable values.

    Column 0 is copied as-is (IDs), column 1 is mapped through ``HardwareState``,
    column 2 through ``SoftwareState``. Every further column is mapped through
    ``SoftwareState`` when the value is ``<= 4``; larger values (link bit/s,
    which have no service state) are kept unchanged.

    example (from the original docstring):
        np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]

    :param obs: 2-D array with one row per node/link.
    :return: List of lists, one inner list per observation column.
    """
    ids = list(obs[:, 0])
    operating_states = [HardwareState(value).name for value in obs[:, 1]]
    os_states = [SoftwareState(value).name for value in obs[:, 2]]
    readable = [ids, operating_states, os_states]

    for column in range(3, obs.shape[1]):
        # Links bit/s don't have a service state
        readable.append(
            [
                SoftwareState(value).name if value <= 4 else value
                for value in obs[:, column]
            ]
        )

    return readable


def transform_obs_readable(obs):
    """Turn a 2-D observation into one readable list per row.

    example (from the original docstring):
        np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']].

    :param obs: 2-D array with one row per node/link.
    :return: List of lists; the transpose of ``transform_change_obs_readable(obs)``.
    """
    per_property = transform_change_obs_readable(obs)
    # zip(*...) transposes property-major lists into row-major tuples.
    return [list(row) for row in zip(*per_property)]


def convert_to_new_obs(obs, num_nodes=10):
    """Flatten the node rows of an old (2-D) observation, dropping the ID column.

    :param obs: 2-D observation whose first ``num_nodes`` rows are nodes and
        whose column 0 is the ID.
    :param num_nodes: Number of leading rows to keep; later rows (links) are dropped.
    :return: 1-D array of the kept rows' non-ID values, row after row.
    """
    return obs[:num_nodes, 1:].flatten()


def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
    """Rebuild the old 2-D observation layout from a flat observation.

    ``obs`` is read as ``num_nodes * (num_services + 2)`` node values followed by
    ``num_links`` link values. IDs ``1..num_nodes + num_links`` are re-created in
    column 0. Link rows are zero except for the last column, which receives the
    link value (that slot is not specific to the last service).

    example:
        obs = array([1, 1, 1, 2, 2, 2, 3, 4]), num_nodes=2, num_links=2

        new_obs = array([[1, 1, 1, 1],
                         [2, 2, 2, 2],
                         [3, 0, 0, 3],
                         [4, 0, 0, 4]])

    :param obs: Flat observation (array-like).
    :param num_nodes: Number of nodes encoded at the start of ``obs``.
    :param num_links: Number of link values at the end of ``obs``; may be 0.
    :param num_services: Services per node (each node has ``num_services + 2`` values).
    :return: ``int64`` array of shape ``(num_nodes + num_links, num_services + 3)``.
    :raises ValueError: If the length of ``obs`` does not match the given sizes.
    """
    obs = np.asarray(obs)
    values_per_node = num_services + 2
    node_value_count = num_nodes * values_per_node

    # Split on the node-value count rather than with negative slicing:
    # obs[:-0] is empty and obs[-0:] is everything, which broke num_links == 0.
    node_values = obs[:node_value_count].reshape(num_nodes, values_per_node)
    link_values = obs[node_value_count:]
    if link_values.shape[0] != num_links:
        raise ValueError(
            f"expected {num_links} link values after {node_value_count} node "
            f"values, got {link_values.shape[0]}"
        )

    old_obs = np.zeros(
        (num_nodes + num_links, values_per_node + 1),
        dtype=np.int64,
    )
    old_obs[:, 0] = np.arange(1, num_nodes + num_links + 1)  # Adding ID back
    old_obs[:num_nodes, 1:] = node_values  # put values back in
    old_obs[num_nodes:, -1] = link_values  # links go in the last slot

    return old_obs
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
    """Describe, as text, what differs between two flat observations.

    Both observations are first expanded with ``convert_to_old_obs``. Every row
    that differs yields one description of the values found in ``obs2``.

    example output for a single changed service value:
        '\\n ID 1: SERVICE 1 changed to GOOD.'

    :param obs1: Flat observation before the change.
    :param obs2: Flat observation after the change.
    :param num_nodes: Number of nodes encoded in each observation.
    :param num_links: Number of links encoded in each observation.
    :param num_services: Number of services per node.
    :return: ``""`` when nothing changed, otherwise the descriptions with each
        one (including the first) preceded by ``"\\n "``.
    """
    before = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
    after = convert_to_old_obs(obs2, num_nodes, num_links, num_services)

    descriptions = []
    for row_index, delta in enumerate(before - after):
        if not delta.any():
            continue
        # Keep the new value where something changed, -1 everywhere else.
        changed_values = np.where(delta != 0, after[row_index], -1)
        changed_values[0] = after[row_index, 0]  # ID is always relevant
        # Rows whose ID is beyond the node count are links.
        row_is_link = changed_values[0] > num_nodes
        descriptions.append(_describe_obs_change_helper(changed_values, row_is_link))

    if not descriptions:
        return ""
    # A leading empty element gives every description the "\n " prefix.
    return "\n ".join([""] + descriptions)


def _describe_obs_change_helper(obs_change, is_link):
    """Describe the changed values of a single observation row.

    example:
        [ 1 -1 -1 -1  1] -> "ID 1: SERVICE 1 changed to GOOD."

    Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 changed to GOOD.'

    :param obs_change: Row whose index 0 is the ID and whose other entries are
        the new value, or ``-1`` where nothing changed.
    :param is_link: Truthy when the row belongs to a link rather than a node.
    :return: Description string for the row.
    """
    # Indexes where a change has occurred, not including the 0th (ID) index
    changed_columns = [
        column for column in range(1, len(obs_change)) if obs_change[column] != -1
    ]

    # Columns 1 and 2 map directly to a NodePOLType; columns >= 3 are services
    # and get a numeric suffix.
    property_names = []
    for column in changed_columns:
        if column < 3:
            property_names.append(NodePOLType(column).name)
        else:
            property_names.append(NodePOLType(3).name + " " + str(column - 3))

    # Account for hardware states, software states and links
    states = []
    for column in changed_columns:
        value = obs_change[column]
        if is_link:
            states.append(LinkStatus(value).name)
        elif column == 1:
            states.append(HardwareState(value).name)
        else:
            states.append(SoftwareState(value).name)

    if is_link:
        return f"ID {obs_change[0]}: Link traffic changed to {states[0]}."

    details = "".join(
        f" {name} changed to {state}." for name, state in zip(property_names, states)
    )
    return f"ID {obs_change[0]}:" + details
def transform_action_node_enum(action):
    """Convert a readable node action into its enumerated form.

    example:
        [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]

    :param action: ``[node_id, property_name, action_name, service_index]``.
    :return: ``[node_id, property_value, action_value, service_index]``. The
        action value is 0 when the property is neither OPERATING, OS nor SERVICE.
    :raises KeyError: If a name is not a member of the corresponding enum.
    """
    node_id = action[0]
    property_name = action[1]
    property_value = NodePOLType[property_name].value

    if property_name == "OPERATING":
        action_value = NodeHardwareAction[action[2]].value
    elif property_name in ("OS", "SERVICE"):
        action_value = NodeSoftwareAction[action[2]].value
    else:
        action_value = 0

    return [node_id, property_value, action_value, action[3]]


def transform_action_node_readable(action):
    """Convert an enumerated node action into its readable form.

    example:
        [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]

    :param action: ``[node_id, property_value, action_value, service_index]``.
    :return: ``[node_id, property_name, action_name, service_index]``. The
        action name is ``"NONE"`` when the property is not OPERATING/OS/SERVICE,
        or when an OS/SERVICE action value is greater than 1.
    """
    property_name = NodePOLType(action[1]).name

    action_name = "NONE"
    if property_name == "OPERATING":
        action_name = NodeHardwareAction(action[2]).name
    elif property_name in ("OS", "SERVICE") and action[2] <= 1:
        # Only software action values 0 and 1 are looked up; anything else is NONE.
        action_name = NodeSoftwareAction(action[2]).name

    return [action[0], property_name, action_name, action[3]]
def node_action_description(action):
    """Generate a string describing a node-based action.

    :param action: Node action in readable form
        (``[node_id, property_name, action_name, service_index]``) or enumerated
        form; an enumerated action is first converted with
        ``transform_action_node_readable``.
    :return: The description, or ``""`` when the action name is ``"NONE"`` or
        the property is not OPERATING, OS or SERVICE.
    """
    # np.integer covers every NumPy integer width (int32, int64, ...), not just
    # int64, so enumerated actions are recognised whatever the array dtype.
    if isinstance(action[1], (int, np.integer)):
        # transform action to readable format
        action = transform_action_node_readable(action)

    node_id = action[0]
    node_property = action[1]
    property_action = action[2]
    service_id = action[3]

    if property_action == "NONE":
        return ""
    if node_property in ("OPERATING", "OS"):
        return f"NODE {node_id}, {node_property}, SET TO {property_action}"
    if node_property == "SERVICE":
        return f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
    return ""


def transform_action_acl_readable(action):
    """Transform an enumerated ACL action into a more readable format.

    example:
        [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]

    :param action: ``[decision, permission, source, destination, protocol, port]``
        with decision in 0..2 and permission in 0..1.
    :return: List with the decision and permission as names, and each of the
        remaining four values replaced by ``"ANY"`` when it equals 0.
    :raises KeyError: If the decision or permission value is not recognised.
    """
    action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
    action_permissions = {0: "DENY", 1: "ALLOW"}

    action_decision = action_decisions[action[0]]
    action_permission = action_permissions[action[1]]

    # For IPs, Ports and Protocols, 0 means any, otherwise it's just an index
    targets = ["ANY" if value == 0 else value for value in list(action[2:6])]

    return [action_decision, action_permission] + targets
def transform_action_acl_enum(action):
    """Convert an ACL action from readable string format to enumerated format.

    example:
        ['CREATE', 'ALLOW', 2, 5, 'ANY', 1] -> array([1, 1, 2, 5, 0, 1])

    :param action: ``[decision, permission, source, destination, protocol, port]``
        with the decision and permission given as names.
    :return: NumPy array with the names replaced by their numbers and every
        ``"ANY"`` among the last four values replaced by 0.
    :raises KeyError: If the decision or permission name is not recognised.
    """
    action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
    action_permissions = {"DENY": 0, "ALLOW": 1}

    action_decision = action_decisions[action[0]]
    action_permission = action_permissions[action[1]]

    # For IPs, Ports and Protocols, ANY has value 0, otherwise it's just an index
    targets = [0 if value == "ANY" else value for value in list(action[2:6])]

    return np.array([action_decision, action_permission] + targets)


def acl_action_description(action):
    """Generate a string describing an ACL-based action.

    :param action: ACL action in readable or enumerated form; an enumerated
        action is first converted with ``transform_action_acl_readable``.
    :return: ``"NO ACL RULE APPLIED"`` for a NONE decision, otherwise a sentence
        naming the decision, permission, source, destination, protocol and port.
    """
    # np.integer covers every NumPy integer width (int32, int64, ...), not just
    # int64, so enumerated actions are recognised whatever the array dtype.
    if isinstance(action[0], (int, np.integer)):
        # transform action to readable format
        action = transform_action_acl_readable(action)

    if action[0] == "NONE":
        return "NO ACL RULE APPLIED"
    return (
        f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]},"
        f" for protocol/service index {action[4]} on port index {action[5]}"
    )


def get_node_of_ip(ip, node_dict):
    """Get the node ID of an IP address.

    :param ip: IP address to look for.
    :param node_dict: Dictionary of nodes where the key is the ID and the value
        is the node (can be obtained from ``env.nodes``); each node must provide
        ``get_ip_address()``.
    :return: The key of the first node whose IP address equals ``ip``, or
        ``None`` when no node matches.
    """
    for node_key, node_value in node_dict.items():
        if node_value.get_ip_address() == ip:
            return node_key
    # Explicit so callers can see that "not found" is a possible result.
    return None
def is_valid_node_action(action):
    """Tell whether a node action can have an effect, judged from the action alone.

    Does NOT consider:
    - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
    - Node already being in that state (turning an ON node ON)

    :param action: Enumerated node action, converted internally with
        ``transform_action_node_readable``.
    :return: ``False`` when the property or the action is NONE, when PATCHING is
        requested on OPERATING, or when a non-OPERATING property is given an
        action other than PATCHING; ``True`` otherwise.
    """
    readable = transform_action_node_readable(action)
    node_property, node_action = readable[1], readable[2]

    if "NONE" in (node_property, node_action):
        return False
    if node_property == "OPERATING":
        # Operating State cannot PATCH
        return node_action != "PATCHING"
    # Software States can only do Nothing or Patch
    return node_action in ["NONE", "PATCHING"]


def is_valid_acl_action(action):
    """Tell whether an ACL action can have an effect, judged from the action alone.

    Does NOT consider:
    - Trying to create identical rules
    - Trying to create a rule which is a subset of another rule (caused by "ANY")

    :param action: Enumerated ACL action, converted internally with
        ``transform_action_acl_readable``.
    :return: ``False`` for a NONE decision, for a rule whose source and
        destination are the same specific node, and for DENY rules;
        ``True`` otherwise.
    """
    readable = transform_action_acl_readable(action)
    decision, permission = readable[0], readable[1]
    source_id, destination_id = readable[2], readable[3]

    if decision == "NONE":
        return False
    if source_id == destination_id and source_id != "ANY":
        # ACL rule towards itself
        return False
    # DENY is unnecessary, we can create and delete allow rules instead.
    # No allow rule = blocked/DENY by default. ALLOW overrides existing DENY.
    return permission != "DENY"
def is_valid_acl_action_extra(action):
    """Stricter version of ``is_valid_acl_action``.

    On top of the base checks, the protocol and the port must not be ``"ANY"``.

    :param action: Enumerated ACL action.
    :return: ``True`` only if the action passes ``is_valid_acl_action`` and names
        a specific protocol and a specific port.
    """
    # Plain truthiness instead of an identity comparison against False.
    if not is_valid_acl_action(action):
        return False

    action_r = transform_action_acl_readable(action)
    action_protocol = action_r[4]
    action_port = action_r[5]

    # Don't allow protocols or ports to be ANY
    # in the future we might want to do the opposite, and only have ANY option for ports and service
    return action_protocol != "ANY" and action_port != "ANY"


def get_new_action(old_action, action_dict):
    """Get the new action key (e.g. 32) for an old action (e.g. [1,1,1,0]).

    old_action can be either node or acl action type.

    :param old_action: Sequence describing the action in the old format.
    :param action_dict: Mapping of new action key -> old-format action.
    :return: The first key whose action equals ``old_action``. Only valid
        actions are in the dict, so an action that is not found is invalid and
        maps to 0.
    """
    return next(
        (
            key
            for key, candidate in action_dict.items()
            if list(candidate) == list(old_action)
        ),
        0,
    )


def get_action_description(action, action_dict):
    """Get a string describing/explaining what an action is doing in words.

    :param action: Key into ``action_dict``.
    :param action_dict: Mapping of action key -> old-format action.
    :return: Node description for length-4 actions, ACL description for
        length-6 actions, ``"Unrecognised action"`` for any other length.
    :raises KeyError: If ``action`` is not a key of ``action_dict``.
    """
    action_array = action_dict[action]
    action_length = len(action_array)
    if action_length == 4:
        # node actions have length 4
        return node_action_description(action_array)
    if action_length == 6:
        # acl actions have length 6
        return acl_action_description(action_array)
    # Should never happen
    return "Unrecognised action"


# --- additions from the src/primaite/common/enums.py hunk of the same diff ---


class NodeHardwareAction(Enum):
    """Node hardware action."""

    NONE = 0
    ON = 1
    OFF = 2
    RESET = 3


class NodeSoftwareAction(Enum):
    """Node software action."""

    NONE = 0
    PATCHING = 1


class LinkStatus(Enum):
    """Link traffic status."""

    NONE = 0
    LOW = 1
    MEDIUM = 2
    HIGH = 3
    OVERLOAD = 4
"""Link traffic status.""" + + NONE = 0 + LOW = 1 + MEDIUM = 2 + HIGH = 3 + OVERLOAD = 4 diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 99c7c09f..1c72bba0 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -15,6 +15,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -232,15 +233,9 @@ class Primaite(Env): # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa - self.action_space = spaces.MultiDiscrete( - [ - self.num_nodes, - self.ACTION_SPACE_NODE_PROPERTY_VALUES, - self.ACTION_SPACE_NODE_ACTION_VALUES, - self.num_services, - ] - ) - else: + self.action_dict = self.create_node_action_dict() + self.action_space = spaces.Discrete(len(self.action_dict)) + elif self.action_type == ActionType.ACL: _LOGGER.info("Action space type ACL selected") # Terms (for ACL action space): # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) @@ -249,16 +244,12 @@ class Primaite(Env): # [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) # [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) # [0, num ports] - Port (0 = any, then 1 -> x resolving to port) - self.action_space = spaces.MultiDiscrete( - [ - self.ACTION_SPACE_ACL_ACTION_VALUES, - self.ACTION_SPACE_ACL_PERMISSION_VALUES, - self.num_nodes + 1, - self.num_nodes + 1, - self.num_services + 1, - self.num_ports + 1, - ] - ) + self.action_dict = 
def create_node_action_dict(self):
    """
    Map every discrete node action to its readable multidiscrete form.

    Method of ``Primaite``; reads ``self.num_nodes`` and ``self.num_services``.

    Note: Only actions that ``is_valid_node_action`` accepts exist in the
    mapping (except for key 0, the do-nothing action).

    example return:
    {0: [1, 0, 0, 0],
    1: [1, 1, 1, 0],
    2: [1, 1, 2, 0],
    3: [1, 1, 3, 0],
    4: [1, 2, 1, 0],
    5: [1, 3, 1, 0],
    ...
    }

    :return: Dict of consecutive integer keys -> ``[node, property, action, service]``.
    """
    # reserve 0 action to be a nothing action
    actions = {0: [1, 0, 0, 0]}

    # 4 node properties (NONE, OPERATING, OS, SERVICE) and up to 4 node actions:
    # (NONE, ON, OFF, RESET) for operating state, (NONE, PATCH) for OS/service.
    # The larger range is used so that none are missed.
    candidates = (
        [node, node_property, node_action, service_index]
        for node in range(1, self.num_nodes + 1)
        for node_property in range(4)
        for node_action in range(4)
        for service_index in range(self.num_services)
    )
    for candidate in candidates:
        # skip nothing actions (those that have no effect)
        if is_valid_node_action(candidate):
            # Keys stay consecutive: the next key is always the current size.
            actions[len(actions)] = candidate

    return actions


def create_acl_action_dict(self):
    """
    Map every discrete ACL action to its readable multidiscrete form.

    Method of ``Primaite``; reads ``self.num_nodes``, ``self.num_services`` and
    ``self.num_ports``. Only actions that ``is_valid_acl_action_extra`` accepts
    exist in the mapping (except for key 0, the do-nothing action).

    :return: Dict of consecutive integer keys ->
        ``[decision, permission, source_ip, dest_ip, protocol, port]``.
    """
    # reserve 0 action to be a nothing action
    actions = {0: [0, 0, 0, 0, 0, 0]}

    # 3 action decisions (0=NOTHING, 1=CREATE, 2=DELETE), 2 permissions
    # (0=DENY, 1=ALLOW); nodes, services and ports each get one extra value
    # (0) that stands for "any".
    candidates = (
        [action_decision, action_permission, source_ip, dest_ip, protocol, port]
        for action_decision in range(3)
        for action_permission in range(2)
        for source_ip in range(self.num_nodes + 1)
        for dest_ip in range(self.num_nodes + 1)
        for protocol in range(self.num_services + 1)
        for port in range(self.num_ports + 1)
    )
    for candidate in candidates:
        # Only keep actions we want to include as possible, i.e. not nothing actions
        if is_valid_acl_action_extra(candidate):
            actions[len(actions)] = candidate

    return actions
def create_node_and_acl_action_dict(self):
    """
    Create a dictionary mapping each discrete action to a readable multidiscrete action.

    The dictionary contains actions of both Node and ACL action types. ACL
    actions keep their keys; node actions are renumbered to follow directly
    after the last ACL key. Only the ACL do-nothing action (key 0) is kept.

    :return: Combined dict of action key -> action list.
    """
    node_actions = self.create_node_action_dict()
    acl_actions = self.create_acl_action_dict()

    # Shift node keys past the ACL keys; node key 0 (the second nothing action)
    # is dropped, hence the "- 1".
    key_offset = len(acl_actions) - 1

    combined_actions = dict(acl_actions)
    for node_key, node_action in node_actions.items():
        if node_key != 0:
            combined_actions[node_key + key_offset] = node_action

    return combined_actions
C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1000000000 + +# Reward values +# Generic +allOk: 0 +# Node Operating State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node O/S or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 +patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time taken to repair the file system +fileSystemRestoringLimit: 5 # The time taken to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config


def test_single_action_space():
    """Build an environment from the one-node on/off configs and print its average reward.

    Original intent (from the original docstring): ensure the blue agent is using
    the ACL action space and is carrying out both kinds of operations.

    NOTE(review): the test makes no assertions, so it only fails if building the
    environment raises. It also loads ``one_node_states_on_off_main_config.yaml``
    rather than the ``single_action_space_config.yaml`` added in the same
    change - confirm which main config is intended.
    """
    main_config = TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml"
    lay_down_config = TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml"

    env = _get_primaite_env_from_config(
        main_config_path=main_config,
        lay_down_config_path=lay_down_config,
    )
    print("Average Reward:", env.average_reward)