#917 - Integrated both SB3 and RLlib agents into PrimaiteSession

This commit is contained in:
Chris McCarthy
2023-06-19 20:27:08 +01:00
parent c2c396052f
commit 23bafde457
13 changed files with 726 additions and 688 deletions

View File

@@ -0,0 +1,251 @@
import json
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Optional, Final, Dict, Union, List
from uuid import uuid4
from primaite import getLogger
from primaite.common.enums import OutputVerboseLevel
from primaite.config import lay_down_config
from primaite.config import training_config
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
from primaite.transactions.transactions_to_file import \
write_transaction_to_file
_LOGGER = getLogger(__name__)
def _get_temp_session_path(session_timestamp: datetime) -> Path:
"""
Get a temp directory session path the test session will output to.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path("./") / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
return session_path
class AgentSessionABC(ABC):
    """
    An ABC that manages an agent training/evaluation session.

    Concrete subclasses wrap a specific agent framework (e.g. SB3 or
    RLlib). The abstract methods that have concrete bodies
    (``__init__``, ``_setup``, ``learn``, ``save``) are intended to be
    extended by subclasses via ``super()`` calls.
    """

    @abstractmethod
    def __init__(
        self,
        training_config_path,
        lay_down_config_path
    ):
        """
        Load the configs and prepare the session output directories.

        :param training_config_path: The training config file path
            (``str`` or ``Path``).
        :param lay_down_config_path: The lay down config file path
            (``str`` or ``Path``).
        """
        if not isinstance(training_config_path, Path):
            training_config_path = Path(training_config_path)
        # Fixed: was `Final[Union[Path]]` — a one-member Union is just Path.
        self._training_config_path: Final[Path] = training_config_path
        self._training_config: Final[TrainingConfig] = training_config.load(
            self._training_config_path
        )

        if not isinstance(lay_down_config_path, Path):
            lay_down_config_path = Path(lay_down_config_path)
        self._lay_down_config_path: Final[Path] = lay_down_config_path
        self._lay_down_config: Dict = lay_down_config.load(
            self._lay_down_config_path
        )

        self.output_verbose_level = self._training_config.output_verbose_level
        self._env: Primaite  # set by the subclass in _setup()
        self._agent = None  # the framework-specific agent object
        self._transaction_list: List[Dict] = []
        self._can_learn: bool = False  # set True once _setup() has run
        self._can_evaluate: bool = False  # set True once learn() has run
        self._uuid = str(uuid4())

        # The session timestamp
        self.session_timestamp: datetime = datetime.now()
        # The session path
        self.session_path = _get_temp_session_path(self.session_timestamp)
        # The session checkpoints path
        self.checkpoints_path = self.session_path / "checkpoints"
        # The session timestamp as a string
        self.timestamp_str = self.session_timestamp.strftime(
            "%Y-%m-%d_%H-%M-%S")

    @property
    def uuid(self):
        """The Agent Session UUID."""
        return self._uuid

    def _write_session_metadata_file(self):
        """
        Write the ``session_metadata.json`` file.

        Creates a ``session_metadata.json`` in the ``session_path`` directory
        and adds the following key/value pairs:

        - uuid: The UUID assigned to the session upon instantiation.
        - start_datetime: The date & time the session started in iso format.
        - end_datetime: NULL.
        - total_episodes: NULL.
        - total_time_steps: NULL.
        - env:
            - training_config:
                - All training config items
            - lay_down_config:
                - All lay down config items
        """
        metadata_dict = {
            "uuid": self.uuid,
            "start_datetime": self.session_timestamp.isoformat(),
            "end_datetime": None,
            "total_episodes": None,
            "total_time_steps": None,
            "env": {
                "training_config": self._training_config.to_dict(
                    json_serializable=True
                ),
                "lay_down_config": self._lay_down_config,
            },
        }
        filepath = self.session_path / "session_metadata.json"
        _LOGGER.debug(f"Writing Session Metadata file: {filepath}")
        with open(filepath, "w") as file:
            json.dump(metadata_dict, file)
        _LOGGER.debug("Finished writing session metadata file")

    def _update_session_metadata_file(self):
        """
        Update the ``session_metadata.json`` file.

        Updates the ``session_metadata.json`` in the ``session_path``
        directory with the following key/value pairs:

        - end_datetime: The date & time the session ended in iso format.
        - total_episodes: The total number of training episodes completed.
        - total_time_steps: The total number of training time steps completed.
        """
        with open(self.session_path / "session_metadata.json", "r") as file:
            metadata_dict = json.load(file)

        metadata_dict["end_datetime"] = datetime.now().isoformat()
        metadata_dict["total_episodes"] = self._env.episode_count
        metadata_dict["total_time_steps"] = self._env.total_step_count

        filepath = self.session_path / "session_metadata.json"
        _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
        with open(filepath, "w") as file:
            json.dump(metadata_dict, file)
        _LOGGER.debug("Finished updating session metadata file")

    @abstractmethod
    def _setup(self):
        """
        Perform the common session setup.

        Logs the welcome message (verbosity permitting), writes the session
        metadata file and marks the session as able to learn. Subclasses
        should call ``super()._setup()`` before building their env/agent.
        """
        if self.output_verbose_level >= OutputVerboseLevel.INFO:
            _LOGGER.info(
                "Welcome to the Primary-level AI Training Environment "
                "(PrimAITE)"
            )
        _LOGGER.debug(
            f"The output directory for this agent is: {self.session_path}"
        )
        self._write_session_metadata_file()
        self._can_learn = True
        self._can_evaluate = False

    @abstractmethod
    def _save_checkpoint(self):
        """Save an agent checkpoint (framework specific)."""
        pass

    @abstractmethod
    def learn(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None,
        **kwargs
    ):
        """
        Train the agent.

        Subclasses perform the actual training and then call
        ``super().learn()``, which writes the transactions to file, updates
        the session metadata and enables evaluation.

        :param time_steps: The number of time steps per episode.
        :param episodes: The number of episodes to train over.
        """
        if self._can_learn:
            _LOGGER.debug("Writing transactions")
            write_transaction_to_file(
                transaction_list=self._transaction_list,
                session_path=self.session_path,
                timestamp_str=self.timestamp_str,
            )
            self._update_session_metadata_file()
            self._can_evaluate = True

    @abstractmethod
    def evaluate(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None,
        **kwargs
    ):
        """
        Evaluate the agent (framework specific).

        :param time_steps: The number of time steps per episode.
        :param episodes: The number of episodes to evaluate over.
        """
        pass

    @abstractmethod
    def _get_latest_checkpoint(self):
        """Retrieve the most recent saved checkpoint (framework specific)."""
        pass

    @classmethod
    @abstractmethod
    def load(cls):
        """Load a previously saved agent session (framework specific)."""
        pass

    @abstractmethod
    def save(self):
        """Save the agent to the session directory."""
        self._agent.save(self.session_path)

    @abstractmethod
    def export(self):
        """Export the agent session (framework specific)."""
        pass
class DeterministicAgentSessionABC(AgentSessionABC):
    """
    An ABC for deterministic (non-learning) agent sessions.

    Does not call ``super().__init__()``: only the config paths are
    stored and no config loading or session-directory creation happens.
    """

    @abstractmethod
    def __init__(
        self,
        training_config_path,
        lay_down_config_path
    ):
        """
        :param training_config_path: The training config file path.
        :param lay_down_config_path: The lay down config file path.
        """
        self._training_config_path = training_config_path
        self._lay_down_config_path = lay_down_config_path
        self._env: Primaite  # set by the subclass in _setup()
        self._agent = None

    @abstractmethod
    def _setup(self):
        """Build the environment/agent (subclass specific)."""
        pass

    @abstractmethod
    def _get_latest_checkpoint(self):
        """Retrieve the most recent checkpoint (subclass specific)."""
        pass

    def learn(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None
    ):
        """Deterministic agents cannot learn; only a warning is logged."""
        _LOGGER.warning("Deterministic agents cannot learn")

    @abstractmethod
    def evaluate(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None
    ):
        """Evaluate the agent (subclass specific)."""
        pass

    @classmethod
    @abstractmethod
    def load(cls):
        """Load a previously saved session (subclass specific)."""
        pass

    @abstractmethod
    def save(self):
        """Persist the session (subclass specific)."""
        pass

    @abstractmethod
    def export(self):
        """Export the session (subclass specific)."""
        pass

View File

@@ -1,132 +0,0 @@
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Optional, Final, Dict, Any, Union, Tuple
import yaml
from primaite import getLogger
from primaite.config.training_config import TrainingConfig, load
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
def _get_temp_session_path(session_timestamp: datetime) -> Path:
"""
Get a temp directory session path the test session will output to.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path("./") / date_dir / session_dir
session_path.mkdir(exist_ok=True, parents=True)
return session_path
class AgentABC(ABC):
    """
    An ABC defining the interface for a framework-specific agent.

    ``__init__`` loads the training config and prepares the session
    output paths; everything else is left to concrete subclasses.
    """

    @abstractmethod
    def __init__(
        self,
        training_config_path,
        lay_down_config_path
    ):
        """
        :param training_config_path: The training config file path.
        :param lay_down_config_path: The lay down config file path.
        """
        self._training_config_path = training_config_path
        self._training_config: Final[TrainingConfig] = load(
            self._training_config_path
        )
        self._lay_down_config_path = lay_down_config_path
        self._env: Primaite  # set by the subclass in _setup()
        self._agent = None

        self.session_timestamp: datetime = datetime.now()
        self.session_path = _get_temp_session_path(self.session_timestamp)
        self.timestamp_str = self.session_timestamp.strftime(
            "%Y-%m-%d_%H-%M-%S")

    @abstractmethod
    def _setup(self):
        """Build the environment/agent (subclass specific)."""
        pass

    @abstractmethod
    def _save_checkpoint(self):
        """Save an agent checkpoint (subclass specific)."""
        pass

    @abstractmethod
    def learn(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None
    ):
        """Train the agent (subclass specific)."""
        pass

    @abstractmethod
    def evaluate(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None
    ):
        """Evaluate the agent (subclass specific)."""
        pass

    @abstractmethod
    def _get_latest_checkpoint(self):
        """Retrieve the most recent checkpoint (subclass specific)."""
        pass

    @classmethod
    @abstractmethod
    def load(cls):
        """Load a previously saved agent (subclass specific)."""
        pass

    @abstractmethod
    def save(self):
        """Persist the agent (subclass specific)."""
        pass

    @abstractmethod
    def export(self):
        """Export the agent (subclass specific)."""
        pass
