From 3670f1676644ee15a755cbd8b1a2e542f7f8a7fe Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 19 Jun 2023 20:27:08 +0100 Subject: [PATCH] #917 - Integrated both SB3 and RLlib agents into PrimaiteSession --- src/primaite/agents/agent.py | 251 ++++++++ src/primaite/agents/agent_abc.py | 132 ----- src/primaite/agents/rllib.py | 15 +- src/primaite/agents/sb3.py | 33 +- src/primaite/common/enums.py | 12 +- .../training/training_config_main.yaml | 7 + src/primaite/config/lay_down_config.py | 56 +- src/primaite/config/training_config.py | 103 ++-- src/primaite/environment/primaite_env.py | 1 - src/primaite/main.py | 534 ++++++++---------- src/primaite/primaite_session.py | 265 +++------ .../transactions/transactions_to_file.py | 1 + tests/conftest.py | 4 +- 13 files changed, 726 insertions(+), 688 deletions(-) create mode 100644 src/primaite/agents/agent.py delete mode 100644 src/primaite/agents/agent_abc.py diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py new file mode 100644 index 00000000..58158dcb --- /dev/null +++ b/src/primaite/agents/agent.py @@ -0,0 +1,251 @@ +import json +from abc import ABC, abstractmethod +from datetime import datetime +from pathlib import Path +from typing import Optional, Final, Dict, Union, List +from uuid import uuid4 + +from primaite import getLogger +from primaite.common.enums import OutputVerboseLevel +from primaite.config import lay_down_config +from primaite.config import training_config +from primaite.config.training_config import TrainingConfig +from primaite.environment.primaite_env import Primaite +from primaite.transactions.transactions_to_file import \ + write_transaction_to_file + +_LOGGER = getLogger(__name__) + + +def _get_temp_session_path(session_timestamp: datetime) -> Path: + """ + Get a temp directory session path the test session will output to. + + :param session_timestamp: This is the datetime that the session started. + :return: The session directory path. + """ + date_dir = session_timestamp.strftime("%Y-%m-%d") + session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = Path("./") / date_dir / session_path + session_path.mkdir(exist_ok=True, parents=True) + + return session_path + + +class AgentSessionABC(ABC): + + @abstractmethod + def __init__( + self, + training_config_path, + lay_down_config_path + ): + if not isinstance(training_config_path, Path): + training_config_path = Path(training_config_path) + self._training_config_path: Final[Union[Path]] = training_config_path + self._training_config: Final[TrainingConfig] = training_config.load( + self._training_config_path + ) + + if not isinstance(lay_down_config_path, Path): + lay_down_config_path = Path(lay_down_config_path) + self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path + self._lay_down_config: Dict = lay_down_config.load( + self._lay_down_config_path + ) + self.output_verbose_level = self._training_config.output_verbose_level + + self._env: Primaite + self._agent = None + self._transaction_list: List[Dict] = [] + self._can_learn: bool = False + self._can_evaluate: bool = False + + self._uuid = str(uuid4()) + self.session_timestamp: datetime = datetime.now() + "The session timestamp" + self.session_path = _get_temp_session_path(self.session_timestamp) + "The Session path" + self.checkpoints_path = self.session_path / "checkpoints" + "The Session checkpoints path" + + self.timestamp_str = self.session_timestamp.strftime( + "%Y-%m-%d_%H-%M-%S") + "The session timestamp as a string" + + @property + def uuid(self): + """The Agent Session UUID.""" + return self._uuid + + def _write_session_metadata_file(self): + """ + Write the ``session_metadata.json`` file. + + Creates a ``session_metadata.json`` in the ``session_path`` directory + and adds the following key/value pairs: + + - uuid: The UUID assigned to the session upon instantiation. + - start_datetime: The date & time the session started in iso format. + - end_datetime: NULL. + - total_episodes: NULL. + - total_time_steps: NULL. + - env: + - training_config: + - All training config items + - lay_down_config: + - All lay down config items + + """ + metadata_dict = { + "uuid": self.uuid, + "start_datetime": self.session_timestamp.isoformat(), + "end_datetime": None, + "total_episodes": None, + "total_time_steps": None, + "env": { + "training_config": self._training_config.to_dict( + json_serializable=True + ), + "lay_down_config": self._lay_down_config, + }, + } + filepath = self.session_path / "session_metadata.json" + _LOGGER.debug(f"Writing Session Metadata file: {filepath}") + with open(filepath, "w") as file: + json.dump(metadata_dict, file) + _LOGGER.debug("Finished writing session metadata file") + + def _update_session_metadata_file(self): + """ + Update the ``session_metadata.json`` file. + + Updates the `session_metadata.json`` in the ``session_path`` directory + with the following key/value pairs: + + - end_datetime: The date & time the session ended in iso format. + - total_episodes: The total number of training episodes completed. + - total_time_steps: The total number of training time steps completed. + """ + with open(self.session_path / "session_metadata.json", "r") as file: + metadata_dict = json.load(file) + + metadata_dict["end_datetime"] = datetime.now().isoformat() + metadata_dict["total_episodes"] = self._env.episode_count + metadata_dict["total_time_steps"] = self._env.total_step_count + + filepath = self.session_path / "session_metadata.json" + _LOGGER.debug(f"Updating Session Metadata file: {filepath}") + with open(filepath, "w") as file: + json.dump(metadata_dict, file) + _LOGGER.debug("Finished updating session metadata file") + + @abstractmethod + def _setup(self): + if self.output_verbose_level >= OutputVerboseLevel.INFO: + _LOGGER.info( + "Welcome to the Primary-level AI Training Environment " + "(PrimAITE)" + ) + _LOGGER.debug( + f"The output directory for this agent is: {self.session_path}" + ) + self._write_session_metadata_file() + self._can_learn = True + self._can_evaluate = False + + @abstractmethod + def _save_checkpoint(self): + pass + + @abstractmethod + def learn( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs + ): + if self._can_learn: + _LOGGER.debug("Writing transactions") + write_transaction_to_file( + transaction_list=self._transaction_list, + session_path=self.session_path, + timestamp_str=self.timestamp_str, + ) + self._update_session_metadata_file() + self._can_evaluate = True + + @abstractmethod + def evaluate( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs + ): + pass + + @abstractmethod + def _get_latest_checkpoint(self): + pass + + @classmethod + @abstractmethod + def load(cls): + pass + + @abstractmethod + def save(self): + self._agent.save(self.session_path) + + @abstractmethod + def export(self): + pass + + +class DeterministicAgentSessionABC(AgentSessionABC): + @abstractmethod + def __init__( + self, + training_config_path, + lay_down_config_path + ): + self._training_config_path = training_config_path + self._lay_down_config_path = lay_down_config_path + self._env: Primaite + self._agent = None + + @abstractmethod + def _setup(self): + pass + + @abstractmethod + def _get_latest_checkpoint(self): + pass + + def learn( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None + ): + _LOGGER.warning("Deterministic agents cannot learn") + + @abstractmethod + def evaluate( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None + ): + pass + + @classmethod + @abstractmethod + def load(cls): + pass + + @abstractmethod + def save(self): + pass + + @abstractmethod + def export(self): + pass diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py deleted file mode 100644 index d5aceeaf..00000000 --- a/src/primaite/agents/agent_abc.py +++ /dev/null @@ -1,132 +0,0 @@ -from abc import ABC, abstractmethod -from datetime import datetime -from pathlib import Path -from typing import Optional, Final, Dict, Any, Union, Tuple - -import yaml - -from primaite import getLogger -from primaite.config.training_config import TrainingConfig, load -from primaite.environment.primaite_env import Primaite - -_LOGGER = getLogger(__name__) - - -def _get_temp_session_path(session_timestamp: datetime) -> Path: - """ - Get a temp directory session path the test session will output to. - - :param session_timestamp: This is the datetime that the session started. - :return: The session directory path. - """ - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path("./") / date_dir / session_dir - session_path.mkdir(exist_ok=True, parents=True) - - return session_path - - -class AgentABC(ABC): - - @abstractmethod - def __init__( - self, - training_config_path, - lay_down_config_path - ): - self._training_config_path = training_config_path - self._training_config: Final[TrainingConfig] = load( - self._training_config_path - ) - self._lay_down_config_path = lay_down_config_path - self._env: Primaite - self._agent = None - self.session_timestamp: datetime = datetime.now() - self.session_path = _get_temp_session_path(self.session_timestamp) - - self.timestamp_str = self.session_timestamp.strftime( - "%Y-%m-%d_%H-%M-%S") - - @abstractmethod - def _setup(self): - pass - - @abstractmethod - def _save_checkpoint(self): - pass - - @abstractmethod - def learn( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None - ): - pass - - @abstractmethod - def evaluate( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None - ): - pass - - @abstractmethod - def _get_latest_checkpoint(self): - pass - - @classmethod - @abstractmethod - def load(cls): - pass - - @abstractmethod - def save(self): - pass - - @abstractmethod - def export(self): - pass - - -class DeterministicAgentABC(AgentABC): - @abstractmethod - def __init__( - self, - training_config_path, - lay_down_config_path - ): - self._training_config_path = training_config_path - self._lay_down_config_path = lay_down_config_path - self._env: Primaite - self._agent = None - - @abstractmethod - def _setup(self): - pass - - @abstractmethod - def _get_latest_checkpoint(self): - pass - - def learn(self, time_steps: Optional[int], episodes: Optional[int]): - pass - _LOGGER.warning("Deterministic agents cannot learn") - - @abstractmethod - def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): - pass - - @classmethod - @abstractmethod - def load(cls): - pass - - @abstractmethod - def save(self): - pass - - @abstractmethod - def export(self): - pass diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index bb0daefb..80318499 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -8,7 +8,7 @@ from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.ppo import PPOConfig from ray.tune.registry import register_env -from primaite.agents.agent_abc import AgentABC +from primaite.agents.agent import AgentSessionABC from primaite.config import training_config from primaite.environment.primaite_env import Primaite @@ -23,7 +23,7 @@ def _env_creator(env_config): ) -class RLlibPPO(AgentABC): +class RLlibPPO(AgentSessionABC): def __init__( self, @@ -34,8 +34,10 @@ class RLlibPPO(AgentABC): self._ppo_config: PPOConfig self._current_result: dict self._setup() + self._agent.save() def _setup(self): + super()._setup() register_env("primaite", _env_creator) self._ppo_config = PPOConfig() @@ -72,12 +74,13 @@ class RLlibPPO(AgentABC): (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes) ): - self._agent.save(self.session_path) + self._agent.save(self.checkpoints_path) def learn( self, time_steps: Optional[int] = None, - episodes: Optional[int] = None + episodes: Optional[int] = None, + **kwargs ): # Temporarily override train_batch_size and horizon if time_steps: @@ -91,11 +94,13 @@ class RLlibPPO(AgentABC): self._current_result = self._agent.train() self._save_checkpoint() self._agent.stop() + super().learn() def evaluate( self, time_steps: Optional[int] = None, - episodes: Optional[int] = None + episodes: Optional[int] = None, + **kwargs ): raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 8fbbd815..6e6d8a5d 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,13 +1,14 @@ from typing import Optional +import numpy as np from stable_baselines3 import PPO -from primaite.agents.agent_abc import AgentABC +from primaite.agents.agent import AgentSessionABC from primaite.environment.primaite_env import Primaite from stable_baselines3.ppo import MlpPolicy as PPOMlp -class SB3PPO(AgentABC): +class SB3PPO(AgentSessionABC): def __init__( self, training_config_path, @@ -16,8 +17,10 @@ class SB3PPO(AgentABC): super().__init__(training_config_path, lay_down_config_path) self._tensorboard_log_path = self.session_path / "tensorboard_logs" self._tensorboard_log_path.mkdir(parents=True, exist_ok=True) + self._setup() def _setup(self): + super()._setup() self._env = Primaite( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, @@ -28,15 +31,30 @@ class SB3PPO(AgentABC): self._agent = PPO( PPOMlp, self._env, - verbose=0, + verbose=1, n_steps=self._training_config.num_steps, tensorboard_log=self._tensorboard_log_path ) + def _save_checkpoint(self): + checkpoint_n = self._training_config.checkpoint_every_n_episodes + episode_count = self._env.episode_count + if checkpoint_n > 0 and episode_count > 0: + if ( + (episode_count % checkpoint_n == 0) + or (episode_count == self._training_config.num_episodes) + ): + self._agent.save( + self.checkpoints_path / f"sb3ppo_{episode_count}.zip") + + def _get_latest_checkpoint(self): + pass + def learn( self, time_steps: Optional[int] = None, - episodes: Optional[int] = None + episodes: Optional[int] = None, + **kwargs ): if not time_steps: time_steps = self._training_config.num_steps @@ -46,12 +64,15 @@ class SB3PPO(AgentABC): for i in range(episodes): self._agent.learn(total_timesteps=time_steps) + self._save_checkpoint() + super().learn() def evaluate( self, time_steps: Optional[int] = None, episodes: Optional[int] = None, - deterministic: bool = True + deterministic: bool = True, + **kwargs ): if not time_steps: time_steps = self._training_config.num_steps @@ -67,6 +88,8 @@ class SB3PPO(AgentABC): obs, deterministic=deterministic ) + if isinstance(action, np.ndarray): + action = np.int64(action) obs, rewards, done, info = self._env.step(action) def load(self): diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index f28916c2..0c787e87 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -1,7 +1,7 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Enumerations for APE.""" -from enum import Enum +from enum import Enum, IntEnum class NodeType(Enum): @@ -172,3 +172,13 @@ class LinkStatus(Enum): MEDIUM = 2 HIGH = 3 OVERLOAD = 4 + + +class OutputVerboseLevel(IntEnum): + """The Agent output verbosity level.""" + NONE = 0 + "No Output" + INFO = 1 + "Info Messages" + ALL = 2 + "All Messages" diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index ebee7f77..703f37f5 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -54,6 +54,13 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip] # The high value for the observation space observation_space_high_value: 1000000000 +# The Agent output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages) +# "ALL" (All Messages) +output_verbose_level: INFO + # Reward values # Generic all_ok: 0 diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 4fd2142e..49a33d6e 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -1,21 +1,63 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. from pathlib import Path -from typing import Final +from typing import Final, Union, Dict, Any import networkx +import yaml from primaite import USERS_CONFIG_DIR, getLogger _LOGGER = getLogger(__name__) -_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" +_EXAMPLE_LAY_DOWN: Final[ + Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" -# class LayDownConfig: -# network: networkx.Graph -# POL -# EIR -# ACL +def convert_legacy_lay_down_config_dict( + legacy_config_dict: Dict[str, Any] +) -> Dict[str, Any]: + """ + Convert a legacy lay down config dict to the new format. + + :param legacy_config_dict: A legacy lay down config dict. + """ + _LOGGER.warning("Legacy lay down config conversion not yet implemented") + return legacy_config_dict + + +def load( + file_path: Union[str, Path], + legacy_file: bool = False +) -> Dict: + """ + Read in a lay down config yaml file. + + :param file_path: The config file path. + :param legacy_file: True if the config file is legacy format, otherwise + False. + :return: The lay down config as a dict. + :raises ValueError: If the file_path does not exist. + """ + if not isinstance(file_path, Path): + file_path = Path(file_path) + if file_path.exists(): + with open(file_path, "r") as file: + config = yaml.safe_load(file) + _LOGGER.debug(f"Loading lay down config file: {file_path}") + if legacy_file: + try: + config = convert_legacy_lay_down_config_dict(config) + except KeyError: + msg = ( + f"Failed to convert lay down config file {file_path} " + f"from legacy format. Attempting to use file as is." + ) + _LOGGER.error(msg) + return config + msg = f"Cannot load the lay down config as it does not exist: {file_path}" + _LOGGER.error(msg) + raise ValueError(msg) + def ddos_basic_one_config_path() -> Path: """ diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index c2cb8db9..0d39f9c4 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -10,11 +10,27 @@ import yaml from primaite import USERS_CONFIG_DIR, getLogger from primaite.common.enums import DeepLearningFramework from primaite.common.enums import ActionType, RedAgentIdentifier, \ - AgentFramework, SessionType + AgentFramework, SessionType, OutputVerboseLevel _LOGGER = getLogger(__name__) -_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training" +_EXAMPLE_TRAINING: Final[ + Path] = USERS_CONFIG_DIR / "example_config" / "training" + + +def main_training_config_path() -> Path: + """ + The path to the example training_config_main.yaml file. + + :return: The file path. + """ + path = _EXAMPLE_TRAINING / "training_config_main.yaml" + if not path.exists(): + msg = "Example config not found. Please run 'primaite setup'" + _LOGGER.critical(msg) + raise FileNotFoundError(msg) + + return path @dataclass() @@ -24,44 +40,47 @@ class TrainingConfig: "The AgentFramework" deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF - "The DeepLearningFramework." + "The DeepLearningFramework" red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO - "The RedAgentIdentifier.." + "The RedAgentIdentifier" action_type: ActionType = ActionType.ANY - "The ActionType to use." + "The ActionType to use" num_episodes: int = 10 - "The number of episodes to train over." + "The number of episodes to train over" num_steps: int = 256 - "The number of steps in an episode." + "The number of steps in an episode" checkpoint_every_n_episodes: int = 5 - "The agent will save a checkpoint every n episodes." + "The agent will save a checkpoint every n episodes" observation_space: dict = field( default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]} ) - "The observation space config dict." + "The observation space config dict" time_delay: int = 10 - "The delay between steps (ms). Applies to generic agents only." + "The delay between steps (ms). Applies to generic agents only" # file session_type: SessionType = SessionType.TRAINING - "The type of PrimAITE session to run." + "The type of PrimAITE session to run" load_agent: str = False - "Determine whether to load an agent from file." + "Determine whether to load an agent from file" agent_load_file: Optional[str] = None - "File path and file name of agent if you're loading one in." + "File path and file name of agent if you're loading one in" # Environment observation_space_high_value: int = 1000000000 - "The high value for the observation space." + "The high value for the observation space" + + output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO + "The Agent output verbosity level" # Reward values # Generic @@ -126,28 +145,28 @@ class TrainingConfig: # Patching / Reset durations os_patching_duration: int = 5 - "The time taken to patch the OS." + "The time taken to patch the OS" node_reset_duration: int = 5 - "The time taken to reset a node (hardware)." + "The time taken to reset a node (hardware)" node_booting_duration: int = 3 - "The Time taken to turn on the node." + "The Time taken to turn on the node" node_shutdown_duration: int = 2 - "The time taken to turn off the node." + "The time taken to turn off the node" service_patching_duration: int = 5 - "The time taken to patch a service." + "The time taken to patch a service" file_system_repairing_limit: int = 5 - "The time take to repair the file system." + "The time take to repair the file system" file_system_restoring_limit: int = 5 - "The time take to restore the file system." + "The time take to restore the file system" file_system_scanning_limit: int = 5 - "The time taken to scan the file system." + "The time taken to scan the file system" @classmethod def from_dict( @@ -157,9 +176,10 @@ class TrainingConfig: field_enum_map = { "agent_framework": AgentFramework, "deep_learning_framework": DeepLearningFramework, - "red_agent_identifier": RedAgentIdentifier, - "action_type": ActionType, - "session_type": SessionType + "red_agent_identifier": RedAgentIdentifier, + "action_type": ActionType, + "session_type": SessionType, + "output_verbose_level": OutputVerboseLevel } for field, enum_class in field_enum_map.items(): @@ -178,28 +198,19 @@ class TrainingConfig: """ data = self.__dict__ if json_serializable: + data["agent_framework"] = self.agent_framework.value + data["deep_learning_framework"] = self.deep_learning_framework.value + data["red_agent_identifier"] = self.red_agent_identifier.value data["action_type"] = self.action_type.value + data["output_verbose_level"] = self.output_verbose_level.value return data -def main_training_config_path() -> Path: - """ - The path to the example training_config_main.yaml file. - - :return: The file path. - """ - path = _EXAMPLE_TRAINING / "training_config_main.yaml" - if not path.exists(): - msg = "Example config not found. Please run 'primaite setup'" - _LOGGER.critical(msg) - raise FileNotFoundError(msg) - - return path - - -def load(file_path: Union[str, Path], - legacy_file: bool = False) -> TrainingConfig: +def load( + file_path: Union[str, Path], + legacy_file: bool = False +) -> TrainingConfig: """ Read in a training config yaml file. @@ -246,7 +257,8 @@ def convert_legacy_training_config_dict( agent_framework: AgentFramework = AgentFramework.SB3, red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO, action_type: ActionType = ActionType.ANY, - num_steps: int = 256 + num_steps: int = 256, + output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -260,13 +272,16 @@ def convert_legacy_training_config_dict( don't have action_type values. :param num_steps: The number of steps to set as legacy training configs don't have num_steps values. + :param output_verbose_level: The agent output verbose level to use as + legacy training configs don't have output_verbose_level values. :return: The converted training config dict. """ config_dict = { "agent_framework": agent_framework.name, "red_agent_identifier": red_agent_identifier.name, "action_type": action_type.name, - "num_steps": num_steps + "num_steps": num_steps, + "output_verbose_level": output_verbose_level } for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index e0cfb119..68209713 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -435,7 +435,6 @@ class Primaite(Env): _action: The action space from the agent """ # At the moment, actions are only affecting nodes - if self.training_config.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) elif self.training_config.action_type == ActionType.ACL: diff --git a/src/primaite/main.py b/src/primaite/main.py index 842b9259..8619dc57 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -1,305 +1,229 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -""" -The main PrimAITE session runner module. - -TODO: This will eventually be refactored out into a proper Session class. -TODO: The passing about of session_dir and timestamp_str is temporary and - will be cleaned up once we move to a proper Session class. -""" -import argparse -import json -import time -from datetime import datetime -from pathlib import Path -from typing import Final, Union -from uuid import uuid4 - -from stable_baselines3 import A2C, PPO -from stable_baselines3.common.evaluation import evaluate_policy -from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.ppo import MlpPolicy as PPOMlp - -from primaite import SESSIONS_DIR, getLogger -from primaite.config.training_config import TrainingConfig -from primaite.environment.primaite_env import Primaite -from primaite.transactions.transactions_to_file import \ - write_transaction_to_file - -_LOGGER = getLogger(__name__) - - -def run_generic(env: Primaite, config_values: TrainingConfig): - """ - Run against a generic agent. - - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. - :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. - """ - for episode in range(0, config_values.num_episodes): - env.reset() - for step in range(0, config_values.num_steps): - # Send the observation space to the agent to get an action - # TEMP - random action for now - # action = env.blue_agent_action(obs) - action = env.action_space.sample() - - # Run the simulation step on the live environment - obs, reward, done, info = env.step(action) - - # Break if done is True - if done: - break - - # Introduce a delay between steps - time.sleep(config_values.time_delay / 1000) - - # Reset the environment at the end of the episode - - env.close() - - -def run_stable_baselines3_ppo( - env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str -): - """ - Run against a stable_baselines3 PPO agent. - - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. - :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - if config_values.load_agent: - try: - agent = PPO.load( - config_values.agent_load_file, - env, - verbose=0, - n_steps=config_values.num_steps, - ) - except Exception: - print( - "ERROR: Could not load agent at location: " - + config_values.agent_load_file - ) - _LOGGER.error("Could not load agent") - _LOGGER.error("Exception occured", exc_info=True) - else: - agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps) - - if config_values.session_type == "TRAINING": - # We're in a training session - print("Starting training session...") - _LOGGER.debug("Starting training session...") - for episode in range(config_values.num_episodes): - agent.learn(total_timesteps=config_values.num_steps) - _save_agent(agent, session_path, timestamp_str) - else: - # Default to being in an evaluation session - print("Starting evaluation session...") - _LOGGER.debug("Starting evaluation session...") - evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) - - env.close() - - -def _write_session_metadata_file( - session_dir: Path, uuid: str, session_timestamp: datetime, env: Primaite -): - """ - Write the ``session_metadata.json`` file. - - Creates a ``session_metadata.json`` in the ``session_dir`` directory - and adds the following key/value pairs: - - - uuid: The UUID assigned to the session upon instantiation. - - start_datetime: The date & time the session started in iso format. - - end_datetime: NULL. - - total_episodes: NULL. - - total_time_steps: NULL. - - env: - - training_config: - - All training config items - - lay_down_config: - - All lay down config items - - """ - metadata_dict = { - "uuid": uuid, - "start_datetime": session_timestamp.isoformat(), - "end_datetime": None, - "total_episodes": None, - "total_time_steps": None, - "env": { - "training_config": env.training_config.to_dict(json_serializable=True), - "lay_down_config": env.lay_down_config, - }, - } - filepath = session_dir / "session_metadata.json" - _LOGGER.debug(f"Writing Session Metadata file: {filepath}") - with open(filepath, "w") as file: - json.dump(metadata_dict, file) - - -def _update_session_metadata_file(session_dir: Path, env: Primaite): - """ - Update the ``session_metadata.json`` file. - - Updates the `session_metadata.json`` in the ``session_dir`` directory - with the following key/value pairs: - - - end_datetime: NULL. - - total_episodes: NULL. - - total_time_steps: NULL. - """ - with open(session_dir / "session_metadata.json", "r") as file: - metadata_dict = json.load(file) - - metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = env.episode_count - metadata_dict["total_time_steps"] = env.total_step_count - - filepath = session_dir / "session_metadata.json" - _LOGGER.debug(f"Updating Session Metadata file: {filepath}") - with open(filepath, "w") as file: - json.dump(metadata_dict, file) - - -def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str): - """ - Persist an agent. - - Only works for stable baselines3 agents at present. - - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - if not isinstance(agent, OnPolicyAlgorithm): - msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}." - _LOGGER.error(msg) - else: - filepath = session_path / f"agent_saved_{timestamp_str}" - agent.save(filepath) - _LOGGER.debug(f"Trained agent saved as: {filepath}") - - -def _get_session_path(session_timestamp: datetime) -> Path: - """ - Get the directory path the session will output to. - - This is set in the format of: - ~/primaite/sessions//_. - - :param session_timestamp: This is the datetime that the session started. - :return: The session directory path. - """ - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = SESSIONS_DIR / date_dir / session_dir - session_path.mkdir(exist_ok=True, parents=True) - - return session_path - - -def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): - """Run the PrimAITE Session. - - :param training_config_path: The training config filepath. - :param lay_down_config_path: The lay down config filepath. - """ - # Welcome message - print("Welcome to the Primary-level AI Training Environment (PrimAITE)") - uuid = str(uuid4()) - session_timestamp: Final[datetime] = datetime.now() - session_dir = _get_session_path(session_timestamp) - timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - - print(f"The output directory for this session is: {session_dir}") - - # Create a list of transactions - # A transaction is an object holding the: - # - episode # - # - step # - # - initial observation space - # - action - # - reward - # - new observation space - transaction_list = [] - - # Create the Primaite environment - env = Primaite( - training_config_path=training_config_path, - lay_down_config_path=lay_down_config_path, - transaction_list=transaction_list, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - print("Writing Session Metadata file...") - - _write_session_metadata_file( - session_dir=session_dir, uuid=uuid, session_timestamp=session_timestamp, env=env - ) - - config_values = env.training_config - - # Get the number of steps (which is stored in the child config file) - config_values.num_steps = env.episode_steps - - # Run environment against an agent - if config_values.agent_identifier == "GENERIC": - run_generic(env=env, config_values=config_values) - elif config_values.agent_identifier == "STABLE_BASELINES3_PPO": - run_stable_baselines3_ppo( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - elif config_values.agent_identifier == "STABLE_BASELINES3_A2C": - run_stable_baselines3_a2c( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - print("Session finished") - _LOGGER.debug("Session finished") - - print("Saving transaction logs...") - write_transaction_to_file( - transaction_list=transaction_list, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - print("Updating Session Metadata file...") - _update_session_metadata_file(session_dir=session_dir, env=env) - - print("Finished") - _LOGGER.debug("Finished") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--tc") - parser.add_argument("--ldc") - args = parser.parse_args() - if not args.tc: - _LOGGER.error( - "Please provide a training config file using the --tc " "argument" - ) - if not args.ldc: - _LOGGER.error( - "Please provide a lay down config file using the --ldc " "argument" - ) - run(training_config_path=args.tc, lay_down_config_path=args.ldc) - - +# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# """ +# The main PrimAITE session runner module. +# +# TODO: This will eventually be refactored out into a proper Session class. +# TODO: The passing about of session_path and timestamp_str is temporary and +# will be cleaned up once we move to a proper Session class. +# """ +# import argparse +# import json +# import time +# from datetime import datetime +# from pathlib import Path +# from typing import Final, Union +# from uuid import uuid4 +# +# from stable_baselines3 import A2C, PPO +# from stable_baselines3.common.evaluation import evaluate_policy +# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +# from stable_baselines3.ppo import MlpPolicy as PPOMlp +# +# from primaite import SESSIONS_DIR, getLogger +# from primaite.config.training_config import TrainingConfig +# from primaite.environment.primaite_env import Primaite +# from primaite.transactions.transactions_to_file import \ +# write_transaction_to_file +# +# _LOGGER = getLogger(__name__) +# +# +# def run_generic(env: Primaite, config_values: TrainingConfig): +# """ +# Run against a generic agent. +# +# :param env: An instance of +# :class:`~primaite.environment.primaite_env.Primaite`. +# :param config_values: An instance of +# :class:`~primaite.config.training_config.TrainingConfig`. +# """ +# for episode in range(0, config_values.num_episodes): +# env.reset() +# for step in range(0, config_values.num_steps): +# # Send the observation space to the agent to get an action +# # TEMP - random action for now +# # action = env.blue_agent_action(obs) +# action = env.action_space.sample() +# +# # Run the simulation step on the live environment +# obs, reward, done, info = env.step(action) +# +# # Break if done is True +# if done: +# break +# +# # Introduce a delay between steps +# time.sleep(config_values.time_delay / 1000) +# +# # Reset the environment at the end of the episode +# +# env.close() +# +# +# def run_stable_baselines3_ppo( +# env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str +# ): +# """ +# Run against a stable_baselines3 PPO agent. +# +# :param env: An instance of +# :class:`~primaite.environment.primaite_env.Primaite`. +# :param config_values: An instance of +# :class:`~primaite.config.training_config.TrainingConfig`. +# :param session_path: The directory path the session is writing to. +# :param timestamp_str: The session timestamp in the format: +# _. +# """ +# if config_values.load_agent: +# try: +# agent = PPO.load( +# config_values.agent_load_file, +# env, +# verbose=0, +# n_steps=config_values.num_steps, +# ) +# except Exception: +# print( +# "ERROR: Could not load agent at location: " +# + config_values.agent_load_file +# ) +# _LOGGER.error("Could not load agent") +# _LOGGER.error("Exception occured", exc_info=True) +# else: +# agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps) +# +# if config_values.session_type == "TRAINING": +# # We're in a training session +# print("Starting training session...") +# _LOGGER.debug("Starting training session...") +# for episode in range(config_values.num_episodes): +# agent.learn(total_timesteps=config_values.num_steps) +# _save_agent(agent, session_path, timestamp_str) +# else: +# # Default to being in an evaluation session +# print("Starting evaluation session...") +# _LOGGER.debug("Starting evaluation session...") +# evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) +# +# env.close() +# +# +# +# +# def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str): +# """ +# Persist an agent. +# +# Only works for stable baselines3 agents at present. +# +# :param session_path: The directory path the session is writing to. +# :param timestamp_str: The session timestamp in the format: +# _. +# """ +# if not isinstance(agent, OnPolicyAlgorithm): +# msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}." +# _LOGGER.error(msg) +# else: +# filepath = session_path / f"agent_saved_{timestamp_str}" +# agent.save(filepath) +# _LOGGER.debug(f"Trained agent saved as: {filepath}") +# +# +# +# +# def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): +# """Run the PrimAITE Session. +# +# :param training_config_path: The training config filepath. +# :param lay_down_config_path: The lay down config filepath. +# """ +# # Welcome message +# print("Welcome to the Primary-level AI Training Environment (PrimAITE)") +# uuid = str(uuid4()) +# session_timestamp: Final[datetime] = datetime.now() +# session_path = _get_session_path(session_timestamp) +# timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") +# +# print(f"The output directory for this session is: {session_path}") +# +# # Create a list of transactions +# # A transaction is an object holding the: +# # - episode # +# # - step # +# # - initial observation space +# # - action +# # - reward +# # - new observation space +# transaction_list = [] +# +# # Create the Primaite environment +# env = Primaite( +# training_config_path=training_config_path, +# lay_down_config_path=lay_down_config_path, +# transaction_list=transaction_list, +# session_path=session_path, +# timestamp_str=timestamp_str, +# ) +# +# print("Writing Session Metadata file...") +# +# _write_session_metadata_file( +# session_path=session_path, uuid=uuid, session_timestamp=session_timestamp, env=env +# ) +# +# config_values = env.training_config +# +# # Get the number of steps (which is stored in the child config file) +# config_values.num_steps = env.episode_steps +# +# # Run environment against an agent +# if config_values.agent_identifier == "GENERIC": +# run_generic(env=env, config_values=config_values) +# elif config_values.agent_identifier == "STABLE_BASELINES3_PPO": +# run_stable_baselines3_ppo( +# env=env, +# config_values=config_values, +# session_path=session_path, +# timestamp_str=timestamp_str, +# ) +# elif config_values.agent_identifier == "STABLE_BASELINES3_A2C": +# run_stable_baselines3_a2c( +# env=env, +# config_values=config_values, +# session_path=session_path, +# timestamp_str=timestamp_str, +# ) +# +# print("Session finished") +# _LOGGER.debug("Session finished") +# +# print("Saving transaction logs...") +# write_transaction_to_file( +# transaction_list=transaction_list, +# session_path=session_path, +# timestamp_str=timestamp_str, +# ) +# +# print("Updating Session Metadata file...") +# _update_session_metadata_file(session_path=session_path, env=env) +# +# print("Finished") +# _LOGGER.debug("Finished") +# +# +# if __name__ == "__main__": +# parser = argparse.ArgumentParser() +# parser.add_argument("--tc") +# parser.add_argument("--ldc") +# args = parser.parse_args() +# if not args.tc: +# _LOGGER.error( +# "Please provide a training config file using the --tc " "argument" +# ) +# if not args.ldc: +# _LOGGER.error( +# "Please provide a lay down config file using the --ldc " "argument" +# ) +# run(training_config_path=args.tc, lay_down_config_path=args.ldc) +# +# diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 0efc0acf..8f3380c8 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -3,12 +3,16 @@ from __future__ import annotations import json from datetime import datetime from pathlib import Path -from typing import Final, Optional, Union +from typing import Final, Optional, Union, Dict from uuid import uuid4 from primaite import getLogger, SESSIONS_DIR +from primaite.agents.agent import AgentSessionABC +from primaite.agents.rllib import RLlibPPO +from primaite.agents.sb3 import SB3PPO from primaite.common.enums import AgentFramework, RedAgentIdentifier, \ ActionType +from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite @@ -26,8 +30,8 @@ def _get_session_path(session_timestamp: datetime) -> Path: :return: The session directory path. """ date_dir = session_timestamp.strftime("%Y-%m-%d") - session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = SESSIONS_DIR / date_dir / session_dir + session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = SESSIONS_DIR / date_dir / session_path session_path.mkdir(exist_ok=True, parents=True) _LOGGER.debug(f"Created PrimAITE Session path: {session_path}") @@ -45,211 +49,100 @@ class PrimaiteSession: if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path]] = training_config_path + self._training_config: Final[TrainingConfig] = training_config.load( + self._training_config_path + ) if not isinstance(lay_down_config_path, Path): lay_down_config_path = Path(lay_down_config_path) self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path - - self._auto: Final[bool] = auto - - self._uuid: str = str(uuid4()) - self._session_timestamp: Final[datetime] = datetime.now() - self._session_path: Final[Path] = _get_session_path( - self._session_timestamp + self._lay_down_config: Dict = lay_down_config.load( + self._lay_down_config_path ) - self._timestamp_str: Final[str] = self._session_timestamp.strftime( - "%Y-%m-%d_%H-%M-%S") - self._metadata_path = self._session_path / "session_metadata.json" - - self._env = None - self._training_config: TrainingConfig - self._can_learn: bool = False - _LOGGER.debug("") + self._auto: bool = auto + self._agent_session: AgentSessionABC = None # noqa if self._auto: self.setup() self.learn() - @property - def uuid(self): - """The session UUID.""" - return self._uuid - - def _setup_primaite_env(self, transaction_list: Optional[list] = None): - if not transaction_list: - transaction_list = [] - self._env: Primaite = Primaite( - training_config_path=self._training_config_path, - lay_down_config_path=self._lay_down_config_path, - transaction_list=transaction_list, - session_path=self._session_path, - timestamp_str=self._timestamp_str - ) - self._training_config: TrainingConfig = self._env.training_config - - def _write_session_metadata_file(self): - """ - Write the ``session_metadata.json`` file. - - Creates a ``session_metadata.json`` in the ``session_dir`` directory - and adds the following key/value pairs: - - - uuid: The UUID assigned to the session upon instantiation. - - start_datetime: The date & time the session started in iso format. - - end_datetime: NULL. - - total_episodes: NULL. - - total_time_steps: NULL. - - env: - - training_config: - - All training config items - - lay_down_config: - - All lay down config items - """ - metadata_dict = { - "uuid": self._uuid, - "start_datetime": self._session_timestamp.isoformat(), - "end_datetime": None, - "total_episodes": None, - "total_time_steps": None, - "env": { - "training_config": self._env.training_config.to_dict( - json_serializable=True - ), - "lay_down_config": self._env.lay_down_config, - }, - } - _LOGGER.debug(f"Writing Session Metadata file: {self._metadata_path}") - with open(self._metadata_path, "w") as file: - json.dump(metadata_dict, file) - - def _update_session_metadata_file(self): - """ - Update the ``session_metadata.json`` file. - - Updates the `session_metadata.json`` in the ``session_dir`` directory - with the following key/value pairs: - - - end_datetime: NULL. - - total_episodes: NULL. - - total_time_steps: NULL. - """ - with open(self._metadata_path, "r") as file: - metadata_dict = json.load(file) - - metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = self._env.episode_count - metadata_dict["total_time_steps"] = self._env.total_step_count - - _LOGGER.debug(f"Updating Session Metadata file: {self._metadata_path}") - with open(self._metadata_path, "w") as file: - json.dump(metadata_dict, file) - def setup(self): - self._setup_primaite_env() - self._can_learn = True - pass + if self._training_config.agent_framework == AgentFramework.NONE: + if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM: + # Stochastic Random Agent + raise NotImplementedError + + elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED: + if self._training_config.action_type == ActionType.NODE: + # Deterministic Hardcoded Agent with Node Action Space + raise NotImplementedError + + elif self._training_config.action_type == ActionType.ACL: + # Deterministic Hardcoded Agent with ACL Action Space + raise NotImplementedError + + elif self._training_config.action_type == ActionType.ANY: + # Deterministic Hardcoded Agent with ANY Action Space + raise NotImplementedError + + else: + # Invalid RedAgentIdentifier ActionType combo + pass + + else: + # Invalid AgentFramework RedAgentIdentifier combo + pass + + elif self._training_config.agent_framework == AgentFramework.SB3: + if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + # Stable Baselines3/Proximal Policy Optimization + self._agent_session = SB3PPO( + self._training_config_path, + self._lay_down_config_path + ) + + elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + # Stable Baselines3/Advantage Actor Critic + raise NotImplementedError + else: + # Invalid AgentFramework RedAgentIdentifier combo + pass + + elif self._training_config.agent_framework == AgentFramework.RLLIB: + if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + # Ray RLlib/Proximal Policy Optimization + self._agent_session = RLlibPPO( + self._training_config_path, + self._lay_down_config_path + ) + + elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + # Ray RLlib/Advantage Actor Critic + raise NotImplementedError + + else: + # Invalid AgentFramework RedAgentIdentifier combo + pass + else: + # Invalid AgentFramework + pass def learn( self, - time_steps: Optional[int], - episodes: Optional[int], - iterations: Optional[int], + time_steps: Optional[int] = None, + episodes: Optional[int] = None, **kwargs ): - if self._can_learn: - # Run environment against an agent - if self._training_config.agent_framework == AgentFramework.NONE: - if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM: - # Stochastic Random Agent - run_generic(env=env, config_values=config_values) - - elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED: - if self._training_config.action_type == ActionType.NODE: - # Deterministic Hardcoded Agent with Node Action Space - pass - - elif self._training_config.action_type == ActionType.ACL: - # Deterministic Hardcoded Agent with ACL Action Space - pass - - elif self._training_config.action_type == ActionType.ANY: - # Deterministic Hardcoded Agent with ANY Action Space - pass - - else: - # Invalid RedAgentIdentifier ActionType combo - pass - - else: - # Invalid AgentFramework RedAgentIdentifier combo - pass - - elif self._training_config.agent_framework == AgentFramework.SB3: - if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: - # Stable Baselines3/Proximal Policy Optimization - run_stable_baselines3_ppo( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: - # Stable Baselines3/Advantage Actor Critic - run_stable_baselines3_a2c( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - else: - # Invalid AgentFramework RedAgentIdentifier combo - pass - - elif self._training_config.agent_framework == AgentFramework.RLLIB: - if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: - # Ray RLlib/Proximal Policy Optimization - pass - - elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: - # Ray RLlib/Advantage Actor Critic - pass - - else: - # Invalid AgentFramework RedAgentIdentifier combo - pass - else: - # Invalid AgentFramework - pass - - print("Session finished") - _LOGGER.debug("Session finished") - - print("Saving transaction logs...") - write_transaction_to_file( - transaction_list=transaction_list, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - print("Updating Session Metadata file...") - _update_session_metadata_file(session_dir=session_dir, env=env) - - print("Finished") - _LOGGER.debug("Finished") + self._agent_session.learn(time_steps, episodes, **kwargs) def evaluate( self, - time_steps: Optional[int], - episodes: Optional[int], + time_steps: Optional[int] = None, + episodes: Optional[int] = None, **kwargs ): - pass - - def export(self): - pass + self._agent_session.evaluate(time_steps, episodes, **kwargs) @classmethod def import_agent( diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index 11e68af8..24581597 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -108,5 +108,6 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st csv_writer.writerow(csv_data) csv_file.close() + _LOGGER.debug("Finished writing transactions") except Exception: _LOGGER.error("Could not save the transaction file", exc_info=True) diff --git a/tests/conftest.py b/tests/conftest.py index f1411ba9..1bad5db0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,8 +19,8 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path: :return: The session directory path. """ date_dir = session_timestamp.strftime("%Y-%m-%d") - session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_dir + session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path session_path.mkdir(exist_ok=True, parents=True) return session_path