#917 - Almost there. All output files being writen for SB3/RLLIB PPO & A2C. Just need to bring in the hardcoded agents then update the testa and docs.
This commit is contained in:
@@ -1 +1 @@
|
||||
2.0.0rc1
|
||||
2.0.0b1
|
||||
@@ -5,19 +5,18 @@ from pathlib import Path
|
||||
from typing import Optional, Final, Dict, Union, List
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite import getLogger, SESSIONS_DIR
|
||||
from primaite.common.enums import OutputVerboseLevel
|
||||
from primaite.config import lay_down_config
|
||||
from primaite.config import training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def _get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
def _get_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
Get a temp directory session path the test session will output to.
|
||||
|
||||
@@ -26,7 +25,7 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = Path("./") / date_dir / session_path
|
||||
session_path = SESSIONS_DIR / date_dir / session_path
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return session_path
|
||||
@@ -57,16 +56,16 @@ class AgentSessionABC(ABC):
|
||||
|
||||
self._env: Primaite
|
||||
self._agent = None
|
||||
self._transaction_list: List[Dict] = []
|
||||
self._can_learn: bool = False
|
||||
self._can_evaluate: bool = False
|
||||
|
||||
self._uuid = str(uuid4())
|
||||
self.session_timestamp: datetime = datetime.now()
|
||||
"The session timestamp"
|
||||
self.session_path = _get_temp_session_path(self.session_timestamp)
|
||||
self.session_path = _get_session_path(self.session_timestamp)
|
||||
"The Session path"
|
||||
self.checkpoints_path = self.session_path / "checkpoints"
|
||||
self.checkpoints_path.mkdir(parents=True, exist_ok=True)
|
||||
"The Session checkpoints path"
|
||||
|
||||
self.timestamp_str = self.session_timestamp.strftime(
|
||||
@@ -167,11 +166,6 @@ class AgentSessionABC(ABC):
|
||||
):
|
||||
if self._can_learn:
|
||||
_LOGGER.debug("Writing transactions")
|
||||
write_transaction_to_file(
|
||||
transaction_list=self._transaction_list,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
self._update_session_metadata_file()
|
||||
self._can_evaluate = True
|
||||
|
||||
|
||||
@@ -1,18 +1,23 @@
|
||||
import glob
|
||||
import time
|
||||
from enum import Enum
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Union, Optional
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.rllib.algorithms.a2c import A2CConfig
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.config import training_config
|
||||
from primaite.common.enums import AgentFramework, RedAgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
def _env_creator(env_config):
|
||||
return Primaite(
|
||||
training_config_path=env_config["training_config_path"],
|
||||
@@ -23,7 +28,17 @@ def _env_creator(env_config):
|
||||
)
|
||||
|
||||
|
||||
class RLlibPPO(AgentSessionABC):
|
||||
def _custom_log_creator(session_path: Path):
|
||||
logdir = session_path / "ray_results"
|
||||
logdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def logger_creator(config):
|
||||
return UnifiedLogger(config, logdir, loggers=None)
|
||||
|
||||
return logger_creator
|
||||
|
||||
|
||||
class RLlibAgent(AgentSessionABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -31,17 +46,63 @@ class RLlibPPO(AgentSessionABC):
|
||||
lay_down_config_path
|
||||
):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self._ppo_config: PPOConfig
|
||||
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
msg = (f"Expected RLLIB agent_framework, "
|
||||
f"got {self._training_config.agent_framework}")
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||
self._agent_config_class = PPOConfig
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||
self._agent_config_class = A2CConfig
|
||||
else:
|
||||
msg = ("Expected PPO or A2C red_agent_identifier, "
|
||||
f"got {self._training_config.red_agent_identifier.value}")
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_config: PPOConfig
|
||||
|
||||
self._current_result: dict
|
||||
self._setup()
|
||||
self._agent.save()
|
||||
_LOGGER.debug(
|
||||
f"Created {self.__class__.__name__} using: "
|
||||
f"agent_framework={self._training_config.agent_framework}, "
|
||||
f"red_agent_identifier="
|
||||
f"{self._training_config.red_agent_identifier}, "
|
||||
f"deep_learning_framework="
|
||||
f"{self._training_config.deep_learning_framework}"
|
||||
)
|
||||
|
||||
def _update_session_metadata_file(self):
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
|
||||
Updates the `session_metadata.json`` in the ``session_path`` directory
|
||||
with the following key/value pairs:
|
||||
|
||||
- end_datetime: The date & time the session ended in iso format.
|
||||
- total_episodes: The total number of training episodes completed.
|
||||
- total_time_steps: The total number of training time steps completed.
|
||||
"""
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
|
||||
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
with open(filepath, "w") as file:
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
def _setup(self):
|
||||
super()._setup()
|
||||
register_env("primaite", _env_creator)
|
||||
self._ppo_config = PPOConfig()
|
||||
self._agent_config = self._agent_config_class()
|
||||
|
||||
self._ppo_config.environment(
|
||||
self._agent_config.environment(
|
||||
env="primaite",
|
||||
env_config=dict(
|
||||
training_config_path=self._training_config_path,
|
||||
@@ -52,19 +113,21 @@ class RLlibPPO(AgentSessionABC):
|
||||
)
|
||||
)
|
||||
|
||||
self._ppo_config.training(
|
||||
self._agent_config.training(
|
||||
train_batch_size=self._training_config.num_steps
|
||||
)
|
||||
self._ppo_config.framework(
|
||||
framework=self._training_config.deep_learning_framework.value
|
||||
self._agent_config.framework(
|
||||
framework=self._training_config.deep_learning_framework
|
||||
)
|
||||
|
||||
self._ppo_config.rollouts(
|
||||
self._agent_config.rollouts(
|
||||
num_rollout_workers=1,
|
||||
num_envs_per_worker=1,
|
||||
horizon=self._training_config.num_steps
|
||||
)
|
||||
self._agent: Algorithm = self._ppo_config.build()
|
||||
self._agent: Algorithm = self._agent_config.build(
|
||||
logger_creator=_custom_log_creator(self.session_path)
|
||||
)
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
@@ -84,8 +147,8 @@ class RLlibPPO(AgentSessionABC):
|
||||
):
|
||||
# Temporarily override train_batch_size and horizon
|
||||
if time_steps:
|
||||
self._ppo_config.train_batch_size = time_steps
|
||||
self._ppo_config.horizon = time_steps
|
||||
self._agent_config.train_batch_size = time_steps
|
||||
self._agent_config.horizon = time_steps
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
@@ -1,23 +1,48 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3 import PPO, A2C
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.common.enums import RedAgentIdentifier, AgentFramework
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
class SB3PPO(AgentSessionABC):
|
||||
|
||||
class SB3Agent(AgentSessionABC):
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path,
|
||||
lay_down_config_path
|
||||
):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.SB3:
|
||||
msg = (f"Expected SB3 agent_framework, "
|
||||
f"got {self._training_config.agent_framework}")
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||
self._agent_class = PPO
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||
self._agent_class = A2C
|
||||
else:
|
||||
msg = ("Expected PPO or A2C red_agent_identifier, "
|
||||
f"got {self._training_config.red_agent_identifier.value}")
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
self._tensorboard_log_path = self.session_path / "tensorboard_logs"
|
||||
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
|
||||
self._setup()
|
||||
_LOGGER.debug(
|
||||
f"Created {self.__class__.__name__} using: "
|
||||
f"agent_framework={self._training_config.agent_framework}, "
|
||||
f"red_agent_identifier="
|
||||
f"{self._training_config.red_agent_identifier}"
|
||||
)
|
||||
|
||||
def _setup(self):
|
||||
super()._setup()
|
||||
@@ -28,10 +53,10 @@ class SB3PPO(AgentSessionABC):
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str
|
||||
)
|
||||
self._agent = PPO(
|
||||
self._agent = self._agent_class(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=1,
|
||||
verbose=self._training_config.output_verbose_level,
|
||||
n_steps=self._training_config.num_steps,
|
||||
tensorboard_log=self._tensorboard_log_path
|
||||
)
|
||||
@@ -65,6 +90,7 @@ class SB3PPO(AgentSessionABC):
|
||||
for i in range(episodes):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
self._env.close()
|
||||
super().learn()
|
||||
|
||||
def evaluate(
|
||||
|
||||
@@ -80,13 +80,13 @@ class Protocol(Enum):
|
||||
|
||||
|
||||
class SessionType(Enum):
|
||||
"The type of PrimAITE Session to be run."
|
||||
"""The type of PrimAITE Session to be run."""
|
||||
TRAINING = 1
|
||||
EVALUATION = 2
|
||||
BOTH = 3
|
||||
|
||||
|
||||
class VerboseLevel(Enum):
|
||||
class VerboseLevel(IntEnum):
|
||||
"""PrimAITE Session Output verbose level."""
|
||||
NO_OUTPUT = 0
|
||||
INFO = 1
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
- item_type: ACTIONS
|
||||
type: NODE
|
||||
- item_type: STEPS
|
||||
steps: 128
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
- item_type: ACTIONS
|
||||
type: NODE
|
||||
- item_type: STEPS
|
||||
steps: 128
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
- item_type: ACTIONS
|
||||
type: NODE
|
||||
- item_type: STEPS
|
||||
steps: 256
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
|
||||
@@ -59,7 +59,7 @@ observation_space_high_value: 1000000000
|
||||
# "NONE" (No Output)
|
||||
# "INFO" (Info Messages)
|
||||
# "ALL" (All Messages)
|
||||
output_verbose_level: INFO
|
||||
output_verbose_level: NONE
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
|
||||
@@ -185,7 +185,6 @@ class TrainingConfig:
|
||||
for field, enum_class in field_enum_map.items():
|
||||
if field in config_dict:
|
||||
config_dict[field] = enum_class[config_dict[field]]
|
||||
|
||||
return TrainingConfig(**config_dict)
|
||||
|
||||
def to_dict(self, json_serializable: bool = True):
|
||||
@@ -203,6 +202,7 @@ class TrainingConfig:
|
||||
data["red_agent_identifier"] = self.red_agent_identifier.value
|
||||
data["action_type"] = self.action_type.value
|
||||
data["output_verbose_level"] = self.output_verbose_level.value
|
||||
data["session_type"] = self.session_type.value
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@@ -45,6 +45,8 @@ from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
|
||||
apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_LOGGER.setLevel(logging.INFO)
|
||||
@@ -407,10 +409,19 @@ class Primaite(Env):
|
||||
# Return
|
||||
return self.env_obs, reward, done, self.step_info
|
||||
|
||||
def __close__(self):
|
||||
"""Override close function."""
|
||||
self.csv_file.close()
|
||||
def close(self):
|
||||
self.__close__()
|
||||
|
||||
def __close__(self):
|
||||
"""
|
||||
Override close function
|
||||
"""
|
||||
write_transaction_to_file(
|
||||
self.transaction_list,
|
||||
self.session_path,
|
||||
self.timestamp_str
|
||||
)
|
||||
self.csv_file.close()
|
||||
def init_acl(self):
|
||||
"""Initialise the Access Control List."""
|
||||
self.acl.remove_all_rules()
|
||||
|
||||
@@ -1,229 +1,162 @@
|
||||
# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# """
|
||||
# The main PrimAITE session runner module.
|
||||
#
|
||||
# TODO: This will eventually be refactored out into a proper Session class.
|
||||
# TODO: The passing about of session_path and timestamp_str is temporary and
|
||||
# will be cleaned up once we move to a proper Session class.
|
||||
# """
|
||||
# import argparse
|
||||
# import json
|
||||
# import time
|
||||
# from datetime import datetime
|
||||
# from pathlib import Path
|
||||
# from typing import Final, Union
|
||||
# from uuid import uuid4
|
||||
#
|
||||
# from stable_baselines3 import A2C, PPO
|
||||
# from stable_baselines3.common.evaluation import evaluate_policy
|
||||
# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
# from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
#
|
||||
# from primaite import SESSIONS_DIR, getLogger
|
||||
# from primaite.config.training_config import TrainingConfig
|
||||
# from primaite.environment.primaite_env import Primaite
|
||||
# from primaite.transactions.transactions_to_file import \
|
||||
# write_transaction_to_file
|
||||
#
|
||||
# _LOGGER = getLogger(__name__)
|
||||
#
|
||||
#
|
||||
# def run_generic(env: Primaite, config_values: TrainingConfig):
|
||||
# """
|
||||
# Run against a generic agent.
|
||||
#
|
||||
# :param env: An instance of
|
||||
# :class:`~primaite.environment.primaite_env.Primaite`.
|
||||
# :param config_values: An instance of
|
||||
# :class:`~primaite.config.training_config.TrainingConfig`.
|
||||
# """
|
||||
# for episode in range(0, config_values.num_episodes):
|
||||
# env.reset()
|
||||
# for step in range(0, config_values.num_steps):
|
||||
# # Send the observation space to the agent to get an action
|
||||
# # TEMP - random action for now
|
||||
# # action = env.blue_agent_action(obs)
|
||||
# action = env.action_space.sample()
|
||||
#
|
||||
# # Run the simulation step on the live environment
|
||||
# obs, reward, done, info = env.step(action)
|
||||
#
|
||||
# # Break if done is True
|
||||
# if done:
|
||||
# break
|
||||
#
|
||||
# # Introduce a delay between steps
|
||||
# time.sleep(config_values.time_delay / 1000)
|
||||
#
|
||||
# # Reset the environment at the end of the episode
|
||||
#
|
||||
# env.close()
|
||||
#
|
||||
#
|
||||
# def run_stable_baselines3_ppo(
|
||||
# env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
|
||||
# ):
|
||||
# """
|
||||
# Run against a stable_baselines3 PPO agent.
|
||||
#
|
||||
# :param env: An instance of
|
||||
# :class:`~primaite.environment.primaite_env.Primaite`.
|
||||
# :param config_values: An instance of
|
||||
# :class:`~primaite.config.training_config.TrainingConfig`.
|
||||
# :param session_path: The directory path the session is writing to.
|
||||
# :param timestamp_str: The session timestamp in the format:
|
||||
# <yyyy-mm-dd>_<hh-mm-ss>.
|
||||
# """
|
||||
# if config_values.load_agent:
|
||||
# try:
|
||||
# agent = PPO.load(
|
||||
# config_values.agent_load_file,
|
||||
# env,
|
||||
# verbose=0,
|
||||
# n_steps=config_values.num_steps,
|
||||
# )
|
||||
# except Exception:
|
||||
# print(
|
||||
# "ERROR: Could not load agent at location: "
|
||||
# + config_values.agent_load_file
|
||||
# )
|
||||
# _LOGGER.error("Could not load agent")
|
||||
# _LOGGER.error("Exception occured", exc_info=True)
|
||||
# else:
|
||||
# agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
|
||||
#
|
||||
# if config_values.session_type == "TRAINING":
|
||||
# # We're in a training session
|
||||
# print("Starting training session...")
|
||||
# _LOGGER.debug("Starting training session...")
|
||||
# for episode in range(config_values.num_episodes):
|
||||
# agent.learn(total_timesteps=config_values.num_steps)
|
||||
# _save_agent(agent, session_path, timestamp_str)
|
||||
# else:
|
||||
# # Default to being in an evaluation session
|
||||
# print("Starting evaluation session...")
|
||||
# _LOGGER.debug("Starting evaluation session...")
|
||||
# evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
#
|
||||
# env.close()
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
# def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
|
||||
# """
|
||||
# Persist an agent.
|
||||
#
|
||||
# Only works for stable baselines3 agents at present.
|
||||
#
|
||||
# :param session_path: The directory path the session is writing to.
|
||||
# :param timestamp_str: The session timestamp in the format:
|
||||
# <yyyy-mm-dd>_<hh-mm-ss>.
|
||||
# """
|
||||
# if not isinstance(agent, OnPolicyAlgorithm):
|
||||
# msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
|
||||
# _LOGGER.error(msg)
|
||||
# else:
|
||||
# filepath = session_path / f"agent_saved_{timestamp_str}"
|
||||
# agent.save(filepath)
|
||||
# _LOGGER.debug(f"Trained agent saved as: {filepath}")
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
# def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
|
||||
# """Run the PrimAITE Session.
|
||||
#
|
||||
# :param training_config_path: The training config filepath.
|
||||
# :param lay_down_config_path: The lay down config filepath.
|
||||
# """
|
||||
# # Welcome message
|
||||
# print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
|
||||
# uuid = str(uuid4())
|
||||
# session_timestamp: Final[datetime] = datetime.now()
|
||||
# session_path = _get_session_path(session_timestamp)
|
||||
# timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
#
|
||||
# print(f"The output directory for this session is: {session_path}")
|
||||
#
|
||||
# # Create a list of transactions
|
||||
# # A transaction is an object holding the:
|
||||
# # - episode #
|
||||
# # - step #
|
||||
# # - initial observation space
|
||||
# # - action
|
||||
# # - reward
|
||||
# # - new observation space
|
||||
# transaction_list = []
|
||||
#
|
||||
# # Create the Primaite environment
|
||||
# env = Primaite(
|
||||
# training_config_path=training_config_path,
|
||||
# lay_down_config_path=lay_down_config_path,
|
||||
# transaction_list=transaction_list,
|
||||
# session_path=session_path,
|
||||
# timestamp_str=timestamp_str,
|
||||
# )
|
||||
#
|
||||
# print("Writing Session Metadata file...")
|
||||
#
|
||||
# _write_session_metadata_file(
|
||||
# session_path=session_path, uuid=uuid, session_timestamp=session_timestamp, env=env
|
||||
# )
|
||||
#
|
||||
# config_values = env.training_config
|
||||
#
|
||||
# # Get the number of steps (which is stored in the child config file)
|
||||
# config_values.num_steps = env.episode_steps
|
||||
#
|
||||
# # Run environment against an agent
|
||||
# if config_values.agent_identifier == "GENERIC":
|
||||
# run_generic(env=env, config_values=config_values)
|
||||
# elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
|
||||
# run_stable_baselines3_ppo(
|
||||
# env=env,
|
||||
# config_values=config_values,
|
||||
# session_path=session_path,
|
||||
# timestamp_str=timestamp_str,
|
||||
# )
|
||||
# elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
|
||||
# run_stable_baselines3_a2c(
|
||||
# env=env,
|
||||
# config_values=config_values,
|
||||
# session_path=session_path,
|
||||
# timestamp_str=timestamp_str,
|
||||
# )
|
||||
#
|
||||
# print("Session finished")
|
||||
# _LOGGER.debug("Session finished")
|
||||
#
|
||||
# print("Saving transaction logs...")
|
||||
# write_transaction_to_file(
|
||||
# transaction_list=transaction_list,
|
||||
# session_path=session_path,
|
||||
# timestamp_str=timestamp_str,
|
||||
# )
|
||||
#
|
||||
# print("Updating Session Metadata file...")
|
||||
# _update_session_metadata_file(session_path=session_path, env=env)
|
||||
#
|
||||
# print("Finished")
|
||||
# _LOGGER.debug("Finished")
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# parser = argparse.ArgumentParser()
|
||||
# parser.add_argument("--tc")
|
||||
# parser.add_argument("--ldc")
|
||||
# args = parser.parse_args()
|
||||
# if not args.tc:
|
||||
# _LOGGER.error(
|
||||
# "Please provide a training config file using the --tc " "argument"
|
||||
# )
|
||||
# if not args.ldc:
|
||||
# _LOGGER.error(
|
||||
# "Please provide a lay down config file using the --ldc " "argument"
|
||||
# )
|
||||
# run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
#
|
||||
#
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The main PrimAITE session runner module.
|
||||
|
||||
TODO: This will eventually be refactored out into a proper Session class.
|
||||
TODO: The passing about of session_path and timestamp_str is temporary and
|
||||
will be cleaned up once we move to a proper Session class.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Final, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite import SESSIONS_DIR, getLogger
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run_generic(env: Primaite, config_values: TrainingConfig):
|
||||
"""
|
||||
Run against a generic agent.
|
||||
|
||||
:param env: An instance of
|
||||
:class:`~primaite.environment.primaite_env.Primaite`.
|
||||
:param config_values: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
"""
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
env.reset()
|
||||
for step in range(0, config_values.num_steps):
|
||||
# Send the observation space to the agent to get an action
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
action = env.action_space.sample()
|
||||
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(config_values.time_delay / 1000)
|
||||
|
||||
# Reset the environment at the end of the episode
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def run_stable_baselines3_ppo(
|
||||
env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
|
||||
):
|
||||
"""
|
||||
Run against a stable_baselines3 PPO agent.
|
||||
|
||||
:param env: An instance of
|
||||
:class:`~primaite.environment.primaite_env.Primaite`.
|
||||
:param config_values: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
if config_values.load_agent:
|
||||
try:
|
||||
agent = PPO.load(
|
||||
config_values.agent_load_file,
|
||||
env,
|
||||
verbose=0,
|
||||
n_steps=config_values.num_steps,
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
"ERROR: Could not load agent at location: "
|
||||
+ config_values.agent_load_file
|
||||
)
|
||||
_LOGGER.error("Could not load agent")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
print("Starting training session...")
|
||||
_LOGGER.debug("Starting training session...")
|
||||
for episode in range(config_values.num_episodes):
|
||||
agent.learn(total_timesteps=config_values.num_steps)
|
||||
_save_agent(agent, session_path, timestamp_str)
|
||||
else:
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
_LOGGER.debug("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
|
||||
|
||||
def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
|
||||
"""
|
||||
Persist an agent.
|
||||
|
||||
Only works for stable baselines3 agents at present.
|
||||
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
if not isinstance(agent, OnPolicyAlgorithm):
|
||||
msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
|
||||
_LOGGER.error(msg)
|
||||
else:
|
||||
filepath = session_path / f"agent_saved_{timestamp_str}"
|
||||
agent.save(filepath)
|
||||
_LOGGER.debug(f"Trained agent saved as: {filepath}")
|
||||
|
||||
|
||||
|
||||
|
||||
def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
|
||||
"""Run the PrimAITE Session.
|
||||
|
||||
:param training_config_path: The training config filepath.
|
||||
:param lay_down_config_path: The lay down config filepath.
|
||||
"""
|
||||
session = PrimaiteSession(training_config_path, lay_down_config_path)
|
||||
|
||||
session.setup()
|
||||
session.learn()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tc")
|
||||
parser.add_argument("--ldc")
|
||||
args = parser.parse_args()
|
||||
if not args.tc:
|
||||
_LOGGER.error(
|
||||
"Please provide a training config file using the --tc " "argument"
|
||||
)
|
||||
if not args.ldc:
|
||||
_LOGGER.error(
|
||||
"Please provide a lay down config file using the --ldc " "argument"
|
||||
)
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
|
||||
|
||||
|
||||
@@ -8,10 +8,10 @@ from uuid import uuid4
|
||||
|
||||
from primaite import getLogger, SESSIONS_DIR
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.rllib import RLlibPPO
|
||||
from primaite.agents.sb3 import SB3PPO
|
||||
from primaite.agents.rllib import RLlibAgent
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.common.enums import AgentFramework, RedAgentIdentifier, \
|
||||
ActionType
|
||||
ActionType, SessionType
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
@@ -95,35 +95,19 @@ class PrimaiteSession:
|
||||
pass
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||
# Stable Baselines3/Proximal Policy Optimization
|
||||
self._agent_session = SB3PPO(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||
# Stable Baselines3/Advantage Actor Critic
|
||||
raise NotImplementedError
|
||||
else:
|
||||
# Invalid AgentFramework RedAgentIdentifier combo
|
||||
pass
|
||||
# Stable Baselines3 Agent
|
||||
self._agent_session = SB3Agent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||
# Ray RLlib/Proximal Policy Optimization
|
||||
self._agent_session = RLlibPPO(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
)
|
||||
# Ray RLlib Agent
|
||||
self._agent_session = RLlibAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||
# Ray RLlib/Advantage Actor Critic
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework RedAgentIdentifier combo
|
||||
pass
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
pass
|
||||
@@ -134,7 +118,8 @@ class PrimaiteSession:
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
self._agent_session.learn(time_steps, episodes, **kwargs)
|
||||
if not self._training_config.session_type == SessionType.EVALUATION:
|
||||
self._agent_session.learn(time_steps, episodes, **kwargs)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
@@ -142,18 +127,5 @@ class PrimaiteSession:
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
self._agent_session.evaluate(time_steps, episodes, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def import_agent(
|
||||
cls,
|
||||
gent_path: str,
|
||||
training_config_path: str,
|
||||
lay_down_config_path: str
|
||||
) -> PrimaiteSession:
|
||||
session = PrimaiteSession(training_config_path, lay_down_config_path)
|
||||
|
||||
# Reset the UUID
|
||||
session._uuid = ""
|
||||
|
||||
return session
|
||||
if not self._training_config.session_type == SessionType.TRAINING:
|
||||
self._agent_session.evaluate(time_steps, episodes, **kwargs)
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -54,8 +56,12 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
|
||||
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action
|
||||
# space as "AS_1"
|
||||
# This will be tied into the PrimAITE Use Case so that they make sense
|
||||
|
||||
template_transation = transaction_list[0]
|
||||
action_length = template_transation.action_space.size
|
||||
if isinstance(template_transation.action_space, int):
|
||||
action_length = template_transation.action_space
|
||||
else:
|
||||
action_length = template_transation.action_space.size
|
||||
obs_shape = template_transation.obs_space_post.shape
|
||||
obs_assets = template_transation.obs_space_post.shape[0]
|
||||
if len(obs_shape) == 1:
|
||||
|
||||
Reference in New Issue
Block a user