class DeterministicAgentABC(AgentABC):
    """
    An ABC for deterministic (non-learning) agents.

    Does not call ``super().__init__()``: only the config paths are
    stored, so no config loading or session path creation happens.
    """

    @abstractmethod
    def __init__(
        self,
        training_config_path,
        lay_down_config_path
    ):
        """
        :param training_config_path: The training config file path.
        :param lay_down_config_path: The lay down config file path.
        """
        self._training_config_path = training_config_path
        self._lay_down_config_path = lay_down_config_path
        self._env: Primaite  # set by the subclass in _setup()
        self._agent = None

    @abstractmethod
    def _setup(self):
        """Build the environment/agent (subclass specific)."""
        pass

    @abstractmethod
    def _get_latest_checkpoint(self):
        """Retrieve the most recent checkpoint (subclass specific)."""
        pass

    def learn(self, time_steps: Optional[int], episodes: Optional[int]):
        """Deterministic agents cannot learn; only a warning is logged."""
        # Fixed: removed a stray no-op `pass` statement that preceded the
        # warning call.
        _LOGGER.warning("Deterministic agents cannot learn")

    @abstractmethod
    def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
        """Evaluate the agent (subclass specific)."""
        pass

    @classmethod
    @abstractmethod
    def load(cls):
        """Load a previously saved agent (subclass specific)."""
        pass

    @abstractmethod
    def save(self):
        """Persist the agent (subclass specific)."""
        pass

    @abstractmethod
    def export(self):
        """Export the agent (subclass specific)."""
        pass

View File

@@ -8,7 +8,7 @@ from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
from primaite.agents.agent_abc import AgentABC
from primaite.agents.agent import AgentSessionABC
from primaite.config import training_config
from primaite.environment.primaite_env import Primaite
@@ -23,7 +23,7 @@ def _env_creator(env_config):
)
class RLlibPPO(AgentABC):
class RLlibPPO(AgentSessionABC):
def __init__(
self,
@@ -34,8 +34,10 @@ class RLlibPPO(AgentABC):
self._ppo_config: PPOConfig
self._current_result: dict
self._setup()
self._agent.save()
def _setup(self):
super()._setup()
register_env("primaite", _env_creator)
self._ppo_config = PPOConfig()
@@ -72,12 +74,13 @@ class RLlibPPO(AgentABC):
(episode_count % checkpoint_n == 0)
or (episode_count == self._training_config.num_episodes)
):
self._agent.save(self.session_path)
self._agent.save(self.checkpoints_path)
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None
episodes: Optional[int] = None,
**kwargs
):
# Temporarily override train_batch_size and horizon
if time_steps:
@@ -91,11 +94,13 @@ class RLlibPPO(AgentABC):
self._current_result = self._agent.train()
self._save_checkpoint()
self._agent.stop()
super().learn()
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None
episodes: Optional[int] = None,
**kwargs
):
raise NotImplementedError

View File

@@ -1,13 +1,14 @@
from typing import Optional
import numpy as np
from stable_baselines3 import PPO
from primaite.agents.agent_abc import AgentABC
from primaite.agents.agent import AgentSessionABC
from primaite.environment.primaite_env import Primaite
from stable_baselines3.ppo import MlpPolicy as PPOMlp
class SB3PPO(AgentABC):
class SB3PPO(AgentSessionABC):
def __init__(
self,
training_config_path,
@@ -16,8 +17,10 @@ class SB3PPO(AgentABC):
super().__init__(training_config_path, lay_down_config_path)
self._tensorboard_log_path = self.session_path / "tensorboard_logs"
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
self._setup()
def _setup(self):
super()._setup()
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
@@ -28,15 +31,30 @@ class SB3PPO(AgentABC):
self._agent = PPO(
PPOMlp,
self._env,
verbose=0,
verbose=1,
n_steps=self._training_config.num_steps,
tensorboard_log=self._tensorboard_log_path
)
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
if checkpoint_n > 0 and episode_count > 0:
if (
(episode_count % checkpoint_n == 0)
or (episode_count == self._training_config.num_episodes)
):
self._agent.save(
self.checkpoints_path / f"sb3ppo_{episode_count}.zip")
def _get_latest_checkpoint(self):
pass
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None
episodes: Optional[int] = None,
**kwargs
):
if not time_steps:
time_steps = self._training_config.num_steps
@@ -46,12 +64,15 @@ class SB3PPO(AgentABC):
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
super().learn()
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
deterministic: bool = True
deterministic: bool = True,
**kwargs
):
if not time_steps:
time_steps = self._training_config.num_steps
@@ -67,6 +88,8 @@ class SB3PPO(AgentABC):
obs,
deterministic=deterministic
)
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
def load(self):

View File

@@ -1,7 +1,7 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Enumerations for APE."""
from enum import Enum
from enum import Enum, IntEnum
class NodeType(Enum):
@@ -172,3 +172,13 @@ class LinkStatus(Enum):
MEDIUM = 2
HIGH = 3
OVERLOAD = 4
class OutputVerboseLevel(IntEnum):
"""The Agent output verbosity level."""
NONE = 0
"No Output"
INFO = 1
"Info Messages"
ALL = 2
"All Messages"

View File

@@ -54,6 +54,13 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# The high value for the observation space
observation_space_high_value: 1000000000
# The Agent output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages)
# "ALL" (All Messages)
output_verbose_level: INFO
# Reward values
# Generic
all_ok: 0

View File

@@ -1,21 +1,63 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Final
from typing import Final, Union, Dict, Any
import networkx
import yaml
from primaite import USERS_CONFIG_DIR, getLogger
_LOGGER = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
_EXAMPLE_LAY_DOWN: Final[
Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
# class LayDownConfig:
# network: networkx.Graph
# POL
# EIR
# ACL
def convert_legacy_lay_down_config_dict(
legacy_config_dict: Dict[str, Any]
) -> Dict[str, Any]:
"""
Convert a legacy lay down config dict to the new format.
:param legacy_config_dict: A legacy lay down config dict.
"""
_LOGGER.warning("Legacy lay down config conversion not yet implemented")
return legacy_config_dict
def load(
file_path: Union[str, Path],
legacy_file: bool = False
) -> Dict:
"""
Read in a lay down config yaml file.
:param file_path: The config file path.
:param legacy_file: True if the config file is legacy format, otherwise
False.
:return: The lay down config as a dict.
:raises ValueError: If the file_path does not exist.
"""
if not isinstance(file_path, Path):
file_path = Path(file_path)
if file_path.exists():
with open(file_path, "r") as file:
config = yaml.safe_load(file)
_LOGGER.debug(f"Loading lay down config file: {file_path}")
if legacy_file:
try:
config = convert_legacy_lay_down_config_dict(config)
except KeyError:
msg = (
f"Failed to convert lay down config file {file_path} "
f"from legacy format. Attempting to use file as is."
)
_LOGGER.error(msg)
return config
msg = f"Cannot load the lay down config as it does not exist: {file_path}"
_LOGGER.error(msg)
raise ValueError(msg)
def ddos_basic_one_config_path() -> Path:
"""

