diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 3510db21..f6c30eff 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -37,6 +37,8 @@ from primaite.pol.ier import IER from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction +_LOGGER = logging.getLogger(__name__) + class Primaite(Env): """PRIMmary AI Training Evironment (Primaite) class.""" @@ -145,14 +147,12 @@ class Primaite(Env): # Open the config file and build the environment laydown try: - self.config_file = open( - "config/" + self.config_values.config_filename_use_case, "r" - ) + self.config_file = open(self.config_values.config_filename_use_case, "r") self.config_data = yaml.safe_load(self.config_file) self.load_config() except Exception: - logging.error("Could not load the environment configuration") - logging.error("Exception occured", exc_info=True) + _LOGGER.error("Could not load the environment configuration") + _LOGGER.error("Exception occured", exc_info=True) # Store the node objects as node attributes # (This is so we can access them as objects) @@ -180,8 +180,8 @@ class Primaite(Env): plt.savefig(filename, format="PNG") plt.clf() except Exception: - logging.error("Could not save network diagram") - logging.error("Exception occured", exc_info=True) + _LOGGER.error("Could not save network diagram") + _LOGGER.error("Exception occured", exc_info=True) print("Could not save network diagram") # Define Observation Space @@ -223,7 +223,7 @@ class Primaite(Env): # Define Action Space - depends on action space type (Node or ACL) if self.action_type == ACTION_TYPE.NODE: - logging.info("Action space type NODE selected") + _LOGGER.info("Action space type NODE selected") # Terms (for node action space): # [0, num nodes] - node ID (0 = nothing, node ID) # [0, 4] - what property it's acting on (0 = nothing, state, o/s 
state, service state, file system state) @@ -238,7 +238,7 @@ class Primaite(Env): ] ) else: - logging.info("Action space type ACL selected") + _LOGGER.info("Action space type ACL selected") # Terms (for ACL action space): # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) # [0, 1] - Permission (0 = DENY, 1 = ALLOW) @@ -273,10 +273,10 @@ class Primaite(Env): self.csv_writer = csv.writer(self.csv_file) self.csv_writer.writerow(header) except Exception: - logging.error( + _LOGGER.error( "Could not create csv file to hold average reward per episode" ) - logging.error("Exception occured", exc_info=True) + _LOGGER.error("Exception occured", exc_info=True) def reset(self): """ @@ -322,7 +322,7 @@ class Primaite(Env): step_info: Additional information relating to this step """ if self.step_count == 0: - print("Episode: " + str(self.episode_count) + " running") + print(f"Episode: {str(self.episode_count)}") # TEMP done = False @@ -402,7 +402,7 @@ class Primaite(Env): self.step_count, self.config_values, ) - # print("Step reward: " + str(reward)) + print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -410,7 +410,7 @@ class Primaite(Env): # For evaluation, need to trigger the done value = True when # step count is reached in order to prevent neverending episode done = True - print("Average reward: " + str(self.average_reward)) + print(f" Average Reward: {str(self.average_reward)}") # Load the reward into the transaction transaction.set_reward(reward) @@ -757,7 +757,7 @@ class Primaite(Env): # Do nothing (bad formatting) pass - logging.info("Environment configuration loaded") + _LOGGER.info("Environment configuration loaded") print("Environment configuration loaded") def create_node(self, item): @@ -1090,7 +1090,7 @@ class Primaite(Env): item: A config data item representing steps info """ self.episode_steps = 
int(steps_info["steps"]) - logging.info("Training episodes have " + str(self.episode_steps) + " steps") + _LOGGER.info("Training episodes have " + str(self.episode_steps) + " steps") def reset_environment(self): """ diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 406cbfaa..548306b0 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -98,27 +98,27 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ score += config_values.all_ok else: # We're different from the reference situation - # Need to compare initial and final state of node (i.e. after red and blue actions) + # Need to compare initial and reference (current) state of node (i.e. at every step) if initial_node_operating_state == HARDWARE_STATE.ON: - if final_node_operating_state == HARDWARE_STATE.OFF: + if reference_node_operating_state == HARDWARE_STATE.OFF: score += config_values.off_should_be_on - elif final_node_operating_state == HARDWARE_STATE.RESETTING: + elif reference_node_operating_state == HARDWARE_STATE.RESETTING: score += config_values.resetting_should_be_on else: pass elif initial_node_operating_state == HARDWARE_STATE.OFF: - if final_node_operating_state == HARDWARE_STATE.ON: + if reference_node_operating_state == HARDWARE_STATE.ON: score += config_values.on_should_be_off - elif final_node_operating_state == HARDWARE_STATE.RESETTING: + elif reference_node_operating_state == HARDWARE_STATE.RESETTING: score += config_values.resetting_should_be_off else: pass elif initial_node_operating_state == HARDWARE_STATE.RESETTING: - if final_node_operating_state == HARDWARE_STATE.ON: + if reference_node_operating_state == HARDWARE_STATE.ON: score += config_values.on_should_be_resetting - elif final_node_operating_state == HARDWARE_STATE.OFF: + elif reference_node_operating_state == HARDWARE_STATE.OFF: score += config_values.off_should_be_resetting - elif final_node_operating_state == 
HARDWARE_STATE.RESETTING: + elif reference_node_operating_state == HARDWARE_STATE.RESETTING: score += config_values.resetting else: pass @@ -148,29 +148,29 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) score += config_values.all_ok else: # We're different from the reference situation - # Need to compare initial and final state of node (i.e. after red and blue actions) + # Need to compare initial and reference (current) state of node (i.e. at every step) if initial_node_os_state == SOFTWARE_STATE.GOOD: - if final_node_os_state == SOFTWARE_STATE.PATCHING: + if reference_node_os_state == SOFTWARE_STATE.PATCHING: score += config_values.patching_should_be_good - elif final_node_os_state == SOFTWARE_STATE.COMPROMISED: + elif reference_node_os_state == SOFTWARE_STATE.COMPROMISED: score += config_values.compromised_should_be_good else: pass elif initial_node_os_state == SOFTWARE_STATE.PATCHING: - if final_node_os_state == SOFTWARE_STATE.GOOD: + if reference_node_os_state == SOFTWARE_STATE.GOOD: score += config_values.good_should_be_patching - elif final_node_os_state == SOFTWARE_STATE.COMPROMISED: + elif reference_node_os_state == SOFTWARE_STATE.COMPROMISED: score += config_values.compromised_should_be_patching - elif final_node_os_state == SOFTWARE_STATE.PATCHING: + elif reference_node_os_state == SOFTWARE_STATE.PATCHING: score += config_values.patching else: pass elif initial_node_os_state == SOFTWARE_STATE.COMPROMISED: - if final_node_os_state == SOFTWARE_STATE.GOOD: + if reference_node_os_state == SOFTWARE_STATE.GOOD: score += config_values.good_should_be_compromised - elif final_node_os_state == SOFTWARE_STATE.PATCHING: + elif reference_node_os_state == SOFTWARE_STATE.PATCHING: score += config_values.patching_should_be_compromised - elif final_node_os_state == SOFTWARE_STATE.COMPROMISED: + elif reference_node_os_state == SOFTWARE_STATE.COMPROMISED: score += config_values.compromised else: pass @@ -204,46 +204,46 @@ def 
score_node_service_state(final_node, initial_node, reference_node, config_va score += config_values.all_ok else: # We're different from the reference situation - # Need to compare initial and final state of node (i.e. after red and blue actions) + # Need to compare initial and reference state of node (i.e. at every step) if initial_service.get_state() == SOFTWARE_STATE.GOOD: - if final_service.get_state() == SOFTWARE_STATE.PATCHING: + if reference_service.get_state() == SOFTWARE_STATE.PATCHING: score += config_values.patching_should_be_good - elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED: + elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED: score += config_values.compromised_should_be_good - elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED: + elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED: score += config_values.overwhelmed_should_be_good else: pass elif initial_service.get_state() == SOFTWARE_STATE.PATCHING: - if final_service.get_state() == SOFTWARE_STATE.GOOD: + if reference_service.get_state() == SOFTWARE_STATE.GOOD: score += config_values.good_should_be_patching - elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED: + elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED: score += config_values.compromised_should_be_patching - elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED: + elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED: score += config_values.overwhelmed_should_be_patching - elif final_service.get_state() == SOFTWARE_STATE.PATCHING: + elif reference_service.get_state() == SOFTWARE_STATE.PATCHING: score += config_values.patching else: pass elif initial_service.get_state() == SOFTWARE_STATE.COMPROMISED: - if final_service.get_state() == SOFTWARE_STATE.GOOD: + if reference_service.get_state() == SOFTWARE_STATE.GOOD: score += config_values.good_should_be_compromised - elif final_service.get_state() == SOFTWARE_STATE.PATCHING: + elif 
reference_service.get_state() == SOFTWARE_STATE.PATCHING: score += config_values.patching_should_be_compromised - elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED: + elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED: score += config_values.compromised - elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED: + elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED: score += config_values.overwhelmed_should_be_compromised else: pass elif initial_service.get_state() == SOFTWARE_STATE.OVERWHELMED: - if final_service.get_state() == SOFTWARE_STATE.GOOD: + if reference_service.get_state() == SOFTWARE_STATE.GOOD: score += config_values.good_should_be_overwhelmed - elif final_service.get_state() == SOFTWARE_STATE.PATCHING: + elif reference_service.get_state() == SOFTWARE_STATE.PATCHING: score += config_values.patching_should_be_overwhelmed - elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED: + elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED: score += config_values.compromised_should_be_overwhelmed - elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED: + elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED: score += config_values.overwhelmed else: pass @@ -276,67 +276,67 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu score += config_values.all_ok else: # We're different from the reference situation - # Need to compare initial and final state of node (i.e. after red and blue actions) + # Need to compare initial and reference state of node (i.e. 
at every step) if initial_node_file_system_state == FILE_SYSTEM_STATE.GOOD: - if final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING: + if reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING: score += config_values.repairing_should_be_good - elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: score += config_values.restoring_should_be_good - elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: score += config_values.corrupt_should_be_good - elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: score += config_values.destroyed_should_be_good else: pass elif initial_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING: - if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD: + if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD: score += config_values.good_should_be_repairing - elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: score += config_values.restoring_should_be_repairing - elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: score += config_values.corrupt_should_be_repairing - elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: score += config_values.destroyed_should_be_repairing - elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING: score += config_values.repairing else: pass elif initial_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: - if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD: + if reference_node_file_system_state == 
FILE_SYSTEM_STATE.GOOD: score += config_values.good_should_be_restoring - elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING: score += config_values.repairing_should_be_restoring - elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: score += config_values.corrupt_should_be_restoring - elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: score += config_values.destroyed_should_be_restoring - elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: score += config_values.restoring else: pass elif initial_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: - if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD: + if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD: score += config_values.good_should_be_corrupt - elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING: score += config_values.repairing_should_be_corrupt - elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: score += config_values.restoring_should_be_corrupt - elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: score += config_values.destroyed_should_be_corrupt - elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: score += config_values.corrupt else: pass elif initial_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: - if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD: + if reference_node_file_system_state == 
FILE_SYSTEM_STATE.GOOD: score += config_values.good_should_be_destroyed - elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING: score += config_values.repairing_should_be_destroyed - elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: score += config_values.restoring_should_be_destroyed - elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: score += config_values.corrupt_should_be_destroyed - elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: + elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: score += config_values.destroyed else: pass diff --git a/src/primaite/main.py b/src/primaite/main.py index 0b5fad5a..b130b16e 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -165,7 +165,9 @@ def load_config_values(): config_values.agent_identifier = config_data["agentIdentifier"] config_values.num_episodes = int(config_data["numEpisodes"]) config_values.time_delay = int(config_data["timeDelay"]) - config_values.config_filename_use_case = config_data["configFilename"] + config_values.config_filename_use_case = ( + "config/" + config_data["configFilename"] + ) config_values.session_type = config_data["sessionType"] config_values.load_agent = bool(config_data["loadAgent"]) config_values.agent_load_file = config_data["agentLoadFile"] diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..4a0bdce1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,6 @@ +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +from pathlib import Path +from typing import Final + +TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config" +"The tests config root directory." 
diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/config/one_node_states_on_off_lay_down_config.yaml b/tests/config/one_node_states_on_off_lay_down_config.yaml new file mode 100644 index 00000000..a543f019 --- /dev/null +++ b/tests/config/one_node_states_on_off_lay_down_config.yaml @@ -0,0 +1,125 @@ +- itemType: ACTIONS + type: NODE +- itemType: STEPS + steps: 13 +- itemType: PORTS + portsList: + - port: '21' +- itemType: SERVICES + serviceList: + - name: ftp +- itemType: NODE + id: '1' + name: node + baseType: SERVICE + nodeType: COMPUTER + priority: P1 + hardwareState: 'ON' + ipAddress: 192.168.0.1 + softwareState: GOOD + fileSystemState: GOOD + services: + - name: ftp + port: '21' + state: GOOD +- itemType: POSITION + positions: + - node: '1' + x_pos: 309 + y_pos: 78 +- itemType: RED_POL + id: '1' + startStep: 1 + endStep: 3 + targetNodeId: '1' + initiator: DIRECT + type: FILE + protocol: NA + state: CORRUPT + sourceNodeId: NA + sourceNodeService: NA + sourceNodeServiceState: NA +- itemType: RED_POL + id: '2' + startStep: 3 + endStep: 13 + targetNodeId: '1' + initiator: DIRECT + type: FILE + protocol: NA + state: GOOD + sourceNodeId: NA + sourceNodeService: NA + sourceNodeServiceState: NA +- itemType: RED_POL + id: '3' + startStep: 4 + endStep: 6 + targetNodeId: '1' + initiator: DIRECT + type: OPERATING + protocol: NA + state: 'OFF' + sourceNodeId: NA + sourceNodeService: NA + sourceNodeServiceState: NA +- itemType: RED_POL + id: '4' + startStep: 6 + endStep: 13 + targetNodeId: '1' + initiator: DIRECT + type: OPERATING + protocol: NA + state: 'ON' + sourceNodeId: NA + sourceNodeService: NA + sourceNodeServiceState: NA +- itemType: RED_POL + id: '5' + startStep: 7 + endStep: 9 + targetNodeId: '1' + initiator: DIRECT + type: SERVICE + protocol: ftp + state: COMPROMISED + sourceNodeId: NA + sourceNodeService: NA + sourceNodeServiceState: NA +- itemType: RED_POL + id: '6' + startStep: 9 
+ endStep: 13 + targetNodeId: '1' + initiator: DIRECT + type: SERVICE + protocol: ftp + state: GOOD + sourceNodeId: NA + sourceNodeService: NA + sourceNodeServiceState: NA +- itemType: RED_POL + id: '7' + startStep: 10 + endStep: 12 + targetNodeId: '1' + initiator: DIRECT + type: OS + protocol: NA + state: COMPROMISED + sourceNodeId: NA + sourceNodeService: NA + sourceNodeServiceState: NA +- itemType: RED_POL + id: '8' + startStep: 12 + endStep: 13 + targetNodeId: '1' + initiator: DIRECT + type: OS + protocol: NA + state: GOOD + sourceNodeId: NA + sourceNodeService: NA + sourceNodeServiceState: NA diff --git a/tests/config/one_node_states_on_off_main_config.yaml b/tests/config/one_node_states_on_off_main_config.yaml new file mode 100644 index 00000000..6f6bb4e6 --- /dev/null +++ b/tests/config/one_node_states_on_off_main_config.yaml @@ -0,0 +1,89 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agentIdentifier: GENERIC +# Number of episodes to run per session +numEpisodes: 1 +# Time delay between steps (for generic agents) +timeDelay: 1 +# Filename of the scenario / laydown +configFilename: one_node_states_on_off_lay_down_config.yaml +# Type of session to be run (TRAINING or EVALUATION) +sessionType: TRAINING +# Determine whether to load an agent from file +loadAgent: False +# File path and file name of agent if you're loading one in +agentLoadFile: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observationSpaceHighValue: 1000000000 + +# Reward values +# Generic +allOk: 0 +# Node Operating State +offShouldBeOn: -10 +offShouldBeResetting: -5 +onShouldBeOff: -2 +onShouldBeResetting: -5 +resettingShouldBeOn: -5 +resettingShouldBeOff: -2 +resetting: -3 +# Node O/S or Service State +goodShouldBePatching: 2 +goodShouldBeCompromised: 5 +goodShouldBeOverwhelmed: 5 
+patchingShouldBeGood: -5 +patchingShouldBeCompromised: 2 +patchingShouldBeOverwhelmed: 2 +patching: -3 +compromisedShouldBeGood: -20 +compromisedShouldBePatching: -20 +compromisedShouldBeOverwhelmed: -20 +compromised: -20 +overwhelmedShouldBeGood: -20 +overwhelmedShouldBePatching: -20 +overwhelmedShouldBeCompromised: -20 +overwhelmed: -20 +# Node File System State +goodShouldBeRepairing: 2 +goodShouldBeRestoring: 2 +goodShouldBeCorrupt: 5 +goodShouldBeDestroyed: 10 +repairingShouldBeGood: -5 +repairingShouldBeRestoring: 2 +repairingShouldBeCorrupt: 2 +repairingShouldBeDestroyed: 0 +repairing: -3 +restoringShouldBeGood: -10 +restoringShouldBeRepairing: -2 +restoringShouldBeCorrupt: 1 +restoringShouldBeDestroyed: 2 +restoring: -6 +corruptShouldBeGood: -10 +corruptShouldBeRepairing: -10 +corruptShouldBeRestoring: -10 +corruptShouldBeDestroyed: 2 +corrupt: -10 +destroyedShouldBeGood: -20 +destroyedShouldBeRepairing: -20 +destroyedShouldBeRestoring: -20 +destroyedShouldBeCorrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +redIerRunning: -5 +greenIerBlocked: -10 + +# Patching / Reset durations +osPatchingDuration: 5 # The time taken to patch the OS +nodeResetDuration: 5 # The time taken to reset a node (hardware) +servicePatchingDuration: 5 # The time taken to patch a service +fileSystemRepairingLimit: 5 # The time taken to repair the file system +fileSystemRestoringLimit: 5 # The time taken to restore the file system +fileSystemScanningLimit: 5 # The time taken to scan the file system diff --git a/tests/conftest.py b/tests/conftest.py index 63f825c2..00f226a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1 +1,199 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
+import time
+from pathlib import Path
+from typing import Union
+
+import yaml
+
+from primaite.common.config_values_main import config_values_main
+from primaite.environment.primaite_env import Primaite
+
+ACTION_SPACE_NODE_VALUES = 1
+ACTION_SPACE_NODE_ACTION_VALUES = 1
+
+
+def _get_primaite_env_from_config(
+    main_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
+) -> Primaite:
+    """
+    Build a Primaite environment from a pair of test config files.
+
+    The main config YAML is parsed and copied onto a fresh
+    ``config_values_main`` instance, the environment is created from it and,
+    when the configured agent identifier is ``"GENERIC"``, the session is run
+    via ``run_generic`` before the environment is returned.
+
+    Args:
+        main_config_path: Path of the main config YAML (agent, session and
+            reward values).
+        lay_down_config_path: Path of the lay down config YAML. It is stored
+            unchanged as ``config_values.config_filename_use_case``.
+
+    Returns:
+        The created Primaite instance (already stepped through its episodes
+        when the agent identifier is ``"GENERIC"``).
+    """
+
+    def load_config_values():
+        """
+        Copy the parsed main config onto ``config_values``.
+
+        Reads ``config_data`` and writes ``config_values`` from the enclosing
+        scope, so it must only be called once both have been assigned.
+        """
+        config_values.agent_identifier = config_data["agentIdentifier"]
+        config_values.num_episodes = int(config_data["numEpisodes"])
+        config_values.time_delay = int(config_data["timeDelay"])
+        config_values.config_filename_use_case = lay_down_config_path
+        config_values.session_type = config_data["sessionType"]
+        config_values.load_agent = bool(config_data["loadAgent"])
+        config_values.agent_load_file = config_data["agentLoadFile"]
+        # Environment
+        config_values.observation_space_high_value = int(
+            config_data["observationSpaceHighValue"]
+        )
+        # Reward values
+        # Generic
+        config_values.all_ok = int(config_data["allOk"])
+        # Node Operating State
+        config_values.off_should_be_on = int(config_data["offShouldBeOn"])
+        config_values.off_should_be_resetting = int(config_data["offShouldBeResetting"])
+        config_values.on_should_be_off = int(config_data["onShouldBeOff"])
+        config_values.on_should_be_resetting = int(config_data["onShouldBeResetting"])
+        config_values.resetting_should_be_on = int(config_data["resettingShouldBeOn"])
+        config_values.resetting_should_be_off = int(config_data["resettingShouldBeOff"])
+        config_values.resetting = int(config_data["resetting"])
+        # Node O/S or Service State
+        config_values.good_should_be_patching = int(config_data["goodShouldBePatching"])
+        config_values.good_should_be_compromised = int(
+            config_data["goodShouldBeCompromised"]
+        )
+        config_values.good_should_be_overwhelmed = int(
+            config_data["goodShouldBeOverwhelmed"]
+        )
+        config_values.patching_should_be_good = int(config_data["patchingShouldBeGood"])
+        config_values.patching_should_be_compromised = int(
+            config_data["patchingShouldBeCompromised"]
+        )
+        config_values.patching_should_be_overwhelmed = int(
+            config_data["patchingShouldBeOverwhelmed"]
+        )
+        config_values.patching = int(config_data["patching"])
+        config_values.compromised_should_be_good = int(
+            config_data["compromisedShouldBeGood"]
+        )
+        config_values.compromised_should_be_patching = int(
+            config_data["compromisedShouldBePatching"]
+        )
+        config_values.compromised_should_be_overwhelmed = int(
+            config_data["compromisedShouldBeOverwhelmed"]
+        )
+        config_values.compromised = int(config_data["compromised"])
+        config_values.overwhelmed_should_be_good = int(
+            config_data["overwhelmedShouldBeGood"]
+        )
+        config_values.overwhelmed_should_be_patching = int(
+            config_data["overwhelmedShouldBePatching"]
+        )
+        config_values.overwhelmed_should_be_compromised = int(
+            config_data["overwhelmedShouldBeCompromised"]
+        )
+        config_values.overwhelmed = int(config_data["overwhelmed"])
+        # Node File System State
+        config_values.good_should_be_repairing = int(
+            config_data["goodShouldBeRepairing"]
+        )
+        config_values.good_should_be_restoring = int(
+            config_data["goodShouldBeRestoring"]
+        )
+        config_values.good_should_be_corrupt = int(config_data["goodShouldBeCorrupt"])
+        config_values.good_should_be_destroyed = int(
+            config_data["goodShouldBeDestroyed"]
+        )
+        config_values.repairing_should_be_good = int(
+            config_data["repairingShouldBeGood"]
+        )
+        config_values.repairing_should_be_restoring = int(
+            config_data["repairingShouldBeRestoring"]
+        )
+        config_values.repairing_should_be_corrupt = int(
+            config_data["repairingShouldBeCorrupt"]
+        )
+        config_values.repairing_should_be_destroyed = int(
+            config_data["repairingShouldBeDestroyed"]
+        )
+        config_values.repairing = int(config_data["repairing"])
+        config_values.restoring_should_be_good = int(
+            config_data["restoringShouldBeGood"]
+        )
+        config_values.restoring_should_be_repairing = int(
+            config_data["restoringShouldBeRepairing"]
+        )
+        config_values.restoring_should_be_corrupt = int(
+            config_data["restoringShouldBeCorrupt"]
+        )
+        config_values.restoring_should_be_destroyed = int(
+            config_data["restoringShouldBeDestroyed"]
+        )
+        config_values.restoring = int(config_data["restoring"])
+        config_values.corrupt_should_be_good = int(config_data["corruptShouldBeGood"])
+        config_values.corrupt_should_be_repairing = int(
+            config_data["corruptShouldBeRepairing"]
+        )
+        config_values.corrupt_should_be_restoring = int(
+            config_data["corruptShouldBeRestoring"]
+        )
+        config_values.corrupt_should_be_destroyed = int(
+            config_data["corruptShouldBeDestroyed"]
+        )
+        config_values.corrupt = int(config_data["corrupt"])
+        config_values.destroyed_should_be_good = int(
+            config_data["destroyedShouldBeGood"]
+        )
+        config_values.destroyed_should_be_repairing = int(
+            config_data["destroyedShouldBeRepairing"]
+        )
+        config_values.destroyed_should_be_restoring = int(
+            config_data["destroyedShouldBeRestoring"]
+        )
+        config_values.destroyed_should_be_corrupt = int(
+            config_data["destroyedShouldBeCorrupt"]
+        )
+        config_values.destroyed = int(config_data["destroyed"])
+        config_values.scanning = int(config_data["scanning"])
+        # IER status
+        config_values.red_ier_running = int(config_data["redIerRunning"])
+        config_values.green_ier_blocked = int(config_data["greenIerBlocked"])
+        # Patching / Reset durations
+        config_values.os_patching_duration = int(config_data["osPatchingDuration"])
+        config_values.node_reset_duration = int(config_data["nodeResetDuration"])
+        config_values.service_patching_duration = int(
+            config_data["servicePatchingDuration"]
+        )
+        config_values.file_system_repairing_limit = int(
+            config_data["fileSystemRepairingLimit"]
+        )
+        config_values.file_system_restoring_limit = int(
+            config_data["fileSystemRestoringLimit"]
+        )
+        config_values.file_system_scanning_limit = int(
+            config_data["fileSystemScanningLimit"]
+        )
+
+    # Use a context manager so the config file handle is always closed,
+    # including when the YAML fails to parse.
+    with open(main_config_path, "r") as config_file_main:
+        config_data = yaml.safe_load(config_file_main)
+    # Create a config class
+    config_values = config_values_main()
+    # Load in config data
+    load_config_values()
+    env = Primaite(config_values, [])
+    # The number of steps per episode comes from the lay down config, so it
+    # is only known once the environment has loaded it.
+    config_values.num_steps = env.episode_steps
+
+    if env.config_values.agent_identifier == "GENERIC":
+        run_generic(env, config_values)
+
+    return env
+
+
+def run_generic(env, config_values):
+    """
+    Run a session against a generic agent that takes random actions.
+
+    Args:
+        env: The environment to step; must provide ``action_space.sample()``
+            and ``step(action)``.
+        config_values: Config object supplying ``num_episodes``, ``num_steps``
+            and ``time_delay``.
+    """
+    # NOTE(review): the environment is neither reset between episodes nor
+    # closed afterwards (those calls were commented out in the original
+    # helper) - TODO confirm this is intended when num_episodes > 1.
+    for _ in range(config_values.num_episodes):
+        for _ in range(config_values.num_steps):
+            # TEMP - random action for now
+            action = env.action_space.sample()
+
+            # Run the simulation step on the live environment; only the
+            # ``done`` flag is needed here.
+            _, _, done, _ = env.step(action)
+
+            # Break if done is True
+            if done:
+                break
+
+            # Introduce a delay between steps
+            # (presumably time_delay is in milliseconds - TODO confirm)
+            time.sleep(config_values.time_delay / 1000)
diff --git a/tests/test_reward.py b/tests/test_reward.py
index e69de29b..dcdd8d82 100644
--- a/tests/test_reward.py
+++ b/tests/test_reward.py
@@ -0,0 +1,31 @@
+from tests import TEST_CONFIG_ROOT
+from tests.conftest import _get_primaite_env_from_config
+
+
+def test_rewards_are_being_penalised_at_each_step_function():
+    """
+    Test that hardware state is penalised at each step.
+
+    When the initial state is OFF compared to reference state which is ON.
+    """
+    env = _get_primaite_env_from_config(
+        main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
+        lay_down_config_path=TEST_CONFIG_ROOT
+        / "one_node_states_on_off_lay_down_config.yaml",
+    )
+
+    """
+    On different steps (of the 13 in total) these are the following rewards for config_6 which are activated:
+        File System State: goodShouldBeCorrupt = 5 (Step 3)
+        Hardware State: onShouldBeOff = -2 (Step 5)
+        Service State: goodShouldBeCompromised = 5 (Step 7)
+        Operating System State (Software State): goodShouldBeCompromised = 5 (Step 10)
+
+    Total Reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26
+    Step Count: 13
+
+    For the 4 steps where this occurs the average reward is:
+        Average Reward: 2 (26 / 13)
+    """
+    print("average reward", env.average_reward)
+    assert env.average_reward == 2.0