diff --git a/docs/source/config.rst b/docs/source/config.rst index afd012cc..c80baa3c 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -82,203 +82,203 @@ The environment config file consists of the following attributes: Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment. -* **Generic [all_ok]** [int] +* **Generic [all_ok]** [float] The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken) -* **Node Hardware State [off_should_be_on]** [int] +* **Node Hardware State [off_should_be_on]** [float] The score to give when the node should be on, but is off -* **Node Hardware State [off_should_be_resetting]** [int] +* **Node Hardware State [off_should_be_resetting]** [float] The score to give when the node should be resetting, but is off -* **Node Hardware State [on_should_be_off]** [int] +* **Node Hardware State [on_should_be_off]** [float] The score to give when the node should be off, but is on -* **Node Hardware State [on_should_be_resetting]** [int] +* **Node Hardware State [on_should_be_resetting]** [float] The score to give when the node should be resetting, but is on -* **Node Hardware State [resetting_should_be_on]** [int] +* **Node Hardware State [resetting_should_be_on]** [float] The score to give when the node should be on, but is resetting -* **Node Hardware State [resetting_should_be_off]** [int] +* **Node Hardware State [resetting_should_be_off]** [float] The score to give when the node should be off, but is resetting -* **Node Hardware State [resetting]** [int] +* **Node Hardware State [resetting]** [float] The score to give when the node is resetting -* **Node Operating System or Service State [good_should_be_patching]** [int] +* **Node Operating System or Service State [good_should_be_patching]** [float] The score to give when the state should be patching, but is good -* **Node Operating System or Service State [good_should_be_compromised]** [int] +* **Node Operating System or Service State [good_should_be_compromised]** [float] The score to give when the state should be compromised, but is good -* **Node Operating System or Service State [good_should_be_overwhelmed]** [int] +* **Node Operating System or Service State [good_should_be_overwhelmed]** [float] The score to give when the state should be overwhelmed, but is good -* **Node Operating System or Service State [patching_should_be_good]** [int] +* **Node Operating System or Service State [patching_should_be_good]** [float] The score to give when the state should be good, but is patching -* **Node Operating System or Service State [patching_should_be_compromised]** [int] +* **Node Operating System or Service State [patching_should_be_compromised]** [float] The score to give when the state should be compromised, but is patching -* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int] +* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float] The score to give when the state should be overwhelmed, but is patching -* **Node Operating System or Service State [patching]** [int] +* **Node Operating System or Service State [patching]** [float] The score to give when the state is patching -* **Node Operating System or Service State [compromised_should_be_good]** [int] +* **Node Operating System or Service State [compromised_should_be_good]** [float] The score to give when the state should be good, but is compromised -* **Node Operating System or Service State [compromised_should_be_patching]** [int] +* **Node Operating System or Service State [compromised_should_be_patching]** [float] The score to give when the state should be patching, but is compromised -* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int] +* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [float] The score to give when the state should be overwhelmed, but is compromised -* **Node Operating System or Service State [compromised]** [int] +* **Node Operating System or Service State [compromised]** [float] The score to give when the state is compromised -* **Node Operating System or Service State [overwhelmed_should_be_good]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_good]** [float] The score to give when the state should be good, but is overwhelmed -* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float] The score to give when the state should be patching, but is overwhelmed -* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float] The score to give when the state should be compromised, but is overwhelmed -* **Node Operating System or Service State [overwhelmed]** [int] +* **Node Operating System or Service State [overwhelmed]** [float] The score to give when the state is overwhelmed -* **Node File System State [good_should_be_repairing]** [int] +* **Node File System State [good_should_be_repairing]** [float] The score to give when the state should be repairing, but is good -* **Node File System State [good_should_be_restoring]** [int] +* **Node File System State [good_should_be_restoring]** [float] The score to give when the state should be restoring, but is good -* **Node File System State [good_should_be_corrupt]** [int] +* **Node File System State [good_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is good -* **Node File System State [good_should_be_destroyed]** [int] +* **Node File System State [good_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is good -* **Node File System State [repairing_should_be_good]** [int] +* **Node File System State [repairing_should_be_good]** [float] The score to give when the state should be good, but is repairing -* **Node File System State [repairing_should_be_restoring]** [int] +* **Node File System State [repairing_should_be_restoring]** [float] The score to give when the state should be restoring, but is repairing -* **Node File System State [repairing_should_be_corrupt]** [int] +* **Node File System State [repairing_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is repairing -* **Node File System State [repairing_should_be_destroyed]** [int] +* **Node File System State [repairing_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is repairing -* **Node File System State [repairing]** [int] +* **Node File System State [repairing]** [float] The score to give when the state is repairing -* **Node File System State [restoring_should_be_good]** [int] +* **Node File System State [restoring_should_be_good]** [float] The score to give when the state should be good, but is restoring -* **Node File System State [restoring_should_be_repairing]** [int] +* **Node File System State [restoring_should_be_repairing]** [float] The score to give when the state should be repairing, but is restoring -* **Node File System State [restoring_should_be_corrupt]** [int] +* **Node File System State [restoring_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is restoring -* **Node File System State [restoring_should_be_destroyed]** [int] +* **Node File System State [restoring_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is restoring -* **Node File System State [restoring]** [int] +* **Node File System State [restoring]** [float] The score to give when the state is restoring -* **Node File System State [corrupt_should_be_good]** [int] +* **Node File System State [corrupt_should_be_good]** [float] The score to give when the state should be good, but is corrupt -* **Node File System State [corrupt_should_be_repairing]** [int] +* **Node File System State [corrupt_should_be_repairing]** [float] The score to give when the state should be repairing, but is corrupt -* **Node File System State [corrupt_should_be_restoring]** [int] +* **Node File System State [corrupt_should_be_restoring]** [float] The score to give when the state should be restoring, but is corrupt -* **Node File System State [corrupt_should_be_destroyed]** [int] +* **Node File System State [corrupt_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is corrupt -* **Node File System State [corrupt]** [int] +* **Node File System State [corrupt]** [float] The score to give when the state is corrupt -* **Node File System State [destroyed_should_be_good]** [int] +* **Node File System State [destroyed_should_be_good]** [float] The score to give when the state should be good, but is destroyed -* **Node File System State [destroyed_should_be_repairing]** [int] +* **Node File System State [destroyed_should_be_repairing]** [float] The score to give when the state should be repairing, but is destroyed -* **Node File System State [destroyed_should_be_restoring]** [int] +* **Node File System State [destroyed_should_be_restoring]** [float] The score to give when the state should be restoring, but is destroyed -* **Node File System State [destroyed_should_be_corrupt]** [int] +* **Node File System State [destroyed_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is destroyed -* **Node File System State [destroyed]** [int] +* **Node File System State [destroyed]** [float] The score to give when the state is destroyed -* **Node File System State [scanning]** [int] +* **Node File System State [scanning]** [float] The score to give when the state is scanning -* **IER Status [red_ier_running]** [int] +* **IER Status [red_ier_running]** [float] The score to give when a red agent IER is permitted to run -* **IER Status [green_ier_blocked]** [int] +* **IER Status [green_ier_blocked]** [float] The score to give when a green agent IER is prevented from running diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 4b39839a..32118597 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -257,14 +257,19 @@ class AgentSessionABC(ABC): raise FileNotFoundError(msg) pass + @property + def _saved_agent_path(self) -> Path: + file_name = ( + f"{self._training_config.agent_framework}_" + f"{self._training_config.agent_identifier}_" + f"{self.timestamp_str}.zip" + ) + return self.learning_path / file_name + + @abstractmethod def save(self): """Save the agent.""" - agent_path = ( - self.session_path - / f"{self._training_config.agent_framework}_{self._training_config.agent_identifier}_{self.timestamp_str}" - ) - _LOGGER.debug(f"Saving agent: {agent_path}") - self._agent.save(agent_path) + pass @abstractmethod def export(self): diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index dcb1f5c5..0bc41762 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,9 +1,11 @@ from __future__ import annotations import json +import shutil from datetime import datetime from pathlib import Path from typing import Union +from uuid import uuid4 from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.a2c import A2CConfig @@ -121,9 +123,11 @@ class RLlibAgent(AgentSessionABC): def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._current_result["episodes_total"] - if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): - self._agent.save(str(self.checkpoints_path)) + save_checkpoint = False + if checkpoint_n: + save_checkpoint = episode_count % checkpoint_n == 0 + if episode_count and save_checkpoint: + self._agent.save(str(self.checkpoints_path)) def learn( self, @@ -141,6 +145,7 @@ class RLlibAgent(AgentSessionABC): for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() + self.save() self._agent.stop() super().learn() @@ -167,6 +172,26 @@ class RLlibAgent(AgentSessionABC): """Load an agent from file.""" raise NotImplementedError + def save(self, overwrite_existing: bool = True): + """Save the agent.""" + # Make temp dir to save in isolation + temp_dir = self.learning_path / str(uuid4()) + temp_dir.mkdir() + + # Save the agent to the temp dir + self._agent.save(str(temp_dir)) + + # Capture the saved Rllib checkpoint inside the temp directory + for file in temp_dir.iterdir(): + checkpoint_dir = file + break + + # Zip the folder + shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa + + # Drop the temp directory + shutil.rmtree(temp_dir) + def export(self): """Export the agent to transportable file format.""" raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 9d295c6f..aa8e312d 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -65,11 +65,13 @@ class SB3Agent(AgentSessionABC): def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._env.episode_count - if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): - checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" - self._agent.save(checkpoint_path) - _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") + save_checkpoint = False + if checkpoint_n: + save_checkpoint = episode_count % checkpoint_n == 0 + if episode_count and save_checkpoint: + checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" + self._agent.save(checkpoint_path) + _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") def _get_latest_checkpoint(self): pass @@ -91,6 +93,7 @@ class SB3Agent(AgentSessionABC): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() self._env.reset() + self.save() self._env.close() super().learn() @@ -134,6 +137,10 @@ class SB3Agent(AgentSessionABC): """Load an agent from file.""" raise NotImplementedError + def save(self): + """Save the agent.""" + self._agent.save(self._saved_agent_path) + def export(self): """Export the agent to transportable file format.""" raise NotImplementedError diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 113457ff..15adc4dd 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -93,58 +93,58 @@ sb3_output_verbose_level: NONE # Generic all_ok: 0 # Node Hardware State -off_should_be_on: -10 -off_should_be_resetting: -5 -on_should_be_off: -2 -on_should_be_resetting: -5 -resetting_should_be_on: -5 -resetting_should_be_off: -2 -resetting: -3 +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 # Node Software or Service State -good_should_be_patching: 2 -good_should_be_compromised: 5 -good_should_be_overwhelmed: 5 -patching_should_be_good: -5 -patching_should_be_compromised: 2 -patching_should_be_overwhelmed: 2 -patching: -3 -compromised_should_be_good: -20 -compromised_should_be_patching: -20 -compromised_should_be_overwhelmed: -20 -compromised: -20 -overwhelmed_should_be_good: -20 -overwhelmed_should_be_patching: -20 -overwhelmed_should_be_compromised: -20 -overwhelmed: -20 +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 # Node File System State -good_should_be_repairing: 2 -good_should_be_restoring: 2 -good_should_be_corrupt: 5 -good_should_be_destroyed: 10 -repairing_should_be_good: -5 -repairing_should_be_restoring: 2 -repairing_should_be_corrupt: 2 -repairing_should_be_destroyed: 0 -repairing: -3 -restoring_should_be_good: -10 -restoring_should_be_repairing: -2 -restoring_should_be_corrupt: 1 -restoring_should_be_destroyed: 2 -restoring: -6 -corrupt_should_be_good: -10 -corrupt_should_be_repairing: -10 -corrupt_should_be_restoring: -10 -corrupt_should_be_destroyed: 2 -corrupt: -10 -destroyed_should_be_good: -20 -destroyed_should_be_repairing: -20 -destroyed_should_be_restoring: -20 -destroyed_should_be_corrupt: -20 -destroyed: -20 -scanning: -2 +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 # IER status -red_ier_running: -5 -green_ier_blocked: -10 +red_ier_running: -0.0005 +green_ier_blocked: -0.001 # Patching / Reset durations os_patching_duration: 5 # The time taken to patch the OS diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index db8927a1..e7b701c7 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -94,64 +94,64 @@ class TrainingConfig: # Reward values # Generic - all_ok: int = 0 + all_ok: float = 0 # Node Hardware State - off_should_be_on: int = -10 - off_should_be_resetting: int = -5 - on_should_be_off: int = -2 - on_should_be_resetting: int = -5 - resetting_should_be_on: int = -5 - resetting_should_be_off: int = -2 - resetting: int = -3 + off_should_be_on: float = -0.001 + off_should_be_resetting: float = -0.0005 + on_should_be_off: float = -0.0002 + on_should_be_resetting: float = -0.0005 + resetting_should_be_on: float = -0.0005 + resetting_should_be_off: float = -0.0002 + resetting: float = -0.0003 # Node Software or Service State - good_should_be_patching: int = 2 - good_should_be_compromised: int = 5 - good_should_be_overwhelmed: int = 5 - patching_should_be_good: int = -5 - patching_should_be_compromised: int = 2 - patching_should_be_overwhelmed: int = 2 - patching: int = -3 - compromised_should_be_good: int = -20 - compromised_should_be_patching: int = -20 - compromised_should_be_overwhelmed: int = -20 - compromised: int = -20 - overwhelmed_should_be_good: int = -20 - overwhelmed_should_be_patching: int = -20 - overwhelmed_should_be_compromised: int = -20 - overwhelmed: int = -20 + good_should_be_patching: float = 0.0002 + good_should_be_compromised: float = 0.0005 + good_should_be_overwhelmed: float = 0.0005 + patching_should_be_good: float = -0.0005 + patching_should_be_compromised: float = 0.0002 + patching_should_be_overwhelmed: float = 0.0002 + patching: float = -0.0003 + compromised_should_be_good: float = -0.002 + compromised_should_be_patching: float = -0.002 + compromised_should_be_overwhelmed: float = -0.002 + compromised: float = -0.002 + overwhelmed_should_be_good: float = -0.002 + overwhelmed_should_be_patching: float = -0.002 + overwhelmed_should_be_compromised: float = -0.002 + overwhelmed: float = -0.002 # Node File System State - good_should_be_repairing: int = 2 - good_should_be_restoring: int = 2 - good_should_be_corrupt: int = 5 - good_should_be_destroyed: int = 10 - repairing_should_be_good: int = -5 - repairing_should_be_restoring: int = 2 - repairing_should_be_corrupt: int = 2 - repairing_should_be_destroyed: int = 0 - repairing: int = -3 - restoring_should_be_good: int = -10 - restoring_should_be_repairing: int = -2 - restoring_should_be_corrupt: int = 1 - restoring_should_be_destroyed: int = 2 - restoring: int = -6 - corrupt_should_be_good: int = -10 - corrupt_should_be_repairing: int = -10 - corrupt_should_be_restoring: int = -10 - corrupt_should_be_destroyed: int = 2 - corrupt: int = -10 - destroyed_should_be_good: int = -20 - destroyed_should_be_repairing: int = -20 - destroyed_should_be_restoring: int = -20 - destroyed_should_be_corrupt: int = -20 - destroyed: int = -20 - scanning: int = -2 + good_should_be_repairing: float = 0.0002 + good_should_be_restoring: float = 0.0002 + good_should_be_corrupt: float = 0.0005 + good_should_be_destroyed: float = 0.001 + repairing_should_be_good: float = -0.0005 + repairing_should_be_restoring: float = 0.0002 + repairing_should_be_corrupt: float = 0.0002 + repairing_should_be_destroyed: float = 0.0000 + repairing: float = -0.0003 + restoring_should_be_good: float = -0.001 + restoring_should_be_repairing: float = -0.0002 + restoring_should_be_corrupt: float = 0.0001 + restoring_should_be_destroyed: float = 0.0002 + restoring: float = -0.0006 + corrupt_should_be_good: float = -0.001 + corrupt_should_be_repairing: float = -0.001 + corrupt_should_be_restoring: float = -0.001 + corrupt_should_be_destroyed: float = 0.0002 + corrupt: float = -0.001 + destroyed_should_be_good: float = -0.002 + destroyed_should_be_repairing: float = -0.002 + destroyed_should_be_restoring: float = -0.002 + destroyed_should_be_corrupt: float = -0.002 + destroyed: float = -0.002 + scanning: float = -0.0002 # IER status - red_ier_running: int = -5 - green_ier_blocked: int = -10 + red_ier_running: float = -0.0005 + green_ier_blocked: float = -0.001 # Patching / Reset durations os_patching_duration: int = 5 diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 03c23f93..3a40066a 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -142,10 +142,10 @@ class Primaite(Env): self.step_info = {} # Total reward - self.total_reward = 0 + self.total_reward: float = 0 # Average reward - self.average_reward = 0 + self.average_reward: float = 0 # Episode count self.episode_count = 0 @@ -283,9 +283,9 @@ class Primaite(Env): self._create_random_red_agent() # Reset counters and totals - self.total_reward = 0 + self.total_reward = 0.0 self.step_count = 0 - self.average_reward = 0 + self.average_reward = 0.0 # Update observations space and return self.update_environent_obs() diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 19094a18..e4353cb9 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -20,7 +20,7 @@ def calculate_reward_function( red_iers, step_count, config_values, -): +) -> float: """ Compares the states of the initial and final nodes/links to get a reward. @@ -33,7 +33,7 @@ def calculate_reward_function( step_count: current step config_values: Config values """ - reward_value = 0 + reward_value: float = 0.0 # For each node, compare hardware state, SoftwareState, service states for node_key, final_node in final_nodes.items(): @@ -94,7 +94,7 @@ def calculate_reward_function( return reward_value -def score_node_operating_state(final_node, initial_node, reference_node, config_values): +def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the hardware state of a node. @@ -104,7 +104,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ reference_node: The node if there had been no red or blue effect config_values: Config values """ - score = 0 + score: float = 0.0 final_node_operating_state = final_node.hardware_state reference_node_operating_state = reference_node.hardware_state @@ -143,7 +143,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ return score -def score_node_os_state(final_node, initial_node, reference_node, config_values): +def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the Software State of a node. @@ -153,7 +153,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) reference_node: The node if there had been no red or blue effect config_values: Config values """ - score = 0 + score: float = 0.0 final_node_os_state = final_node.software_state reference_node_os_state = reference_node.software_state @@ -194,7 +194,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) return score -def score_node_service_state(final_node, initial_node, reference_node, config_values): +def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the service state(s) of a node. @@ -204,7 +204,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va reference_node: The node if there had been no red or blue effect config_values: Config values """ - score = 0 + score: float = 0.0 final_node_services: Dict[str, Service] = final_node.services reference_node_services: Dict[str, Service] = reference_node.services @@ -266,7 +266,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va return score -def score_node_file_system(final_node, initial_node, reference_node, config_values): +def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the file system state of a node. @@ -275,7 +275,7 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu initial_node: The node before red and blue agents take effect reference_node: The node if there had been no red or blue effect """ - score = 0 + score: float = 0.0 final_node_file_system_state = final_node.file_system_state_actual reference_node_file_system_state = reference_node.file_system_state_actual diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 7db2444a..95be8115 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -31,7 +31,7 @@ class Transaction(object): "The observation space before any actions are taken" self.obs_space_post = None "The observation space after any actions are taken" - self.reward = None + self.reward: float = None "The reward value" self.action_space = None "The action space invoked by the agent" diff --git a/tests/config/ppo_not_seeded_training_config.yaml b/tests/config/ppo_not_seeded_training_config.yaml index f43c151c..23cff44e 100644 --- a/tests/config/ppo_not_seeded_training_config.yaml +++ b/tests/config/ppo_not_seeded_training_config.yaml @@ -91,60 +91,60 @@ sb3_output_verbose_level: NONE # Reward values # Generic -all_ok: 0 +all_ok: 0.0000 # Node Hardware State -off_should_be_on: -10 -off_should_be_resetting: -5 -on_should_be_off: -2 -on_should_be_resetting: -5 -resetting_should_be_on: -5 -resetting_should_be_off: -2 -resetting: -3 +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 # Node Software or Service State -good_should_be_patching: 2 -good_should_be_compromised: 5 -good_should_be_overwhelmed: 5 -patching_should_be_good: -5 -patching_should_be_compromised: 2 -patching_should_be_overwhelmed: 2 -patching: -3 -compromised_should_be_good: -20 -compromised_should_be_patching: -20 -compromised_should_be_overwhelmed: -20 -compromised: -20 -overwhelmed_should_be_good: -20 -overwhelmed_should_be_patching: -20 -overwhelmed_should_be_compromised: -20 -overwhelmed: -20 +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 # Node File System State -good_should_be_repairing: 2 -good_should_be_restoring: 2 -good_should_be_corrupt: 5 -good_should_be_destroyed: 10 -repairing_should_be_good: -5 -repairing_should_be_restoring: 2 -repairing_should_be_corrupt: 2 -repairing_should_be_destroyed: 0 -repairing: -3 -restoring_should_be_good: -10 -restoring_should_be_repairing: -2 -restoring_should_be_corrupt: 1 -restoring_should_be_destroyed: 2 -restoring: -6 -corrupt_should_be_good: -10 -corrupt_should_be_repairing: -10 -corrupt_should_be_restoring: -10 -corrupt_should_be_destroyed: 2 -corrupt: -10 -destroyed_should_be_good: -20 -destroyed_should_be_repairing: -20 -destroyed_should_be_restoring: -20 -destroyed_should_be_corrupt: -20 -destroyed: -20 -scanning: -2 +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 # IER status -red_ier_running: -5 -green_ier_blocked: -10 +red_ier_running: -0.0005 +green_ier_blocked: -0.001 # Patching / Reset durations os_patching_duration: 5 # The time taken to patch the OS diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py index ae0b0870..75ea5882 100644 --- a/tests/test_primaite_session.py +++ b/tests/test_primaite_session.py @@ -33,6 +33,9 @@ def test_primaite_session(temp_primaite_session): # Check that the network png file exists assert (session_path / f"network_{session.timestamp_str}.png").exists() + # Check that the saved agent exists + assert session._agent_session._saved_agent_path.exists() + # Check that both the transactions and av reward csv files exist for file in session.learning_path.iterdir(): if file.suffix == ".csv":