View File

@@ -10,11 +10,27 @@ import yaml
from primaite import USERS_CONFIG_DIR, getLogger
from primaite.common.enums import DeepLearningFramework
from primaite.common.enums import ActionType, RedAgentIdentifier, \
AgentFramework, SessionType
AgentFramework, SessionType, OutputVerboseLevel
_LOGGER = getLogger(__name__)
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
_EXAMPLE_TRAINING: Final[
Path] = USERS_CONFIG_DIR / "example_config" / "training"
def main_training_config_path() -> Path:
"""
The path to the example training_config_main.yaml file.
:return: The file path.
"""
path = _EXAMPLE_TRAINING / "training_config_main.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
@dataclass()
@@ -24,44 +40,47 @@ class TrainingConfig:
"The AgentFramework"
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
"The DeepLearningFramework."
"The DeepLearningFramework"
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO
"The RedAgentIdentifier.."
"The RedAgentIdentifier"
action_type: ActionType = ActionType.ANY
"The ActionType to use."
"The ActionType to use"
num_episodes: int = 10
"The number of episodes to train over."
"The number of episodes to train over"
num_steps: int = 256
"The number of steps in an episode."
"The number of steps in an episode"
checkpoint_every_n_episodes: int = 5
"The agent will save a checkpoint every n episodes."
"The agent will save a checkpoint every n episodes"
observation_space: dict = field(
default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}
)
"The observation space config dict."
"The observation space config dict"
time_delay: int = 10
"The delay between steps (ms). Applies to generic agents only."
"The delay between steps (ms). Applies to generic agents only"
# file
session_type: SessionType = SessionType.TRAINING
"The type of PrimAITE session to run."
"The type of PrimAITE session to run"
load_agent: str = False
"Determine whether to load an agent from file."
"Determine whether to load an agent from file"
agent_load_file: Optional[str] = None
"File path and file name of agent if you're loading one in."
"File path and file name of agent if you're loading one in"
# Environment
observation_space_high_value: int = 1000000000
"The high value for the observation space."
"The high value for the observation space"
output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO
"The Agent output verbosity level"
# Reward values
# Generic
@@ -126,28 +145,28 @@ class TrainingConfig:
# Patching / Reset durations
os_patching_duration: int = 5
"The time taken to patch the OS."
"The time taken to patch the OS"
node_reset_duration: int = 5
"The time taken to reset a node (hardware)."
"The time taken to reset a node (hardware)"
node_booting_duration: int = 3
"The Time taken to turn on the node."
"The Time taken to turn on the node"
node_shutdown_duration: int = 2
"The time taken to turn off the node."
"The time taken to turn off the node"
service_patching_duration: int = 5
"The time taken to patch a service."
"The time taken to patch a service"
file_system_repairing_limit: int = 5
"The time take to repair the file system."
"The time take to repair the file system"
file_system_restoring_limit: int = 5
"The time take to restore the file system."
"The time take to restore the file system"
file_system_scanning_limit: int = 5
"The time taken to scan the file system."
"The time taken to scan the file system"
@classmethod
def from_dict(
@@ -157,9 +176,10 @@ class TrainingConfig:
field_enum_map = {
"agent_framework": AgentFramework,
"deep_learning_framework": DeepLearningFramework,
"red_agent_identifier": RedAgentIdentifier,
"action_type": ActionType,
"session_type": SessionType
"red_agent_identifier": RedAgentIdentifier,
"action_type": ActionType,
"session_type": SessionType,
"output_verbose_level": OutputVerboseLevel
}
for field, enum_class in field_enum_map.items():
@@ -178,28 +198,19 @@ class TrainingConfig:
"""
data = self.__dict__
if json_serializable:
data["agent_framework"] = self.agent_framework.value
data["deep_learning_framework"] = self.deep_learning_framework.value
data["red_agent_identifier"] = self.red_agent_identifier.value
data["action_type"] = self.action_type.value
data["output_verbose_level"] = self.output_verbose_level.value
return data
def main_training_config_path() -> Path:
"""
The path to the example training_config_main.yaml file.
:return: The file path.
"""
path = _EXAMPLE_TRAINING / "training_config_main.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
def load(file_path: Union[str, Path],
legacy_file: bool = False) -> TrainingConfig:
def load(
file_path: Union[str, Path],
legacy_file: bool = False
) -> TrainingConfig:
"""
Read in a training config yaml file.
@@ -246,7 +257,8 @@ def convert_legacy_training_config_dict(
agent_framework: AgentFramework = AgentFramework.SB3,
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO,
action_type: ActionType = ActionType.ANY,
num_steps: int = 256
num_steps: int = 256,
output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
@@ -260,13 +272,16 @@ def convert_legacy_training_config_dict(
don't have action_type values.
:param num_steps: The number of steps to set as legacy training configs
don't have num_steps values.
:param output_verbose_level: The agent output verbose level to use as
legacy training configs don't have output_verbose_level values.
:return: The converted training config dict.
"""
config_dict = {
"agent_framework": agent_framework.name,
"red_agent_identifier": red_agent_identifier.name,
"action_type": action_type.name,
"num_steps": num_steps
"num_steps": num_steps,
"output_verbose_level": output_verbose_level
}
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)

View File

@@ -435,7 +435,6 @@ class Primaite(Env):
_action: The action space from the agent
"""
# At the moment, actions are only affecting nodes
if self.training_config.action_type == ActionType.NODE:
self.apply_actions_to_nodes(_action)
elif self.training_config.action_type == ActionType.ACL:

View File

@@ -1,305 +1,229 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""
The main PrimAITE session runner module.
TODO: This will eventually be refactored out into a proper Session class.
TODO: The passing about of session_dir and timestamp_str is temporary and
will be cleaned up once we move to a proper Session class.
"""
import argparse
import json
import time
from datetime import datetime
from pathlib import Path
from typing import Final, Union
from uuid import uuid4
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import SESSIONS_DIR, getLogger
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
from primaite.transactions.transactions_to_file import \
write_transaction_to_file
_LOGGER = getLogger(__name__)
def run_generic(env: Primaite, config_values: TrainingConfig):
    """
    Run against a generic agent.

    Samples random actions for ``num_episodes`` x ``num_steps`` steps,
    sleeping ``time_delay`` ms between steps, then closes the env.

    :param env: An instance of
        :class:`~primaite.environment.primaite_env.Primaite`.
    :param config_values: An instance of
        :class:`~primaite.config.training_config.TrainingConfig`.
    """
    for episode in range(0, config_values.num_episodes):
        env.reset()
        for step in range(0, config_values.num_steps):
            # Send the observation space to the agent to get an action
            # TEMP - random action for now
            # action = env.blue_agent_action(obs)
            action = env.action_space.sample()
            # Run the simulation step on the live environment
            obs, reward, done, info = env.step(action)
            # Break if done is True
            if done:
                break
            # Introduce a delay between steps
            time.sleep(config_values.time_delay / 1000)
    # Reset the environment at the end of the episode
    # NOTE(review): original indentation was lost in extraction; close()
    # placed after the episode loop — confirm against upstream history.
    env.close()
def run_stable_baselines3_ppo(
    env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
):
    """
    Run against a stable_baselines3 PPO agent.

    Either loads an existing agent (when ``load_agent`` is set) or builds
    a fresh PPO agent, then runs a training or evaluation session
    depending on ``session_type``.

    :param env: An instance of
        :class:`~primaite.environment.primaite_env.Primaite`.
    :param config_values: An instance of
        :class:`~primaite.config.training_config.TrainingConfig`.
    :param session_path: The directory path the session is writing to.
    :param timestamp_str: The session timestamp in the format:
        <yyyy-mm-dd>_<hh-mm-ss>.
    """
    if config_values.load_agent:
        try:
            agent = PPO.load(
                config_values.agent_load_file,
                env,
                verbose=0,
                n_steps=config_values.num_steps,
            )
        # NOTE(review): broad except — a failed load leaves `agent`
        # unbound and the code below will raise NameError.
        except Exception:
            print(
                "ERROR: Could not load agent at location: "
                + config_values.agent_load_file
            )
            _LOGGER.error("Could not load agent")
            _LOGGER.error("Exception occured", exc_info=True)
    else:
        agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)

    if config_values.session_type == "TRAINING":
        # We're in a training session
        print("Starting training session...")
        _LOGGER.debug("Starting training session...")
        for episode in range(config_values.num_episodes):
            agent.learn(total_timesteps=config_values.num_steps)
        # NOTE(review): indentation lost in extraction; agent save placed
        # after the training loop — confirm against upstream history.
        _save_agent(agent, session_path, timestamp_str)
    else:
        # Default to being in an evaluation session
        print("Starting evaluation session...")
        _LOGGER.debug("Starting evaluation session...")
        evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
    env.close()
def _write_session_metadata_file(
    session_dir: Path, uuid: str, session_timestamp: datetime, env: Primaite
):
    """
    Write the ``session_metadata.json`` file.

    Creates a ``session_metadata.json`` in the ``session_dir`` directory
    and adds the following key/value pairs:

    - uuid: The UUID assigned to the session upon instantiation.
    - start_datetime: The date & time the session started in iso format.
    - end_datetime: NULL.
    - total_episodes: NULL.
    - total_time_steps: NULL.
    - env:
        - training_config:
            - All training config items
        - lay_down_config:
            - All lay down config items

    :param session_dir: The directory path the session is writing to.
    :param uuid: The UUID assigned to the session.
    :param session_timestamp: The datetime the session started.
    :param env: The Primaite environment whose configs are serialised.
    """
    metadata_dict = {
        "uuid": uuid,
        "start_datetime": session_timestamp.isoformat(),
        "end_datetime": None,
        "total_episodes": None,
        "total_time_steps": None,
        "env": {
            "training_config": env.training_config.to_dict(json_serializable=True),
            "lay_down_config": env.lay_down_config,
        },
    }
    filepath = session_dir / "session_metadata.json"
    _LOGGER.debug(f"Writing Session Metadata file: {filepath}")
    with open(filepath, "w") as file:
        json.dump(metadata_dict, file)
def _update_session_metadata_file(session_dir: Path, env: Primaite):
    """
    Update the ``session_metadata.json`` file.

    Updates the ``session_metadata.json`` in the ``session_dir`` directory
    with the following key/value pairs:

    - end_datetime: The date & time the session ended in iso format.
    - total_episodes: The total number of episodes completed.
    - total_time_steps: The total number of time steps completed.

    :param session_dir: The directory path the session is writing to.
    :param env: The Primaite environment the counts are read from.
    """
    with open(session_dir / "session_metadata.json", "r") as file:
        metadata_dict = json.load(file)

    metadata_dict["end_datetime"] = datetime.now().isoformat()
    metadata_dict["total_episodes"] = env.episode_count
    metadata_dict["total_time_steps"] = env.total_step_count

    filepath = session_dir / "session_metadata.json"
    _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
    with open(filepath, "w") as file:
        json.dump(metadata_dict, file)
def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
    """
    Persist an agent.

    Only works for stable baselines3 agents at present. Non
    ``OnPolicyAlgorithm`` agents are not saved; an error is logged
    instead.

    :param agent: The stable_baselines3 on-policy agent to save.
    :param session_path: The directory path the session is writing to.
    :param timestamp_str: The session timestamp in the format:
        <yyyy-mm-dd>_<hh-mm-ss>.
    """
    if not isinstance(agent, OnPolicyAlgorithm):
        msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
        _LOGGER.error(msg)
    else:
        filepath = session_path / f"agent_saved_{timestamp_str}"
        agent.save(filepath)
        _LOGGER.debug(f"Trained agent saved as: {filepath}")
def _get_session_path(session_timestamp: datetime) -> Path:
    """
    Get the directory path the session will output to.

    This is set in the format of:
    ~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.

    :param session_timestamp: This is the datetime that the session started.
    :return: The session directory path.
    """
    session_path = (
        SESSIONS_DIR
        / session_timestamp.strftime("%Y-%m-%d")
        / session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    )
    session_path.mkdir(exist_ok=True, parents=True)
    return session_path
def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
"""Run the PrimAITE Session.
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.
"""
# Welcome message
print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
uuid = str(uuid4())
session_timestamp: Final[datetime] = datetime.now()
session_dir = _get_session_path(session_timestamp)
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
print(f"The output directory for this session is: {session_dir}")
# Create a list of transactions
# A transaction is an object holding the:
# - episode #
# - step #
# - initial observation space
# - action
# - reward
# - new observation space
transaction_list = []
# Create the Primaite environment
env = Primaite(
training_config_path=training_config_path,
lay_down_config_path=lay_down_config_path,
transaction_list=transaction_list,
session_path=session_dir,
timestamp_str=timestamp_str,
)
print("Writing Session Metadata file...")
_write_session_metadata_file(
session_dir=session_dir, uuid=uuid, session_timestamp=session_timestamp, env=env
)
config_values = env.training_config
# Get the number of steps (which is stored in the child config file)
config_values.num_steps = env.episode_steps
# Run environment against an agent
if config_values.agent_identifier == "GENERIC":
run_generic(env=env, config_values=config_values)
elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
run_stable_baselines3_ppo(
env=env,
config_values=config_values,
session_path=session_dir,
timestamp_str=timestamp_str,
)
elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
run_stable_baselines3_a2c(
env=env,
config_values=config_values,
session_path=session_dir,
timestamp_str=timestamp_str,
)
print("Session finished")
_LOGGER.debug("Session finished")
print("Saving transaction logs...")
write_transaction_to_file(
transaction_list=transaction_list,
session_path=session_dir,
timestamp_str=timestamp_str,
)
print("Updating Session Metadata file...")
_update_session_metadata_file(session_dir=session_dir, env=env)
print("Finished")
_LOGGER.debug("Finished")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tc")
parser.add_argument("--ldc")
args = parser.parse_args()
if not args.tc:
_LOGGER.error(
"Please provide a training config file using the --tc " "argument"
)
if not args.ldc:
_LOGGER.error(
"Please provide a lay down config file using the --ldc " "argument"
)
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# """
# The main PrimAITE session runner module.
#
# TODO: This will eventually be refactored out into a proper Session class.
# TODO: The passing about of session_path and timestamp_str is temporary and
# will be cleaned up once we move to a proper Session class.
# """
# import argparse
# import json
# import time
# from datetime import datetime
# from pathlib import Path
# from typing import Final, Union
# from uuid import uuid4
#
# from stable_baselines3 import A2C, PPO
# from stable_baselines3.common.evaluation import evaluate_policy
# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
# from stable_baselines3.ppo import MlpPolicy as PPOMlp
#
# from primaite import SESSIONS_DIR, getLogger
# from primaite.config.training_config import TrainingConfig
# from primaite.environment.primaite_env import Primaite
# from primaite.transactions.transactions_to_file import \
# write_transaction_to_file
#
# _LOGGER = getLogger(__name__)
#
#
# def run_generic(env: Primaite, config_values: TrainingConfig):
# """
# Run against a generic agent.
#
# :param env: An instance of
# :class:`~primaite.environment.primaite_env.Primaite`.
# :param config_values: An instance of
# :class:`~primaite.config.training_config.TrainingConfig`.
# """
# for episode in range(0, config_values.num_episodes):
# env.reset()
# for step in range(0, config_values.num_steps):
# # Send the observation space to the agent to get an action
# # TEMP - random action for now
# # action = env.blue_agent_action(obs)
# action = env.action_space.sample()
#
# # Run the simulation step on the live environment
# obs, reward, done, info = env.step(action)
#
# # Break if done is True
# if done:
# break
#
# # Introduce a delay between steps
# time.sleep(config_values.time_delay / 1000)
#
# # Reset the environment at the end of the episode
#
# env.close()
#
#
# def run_stable_baselines3_ppo(
# env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
# ):
# """
# Run against a stable_baselines3 PPO agent.
#
# :param env: An instance of
# :class:`~primaite.environment.primaite_env.Primaite`.
# :param config_values: An instance of
# :class:`~primaite.config.training_config.TrainingConfig`.
# :param session_path: The directory path the session is writing to.
# :param timestamp_str: The session timestamp in the format:
# <yyyy-mm-dd>_<hh-mm-ss>.
# """
# if config_values.load_agent:
# try:
# agent = PPO.load(
# config_values.agent_load_file,
# env,
# verbose=0,
# n_steps=config_values.num_steps,
# )
# except Exception:
# print(
# "ERROR: Could not load agent at location: "
# + config_values.agent_load_file
# )
# _LOGGER.error("Could not load agent")
# _LOGGER.error("Exception occured", exc_info=True)
# else:
# agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
#
# if config_values.session_type == "TRAINING":
# # We're in a training session
# print("Starting training session...")
# _LOGGER.debug("Starting training session...")
# for episode in range(config_values.num_episodes):
# agent.learn(total_timesteps=config_values.num_steps)
# _save_agent(agent, session_path, timestamp_str)
# else:
# # Default to being in an evaluation session
# print("Starting evaluation session...")
# _LOGGER.debug("Starting evaluation session...")
# evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
#
# env.close()
#
#
#
#
# def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
# """
# Persist an agent.
#
# Only works for stable baselines3 agents at present.
#
# :param session_path: The directory path the session is writing to.
# :param timestamp_str: The session timestamp in the format:
# <yyyy-mm-dd>_<hh-mm-ss>.
# """
# if not isinstance(agent, OnPolicyAlgorithm):
# msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
# _LOGGER.error(msg)
# else:
# filepath = session_path / f"agent_saved_{timestamp_str}"
# agent.save(filepath)
# _LOGGER.debug(f"Trained agent saved as: {filepath}")
#
#
#
#
# def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
# """Run the PrimAITE Session.
#
# :param training_config_path: The training config filepath.
# :param lay_down_config_path: The lay down config filepath.
# """
# # Welcome message
# print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
# uuid = str(uuid4())
# session_timestamp: Final[datetime] = datetime.now()
# session_path = _get_session_path(session_timestamp)
# timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
#
# print(f"The output directory for this session is: {session_path}")
#
# # Create a list of transactions
# # A transaction is an object holding the:
# # - episode #
# # - step #
# # - initial observation space
# # - action
# # - reward
# # - new observation space
# transaction_list = []
#
# # Create the Primaite environment
# env = Primaite(
# training_config_path=training_config_path,
# lay_down_config_path=lay_down_config_path,
# transaction_list=transaction_list,
# session_path=session_path,
# timestamp_str=timestamp_str,
# )
#
# print("Writing Session Metadata file...")
#
# _write_session_metadata_file(
# session_path=session_path, uuid=uuid, session_timestamp=session_timestamp, env=env
# )
#
# config_values = env.training_config
#
# # Get the number of steps (which is stored in the child config file)
# config_values.num_steps = env.episode_steps
#
# # Run environment against an agent
# if config_values.agent_identifier == "GENERIC":
# run_generic(env=env, config_values=config_values)
# elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
# run_stable_baselines3_ppo(
# env=env,
# config_values=config_values,
# session_path=session_path,
# timestamp_str=timestamp_str,
# )
# elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
# run_stable_baselines3_a2c(
# env=env,
# config_values=config_values,
# session_path=session_path,
# timestamp_str=timestamp_str,
# )
#
# print("Session finished")
# _LOGGER.debug("Session finished")
#
# print("Saving transaction logs...")
# write_transaction_to_file(
# transaction_list=transaction_list,
# session_path=session_path,
# timestamp_str=timestamp_str,
# )
#
# print("Updating Session Metadata file...")
# _update_session_metadata_file(session_path=session_path, env=env)
#
# print("Finished")
# _LOGGER.debug("Finished")
#
#
# if __name__ == "__main__":
# parser = argparse.ArgumentParser()
# parser.add_argument("--tc")
# parser.add_argument("--ldc")
# args = parser.parse_args()
# if not args.tc:
# _LOGGER.error(
# "Please provide a training config file using the --tc " "argument"
# )
# if not args.ldc:
# _LOGGER.error(
# "Please provide a lay down config file using the --ldc " "argument"
# )
# run(training_config_path=args.tc, lay_down_config_path=args.ldc)
#
#

View File

@@ -3,12 +3,16 @@ from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Final, Optional, Union
from typing import Final, Optional, Union, Dict
from uuid import uuid4
from primaite import getLogger, SESSIONS_DIR
from primaite.agents.agent import AgentSessionABC
from primaite.agents.rllib import RLlibPPO
from primaite.agents.sb3 import SB3PPO
from primaite.common.enums import AgentFramework, RedAgentIdentifier, \
ActionType
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
@@ -26,8 +30,8 @@ def _get_session_path(session_timestamp: datetime) -> Path:
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = SESSIONS_DIR / date_dir / session_dir
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = SESSIONS_DIR / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
_LOGGER.debug(f"Created PrimAITE Session path: {session_path}")
@@ -45,211 +49,100 @@ class PrimaiteSession:
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path]] = training_config_path
self._training_config: Final[TrainingConfig] = training_config.load(
self._training_config_path
)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
self._auto: Final[bool] = auto
self._uuid: str = str(uuid4())
self._session_timestamp: Final[datetime] = datetime.now()
self._session_path: Final[Path] = _get_session_path(
self._session_timestamp
self._lay_down_config: Dict = lay_down_config.load(
self._lay_down_config_path
)
self._timestamp_str: Final[str] = self._session_timestamp.strftime(
"%Y-%m-%d_%H-%M-%S")
self._metadata_path = self._session_path / "session_metadata.json"
self._env = None
self._training_config: TrainingConfig
self._can_learn: bool = False
_LOGGER.debug("")
self._auto: bool = auto
self._agent_session: AgentSessionABC = None # noqa
if self._auto:
self.setup()
self.learn()
@property
def uuid(self):
"""The session UUID."""
return self._uuid
def _setup_primaite_env(self, transaction_list: Optional[list] = None):
if not transaction_list:
transaction_list = []
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
transaction_list=transaction_list,
session_path=self._session_path,
timestamp_str=self._timestamp_str
)
self._training_config: TrainingConfig = self._env.training_config
def _write_session_metadata_file(self):
"""
Write the ``session_metadata.json`` file.
Creates a ``session_metadata.json`` in the ``session_dir`` directory
and adds the following key/value pairs:
- uuid: The UUID assigned to the session upon instantiation.
- start_datetime: The date & time the session started in iso format.
- end_datetime: NULL.
- total_episodes: NULL.
- total_time_steps: NULL.
- env:
- training_config:
- All training config items
- lay_down_config:
- All lay down config items
"""
metadata_dict = {
"uuid": self._uuid,
"start_datetime": self._session_timestamp.isoformat(),
"end_datetime": None,
"total_episodes": None,
"total_time_steps": None,
"env": {
"training_config": self._env.training_config.to_dict(
json_serializable=True
),
"lay_down_config": self._env.lay_down_config,
},
}
_LOGGER.debug(f"Writing Session Metadata file: {self._metadata_path}")
with open(self._metadata_path, "w") as file:
json.dump(metadata_dict, file)
def _update_session_metadata_file(self):
"""
Update the ``session_metadata.json`` file.
Updates the `session_metadata.json`` in the ``session_dir`` directory
with the following key/value pairs:
- end_datetime: NULL.
- total_episodes: NULL.
- total_time_steps: NULL.
"""
with open(self._metadata_path, "r") as file:
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._env.episode_count
metadata_dict["total_time_steps"] = self._env.total_step_count
_LOGGER.debug(f"Updating Session Metadata file: {self._metadata_path}")
with open(self._metadata_path, "w") as file:
json.dump(metadata_dict, file)
def setup(self):
self._setup_primaite_env()
self._can_learn = True
pass
if self._training_config.agent_framework == AgentFramework.NONE:
if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM:
# Stochastic Random Agent
raise NotImplementedError
elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED:
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
raise NotImplementedError
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
raise NotImplementedError
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
raise NotImplementedError
else:
# Invalid RedAgentIdentifier ActionType combo
pass
else:
# Invalid AgentFramework RedAgentIdentifier combo
pass
elif self._training_config.agent_framework == AgentFramework.SB3:
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
# Stable Baselines3/Proximal Policy Optimization
self._agent_session = SB3PPO(
self._training_config_path,
self._lay_down_config_path
)
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
# Stable Baselines3/Advantage Actor Critic
raise NotImplementedError
else:
# Invalid AgentFramework RedAgentIdentifier combo
pass
elif self._training_config.agent_framework == AgentFramework.RLLIB:
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
# Ray RLlib/Proximal Policy Optimization
self._agent_session = RLlibPPO(
self._training_config_path,
self._lay_down_config_path
)
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
# Ray RLlib/Advantage Actor Critic
raise NotImplementedError
else:
# Invalid AgentFramework RedAgentIdentifier combo
pass
else:
# Invalid AgentFramework
pass
def learn(
self,
time_steps: Optional[int],
episodes: Optional[int],
iterations: Optional[int],
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs
):
if self._can_learn:
# Run environment against an agent
if self._training_config.agent_framework == AgentFramework.NONE:
if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM:
# Stochastic Random Agent
run_generic(env=env, config_values=config_values)
elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED:
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
pass
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
pass
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
pass
else:
# Invalid RedAgentIdentifier ActionType combo
pass
else:
# Invalid AgentFramework RedAgentIdentifier combo
pass
elif self._training_config.agent_framework == AgentFramework.SB3:
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
# Stable Baselines3/Proximal Policy Optimization
run_stable_baselines3_ppo(
env=env,
config_values=config_values,
session_path=session_dir,
timestamp_str=timestamp_str,
)
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
# Stable Baselines3/Advantage Actor Critic
run_stable_baselines3_a2c(
env=env,
config_values=config_values,
session_path=session_dir,
timestamp_str=timestamp_str,
)
else:
# Invalid AgentFramework RedAgentIdentifier combo
pass
elif self._training_config.agent_framework == AgentFramework.RLLIB:
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
# Ray RLlib/Proximal Policy Optimization
pass
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
# Ray RLlib/Advantage Actor Critic
pass
else:
# Invalid AgentFramework RedAgentIdentifier combo
pass
else:
# Invalid AgentFramework
pass
print("Session finished")
_LOGGER.debug("Session finished")
print("Saving transaction logs...")
write_transaction_to_file(
transaction_list=transaction_list,
session_path=session_dir,
timestamp_str=timestamp_str,
)
print("Updating Session Metadata file...")
_update_session_metadata_file(session_dir=session_dir, env=env)
print("Finished")
_LOGGER.debug("Finished")
self._agent_session.learn(time_steps, episodes, **kwargs)
def evaluate(
self,
time_steps: Optional[int],
episodes: Optional[int],
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs
):
pass
def export(self):
pass
self._agent_session.evaluate(time_steps, episodes, **kwargs)
@classmethod
def import_agent(

View File

@@ -108,5 +108,6 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
csv_writer.writerow(csv_data)
csv_file.close()
_LOGGER.debug("Finished writing transactions")
except Exception:
_LOGGER.error("Could not save the transaction file", exc_info=True)

View File

@@ -19,8 +19,8 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path:
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_dir
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
return session_path