diff --git a/src/primaite/VERSION b/src/primaite/VERSION
index bd82b28c..4111d137 100644
--- a/src/primaite/VERSION
+++ b/src/primaite/VERSION
@@ -1 +1 @@
-2.0.0dev0
+2.0.0rc1
diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py
index c9067210..d5aceeaf 100644
--- a/src/primaite/agents/agent_abc.py
+++ b/src/primaite/agents/agent_abc.py
@@ -1,36 +1,84 @@
 from abc import ABC, abstractmethod
-from typing import Optional, Final, Dict, Any
+from datetime import datetime
+from pathlib import Path
+from typing import Optional, Final, Dict, Any, Union, Tuple
+
+import yaml
 
 from primaite import getLogger
-from primaite.config.training_config import TrainingConfig
+from primaite.config.training_config import TrainingConfig, load
 from primaite.environment.primaite_env import Primaite
 
 _LOGGER = getLogger(__name__)
 
 
+def _get_temp_session_path(session_timestamp: datetime) -> Path:
+    """
+    Get a temp directory session path that the session will output to.
+
+    :param session_timestamp: The datetime the session started.
+    :return: The session directory path.
+    """
+    date_dir = session_timestamp.strftime("%Y-%m-%d")
+    session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
+    session_path = Path("./") / date_dir / session_dir
+    session_path.mkdir(exist_ok=True, parents=True)
+
+    return session_path
+
+
 class AgentABC(ABC):
     @abstractmethod
-    def __init__(self, env: Primaite):
-        self._env: Primaite = env
-        self._training_config: Final[TrainingConfig] = self._env.training_config
-        self._lay_down_config: Dict[str, Any] = self._env.lay_down_config
+    def __init__(
+        self,
+        training_config_path,
+        lay_down_config_path
+    ):
+        self._training_config_path = training_config_path
+        self._training_config: Final[TrainingConfig] = load(
+            self._training_config_path
+        )
+        self._lay_down_config_path = lay_down_config_path
+        self._env: Primaite
         self._agent = None
+        self.session_timestamp: datetime = datetime.now()
+        self.session_path = _get_temp_session_path(self.session_timestamp)
+
+        self.timestamp_str = self.session_timestamp.strftime(
+            "%Y-%m-%d_%H-%M-%S")
 
     @abstractmethod
     def _setup(self):
         pass
 
     @abstractmethod
-    def learn(self, time_steps: Optional[int], episodes: Optional[int]):
+    def _save_checkpoint(self):
         pass
 
     @abstractmethod
-    def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
+    def learn(
+        self,
+        time_steps: Optional[int] = None,
+        episodes: Optional[int] = None
+    ):
         pass
 
     @abstractmethod
-    def load(self):
+    def evaluate(
+        self,
+        time_steps: Optional[int] = None,
+        episodes: Optional[int] = None
+    ):
+        pass
+
+    @abstractmethod
+    def _get_latest_checkpoint(self):
+        pass
+
+    @classmethod
+    @abstractmethod
+    def load(cls):
         pass
 
     @abstractmethod
@@ -44,14 +92,24 @@ class AgentABC(ABC):
 
 class DeterministicAgentABC(AgentABC):
     @abstractmethod
-    def __init__(self, env: Primaite):
-        self._env: Primaite = env
+    def __init__(
+        self,
+        training_config_path,
+        lay_down_config_path
+    ):
+        self._training_config_path = training_config_path
+        self._lay_down_config_path = lay_down_config_path
+        self._env: Primaite
         self._agent = None
 
     @abstractmethod
     def _setup(self):
         pass
 
+    @abstractmethod
+    def _get_latest_checkpoint(self):
+        pass
+
     def learn(self, time_steps: Optional[int], episodes: Optional[int]):
         pass
         _LOGGER.warning("Deterministic agents cannot learn")
@@ -60,8 +118,9 @@ class DeterministicAgentABC(AgentABC):
     def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
         pass
 
+    @classmethod
     @abstractmethod
-    def load(self):
+    def load(cls):
         pass
@abstractmethod diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index d07265b4..bb0daefb 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -8,170 +8,106 @@ from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.ppo import PPOConfig from ray.tune.registry import register_env +from primaite.agents.agent_abc import AgentABC from primaite.config import training_config from primaite.environment.primaite_env import Primaite -class DLFramework(Enum): - """The DL Frameworks enumeration.""" - TF = "tf" - TF2 = "tf2" - TORCH = "torch" +def _env_creator(env_config): + return Primaite( + training_config_path=env_config["training_config_path"], + lay_down_config_path=env_config["lay_down_config_path"], + transaction_list=env_config["transaction_list"], + session_path=env_config["session_path"], + timestamp_str=env_config["timestamp_str"] + ) -def env_creator(env_config): - training_config_path = env_config["training_config_path"] - lay_down_config_path = env_config["lay_down_config_path"] - return Primaite(training_config_path, lay_down_config_path, []) +class RLlibPPO(AgentABC): + def __init__( + self, + training_config_path, + lay_down_config_path + ): + super().__init__(training_config_path, lay_down_config_path) + self._ppo_config: PPOConfig + self._current_result: dict + self._setup() -def get_ppo_config( - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - framework: Optional[DLFramework] = DLFramework.TORCH -) -> PPOConfig(): - # Register environment - register_env("primaite", env_creator) + def _setup(self): + register_env("primaite", _env_creator) + self._ppo_config = PPOConfig() - # Setup PPO - config = PPOConfig() - - config_values = training_config.load(training_config_path) - - # Setup our config object to use our environment - config.environment( - env="primaite", - env_config=dict( - training_config_path=training_config_path, - lay_down_config_path=lay_down_config_path + self._ppo_config.environment( + env="primaite", + env_config=dict( + training_config_path=self._training_config_path, + lay_down_config_path=self._lay_down_config_path, + transaction_list=[], + session_path=self.session_path, + timestamp_str=self.timestamp_str + ) ) - ) - env_config = config_values - action_type = env_config.action_type - red_agent = env_config.red_agent_identifier + self._ppo_config.training( + train_batch_size=self._training_config.num_steps + ) + self._ppo_config.framework( + framework=self._training_config.deep_learning_framework.value + ) - if red_agent == "RANDOM" and action_type == "NODE": - config.training( - train_batch_size=6000, lr=5e-5 - ) # number of steps in a training iteration - elif red_agent == "RANDOM" and action_type != "NODE": - config.training(train_batch_size=6000, lr=5e-5) - elif red_agent == "CONFIG" and action_type == "NODE": - config.training(train_batch_size=400, lr=5e-5) - elif red_agent == "CONFIG" and action_type != "NONE": - config.training(train_batch_size=500, lr=5e-5) - else: - config.training(train_batch_size=500, lr=5e-5) + self._ppo_config.rollouts( + num_rollout_workers=1, + num_envs_per_worker=1, + horizon=self._training_config.num_steps + ) + self._agent: Algorithm = self._ppo_config.build() - # Decide if you want torch or tensorflow DL framework. 
Default is "tf" - config.framework(framework=framework.value) + def _save_checkpoint(self): + checkpoint_n = self._training_config.checkpoint_every_n_episodes + episode_count = self._current_result["episodes_total"] + if checkpoint_n > 0 and episode_count > 0: + if ( + (episode_count % checkpoint_n == 0) + or (episode_count == self._training_config.num_episodes) + ): + self._agent.save(self.session_path) - # Set the log level to DEBUG, INFO, WARN, or ERROR - config.debugging(seed=415, log_level="ERROR") + def learn( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None + ): + # Temporarily override train_batch_size and horizon + if time_steps: + self._ppo_config.train_batch_size = time_steps + self._ppo_config.horizon = time_steps - # Setup evaluation - # Explicitly set "explore"=False to override default - # config.evaluation( - # evaluation_interval=100, - # evaluation_duration=20, - # # evaluation_duration_unit="timesteps",) #default episodes - # evaluation_config={"explore": False}, - # ) + if not episodes: + episodes = self._training_config.num_episodes - # Setup sampling rollout workers - config.rollouts( - num_rollout_workers=4, - num_envs_per_worker=1, - horizon=128, # num parralel workiers - ) # max num steps in an episode + for i in range(episodes): + self._current_result = self._agent.train() + self._save_checkpoint() + self._agent.stop() - config.build() # Build config + def evaluate( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None + ): + raise NotImplementedError - return config + def _get_latest_checkpoint(self): + raise NotImplementedError + @classmethod + def load(cls): + raise NotImplementedError -def train( - num_iterations: int, - config: Optional[PPOConfig] = None, - algo: Optional[Algorithm] = None -): - """ + def save(self): + raise NotImplementedError - Requires either the algorithm config (new model) or the algorithm itself (continue training from checkpoint) - """ - - start_time = time.time() - - if algo is None: - algo = config.build() - elif config is None: - config = algo.get_config() - - print(f"Algorithm type: {type(algo)}") - - # iterations are not the same as episodes. 
- for i in range(num_iterations): - result = algo.train() - # # Save every 10 iterations or after last iteration in training - # if (i % 100 == 0) or (i == num_iterations - 1): - print( - f"Iteration={i}, Mean Reward={result['episode_reward_mean']:.2f}") - # save checkpoint file - checkpoint_file = algo.save("./") - print(f"Checkpoint saved at {checkpoint_file}") - - # convert num_iterations to num_episodes - num_episodes = len( - result["hist_stats"]["episode_lengths"]) * num_iterations - # convert num_iterations to num_timesteps - num_timesteps = sum( - result["hist_stats"]["episode_lengths"] * num_iterations) - # calculate number of wins - - # train time - print(f"Training took {time.time() - start_time:.2f} seconds") - print( - f"Number of episodes {num_episodes}, Number of timesteps: {num_timesteps}") - return result - - -def load_model_from_checkpoint(config, checkpoint=None): - # create an empty Algorithm - algo = config.build() - - if checkpoint is None: - # Get the checkpoint with the highest iteration number - checkpoint = get_most_recent_checkpoint(config) - - # restore the agent from the checkpoint - algo.restore(checkpoint) - - return algo - - -def get_most_recent_checkpoint(config): - """ - Get the most recent checkpoint for specified action type, red agent and algorithm - """ - - env_config = list(config.env_config.values())[0] - action_type = env_config.action_type - red_agent = env_config.red_agent_identifier - algo_name = config.algo_class.__name__ - - # Gets the latest checkpoint (highest iteration not datetime) to use as the final trained model - relevant_checkpoints = glob.glob( - f"/app/outputs/agents/{action_type}/{red_agent}/{algo_name}/*" - ) - checkpoint_numbers = [int(i.split("_")[1]) for i in relevant_checkpoints] - max_checkpoint = str(max(checkpoint_numbers)) - checkpoint_number_to_use = "0" * (6 - len(max_checkpoint)) + max_checkpoint - checkpoint = ( - relevant_checkpoints[0].split("_")[0] - + "_" - + checkpoint_number_to_use - + "/rllib_checkpoint.json" - ) - - return checkpoint + def export(self): + raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 7d0fba3b..8fbbd815 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -6,30 +6,74 @@ from primaite.agents.agent_abc import AgentABC from primaite.environment.primaite_env import Primaite from stable_baselines3.ppo import MlpPolicy as PPOMlp + class SB3PPO(AgentABC): - def __init__(self, env: Primaite): - super().__init__(env) + def __init__( + self, + training_config_path, + lay_down_config_path + ): + super().__init__(training_config_path, lay_down_config_path) + self._tensorboard_log_path = self.session_path / "tensorboard_logs" + self._tensorboard_log_path.mkdir(parents=True, exist_ok=True) def _setup(self): + self._env = Primaite( + training_config_path=self._training_config_path, + lay_down_config_path=self._lay_down_config_path, + transaction_list=[], + session_path=self.session_path, + timestamp_str=self.timestamp_str + ) self._agent = PPO( PPOMlp, self._env, verbose=0, - n_steps=self._training_config.num_steps + n_steps=self._training_config.num_steps, + tensorboard_log=self._tensorboard_log_path ) + def learn( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None + ): + if not time_steps: + time_steps = self._training_config.num_steps - def learn(self, time_steps: Optional[int], episodes: Optional[int]): - pass + if not episodes: + episodes = self._training_config.num_episodes - def evaluate(self, 
time_steps: Optional[int], episodes: Optional[int]):
-        pass
+        for i in range(episodes):
+            self._agent.learn(total_timesteps=time_steps)
+
+    def evaluate(
+        self,
+        time_steps: Optional[int] = None,
+        episodes: Optional[int] = None,
+        deterministic: bool = True
+    ):
+        if not time_steps:
+            time_steps = self._training_config.num_steps
+
+        if not episodes:
+            episodes = self._training_config.num_episodes
+
+        for episode in range(episodes):
+            obs = self._env.reset()
+
+            for step in range(time_steps):
+                action, _states = self._agent.predict(
+                    obs,
+                    deterministic=deterministic
+                )
+                obs, rewards, done, info = self._env.step(action)
 
     def load(self):
-        pass
+        raise NotImplementedError
 
     def save(self):
-        pass
+        raise NotImplementedError
 
     def export(self):
-        pass
\ No newline at end of file
+        raise NotImplementedError
diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py
index 121beb60..f28916c2 100644
--- a/src/primaite/common/enums.py
+++ b/src/primaite/common/enums.py
@@ -95,15 +95,32 @@ class VerboseLevel(Enum):
 
 class AgentFramework(Enum):
     NONE = 0
+    "Custom Agent"
     SB3 = 1
+    "Stable Baselines3"
     RLLIB = 2
+    "Ray RLlib"
+
+
+class DeepLearningFramework(Enum):
+    """The deep learning framework enumeration."""
+    TF = "tf"
+    "Tensorflow"
+    TF2 = "tf2"
+    "Tensorflow 2.x"
+    TORCH = "torch"
+    "PyTorch"
 
 
 class RedAgentIdentifier(Enum):
     A2C = 1
+    "Advantage Actor Critic"
     PPO = 2
+    "Proximal Policy Optimization"
     HARDCODED = 3
+    "Custom Agent"
     RANDOM = 4
+    "Random Action"
 
 
 class ActionType(Enum):
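The bare strings added beneath each enum member above are attribute docstrings, the same convention TrainingConfig already uses; they are picked up by documentation tools and have no runtime effect. A minimal sketch (not part of the patch) of the name-based lookup that the training_config loader later in this diff relies on, assuming the enum definitions above:

    from primaite.common.enums import DeepLearningFramework

    # The YAML supplies the member name (e.g. "TORCH");
    # subscripting an Enum class looks the member up by name.
    framework = DeepLearningFramework["TORCH"]
    assert framework is DeepLearningFramework.TORCH
    # .value is the string handed to PPOConfig.framework() in rllib.py.
    assert framework.value == "torch"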
diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml
index d01f51f3..ebee7f77 100644
--- a/src/primaite/config/_package_data/training/training_config_main.yaml
+++ b/src/primaite/config/_package_data/training/training_config_main.yaml
@@ -1,26 +1,52 @@
 # Main Config File
-# Generic config values
-# Choose one of these (dependent on Agent being trained)
-# "STABLE_BASELINES3_PPO"
-# "STABLE_BASELINES3_A2C"
-# "GENERIC"
-agent_identifier: STABLE_BASELINES3_A2C
+# Sets which agent algorithm framework will be used:
+# Options are:
+# "SB3" (Stable Baselines3)
+# "RLLIB" (Ray RLlib)
+# "NONE" (Custom Agent)
+agent_framework: RLLIB
+
+# Sets which deep learning framework will be used. Default is TF (Tensorflow).
+# Options are:
+# "TF" (Tensorflow)
+# "TF2" (Tensorflow 2.X)
+# "TORCH" (PyTorch)
+deep_learning_framework: TORCH
+
+# Sets which Red Agent algo/class will be used:
+# Options are:
+# "A2C" (Advantage Actor Critic)
+# "PPO" (Proximal Policy Optimization)
+# "HARDCODED" (Custom Agent)
+# "RANDOM" (Random Action)
+red_agent_identifier: PPO
+
 # Sets How the Action Space is defined:
 # "NODE"
 # "ACL"
 # "ANY" node and acl actions
 action_type: NODE
+
 # Number of episodes to run per session
 num_episodes: 10
+
 # Number of time_steps per episode
 num_steps: 256
+
+# Sets how often the agent will save a checkpoint (every n episodes).
+# Set to 0 if no checkpoints are required. Default is 5.
+checkpoint_every_n_episodes: 5
+
 # Time delay between steps (for generic agents)
 time_delay: 10
+
 # Type of session to be run (TRAINING or EVALUATION)
 session_type: TRAINING
+
 # Determine whether to load an agent from file
 load_agent: False
+
 # File path and file name of agent if you're loading one in
 agent_load_file: C:\[Path]\[agent_saved_filename.zip]
diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py
index b0956d42..c2cb8db9 100644
--- a/src/primaite/config/training_config.py
+++ b/src/primaite/config/training_config.py
@@ -8,6 +8,7 @@ from typing import Any, Dict, Final, Union, Optional
 import yaml
 
 from primaite import USERS_CONFIG_DIR, getLogger
+from primaite.common.enums import DeepLearningFramework
 from primaite.common.enums import ActionType, RedAgentIdentifier, \
     AgentFramework, SessionType
 
@@ -20,10 +21,13 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training
 class TrainingConfig:
     """The Training Config class."""
     agent_framework: AgentFramework = AgentFramework.SB3
-    "The agent framework."
+    "The AgentFramework."
+
+    deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
+    "The DeepLearningFramework."
 
     red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO
-    "The red agent/algo class."
+    "The RedAgentIdentifier."
 
     action_type: ActionType = ActionType.ANY
     "The ActionType to use."
@@ -33,6 +37,10 @@
     num_steps: int = 256
     "The number of steps in an episode."
+
+    checkpoint_every_n_episodes: int = 5
+    "The agent will save a checkpoint every n episodes."
+
     observation_space: dict = field(
         default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}
     )
@@ -148,6 +156,7 @@
     ) -> TrainingConfig:
         field_enum_map = {
             "agent_framework": AgentFramework,
+            "deep_learning_framework": DeepLearningFramework,
             "red_agent_identifier": RedAgentIdentifier,
             "action_type": ActionType,
             "session_type": SessionType
@@ -155,7 +164,7 @@
 
         for field, enum_class in field_enum_map.items():
             if field in config_dict:
-                config_dict[field] = enum_class[field]
+                config_dict[field] = enum_class[config_dict[field]]
 
         return TrainingConfig(**config_dict)
 
@@ -219,7 +228,7 @@
             )
             _LOGGER.error(msg)
     try:
-        return TrainingConfig.from_dict(**config)
+        return TrainingConfig.from_dict(config)
     except TypeError as e:
         msg = (
             f"Error when creating an instance of {TrainingConfig} "
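For clarity on the from_dict change above: the old line enum_class[field] subscripted each enum with the field name itself (e.g. AgentFramework["agent_framework"]), which always raises KeyError; the new line converts the string read from the YAML file into the corresponding enum member. A minimal sketch (not part of the patch) of the corrected behaviour:

    from primaite.common.enums import AgentFramework

    config_dict = {"agent_framework": "RLLIB"}

    # Old: AgentFramework["agent_framework"] -> KeyError (no such member).
    # New: look up the member named by the value from the YAML file.
    config_dict["agent_framework"] = AgentFramework[config_dict["agent_framework"]]
    assert config_dict["agent_framework"] is AgentFramework.RLLIB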
""" + self.session_path: Final[Path] = session_path + self.timestamp_str: Final[str] = timestamp_str self._training_config_path = training_config_path self._lay_down_config_path = lay_down_config_path @@ -93,7 +95,7 @@ class Primaite(Env): self.transaction_list = transaction_list # The agent in use - self.agent_identifier = self.training_config.agent_identifier + self.agent_identifier = self.training_config.red_agent_identifier # Create a dictionary to hold all the nodes self.nodes: Dict[str, NodeUnion] = {} diff --git a/src/primaite/main.py b/src/primaite/main.py index ac32a018..842b9259 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -108,54 +108,6 @@ def run_stable_baselines3_ppo( env.close() -def run_stable_baselines3_a2c( - env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str -): - """ - Run against a stable_baselines3 A2C agent. - - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. - :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. - param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - if config_values.load_agent: - try: - agent = A2C.load( - config_values.agent_load_file, - env, - verbose=0, - n_steps=config_values.num_steps, - ) - except Exception: - print( - "ERROR: Could not load agent at location: " - + config_values.agent_load_file - ) - _LOGGER.error("Could not load agent") - _LOGGER.error("Exception occured", exc_info=True) - else: - agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps) - - if config_values.session_type == "TRAINING": - # We're in a training session - print("Starting training session...") - _LOGGER.debug("Starting training session...") - for episode in range(config_values.num_episodes): - agent.learn(total_timesteps=config_values.num_steps) - _save_agent(agent, session_path, timestamp_str) - else: - # Default to being in an evaluation session - print("Starting evaluation session...") - _LOGGER.debug("Starting evaluation session...") - evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) - - env.close() - - def _write_session_metadata_file( session_dir: Path, uuid: str, session_timestamp: datetime, env: Primaite ):