From 03ae4884e00019daa10e107c08954bd0bd6519cc Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 19 Jun 2023 21:53:25 +0100 Subject: [PATCH] #917 - Almost there. All output files being writen for SB3/RLLIB PPO & A2C. Just need to bring in the hardcoded agents then update the testa and docs. --- src/primaite/VERSION | 2 +- src/primaite/agents/agent.py | 18 +- src/primaite/agents/rllib.py | 97 ++++- src/primaite/agents/sb3.py | 34 +- src/primaite/common/enums.py | 4 +- .../lay_down_config_1_DDOS_basic.yaml | 4 - .../lay_down_config_2_DDOS_basic.yaml | 4 - .../lay_down_config_3_DOS_very_basic.yaml | 4 - .../training/training_config_main.yaml | 2 +- src/primaite/config/training_config.py | 2 +- src/primaite/environment/primaite_env.py | 17 +- src/primaite/main.py | 391 ++++++++---------- src/primaite/primaite_session.py | 62 +-- .../transactions/transactions_to_file.py | 8 +- 14 files changed, 321 insertions(+), 328 deletions(-) diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 4111d137..0da493b5 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -2.0.0rc1 +2.0.0b1 \ No newline at end of file diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 58158dcb..34ad0adb 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -5,19 +5,18 @@ from pathlib import Path from typing import Optional, Final, Dict, Union, List from uuid import uuid4 -from primaite import getLogger +from primaite import getLogger, SESSIONS_DIR from primaite.common.enums import OutputVerboseLevel from primaite.config import lay_down_config from primaite.config import training_config from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite -from primaite.transactions.transactions_to_file import \ - write_transaction_to_file + _LOGGER = getLogger(__name__) -def _get_temp_session_path(session_timestamp: datetime) -> Path: +def _get_session_path(session_timestamp: datetime) -> Path: """ Get a temp directory session path the test session will output to. @@ -26,7 +25,7 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path: """ date_dir = session_timestamp.strftime("%Y-%m-%d") session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path("./") / date_dir / session_path + session_path = SESSIONS_DIR / date_dir / session_path session_path.mkdir(exist_ok=True, parents=True) return session_path @@ -57,16 +56,16 @@ class AgentSessionABC(ABC): self._env: Primaite self._agent = None - self._transaction_list: List[Dict] = [] self._can_learn: bool = False self._can_evaluate: bool = False self._uuid = str(uuid4()) self.session_timestamp: datetime = datetime.now() "The session timestamp" - self.session_path = _get_temp_session_path(self.session_timestamp) + self.session_path = _get_session_path(self.session_timestamp) "The Session path" self.checkpoints_path = self.session_path / "checkpoints" + self.checkpoints_path.mkdir(parents=True, exist_ok=True) "The Session checkpoints path" self.timestamp_str = self.session_timestamp.strftime( @@ -167,11 +166,6 @@ class AgentSessionABC(ABC): ): if self._can_learn: _LOGGER.debug("Writing transactions") - write_transaction_to_file( - transaction_list=self._transaction_list, - session_path=self.session_path, - timestamp_str=self.timestamp_str, - ) self._update_session_metadata_file() self._can_evaluate = True diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 80318499..67ba6213 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,18 +1,23 @@ -import glob -import time -from enum import Enum +import json +from datetime import datetime from pathlib import Path -from typing import Union, Optional +from pathlib import Path +from typing import Optional from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.algorithms.a2c import A2CConfig +from ray.tune.logger import UnifiedLogger from ray.tune.registry import register_env +from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.config import training_config +from primaite.common.enums import AgentFramework, RedAgentIdentifier from primaite.environment.primaite_env import Primaite +_LOGGER = getLogger(__name__) + def _env_creator(env_config): return Primaite( training_config_path=env_config["training_config_path"], @@ -23,7 +28,17 @@ def _env_creator(env_config): ) -class RLlibPPO(AgentSessionABC): +def _custom_log_creator(session_path: Path): + logdir = session_path / "ray_results" + logdir.mkdir(parents=True, exist_ok=True) + + def logger_creator(config): + return UnifiedLogger(config, logdir, loggers=None) + + return logger_creator + + +class RLlibAgent(AgentSessionABC): def __init__( self, @@ -31,17 +46,63 @@ class RLlibPPO(AgentSessionABC): lay_down_config_path ): super().__init__(training_config_path, lay_down_config_path) - self._ppo_config: PPOConfig + if not self._training_config.agent_framework == AgentFramework.RLLIB: + msg = (f"Expected RLLIB agent_framework, " + f"got {self._training_config.agent_framework}") + _LOGGER.error(msg) + raise ValueError(msg) + if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + self._agent_config_class = PPOConfig + elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + self._agent_config_class = A2CConfig + else: + msg = ("Expected PPO or A2C red_agent_identifier, " + f"got {self._training_config.red_agent_identifier.value}") + _LOGGER.error(msg) + raise ValueError(msg) + self._agent_config: PPOConfig + self._current_result: dict self._setup() - self._agent.save() + _LOGGER.debug( + f"Created {self.__class__.__name__} using: " + f"agent_framework={self._training_config.agent_framework}, " + f"red_agent_identifier=" + f"{self._training_config.red_agent_identifier}, " + f"deep_learning_framework=" + f"{self._training_config.deep_learning_framework}" + ) + + def _update_session_metadata_file(self): + """ + Update the ``session_metadata.json`` file. + + Updates the `session_metadata.json`` in the ``session_path`` directory + with the following key/value pairs: + + - end_datetime: The date & time the session ended in iso format. + - total_episodes: The total number of training episodes completed. + - total_time_steps: The total number of training time steps completed. + """ + with open(self.session_path / "session_metadata.json", "r") as file: + metadata_dict = json.load(file) + + metadata_dict["end_datetime"] = datetime.now().isoformat() + metadata_dict["total_episodes"] = self._current_result["episodes_total"] + metadata_dict["total_time_steps"] = self._current_result["timesteps_total"] + + filepath = self.session_path / "session_metadata.json" + _LOGGER.debug(f"Updating Session Metadata file: {filepath}") + with open(filepath, "w") as file: + json.dump(metadata_dict, file) + _LOGGER.debug("Finished updating session metadata file") def _setup(self): super()._setup() register_env("primaite", _env_creator) - self._ppo_config = PPOConfig() + self._agent_config = self._agent_config_class() - self._ppo_config.environment( + self._agent_config.environment( env="primaite", env_config=dict( training_config_path=self._training_config_path, @@ -52,19 +113,21 @@ class RLlibPPO(AgentSessionABC): ) ) - self._ppo_config.training( + self._agent_config.training( train_batch_size=self._training_config.num_steps ) - self._ppo_config.framework( - framework=self._training_config.deep_learning_framework.value + self._agent_config.framework( + framework=self._training_config.deep_learning_framework ) - self._ppo_config.rollouts( + self._agent_config.rollouts( num_rollout_workers=1, num_envs_per_worker=1, horizon=self._training_config.num_steps ) - self._agent: Algorithm = self._ppo_config.build() + self._agent: Algorithm = self._agent_config.build( + logger_creator=_custom_log_creator(self.session_path) + ) def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes @@ -84,8 +147,8 @@ class RLlibPPO(AgentSessionABC): ): # Temporarily override train_batch_size and horizon if time_steps: - self._ppo_config.train_batch_size = time_steps - self._ppo_config.horizon = time_steps + self._agent_config.train_batch_size = time_steps + self._agent_config.horizon = time_steps if not episodes: episodes = self._training_config.num_episodes diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 6e6d8a5d..3cd2e50a 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,23 +1,48 @@ from typing import Optional import numpy as np -from stable_baselines3 import PPO +from stable_baselines3 import PPO, A2C +from primaite import getLogger from primaite.agents.agent import AgentSessionABC +from primaite.common.enums import RedAgentIdentifier, AgentFramework from primaite.environment.primaite_env import Primaite from stable_baselines3.ppo import MlpPolicy as PPOMlp +_LOGGER = getLogger(__name__) -class SB3PPO(AgentSessionABC): + +class SB3Agent(AgentSessionABC): def __init__( self, training_config_path, lay_down_config_path ): super().__init__(training_config_path, lay_down_config_path) + if not self._training_config.agent_framework == AgentFramework.SB3: + msg = (f"Expected SB3 agent_framework, " + f"got {self._training_config.agent_framework}") + _LOGGER.error(msg) + raise ValueError(msg) + if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + self._agent_class = PPO + elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + self._agent_class = A2C + else: + msg = ("Expected PPO or A2C red_agent_identifier, " + f"got {self._training_config.red_agent_identifier.value}") + _LOGGER.error(msg) + raise ValueError(msg) + self._tensorboard_log_path = self.session_path / "tensorboard_logs" self._tensorboard_log_path.mkdir(parents=True, exist_ok=True) self._setup() + _LOGGER.debug( + f"Created {self.__class__.__name__} using: " + f"agent_framework={self._training_config.agent_framework}, " + f"red_agent_identifier=" + f"{self._training_config.red_agent_identifier}" + ) def _setup(self): super()._setup() @@ -28,10 +53,10 @@ class SB3PPO(AgentSessionABC): session_path=self.session_path, timestamp_str=self.timestamp_str ) - self._agent = PPO( + self._agent = self._agent_class( PPOMlp, self._env, - verbose=1, + verbose=self._training_config.output_verbose_level, n_steps=self._training_config.num_steps, tensorboard_log=self._tensorboard_log_path ) @@ -65,6 +90,7 @@ class SB3PPO(AgentSessionABC): for i in range(episodes): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() + self._env.close() super().learn() def evaluate( diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 0c787e87..89bfd737 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -80,13 +80,13 @@ class Protocol(Enum): class SessionType(Enum): - "The type of PrimAITE Session to be run." + """The type of PrimAITE Session to be run.""" TRAINING = 1 EVALUATION = 2 BOTH = 3 -class VerboseLevel(Enum): +class VerboseLevel(IntEnum): """PrimAITE Session Output verbose level.""" NO_OUTPUT = 0 INFO = 1 diff --git a/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml index f7c1e372..3f0c546a 100644 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml +++ b/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml @@ -1,7 +1,3 @@ -- item_type: ACTIONS - type: NODE -- item_type: STEPS - steps: 128 - item_type: PORTS ports_list: - port: '80' diff --git a/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml index e4a3385d..39bf7dac 100644 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml +++ b/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml @@ -1,7 +1,3 @@ -- item_type: ACTIONS - type: NODE -- item_type: STEPS - steps: 128 - item_type: PORTS ports_list: - port: '80' diff --git a/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml index 9f37a6f0..619a0d35 100644 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml +++ b/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml @@ -1,7 +1,3 @@ -- item_type: ACTIONS - type: NODE -- item_type: STEPS - steps: 256 - item_type: PORTS ports_list: - port: '80' diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 703f37f5..d7b4db98 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -59,7 +59,7 @@ observation_space_high_value: 1000000000 # "NONE" (No Output) # "INFO" (Info Messages) # "ALL" (All Messages) -output_verbose_level: INFO +output_verbose_level: NONE # Reward values # Generic diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 0d39f9c4..4695f2f5 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -185,7 +185,6 @@ class TrainingConfig: for field, enum_class in field_enum_map.items(): if field in config_dict: config_dict[field] = enum_class[config_dict[field]] - return TrainingConfig(**config_dict) def to_dict(self, json_serializable: bool = True): @@ -203,6 +202,7 @@ class TrainingConfig: data["red_agent_identifier"] = self.red_agent_identifier.value data["action_type"] = self.action_type.value data["output_verbose_level"] = self.output_verbose_level.value + data["session_type"] = self.session_type.value return data diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 68209713..0876f070 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -45,6 +45,8 @@ from primaite.pol.ier import IER from primaite.pol.red_agent_pol import apply_red_agent_iers, \ apply_red_agent_node_pol from primaite.transactions.transaction import Transaction +from primaite.transactions.transactions_to_file import \ + write_transaction_to_file _LOGGER = logging.getLogger(__name__) _LOGGER.setLevel(logging.INFO) @@ -407,10 +409,19 @@ class Primaite(Env): # Return return self.env_obs, reward, done, self.step_info - def __close__(self): - """Override close function.""" - self.csv_file.close() + def close(self): + self.__close__() + def __close__(self): + """ + Override close function + """ + write_transaction_to_file( + self.transaction_list, + self.session_path, + self.timestamp_str + ) + self.csv_file.close() def init_acl(self): """Initialise the Access Control List.""" self.acl.remove_all_rules() diff --git a/src/primaite/main.py b/src/primaite/main.py index 8619dc57..34134ba2 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -1,229 +1,162 @@ -# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -# """ -# The main PrimAITE session runner module. -# -# TODO: This will eventually be refactored out into a proper Session class. -# TODO: The passing about of session_path and timestamp_str is temporary and -# will be cleaned up once we move to a proper Session class. -# """ -# import argparse -# import json -# import time -# from datetime import datetime -# from pathlib import Path -# from typing import Final, Union -# from uuid import uuid4 -# -# from stable_baselines3 import A2C, PPO -# from stable_baselines3.common.evaluation import evaluate_policy -# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -# from stable_baselines3.ppo import MlpPolicy as PPOMlp -# -# from primaite import SESSIONS_DIR, getLogger -# from primaite.config.training_config import TrainingConfig -# from primaite.environment.primaite_env import Primaite -# from primaite.transactions.transactions_to_file import \ -# write_transaction_to_file -# -# _LOGGER = getLogger(__name__) -# -# -# def run_generic(env: Primaite, config_values: TrainingConfig): -# """ -# Run against a generic agent. -# -# :param env: An instance of -# :class:`~primaite.environment.primaite_env.Primaite`. -# :param config_values: An instance of -# :class:`~primaite.config.training_config.TrainingConfig`. -# """ -# for episode in range(0, config_values.num_episodes): -# env.reset() -# for step in range(0, config_values.num_steps): -# # Send the observation space to the agent to get an action -# # TEMP - random action for now -# # action = env.blue_agent_action(obs) -# action = env.action_space.sample() -# -# # Run the simulation step on the live environment -# obs, reward, done, info = env.step(action) -# -# # Break if done is True -# if done: -# break -# -# # Introduce a delay between steps -# time.sleep(config_values.time_delay / 1000) -# -# # Reset the environment at the end of the episode -# -# env.close() -# -# -# def run_stable_baselines3_ppo( -# env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str -# ): -# """ -# Run against a stable_baselines3 PPO agent. -# -# :param env: An instance of -# :class:`~primaite.environment.primaite_env.Primaite`. -# :param config_values: An instance of -# :class:`~primaite.config.training_config.TrainingConfig`. -# :param session_path: The directory path the session is writing to. -# :param timestamp_str: The session timestamp in the format: -# _. -# """ -# if config_values.load_agent: -# try: -# agent = PPO.load( -# config_values.agent_load_file, -# env, -# verbose=0, -# n_steps=config_values.num_steps, -# ) -# except Exception: -# print( -# "ERROR: Could not load agent at location: " -# + config_values.agent_load_file -# ) -# _LOGGER.error("Could not load agent") -# _LOGGER.error("Exception occured", exc_info=True) -# else: -# agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps) -# -# if config_values.session_type == "TRAINING": -# # We're in a training session -# print("Starting training session...") -# _LOGGER.debug("Starting training session...") -# for episode in range(config_values.num_episodes): -# agent.learn(total_timesteps=config_values.num_steps) -# _save_agent(agent, session_path, timestamp_str) -# else: -# # Default to being in an evaluation session -# print("Starting evaluation session...") -# _LOGGER.debug("Starting evaluation session...") -# evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) -# -# env.close() -# -# -# -# -# def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str): -# """ -# Persist an agent. -# -# Only works for stable baselines3 agents at present. -# -# :param session_path: The directory path the session is writing to. -# :param timestamp_str: The session timestamp in the format: -# _. -# """ -# if not isinstance(agent, OnPolicyAlgorithm): -# msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}." -# _LOGGER.error(msg) -# else: -# filepath = session_path / f"agent_saved_{timestamp_str}" -# agent.save(filepath) -# _LOGGER.debug(f"Trained agent saved as: {filepath}") -# -# -# -# -# def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): -# """Run the PrimAITE Session. -# -# :param training_config_path: The training config filepath. -# :param lay_down_config_path: The lay down config filepath. -# """ -# # Welcome message -# print("Welcome to the Primary-level AI Training Environment (PrimAITE)") -# uuid = str(uuid4()) -# session_timestamp: Final[datetime] = datetime.now() -# session_path = _get_session_path(session_timestamp) -# timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") -# -# print(f"The output directory for this session is: {session_path}") -# -# # Create a list of transactions -# # A transaction is an object holding the: -# # - episode # -# # - step # -# # - initial observation space -# # - action -# # - reward -# # - new observation space -# transaction_list = [] -# -# # Create the Primaite environment -# env = Primaite( -# training_config_path=training_config_path, -# lay_down_config_path=lay_down_config_path, -# transaction_list=transaction_list, -# session_path=session_path, -# timestamp_str=timestamp_str, -# ) -# -# print("Writing Session Metadata file...") -# -# _write_session_metadata_file( -# session_path=session_path, uuid=uuid, session_timestamp=session_timestamp, env=env -# ) -# -# config_values = env.training_config -# -# # Get the number of steps (which is stored in the child config file) -# config_values.num_steps = env.episode_steps -# -# # Run environment against an agent -# if config_values.agent_identifier == "GENERIC": -# run_generic(env=env, config_values=config_values) -# elif config_values.agent_identifier == "STABLE_BASELINES3_PPO": -# run_stable_baselines3_ppo( -# env=env, -# config_values=config_values, -# session_path=session_path, -# timestamp_str=timestamp_str, -# ) -# elif config_values.agent_identifier == "STABLE_BASELINES3_A2C": -# run_stable_baselines3_a2c( -# env=env, -# config_values=config_values, -# session_path=session_path, -# timestamp_str=timestamp_str, -# ) -# -# print("Session finished") -# _LOGGER.debug("Session finished") -# -# print("Saving transaction logs...") -# write_transaction_to_file( -# transaction_list=transaction_list, -# session_path=session_path, -# timestamp_str=timestamp_str, -# ) -# -# print("Updating Session Metadata file...") -# _update_session_metadata_file(session_path=session_path, env=env) -# -# print("Finished") -# _LOGGER.debug("Finished") -# -# -# if __name__ == "__main__": -# parser = argparse.ArgumentParser() -# parser.add_argument("--tc") -# parser.add_argument("--ldc") -# args = parser.parse_args() -# if not args.tc: -# _LOGGER.error( -# "Please provide a training config file using the --tc " "argument" -# ) -# if not args.ldc: -# _LOGGER.error( -# "Please provide a lay down config file using the --ldc " "argument" -# ) -# run(training_config_path=args.tc, lay_down_config_path=args.ldc) -# -# +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +""" +The main PrimAITE session runner module. + +TODO: This will eventually be refactored out into a proper Session class. +TODO: The passing about of session_path and timestamp_str is temporary and + will be cleaned up once we move to a proper Session class. +""" +import argparse +import json +import time +from datetime import datetime +from pathlib import Path +from typing import Final, Union +from uuid import uuid4 + +from stable_baselines3 import A2C, PPO +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.ppo import MlpPolicy as PPOMlp + +from primaite import SESSIONS_DIR, getLogger +from primaite.config.training_config import TrainingConfig +from primaite.environment.primaite_env import Primaite +from primaite.primaite_session import PrimaiteSession +from primaite.transactions.transactions_to_file import \ + write_transaction_to_file + +_LOGGER = getLogger(__name__) + + +def run_generic(env: Primaite, config_values: TrainingConfig): + """ + Run against a generic agent. + + :param env: An instance of + :class:`~primaite.environment.primaite_env.Primaite`. + :param config_values: An instance of + :class:`~primaite.config.training_config.TrainingConfig`. + """ + for episode in range(0, config_values.num_episodes): + env.reset() + for step in range(0, config_values.num_steps): + # Send the observation space to the agent to get an action + # TEMP - random action for now + # action = env.blue_agent_action(obs) + action = env.action_space.sample() + + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + + # Introduce a delay between steps + time.sleep(config_values.time_delay / 1000) + + # Reset the environment at the end of the episode + + env.close() + + +def run_stable_baselines3_ppo( + env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str +): + """ + Run against a stable_baselines3 PPO agent. + + :param env: An instance of + :class:`~primaite.environment.primaite_env.Primaite`. + :param config_values: An instance of + :class:`~primaite.config.training_config.TrainingConfig`. + :param session_path: The directory path the session is writing to. + :param timestamp_str: The session timestamp in the format: + _. + """ + if config_values.load_agent: + try: + agent = PPO.load( + config_values.agent_load_file, + env, + verbose=0, + n_steps=config_values.num_steps, + ) + except Exception: + print( + "ERROR: Could not load agent at location: " + + config_values.agent_load_file + ) + _LOGGER.error("Could not load agent") + _LOGGER.error("Exception occured", exc_info=True) + else: + agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps) + + if config_values.session_type == "TRAINING": + # We're in a training session + print("Starting training session...") + _LOGGER.debug("Starting training session...") + for episode in range(config_values.num_episodes): + agent.learn(total_timesteps=config_values.num_steps) + _save_agent(agent, session_path, timestamp_str) + else: + # Default to being in an evaluation session + print("Starting evaluation session...") + _LOGGER.debug("Starting evaluation session...") + evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) + + env.close() + + + + +def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str): + """ + Persist an agent. + + Only works for stable baselines3 agents at present. + + :param session_path: The directory path the session is writing to. + :param timestamp_str: The session timestamp in the format: + _. + """ + if not isinstance(agent, OnPolicyAlgorithm): + msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}." + _LOGGER.error(msg) + else: + filepath = session_path / f"agent_saved_{timestamp_str}" + agent.save(filepath) + _LOGGER.debug(f"Trained agent saved as: {filepath}") + + + + +def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): + """Run the PrimAITE Session. + + :param training_config_path: The training config filepath. + :param lay_down_config_path: The lay down config filepath. + """ + session = PrimaiteSession(training_config_path, lay_down_config_path) + + session.setup() + session.learn() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--tc") + parser.add_argument("--ldc") + args = parser.parse_args() + if not args.tc: + _LOGGER.error( + "Please provide a training config file using the --tc " "argument" + ) + if not args.ldc: + _LOGGER.error( + "Please provide a lay down config file using the --ldc " "argument" + ) + run(training_config_path=args.tc, lay_down_config_path=args.ldc) + + diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 8f3380c8..a4148d12 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -8,10 +8,10 @@ from uuid import uuid4 from primaite import getLogger, SESSIONS_DIR from primaite.agents.agent import AgentSessionABC -from primaite.agents.rllib import RLlibPPO -from primaite.agents.sb3 import SB3PPO +from primaite.agents.rllib import RLlibAgent +from primaite.agents.sb3 import SB3Agent from primaite.common.enums import AgentFramework, RedAgentIdentifier, \ - ActionType + ActionType, SessionType from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite @@ -95,35 +95,19 @@ class PrimaiteSession: pass elif self._training_config.agent_framework == AgentFramework.SB3: - if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: - # Stable Baselines3/Proximal Policy Optimization - self._agent_session = SB3PPO( - self._training_config_path, - self._lay_down_config_path - ) - - elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: - # Stable Baselines3/Advantage Actor Critic - raise NotImplementedError - else: - # Invalid AgentFramework RedAgentIdentifier combo - pass + # Stable Baselines3 Agent + self._agent_session = SB3Agent( + self._training_config_path, + self._lay_down_config_path + ) elif self._training_config.agent_framework == AgentFramework.RLLIB: - if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: - # Ray RLlib/Proximal Policy Optimization - self._agent_session = RLlibPPO( - self._training_config_path, - self._lay_down_config_path - ) + # Ray RLlib Agent + self._agent_session = RLlibAgent( + self._training_config_path, + self._lay_down_config_path + ) - elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: - # Ray RLlib/Advantage Actor Critic - raise NotImplementedError - - else: - # Invalid AgentFramework RedAgentIdentifier combo - pass else: # Invalid AgentFramework pass @@ -134,7 +118,8 @@ class PrimaiteSession: episodes: Optional[int] = None, **kwargs ): - self._agent_session.learn(time_steps, episodes, **kwargs) + if not self._training_config.session_type == SessionType.EVALUATION: + self._agent_session.learn(time_steps, episodes, **kwargs) def evaluate( self, @@ -142,18 +127,5 @@ class PrimaiteSession: episodes: Optional[int] = None, **kwargs ): - self._agent_session.evaluate(time_steps, episodes, **kwargs) - - @classmethod - def import_agent( - cls, - gent_path: str, - training_config_path: str, - lay_down_config_path: str - ) -> PrimaiteSession: - session = PrimaiteSession(training_config_path, lay_down_config_path) - - # Reset the UUID - session._uuid = "" - - return session + if not self._training_config.session_type == SessionType.TRAINING: + self._agent_session.evaluate(time_steps, episodes, **kwargs) diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index 24581597..ed7a8f1c 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -4,6 +4,8 @@ import csv from pathlib import Path +import numpy as np + from primaite import getLogger _LOGGER = getLogger(__name__) @@ -54,8 +56,12 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st # Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action # space as "AS_1" # This will be tied into the PrimAITE Use Case so that they make sense + template_transation = transaction_list[0] - action_length = template_transation.action_space.size + if isinstance(template_transation.action_space, int): + action_length = template_transation.action_space + else: + action_length = template_transation.action_space.size obs_shape = template_transation.obs_space_post.shape obs_assets = template_transation.obs_space_post.shape[0] if len(obs_shape) == 1: