From 8faf9d70a00edc7259709b365fa0667727d4b1f0 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 6 Jul 2023 10:07:54 +0100 Subject: [PATCH 1/7] temp --- src/primaite/agents/agent.py | 8 ++++- src/primaite/agents/rllib.py | 58 +++++++++++++++++++++++++++--------- src/primaite/agents/sb3.py | 15 ++++++---- 3 files changed, 60 insertions(+), 21 deletions(-) diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 685fe776..fecf84d0 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -257,10 +257,16 @@ class AgentSessionABC(ABC): raise FileNotFoundError(msg) pass + @property + def _saved_agent_path(self) -> Path: + file_name = (f"{self._training_config.agent_framework}_" + f"{self._training_config.agent_identifier}_" + f"{self.timestamp_str}.zip") + return self.learning_path / file_name @abstractmethod def save(self): """Save the agent.""" - self._agent.save(self.session_path) + pass @abstractmethod def export(self): diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index d851ba9c..32dc3dc0 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,9 +1,11 @@ from __future__ import annotations import json +import shutil from datetime import datetime from pathlib import Path from typing import Union +from uuid import uuid4 from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.a2c import A2CConfig @@ -83,8 +85,10 @@ class RLlibAgent(AgentSessionABC): metadata_dict = json.load(file) metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = self._current_result["episodes_total"] - metadata_dict["total_time_steps"] = self._current_result["timesteps_total"] + metadata_dict["total_episodes"] = self._current_result[ + "episodes_total"] + metadata_dict["total_time_steps"] = self._current_result[ + "timesteps_total"] filepath = self.session_path / "session_metadata.json" _LOGGER.debug(f"Updating Session Metadata file: {filepath}") @@ -107,7 +111,8 @@ class RLlibAgent(AgentSessionABC): ), ) - self._agent_config.training(train_batch_size=self._training_config.num_steps) + self._agent_config.training( + train_batch_size=self._training_config.num_steps) self._agent_config.framework(framework="tf") self._agent_config.rollouts( @@ -115,18 +120,21 @@ class RLlibAgent(AgentSessionABC): num_envs_per_worker=1, horizon=self._training_config.num_steps, ) - self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path)) + self._agent: Algorithm = self._agent_config.build( + logger_creator=_custom_log_creator(self.learning_path)) def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._current_result["episodes_total"] - if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): - self._agent.save(str(self.checkpoints_path)) + save_checkpoint = False + if checkpoint_n: + save_checkpoint = episode_count % checkpoint_n == 0 + if episode_count and save_checkpoint: + self._agent.save(str(self.checkpoints_path)) def learn( - self, - **kwargs, + self, + **kwargs, ): """ Evaluate the agent. @@ -136,16 +144,18 @@ class RLlibAgent(AgentSessionABC): time_steps = self._training_config.num_steps episodes = self._training_config.num_episodes - _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") + _LOGGER.info( + f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() + self.save() self._agent.stop() super().learn() def evaluate( - self, - **kwargs, + self, + **kwargs, ): """ Evaluate the agent. @@ -162,9 +172,29 @@ class RLlibAgent(AgentSessionABC): """Load an agent from file.""" raise NotImplementedError - def save(self): + def save(self, overwrite_existing: bool = True): """Save the agent.""" - raise NotImplementedError + # Make temp dir to save in isolation + temp_dir = self.learning_path / str(uuid4()) + temp_dir.mkdir() + + # Save the agent to the temp dir + self._agent.save(str(temp_dir)) + + # Capture the saved Rllib checkpoint inside the temp directory + for file in temp_dir.iterdir(): + checkpoint_dir = file + break + + shutil.make_archive( + str(self._saved_agent_path).replace(".zip", ""), + "zip", + checkpoint_dir # noqa + ) + + # Drop the temp directory + shutil.rmtree(temp_dir) + def export(self): """Export the agent to transportable file format.""" diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index f5ac44cb..18e208e4 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -64,11 +64,13 @@ class SB3Agent(AgentSessionABC): def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._env.episode_count - if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): - checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" - self._agent.save(checkpoint_path) - _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") + save_checkpoint = False + if checkpoint_n: + save_checkpoint = episode_count % checkpoint_n == 0 + if episode_count and save_checkpoint: + checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" + self._agent.save(checkpoint_path) + _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") def _get_latest_checkpoint(self): pass @@ -90,6 +92,7 @@ class SB3Agent(AgentSessionABC): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() self._env.reset() + self.save() self._env.close() super().learn() @@ -134,7 +137,7 @@ class SB3Agent(AgentSessionABC): def save(self): """Save the agent.""" - raise NotImplementedError + self._agent.save(self._saved_agent_path) def export(self): """Export the agent to transportable file format.""" From e174db5d9eac280ca472231aa11822362c5db2a8 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 6 Jul 2023 10:51:34 +0100 Subject: [PATCH 2/7] Rescaled default rewards by a factor of 1/10000 --- .../training/training_config_main.yaml | 98 ++++++++--------- .../training_config_random_red_agent.yaml | 100 +++++++++--------- 2 files changed, 99 insertions(+), 99 deletions(-) diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index a638fe14..7e7f239d 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -83,58 +83,58 @@ sb3_output_verbose_level: NONE # Generic all_ok: 0 # Node Hardware State -off_should_be_on: -10 -off_should_be_resetting: -5 -on_should_be_off: -2 -on_should_be_resetting: -5 -resetting_should_be_on: -5 -resetting_should_be_off: -2 -resetting: -3 +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 # Node Software or Service State -good_should_be_patching: 2 -good_should_be_compromised: 5 -good_should_be_overwhelmed: 5 -patching_should_be_good: -5 -patching_should_be_compromised: 2 -patching_should_be_overwhelmed: 2 -patching: -3 -compromised_should_be_good: -20 -compromised_should_be_patching: -20 -compromised_should_be_overwhelmed: -20 -compromised: -20 -overwhelmed_should_be_good: -20 -overwhelmed_should_be_patching: -20 -overwhelmed_should_be_compromised: -20 -overwhelmed: -20 +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 # Node File System State -good_should_be_repairing: 2 -good_should_be_restoring: 2 -good_should_be_corrupt: 5 -good_should_be_destroyed: 10 -repairing_should_be_good: -5 -repairing_should_be_restoring: 2 -repairing_should_be_corrupt: 2 -repairing_should_be_destroyed: 0 -repairing: -3 -restoring_should_be_good: -10 -restoring_should_be_repairing: -2 -restoring_should_be_corrupt: 1 -restoring_should_be_destroyed: 2 -restoring: -6 -corrupt_should_be_good: -10 -corrupt_should_be_repairing: -10 -corrupt_should_be_restoring: -10 -corrupt_should_be_destroyed: 2 -corrupt: -10 -destroyed_should_be_good: -20 -destroyed_should_be_repairing: -20 -destroyed_should_be_restoring: -20 -destroyed_should_be_corrupt: -20 -destroyed: -20 -scanning: -2 +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 # IER status -red_ier_running: -5 -green_ier_blocked: -10 +red_ier_running: -0.0005 +green_ier_blocked: -0.001 # Patching / Reset durations os_patching_duration: 5 # The time taken to patch the OS diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml index 96243daf..1ccc7c38 100644 --- a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -37,60 +37,60 @@ observation_space_high_value: 1000000000 # Reward values # Generic -all_ok: 0 +all_ok: 0.0000 # Node Hardware State -off_should_be_on: -10 -off_should_be_resetting: -5 -on_should_be_off: -2 -on_should_be_resetting: -5 -resetting_should_be_on: -5 -resetting_should_be_off: -2 -resetting: -3 +off_should_be_on: -0.001 +off_should_be_resetting: -0.0005 +on_should_be_off: -0.0002 +on_should_be_resetting: -0.0005 +resetting_should_be_on: -0.0005 +resetting_should_be_off: -0.0002 +resetting: -0.0003 # Node Software or Service State -good_should_be_patching: 2 -good_should_be_compromised: 5 -good_should_be_overwhelmed: 5 -patching_should_be_good: -5 -patching_should_be_compromised: 2 -patching_should_be_overwhelmed: 2 -patching: -3 -compromised_should_be_good: -20 -compromised_should_be_patching: -20 -compromised_should_be_overwhelmed: -20 -compromised: -20 -overwhelmed_should_be_good: -20 -overwhelmed_should_be_patching: -20 -overwhelmed_should_be_compromised: -20 -overwhelmed: -20 +good_should_be_patching: 0.0002 +good_should_be_compromised: 0.0005 +good_should_be_overwhelmed: 0.0005 +patching_should_be_good: -0.0005 +patching_should_be_compromised: 0.0002 +patching_should_be_overwhelmed: 0.0002 +patching: -0.0003 +compromised_should_be_good: -0.002 +compromised_should_be_patching: -0.002 +compromised_should_be_overwhelmed: -0.002 +compromised: -0.002 +overwhelmed_should_be_good: -0.002 +overwhelmed_should_be_patching: -0.002 +overwhelmed_should_be_compromised: -0.002 +overwhelmed: -0.002 # Node File System State -good_should_be_repairing: 2 -good_should_be_restoring: 2 -good_should_be_corrupt: 5 -good_should_be_destroyed: 10 -repairing_should_be_good: -5 -repairing_should_be_restoring: 2 -repairing_should_be_corrupt: 2 -repairing_should_be_destroyed: 0 -repairing: -3 -restoring_should_be_good: -10 -restoring_should_be_repairing: -2 -restoring_should_be_corrupt: 1 -restoring_should_be_destroyed: 2 -restoring: -6 -corrupt_should_be_good: -10 -corrupt_should_be_repairing: -10 -corrupt_should_be_restoring: -10 -corrupt_should_be_destroyed: 2 -corrupt: -10 -destroyed_should_be_good: -20 -destroyed_should_be_repairing: -20 -destroyed_should_be_restoring: -20 -destroyed_should_be_corrupt: -20 -destroyed: -20 -scanning: -2 +good_should_be_repairing: 0.0002 +good_should_be_restoring: 0.0002 +good_should_be_corrupt: 0.0005 +good_should_be_destroyed: 0.001 +repairing_should_be_good: -0.0005 +repairing_should_be_restoring: 0.0002 +repairing_should_be_corrupt: 0.0002 +repairing_should_be_destroyed: 0.0000 +repairing: -0.0003 +restoring_should_be_good: -0.001 +restoring_should_be_repairing: -0.0002 +restoring_should_be_corrupt: 0.0001 +restoring_should_be_destroyed: 0.0002 +restoring: -0.0006 +corrupt_should_be_good: -0.001 +corrupt_should_be_repairing: -0.001 +corrupt_should_be_restoring: -0.001 +corrupt_should_be_destroyed: 0.0002 +corrupt: -0.001 +destroyed_should_be_good: -0.002 +destroyed_should_be_repairing: -0.002 +destroyed_should_be_restoring: -0.002 +destroyed_should_be_corrupt: -0.002 +destroyed: -0.002 +scanning: -0.0002 # IER status -red_ier_running: -5 -green_ier_blocked: -10 +red_ier_running: -0.0005 +green_ier_blocked: -0.001 # Patching / Reset durations os_patching_duration: 5 # The time taken to patch the OS From c5d7d5574773fe927e0d5386289a455f6b54a08d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 6 Jul 2023 12:52:14 +0100 Subject: [PATCH 3/7] Change reward to float and divide by 10000 --- src/primaite/config/training_config.py | 100 +++++++++++------------ src/primaite/environment/primaite_env.py | 8 +- src/primaite/environment/reward.py | 20 ++--- src/primaite/transactions/transaction.py | 2 +- 4 files changed, 65 insertions(+), 65 deletions(-) diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index bd73f65b..a89d8c4b 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -94,64 +94,64 @@ class TrainingConfig: # Reward values # Generic - all_ok: int = 0 + all_ok: float = 0 # Node Hardware State - off_should_be_on: int = -10 - off_should_be_resetting: int = -5 - on_should_be_off: int = -2 - on_should_be_resetting: int = -5 - resetting_should_be_on: int = -5 - resetting_should_be_off: int = -2 - resetting: int = -3 + off_should_be_on: float = -0.001 + off_should_be_resetting: float = -0.0005 + on_should_be_off: float = -0.0002 + on_should_be_resetting: float = -0.0005 + resetting_should_be_on: float = -0.0005 + resetting_should_be_off: float = -0.0002 + resetting: float = -0.0003 # Node Software or Service State - good_should_be_patching: int = 2 - good_should_be_compromised: int = 5 - good_should_be_overwhelmed: int = 5 - patching_should_be_good: int = -5 - patching_should_be_compromised: int = 2 - patching_should_be_overwhelmed: int = 2 - patching: int = -3 - compromised_should_be_good: int = -20 - compromised_should_be_patching: int = -20 - compromised_should_be_overwhelmed: int = -20 - compromised: int = -20 - overwhelmed_should_be_good: int = -20 - overwhelmed_should_be_patching: int = -20 - overwhelmed_should_be_compromised: int = -20 - overwhelmed: int = -20 + good_should_be_patching: float = 0.0002 + good_should_be_compromised: float = 0.0005 + good_should_be_overwhelmed: float = 0.0005 + patching_should_be_good: float = -0.0005 + patching_should_be_compromised: float = 0.0002 + patching_should_be_overwhelmed: float = 0.0002 + patching: float = -0.0003 + compromised_should_be_good: float = -0.002 + compromised_should_be_patching: float = -0.002 + compromised_should_be_overwhelmed: float = -0.002 + compromised: float = -0.002 + overwhelmed_should_be_good: float = -0.002 + overwhelmed_should_be_patching: float = -0.002 + overwhelmed_should_be_compromised: float = -0.002 + overwhelmed: float = -0.002 # Node File System State - good_should_be_repairing: int = 2 - good_should_be_restoring: int = 2 - good_should_be_corrupt: int = 5 - good_should_be_destroyed: int = 10 - repairing_should_be_good: int = -5 - repairing_should_be_restoring: int = 2 - repairing_should_be_corrupt: int = 2 - repairing_should_be_destroyed: int = 0 - repairing: int = -3 - restoring_should_be_good: int = -10 - restoring_should_be_repairing: int = -2 - restoring_should_be_corrupt: int = 1 - restoring_should_be_destroyed: int = 2 - restoring: int = -6 - corrupt_should_be_good: int = -10 - corrupt_should_be_repairing: int = -10 - corrupt_should_be_restoring: int = -10 - corrupt_should_be_destroyed: int = 2 - corrupt: int = -10 - destroyed_should_be_good: int = -20 - destroyed_should_be_repairing: int = -20 - destroyed_should_be_restoring: int = -20 - destroyed_should_be_corrupt: int = -20 - destroyed: int = -20 - scanning: int = -2 + good_should_be_repairing: float = 0.0002 + good_should_be_restoring: float = 0.0002 + good_should_be_corrupt: float = 0.0005 + good_should_be_destroyed: float = 0.001 + repairing_should_be_good: float = -0.0005 + repairing_should_be_restoring: float = 0.0002 + repairing_should_be_corrupt: float = 0.0002 + repairing_should_be_destroyed: float = 0.0000 + repairing: float = -0.0003 + restoring_should_be_good: float = -0.001 + restoring_should_be_repairing: float = -0.0002 + restoring_should_be_corrupt: float = 0.0001 + restoring_should_be_destroyed: float = 0.0002 + restoring: float = -0.0006 + corrupt_should_be_good: float = -0.001 + corrupt_should_be_repairing: float = -0.001 + corrupt_should_be_restoring: float = -0.001 + corrupt_should_be_destroyed: float = 0.0002 + corrupt: float = -0.001 + destroyed_should_be_good: float = -0.002 + destroyed_should_be_repairing: float = -0.002 + destroyed_should_be_restoring: float = -0.002 + destroyed_should_be_corrupt: float = -0.002 + destroyed: float = -0.002 + scanning: float = -0.0002 # IER status - red_ier_running: int = -5 - green_ier_blocked: int = -10 + red_ier_running: float = -0.0005 + green_ier_blocked: float = -0.001 # Patching / Reset durations os_patching_duration: int = 5 diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 03c23f93..3a40066a 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -142,10 +142,10 @@ class Primaite(Env): self.step_info = {} # Total reward - self.total_reward = 0 + self.total_reward: float = 0 # Average reward - self.average_reward = 0 + self.average_reward: float = 0 # Episode count self.episode_count = 0 @@ -283,9 +283,9 @@ class Primaite(Env): self._create_random_red_agent() # Reset counters and totals - self.total_reward = 0 + self.total_reward = 0.0 self.step_count = 0 - self.average_reward = 0 + self.average_reward = 0.0 # Update observations space and return self.update_environent_obs() diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 19094a18..e4353cb9 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -20,7 +20,7 @@ def calculate_reward_function( red_iers, step_count, config_values, -): +) -> float: """ Compares the states of the initial and final nodes/links to get a reward. @@ -33,7 +33,7 @@ def calculate_reward_function( step_count: current step config_values: Config values """ - reward_value = 0 + reward_value: float = 0.0 # For each node, compare hardware state, SoftwareState, service states for node_key, final_node in final_nodes.items(): @@ -94,7 +94,7 @@ def calculate_reward_function( return reward_value -def score_node_operating_state(final_node, initial_node, reference_node, config_values): +def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the hardware state of a node. @@ -104,7 +104,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ reference_node: The node if there had been no red or blue effect config_values: Config values """ - score = 0 + score: float = 0.0 final_node_operating_state = final_node.hardware_state reference_node_operating_state = reference_node.hardware_state @@ -143,7 +143,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ return score -def score_node_os_state(final_node, initial_node, reference_node, config_values): +def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the Software State of a node. @@ -153,7 +153,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) reference_node: The node if there had been no red or blue effect config_values: Config values """ - score = 0 + score: float = 0.0 final_node_os_state = final_node.software_state reference_node_os_state = reference_node.software_state @@ -194,7 +194,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) return score -def score_node_service_state(final_node, initial_node, reference_node, config_values): +def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the service state(s) of a node. @@ -204,7 +204,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va reference_node: The node if there had been no red or blue effect config_values: Config values """ - score = 0 + score: float = 0.0 final_node_services: Dict[str, Service] = final_node.services reference_node_services: Dict[str, Service] = reference_node.services @@ -266,7 +266,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va return score -def score_node_file_system(final_node, initial_node, reference_node, config_values): +def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float: """ Calculates score relating to the file system state of a node. @@ -275,7 +275,7 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu initial_node: The node before red and blue agents take effect reference_node: The node if there had been no red or blue effect """ - score = 0 + score: float = 0.0 final_node_file_system_state = final_node.file_system_state_actual reference_node_file_system_state = reference_node.file_system_state_actual diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 7db2444a..95be8115 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -31,7 +31,7 @@ class Transaction(object): "The observation space before any actions are taken" self.obs_space_post = None "The observation space after any actions are taken" - self.reward = None + self.reward: float = None "The reward value" self.action_space = None "The action space invoked by the agent" From 3b91a990708abe17cd596c3c79ba12516c387032 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 6 Jul 2023 12:56:24 +0100 Subject: [PATCH 4/7] Updated rewards type description in docs --- docs/source/config.rst | 100 ++++++++++++++++++++--------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/docs/source/config.rst b/docs/source/config.rst index 71ade6c5..164a75e1 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -82,203 +82,203 @@ The environment config file consists of the following attributes: Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment. -* **Generic [all_ok]** [int] +* **Generic [all_ok]** [float] The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken) -* **Node Hardware State [off_should_be_on]** [int] +* **Node Hardware State [off_should_be_on]** [float] The score to give when the node should be on, but is off -* **Node Hardware State [off_should_be_resetting]** [int] +* **Node Hardware State [off_should_be_resetting]** [float] The score to give when the node should be resetting, but is off -* **Node Hardware State [on_should_be_off]** [int] +* **Node Hardware State [on_should_be_off]** [float] The score to give when the node should be off, but is on -* **Node Hardware State [on_should_be_resetting]** [int] +* **Node Hardware State [on_should_be_resetting]** [float] The score to give when the node should be resetting, but is on -* **Node Hardware State [resetting_should_be_on]** [int] +* **Node Hardware State [resetting_should_be_on]** [float] The score to give when the node should be on, but is resetting -* **Node Hardware State [resetting_should_be_off]** [int] +* **Node Hardware State [resetting_should_be_off]** [float] The score to give when the node should be off, but is resetting -* **Node Hardware State [resetting]** [int] +* **Node Hardware State [resetting]** [float] The score to give when the node is resetting -* **Node Operating System or Service State [good_should_be_patching]** [int] +* **Node Operating System or Service State [good_should_be_patching]** [float] The score to give when the state should be patching, but is good -* **Node Operating System or Service State [good_should_be_compromised]** [int] +* **Node Operating System or Service State [good_should_be_compromised]** [float] The score to give when the state should be compromised, but is good -* **Node Operating System or Service State [good_should_be_overwhelmed]** [int] +* **Node Operating System or Service State [good_should_be_overwhelmed]** [float] The score to give when the state should be overwhelmed, but is good -* **Node Operating System or Service State [patching_should_be_good]** [int] +* **Node Operating System or Service State [patching_should_be_good]** [float] The score to give when the state should be good, but is patching -* **Node Operating System or Service State [patching_should_be_compromised]** [int] +* **Node Operating System or Service State [patching_should_be_compromised]** [float] The score to give when the state should be compromised, but is patching -* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int] +* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float] The score to give when the state should be overwhelmed, but is patching -* **Node Operating System or Service State [patching]** [int] +* **Node Operating System or Service State [patching]** [float] The score to give when the state is patching -* **Node Operating System or Service State [compromised_should_be_good]** [int] +* **Node Operating System or Service State [compromised_should_be_good]** [float] The score to give when the state should be good, but is compromised -* **Node Operating System or Service State [compromised_should_be_patching]** [int] +* **Node Operating System or Service State [compromised_should_be_patching]** [float] The score to give when the state should be patching, but is compromised -* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int] +* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [float] The score to give when the state should be overwhelmed, but is compromised -* **Node Operating System or Service State [compromised]** [int] +* **Node Operating System or Service State [compromised]** [float] The score to give when the state is compromised -* **Node Operating System or Service State [overwhelmed_should_be_good]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_good]** [float] The score to give when the state should be good, but is overwhelmed -* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float] The score to give when the state should be patching, but is overwhelmed -* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int] +* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float] The score to give when the state should be compromised, but is overwhelmed -* **Node Operating System or Service State [overwhelmed]** [int] +* **Node Operating System or Service State [overwhelmed]** [float] The score to give when the state is overwhelmed -* **Node File System State [good_should_be_repairing]** [int] +* **Node File System State [good_should_be_repairing]** [float] The score to give when the state should be repairing, but is good -* **Node File System State [good_should_be_restoring]** [int] +* **Node File System State [good_should_be_restoring]** [float] The score to give when the state should be restoring, but is good -* **Node File System State [good_should_be_corrupt]** [int] +* **Node File System State [good_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is good -* **Node File System State [good_should_be_destroyed]** [int] +* **Node File System State [good_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is good -* **Node File System State [repairing_should_be_good]** [int] +* **Node File System State [repairing_should_be_good]** [float] The score to give when the state should be good, but is repairing -* **Node File System State [repairing_should_be_restoring]** [int] +* **Node File System State [repairing_should_be_restoring]** [float] The score to give when the state should be restoring, but is repairing -* **Node File System State [repairing_should_be_corrupt]** [int] +* **Node File System State [repairing_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is repairing -* **Node File System State [repairing_should_be_destroyed]** [int] +* **Node File System State [repairing_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is repairing -* **Node File System State [repairing]** [int] +* **Node File System State [repairing]** [float] The score to give when the state is repairing -* **Node File System State [restoring_should_be_good]** [int] +* **Node File System State [restoring_should_be_good]** [float] The score to give when the state should be good, but is restoring -* **Node File System State [restoring_should_be_repairing]** [int] +* **Node File System State [restoring_should_be_repairing]** [float] The score to give when the state should be repairing, but is restoring -* **Node File System State [restoring_should_be_corrupt]** [int] +* **Node File System State [restoring_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is restoring -* **Node File System State [restoring_should_be_destroyed]** [int] +* **Node File System State [restoring_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is restoring -* **Node File System State [restoring]** [int] +* **Node File System State [restoring]** [float] The score to give when the state is restoring -* **Node File System State [corrupt_should_be_good]** [int] +* **Node File System State [corrupt_should_be_good]** [float] The score to give when the state should be good, but is corrupt -* **Node File System State [corrupt_should_be_repairing]** [int] +* **Node File System State [corrupt_should_be_repairing]** [float] The score to give when the state should be repairing, but is corrupt -* **Node File System State [corrupt_should_be_restoring]** [int] +* **Node File System State [corrupt_should_be_restoring]** [float] The score to give when the state should be restoring, but is corrupt -* **Node File System State [corrupt_should_be_destroyed]** [int] +* **Node File System State [corrupt_should_be_destroyed]** [float] The score to give when the state should be destroyed, but is corrupt -* **Node File System State [corrupt]** [int] +* **Node File System State [corrupt]** [float] The score to give when the state is corrupt -* **Node File System State [destroyed_should_be_good]** [int] +* **Node File System State [destroyed_should_be_good]** [float] The score to give when the state should be good, but is destroyed -* **Node File System State [destroyed_should_be_repairing]** [int] +* **Node File System State [destroyed_should_be_repairing]** [float] The score to give when the state should be repairing, but is destroyed -* **Node File System State [destroyed_should_be_restoring]** [int] +* **Node File System State [destroyed_should_be_restoring]** [float] The score to give when the state should be restoring, but is destroyed -* **Node File System State [destroyed_should_be_corrupt]** [int] +* **Node File System State [destroyed_should_be_corrupt]** [float] The score to give when the state should be corrupt, but is destroyed -* **Node File System State [destroyed]** [int] +* **Node File System State [destroyed]** [float] The score to give when the state is destroyed -* **Node File System State [scanning]** [int] +* **Node File System State [scanning]** [float] The score to give when the state is scanning -* **IER Status [red_ier_running]** [int] +* **IER Status [red_ier_running]** [float] The score to give when a red agent IER is permitted to run -* **IER Status [green_ier_blocked]** [int] +* **IER Status [green_ier_blocked]** [float] The score to give when a green agent IER is prevented from running From 159d47fd6c4d7000176b1b49cc3860c323022bf4 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 6 Jul 2023 13:56:12 +0100 Subject: [PATCH 5/7] #1963 - Made RLlib and SB3 agents save at the end of each learning session by default using a common file naming format. Also now agents only checkpoint every n and not on the final episode --- src/primaite/agents/rllib.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 32dc3dc0..ce072a03 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -186,6 +186,7 @@ class RLlibAgent(AgentSessionABC): checkpoint_dir = file break + # Zip the folder shutil.make_archive( str(self._saved_agent_path).replace(".zip", ""), "zip", From 82d7c168fe40f0e085ba9b9e737b3547cc882673 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 6 Jul 2023 14:13:02 +0100 Subject: [PATCH 6/7] #1593 - Check that agent saved file exists --- tests/test_primaite_session.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py index ae0b0870..75ea5882 100644 --- a/tests/test_primaite_session.py +++ b/tests/test_primaite_session.py @@ -33,6 +33,9 @@ def test_primaite_session(temp_primaite_session): # Check that the network png file exists assert (session_path / f"network_{session.timestamp_str}.png").exists() + # Check that the saved agent exists + assert session._agent_session._saved_agent_path.exists() + # Check that both the transactions and av reward csv files exist for file in session.learning_path.iterdir(): if file.suffix == ".csv": From c9f474165507ad5c2665e200d31d4a169a244cdc Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 6 Jul 2023 14:18:49 +0100 Subject: [PATCH 7/7] #1593 - Ran pre-commit hook --- src/primaite/agents/agent.py | 9 ++++++--- src/primaite/agents/rllib.py | 30 ++++++++++-------------------- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index fecf84d0..32118597 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -259,10 +259,13 @@ class AgentSessionABC(ABC): @property def _saved_agent_path(self) -> Path: - file_name = (f"{self._training_config.agent_framework}_" - f"{self._training_config.agent_identifier}_" - f"{self.timestamp_str}.zip") + file_name = ( + f"{self._training_config.agent_framework}_" + f"{self._training_config.agent_identifier}_" + f"{self.timestamp_str}.zip" + ) return self.learning_path / file_name + @abstractmethod def save(self): """Save the agent.""" diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index ce072a03..427072c4 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -85,10 +85,8 @@ class RLlibAgent(AgentSessionABC): metadata_dict = json.load(file) metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = self._current_result[ - "episodes_total"] - metadata_dict["total_time_steps"] = self._current_result[ - "timesteps_total"] + metadata_dict["total_episodes"] = self._current_result["episodes_total"] + metadata_dict["total_time_steps"] = self._current_result["timesteps_total"] filepath = self.session_path / "session_metadata.json" _LOGGER.debug(f"Updating Session Metadata file: {filepath}") @@ -111,8 +109,7 @@ class RLlibAgent(AgentSessionABC): ), ) - self._agent_config.training( - train_batch_size=self._training_config.num_steps) + self._agent_config.training(train_batch_size=self._training_config.num_steps) self._agent_config.framework(framework="tf") self._agent_config.rollouts( @@ -120,8 +117,7 @@ class RLlibAgent(AgentSessionABC): num_envs_per_worker=1, horizon=self._training_config.num_steps, ) - self._agent: Algorithm = self._agent_config.build( - logger_creator=_custom_log_creator(self.learning_path)) + self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path)) def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes @@ -133,8 +129,8 @@ class RLlibAgent(AgentSessionABC): self._agent.save(str(self.checkpoints_path)) def learn( - self, - **kwargs, + self, + **kwargs, ): """ Evaluate the agent. @@ -144,8 +140,7 @@ class RLlibAgent(AgentSessionABC): time_steps = self._training_config.num_steps episodes = self._training_config.num_episodes - _LOGGER.info( - f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") + _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() @@ -154,8 +149,8 @@ class RLlibAgent(AgentSessionABC): super().learn() def evaluate( - self, - **kwargs, + self, + **kwargs, ): """ Evaluate the agent. @@ -187,16 +182,11 @@ class RLlibAgent(AgentSessionABC): break # Zip the folder - shutil.make_archive( - str(self._saved_agent_path).replace(".zip", ""), - "zip", - checkpoint_dir # noqa - ) + shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa # Drop the temp directory shutil.rmtree(temp_dir) - def export(self): """Export the agent to transportable file format.""" raise NotImplementedError