#917 - Integrated both SB3 and RLlib agents into PrimaiteSession
This commit is contained in:
251
src/primaite/agents/agent.py
Normal file
251
src/primaite/agents/agent.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Final, Dict, Union, List
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.enums import OutputVerboseLevel
|
||||
from primaite.config import lay_down_config
|
||||
from primaite.config import training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def _get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
Get a temp directory session path the test session will output to.
|
||||
|
||||
:param session_timestamp: This is the datetime that the session started.
|
||||
:return: The session directory path.
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = Path("./") / date_dir / session_path
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return session_path
|
||||
|
||||
|
||||
class AgentSessionABC(ABC):
    """
    Abstract base class for an agent training/evaluation session.

    Loads the training and lay down configs, creates a timestamped session
    output directory (plus a checkpoints sub-directory) and manages the
    ``session_metadata.json`` file recording the session's lifecycle.
    """

    @abstractmethod
    def __init__(
            self,
            training_config_path,
            lay_down_config_path
    ):
        """
        :param training_config_path: Path (or str) of the training config
            yaml file.
        :param lay_down_config_path: Path (or str) of the lay down config
            yaml file.
        """
        # Coerce str paths to Path objects before storing them.
        if not isinstance(training_config_path, Path):
            training_config_path = Path(training_config_path)
        self._training_config_path: Final[Path] = training_config_path
        self._training_config: Final[TrainingConfig] = training_config.load(
            self._training_config_path
        )

        if not isinstance(lay_down_config_path, Path):
            lay_down_config_path = Path(lay_down_config_path)
        self._lay_down_config_path: Final[Path] = lay_down_config_path
        self._lay_down_config: Dict = lay_down_config.load(
            self._lay_down_config_path
        )
        self.output_verbose_level = self._training_config.output_verbose_level

        # Populated by concrete subclasses during _setup().
        self._env: Primaite
        self._agent = None
        self._transaction_list: List[Dict] = []
        # Lifecycle guards: learning is enabled by _setup(); evaluation is
        # enabled only after a completed learn().
        self._can_learn: bool = False
        self._can_evaluate: bool = False

        self._uuid = str(uuid4())
        self.session_timestamp: datetime = datetime.now()
        "The session timestamp"
        self.session_path = _get_temp_session_path(self.session_timestamp)
        "The Session path"
        self.checkpoints_path = self.session_path / "checkpoints"
        "The Session checkpoints path"

        self.timestamp_str = self.session_timestamp.strftime(
            "%Y-%m-%d_%H-%M-%S")
        "The session timestamp as a string"

    @property
    def uuid(self):
        """The Agent Session UUID."""
        return self._uuid

    def _write_session_metadata_file(self):
        """
        Write the ``session_metadata.json`` file.

        Creates a ``session_metadata.json`` in the ``session_path`` directory
        and adds the following key/value pairs:

        - uuid: The UUID assigned to the session upon instantiation.
        - start_datetime: The date & time the session started in iso format.
        - end_datetime: NULL.
        - total_episodes: NULL.
        - total_time_steps: NULL.
        - env:
            - training_config:
                - All training config items
            - lay_down_config:
                - All lay down config items

        """
        metadata_dict = {
            "uuid": self.uuid,
            "start_datetime": self.session_timestamp.isoformat(),
            "end_datetime": None,
            "total_episodes": None,
            "total_time_steps": None,
            "env": {
                "training_config": self._training_config.to_dict(
                    json_serializable=True
                ),
                "lay_down_config": self._lay_down_config,
            },
        }
        filepath = self.session_path / "session_metadata.json"
        _LOGGER.debug(f"Writing Session Metadata file: {filepath}")
        with open(filepath, "w") as file:
            json.dump(metadata_dict, file)
        _LOGGER.debug("Finished writing session metadata file")

    def _update_session_metadata_file(self):
        """
        Update the ``session_metadata.json`` file.

        Updates the `session_metadata.json`` in the ``session_path`` directory
        with the following key/value pairs:

        - end_datetime: The date & time the session ended in iso format.
        - total_episodes: The total number of training episodes completed.
        - total_time_steps: The total number of training time steps completed.
        """
        # Read the existing metadata, amend it in memory, then rewrite the
        # whole file.
        with open(self.session_path / "session_metadata.json", "r") as file:
            metadata_dict = json.load(file)

        metadata_dict["end_datetime"] = datetime.now().isoformat()
        metadata_dict["total_episodes"] = self._env.episode_count
        metadata_dict["total_time_steps"] = self._env.total_step_count

        filepath = self.session_path / "session_metadata.json"
        _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
        with open(filepath, "w") as file:
            json.dump(metadata_dict, file)
        _LOGGER.debug("Finished updating session metadata file")

    @abstractmethod
    def _setup(self):
        """
        Perform common session setup.

        Logs a welcome banner (respecting ``output_verbose_level``), writes
        the session metadata file and enables learning. Concrete subclasses
        extend this to build their environment and agent.
        """
        if self.output_verbose_level >= OutputVerboseLevel.INFO:
            _LOGGER.info(
                "Welcome to the Primary-level AI Training Environment "
                "(PrimAITE)"
            )
        _LOGGER.debug(
            f"The output directory for this agent is: {self.session_path}"
        )
        self._write_session_metadata_file()
        self._can_learn = True
        self._can_evaluate = False

    @abstractmethod
    def _save_checkpoint(self):
        """Save an agent checkpoint; implemented by subclasses."""
        pass

    @abstractmethod
    def learn(
            self,
            time_steps: Optional[int] = None,
            episodes: Optional[int] = None,
            **kwargs
    ):
        """
        Train the agent.

        Subclasses run their training loop, then call ``super().learn()`` to
        write the transaction file, update the session metadata and enable
        evaluation.

        :param time_steps: Number of time steps per episode (optional).
        :param episodes: Number of episodes to train over (optional).
        """
        if self._can_learn:
            _LOGGER.debug("Writing transactions")
            write_transaction_to_file(
                transaction_list=self._transaction_list,
                session_path=self.session_path,
                timestamp_str=self.timestamp_str,
            )
            self._update_session_metadata_file()
            self._can_evaluate = True

    @abstractmethod
    def evaluate(
            self,
            time_steps: Optional[int] = None,
            episodes: Optional[int] = None,
            **kwargs
    ):
        """Evaluate the trained agent; implemented by subclasses."""
        pass

    @abstractmethod
    def _get_latest_checkpoint(self):
        """Locate/load the most recent checkpoint; implemented by subclasses."""
        pass

    @classmethod
    @abstractmethod
    def load(cls):
        """Load an agent session from file; implemented by subclasses."""
        pass

    @abstractmethod
    def save(self):
        """Save the trained agent to the session directory."""
        self._agent.save(self.session_path)

    @abstractmethod
    def export(self):
        """Export the trained agent; implemented by subclasses."""
        pass
|
||||
class DeterministicAgentSessionABC(AgentSessionABC):
    """
    Abstract base class for deterministic (rule-based) agent sessions.

    Deterministic agents do not train: ``learn`` is overridden to log a
    warning and do nothing.
    """

    @abstractmethod
    def __init__(
            self,
            training_config_path,
            lay_down_config_path
    ):
        """
        :param training_config_path: Path of the training config yaml file.
        :param lay_down_config_path: Path of the lay down config yaml file.
        """
        # NOTE(review): unlike AgentSessionABC.__init__, this does not call
        # super().__init__() and stores the raw paths without loading the
        # configs — confirm this divergence is intentional.
        self._training_config_path = training_config_path
        self._lay_down_config_path = lay_down_config_path
        self._env: Primaite
        self._agent = None

    @abstractmethod
    def _setup(self):
        """Build the environment/agent; implemented by subclasses."""
        pass

    @abstractmethod
    def _get_latest_checkpoint(self):
        """Locate/load the most recent checkpoint; implemented by subclasses."""
        pass

    def learn(
            self,
            time_steps: Optional[int] = None,
            episodes: Optional[int] = None
    ):
        """Deterministic agents cannot learn; logs a warning and returns."""
        _LOGGER.warning("Deterministic agents cannot learn")

    @abstractmethod
    def evaluate(
            self,
            time_steps: Optional[int] = None,
            episodes: Optional[int] = None
    ):
        """Evaluate the deterministic agent; implemented by subclasses."""
        pass

    @classmethod
    @abstractmethod
    def load(cls):
        """Load an agent session from file; implemented by subclasses."""
        pass

    @abstractmethod
    def save(self):
        """Save the agent; implemented by subclasses."""
        pass

    @abstractmethod
    def export(self):
        """Export the agent; implemented by subclasses."""
        pass
|
||||
@@ -1,132 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Final, Dict, Any, Union, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.training_config import TrainingConfig, load
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def _get_temp_session_path(session_timestamp: datetime) -> Path:
    """
    Get a temp directory session path the test session will output to.

    :param session_timestamp: This is the datetime that the session started.
    :return: The session directory path.
    """
    # Layout: ./<YYYY-MM-DD>/<YYYY-MM-DD_HH-MM-SS>; created if missing.
    date_dir = session_timestamp.strftime("%Y-%m-%d")
    session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    session_path = Path("./") / date_dir / session_dir
    session_path.mkdir(exist_ok=True, parents=True)

    return session_path
|
||||
|
||||
|
||||
class AgentABC(ABC):
    """
    Abstract base class for RL agents.

    Loads the training config and creates a timestamped session output
    directory on instantiation.
    """

    @abstractmethod
    def __init__(
            self,
            training_config_path,
            lay_down_config_path
    ):
        """
        :param training_config_path: Path of the training config yaml file.
        :param lay_down_config_path: Path of the lay down config yaml file.
        """
        self._training_config_path = training_config_path
        self._training_config: Final[TrainingConfig] = load(
            self._training_config_path
        )
        # The lay down config path is stored but not loaded here.
        self._lay_down_config_path = lay_down_config_path
        # Populated by concrete subclasses during _setup().
        self._env: Primaite
        self._agent = None
        self.session_timestamp: datetime = datetime.now()
        self.session_path = _get_temp_session_path(self.session_timestamp)

        self.timestamp_str = self.session_timestamp.strftime(
            "%Y-%m-%d_%H-%M-%S")

    @abstractmethod
    def _setup(self):
        """Build the environment/agent; implemented by subclasses."""
        pass

    @abstractmethod
    def _save_checkpoint(self):
        """Save an agent checkpoint; implemented by subclasses."""
        pass

    @abstractmethod
    def learn(
            self,
            time_steps: Optional[int] = None,
            episodes: Optional[int] = None
    ):
        """Train the agent; implemented by subclasses."""
        pass

    @abstractmethod
    def evaluate(
            self,
            time_steps: Optional[int] = None,
            episodes: Optional[int] = None
    ):
        """Evaluate the agent; implemented by subclasses."""
        pass

    @abstractmethod
    def _get_latest_checkpoint(self):
        """Locate/load the most recent checkpoint; implemented by subclasses."""
        pass

    @classmethod
    @abstractmethod
    def load(cls):
        """Load an agent from file; implemented by subclasses."""
        pass

    @abstractmethod
    def save(self):
        """Save the agent; implemented by subclasses."""
        pass

    @abstractmethod
    def export(self):
        """Export the agent; implemented by subclasses."""
        pass
|
||||
|
||||
|
||||
class DeterministicAgentABC(AgentABC):
    """
    Abstract base class for deterministic (rule-based) agents.

    Deterministic agents do not train: ``learn`` is overridden to log a
    warning and do nothing.
    """

    @abstractmethod
    def __init__(
            self,
            training_config_path,
            lay_down_config_path
    ):
        """
        :param training_config_path: Path of the training config yaml file.
        :param lay_down_config_path: Path of the lay down config yaml file.
        """
        self._training_config_path = training_config_path
        self._lay_down_config_path = lay_down_config_path
        self._env: Primaite
        self._agent = None

    @abstractmethod
    def _setup(self):
        """Build the environment/agent; implemented by subclasses."""
        pass

    @abstractmethod
    def _get_latest_checkpoint(self):
        """Locate/load the most recent checkpoint; implemented by subclasses."""
        pass

    def learn(self, time_steps: Optional[int], episodes: Optional[int]):
        """Deterministic agents cannot learn; logs a warning and returns."""
        # Fix: removed a stray dead `pass` statement that preceded the
        # warning in the original.
        _LOGGER.warning("Deterministic agents cannot learn")

    @abstractmethod
    def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
        """Evaluate the deterministic agent; implemented by subclasses."""
        pass

    @classmethod
    @abstractmethod
    def load(cls):
        """Load an agent from file; implemented by subclasses."""
        pass

    @abstractmethod
    def save(self):
        """Save the agent; implemented by subclasses."""
        pass

    @abstractmethod
    def export(self):
        """Export the agent; implemented by subclasses."""
        pass
|
||||
@@ -8,7 +8,7 @@ from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
from primaite.agents.agent_abc import AgentABC
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.config import training_config
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
@@ -23,7 +23,7 @@ def _env_creator(env_config):
|
||||
)
|
||||
|
||||
|
||||
class RLlibPPO(AgentABC):
|
||||
class RLlibPPO(AgentSessionABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -34,8 +34,10 @@ class RLlibPPO(AgentABC):
|
||||
self._ppo_config: PPOConfig
|
||||
self._current_result: dict
|
||||
self._setup()
|
||||
self._agent.save()
|
||||
|
||||
def _setup(self):
|
||||
super()._setup()
|
||||
register_env("primaite", _env_creator)
|
||||
self._ppo_config = PPOConfig()
|
||||
|
||||
@@ -72,12 +74,13 @@ class RLlibPPO(AgentABC):
|
||||
(episode_count % checkpoint_n == 0)
|
||||
or (episode_count == self._training_config.num_episodes)
|
||||
):
|
||||
self._agent.save(self.session_path)
|
||||
self._agent.save(self.checkpoints_path)
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
# Temporarily override train_batch_size and horizon
|
||||
if time_steps:
|
||||
@@ -91,11 +94,13 @@ class RLlibPPO(AgentABC):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
self._agent.stop()
|
||||
super().learn()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
from primaite.agents.agent_abc import AgentABC
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
|
||||
class SB3PPO(AgentABC):
|
||||
class SB3PPO(AgentSessionABC):
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path,
|
||||
@@ -16,8 +17,10 @@ class SB3PPO(AgentABC):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self._tensorboard_log_path = self.session_path / "tensorboard_logs"
|
||||
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
super()._setup()
|
||||
self._env = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
@@ -28,15 +31,30 @@ class SB3PPO(AgentABC):
|
||||
self._agent = PPO(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=0,
|
||||
verbose=1,
|
||||
n_steps=self._training_config.num_steps,
|
||||
tensorboard_log=self._tensorboard_log_path
|
||||
)
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (
|
||||
(episode_count % checkpoint_n == 0)
|
||||
or (episode_count == self._training_config.num_episodes)
|
||||
):
|
||||
self._agent.save(
|
||||
self.checkpoints_path / f"sb3ppo_{episode_count}.zip")
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
@@ -46,12 +64,15 @@ class SB3PPO(AgentABC):
|
||||
|
||||
for i in range(episodes):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
super().learn()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
deterministic: bool = True
|
||||
deterministic: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
@@ -67,6 +88,8 @@ class SB3PPO(AgentABC):
|
||||
obs,
|
||||
deterministic=deterministic
|
||||
)
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
|
||||
def load(self):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""Enumerations for APE."""
|
||||
|
||||
from enum import Enum
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
@@ -172,3 +172,13 @@ class LinkStatus(Enum):
|
||||
MEDIUM = 2
|
||||
HIGH = 3
|
||||
OVERLOAD = 4
|
||||
|
||||
|
||||
class OutputVerboseLevel(IntEnum):
    """The Agent output verbosity level."""

    # IntEnum so levels can be ordered, e.g. `level >= OutputVerboseLevel.INFO`.
    NONE = 0
    "No Output"
    INFO = 1
    "Info Messages"
    ALL = 2
    "All Messages"
|
||||
|
||||
@@ -54,6 +54,13 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# The Agent output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
# "INFO" (Info Messages)
|
||||
# "ALL" (All Messages)
|
||||
output_verbose_level: INFO
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
|
||||
@@ -1,21 +1,63 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
from typing import Final, Union, Dict, Any
|
||||
|
||||
import networkx
|
||||
import yaml
|
||||
|
||||
from primaite import USERS_CONFIG_DIR, getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
|
||||
_EXAMPLE_LAY_DOWN: Final[
|
||||
Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
|
||||
|
||||
|
||||
# class LayDownConfig:
|
||||
# network: networkx.Graph
|
||||
# POL
|
||||
# EIR
|
||||
# ACL
|
||||
def convert_legacy_lay_down_config_dict(
        legacy_config_dict: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Convert a legacy lay down config dict to the new format.

    Currently a stub: logs a warning and returns the dict unchanged.

    :param legacy_config_dict: A legacy lay down config dict.
    :return: The (for now unconverted) lay down config dict.
    """
    _LOGGER.warning("Legacy lay down config conversion not yet implemented")
    return legacy_config_dict
|
||||
|
||||
|
||||
def load(
        file_path: Union[str, Path],
        legacy_file: bool = False
) -> Dict:
    """
    Read in a lay down config yaml file.

    :param file_path: The config file path.
    :param legacy_file: True if the config file is legacy format, otherwise
        False.
    :return: The lay down config as a dict.
    :raises ValueError: If the file_path does not exist.
    """
    if not isinstance(file_path, Path):
        file_path = Path(file_path)

    # Guard clause: fail fast when the file is missing.
    if not file_path.exists():
        msg = f"Cannot load the lay down config as it does not exist: {file_path}"
        _LOGGER.error(msg)
        raise ValueError(msg)

    with open(file_path, "r") as file:
        config = yaml.safe_load(file)
    _LOGGER.debug(f"Loading lay down config file: {file_path}")

    if legacy_file:
        # Legacy conversion is best-effort: on failure, fall back to the
        # file contents as-is.
        try:
            config = convert_legacy_lay_down_config_dict(config)
        except KeyError:
            msg = (
                f"Failed to convert lay down config file {file_path} "
                f"from legacy format. Attempting to use file as is."
            )
            _LOGGER.error(msg)
    return config
|
||||
|
||||
|
||||
def ddos_basic_one_config_path() -> Path:
|
||||
"""
|
||||
|
||||
@@ -10,11 +10,27 @@ import yaml
|
||||
from primaite import USERS_CONFIG_DIR, getLogger
|
||||
from primaite.common.enums import DeepLearningFramework
|
||||
from primaite.common.enums import ActionType, RedAgentIdentifier, \
|
||||
AgentFramework, SessionType
|
||||
AgentFramework, SessionType, OutputVerboseLevel
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
|
||||
_EXAMPLE_TRAINING: Final[
|
||||
Path] = USERS_CONFIG_DIR / "example_config" / "training"
|
||||
|
||||
|
||||
def main_training_config_path() -> Path:
    """
    The path to the example training_config_main.yaml file.

    :return: The file path.
    """
    config_file = _EXAMPLE_TRAINING / "training_config_main.yaml"
    if config_file.exists():
        return config_file
    # The example configs are installed by the setup command.
    msg = "Example config not found. Please run 'primaite setup'"
    _LOGGER.critical(msg)
    raise FileNotFoundError(msg)
|
||||
|
||||
|
||||
@dataclass()
|
||||
@@ -24,44 +40,47 @@ class TrainingConfig:
|
||||
"The AgentFramework"
|
||||
|
||||
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
|
||||
"The DeepLearningFramework."
|
||||
"The DeepLearningFramework"
|
||||
|
||||
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO
|
||||
"The RedAgentIdentifier.."
|
||||
"The RedAgentIdentifier"
|
||||
|
||||
action_type: ActionType = ActionType.ANY
|
||||
"The ActionType to use."
|
||||
"The ActionType to use"
|
||||
|
||||
num_episodes: int = 10
|
||||
"The number of episodes to train over."
|
||||
"The number of episodes to train over"
|
||||
|
||||
num_steps: int = 256
|
||||
"The number of steps in an episode."
|
||||
"The number of steps in an episode"
|
||||
|
||||
checkpoint_every_n_episodes: int = 5
|
||||
"The agent will save a checkpoint every n episodes."
|
||||
"The agent will save a checkpoint every n episodes"
|
||||
|
||||
observation_space: dict = field(
|
||||
default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}
|
||||
)
|
||||
"The observation space config dict."
|
||||
"The observation space config dict"
|
||||
|
||||
time_delay: int = 10
|
||||
"The delay between steps (ms). Applies to generic agents only."
|
||||
"The delay between steps (ms). Applies to generic agents only"
|
||||
|
||||
# file
|
||||
session_type: SessionType = SessionType.TRAINING
|
||||
"The type of PrimAITE session to run."
|
||||
"The type of PrimAITE session to run"
|
||||
|
||||
load_agent: str = False
|
||||
"Determine whether to load an agent from file."
|
||||
"Determine whether to load an agent from file"
|
||||
|
||||
agent_load_file: Optional[str] = None
|
||||
"File path and file name of agent if you're loading one in."
|
||||
"File path and file name of agent if you're loading one in"
|
||||
|
||||
# Environment
|
||||
observation_space_high_value: int = 1000000000
|
||||
"The high value for the observation space."
|
||||
"The high value for the observation space"
|
||||
|
||||
output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO
|
||||
"The Agent output verbosity level"
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
@@ -126,28 +145,28 @@ class TrainingConfig:
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: int = 5
|
||||
"The time taken to patch the OS."
|
||||
"The time taken to patch the OS"
|
||||
|
||||
node_reset_duration: int = 5
|
||||
"The time taken to reset a node (hardware)."
|
||||
"The time taken to reset a node (hardware)"
|
||||
|
||||
node_booting_duration: int = 3
|
||||
"The Time taken to turn on the node."
|
||||
"The Time taken to turn on the node"
|
||||
|
||||
node_shutdown_duration: int = 2
|
||||
"The time taken to turn off the node."
|
||||
"The time taken to turn off the node"
|
||||
|
||||
service_patching_duration: int = 5
|
||||
"The time taken to patch a service."
|
||||
"The time taken to patch a service"
|
||||
|
||||
file_system_repairing_limit: int = 5
|
||||
"The time take to repair the file system."
|
||||
"The time take to repair the file system"
|
||||
|
||||
file_system_restoring_limit: int = 5
|
||||
"The time take to restore the file system."
|
||||
"The time take to restore the file system"
|
||||
|
||||
file_system_scanning_limit: int = 5
|
||||
"The time taken to scan the file system."
|
||||
"The time taken to scan the file system"
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
@@ -157,9 +176,10 @@ class TrainingConfig:
|
||||
field_enum_map = {
|
||||
"agent_framework": AgentFramework,
|
||||
"deep_learning_framework": DeepLearningFramework,
|
||||
"red_agent_identifier": RedAgentIdentifier,
|
||||
"action_type": ActionType,
|
||||
"session_type": SessionType
|
||||
"red_agent_identifier": RedAgentIdentifier,
|
||||
"action_type": ActionType,
|
||||
"session_type": SessionType,
|
||||
"output_verbose_level": OutputVerboseLevel
|
||||
}
|
||||
|
||||
for field, enum_class in field_enum_map.items():
|
||||
@@ -178,28 +198,19 @@ class TrainingConfig:
|
||||
"""
|
||||
data = self.__dict__
|
||||
if json_serializable:
|
||||
data["agent_framework"] = self.agent_framework.value
|
||||
data["deep_learning_framework"] = self.deep_learning_framework.value
|
||||
data["red_agent_identifier"] = self.red_agent_identifier.value
|
||||
data["action_type"] = self.action_type.value
|
||||
data["output_verbose_level"] = self.output_verbose_level.value
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def main_training_config_path() -> Path:
|
||||
"""
|
||||
The path to the example training_config_main.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_TRAINING / "training_config_main.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path],
|
||||
legacy_file: bool = False) -> TrainingConfig:
|
||||
def load(
|
||||
file_path: Union[str, Path],
|
||||
legacy_file: bool = False
|
||||
) -> TrainingConfig:
|
||||
"""
|
||||
Read in a training config yaml file.
|
||||
|
||||
@@ -246,7 +257,8 @@ def convert_legacy_training_config_dict(
|
||||
agent_framework: AgentFramework = AgentFramework.SB3,
|
||||
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO,
|
||||
action_type: ActionType = ActionType.ANY,
|
||||
num_steps: int = 256
|
||||
num_steps: int = 256,
|
||||
output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy training config dict to the new format.
|
||||
@@ -260,13 +272,16 @@ def convert_legacy_training_config_dict(
|
||||
don't have action_type values.
|
||||
:param num_steps: The number of steps to set as legacy training configs
|
||||
don't have num_steps values.
|
||||
:param output_verbose_level: The agent output verbose level to use as
|
||||
legacy training configs don't have output_verbose_level values.
|
||||
:return: The converted training config dict.
|
||||
"""
|
||||
config_dict = {
|
||||
"agent_framework": agent_framework.name,
|
||||
"red_agent_identifier": red_agent_identifier.name,
|
||||
"action_type": action_type.name,
|
||||
"num_steps": num_steps
|
||||
"num_steps": num_steps,
|
||||
"output_verbose_level": output_verbose_level
|
||||
}
|
||||
for legacy_key, value in legacy_config_dict.items():
|
||||
new_key = _get_new_key_from_legacy(legacy_key)
|
||||
|
||||
@@ -435,7 +435,6 @@ class Primaite(Env):
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
# At the moment, actions are only affecting nodes
|
||||
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
self.apply_actions_to_nodes(_action)
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
|
||||
@@ -1,305 +1,229 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The main PrimAITE session runner module.
|
||||
|
||||
TODO: This will eventually be refactored out into a proper Session class.
|
||||
TODO: The passing about of session_dir and timestamp_str is temporary and
|
||||
will be cleaned up once we move to a proper Session class.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Final, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite import SESSIONS_DIR, getLogger
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run_generic(env: Primaite, config_values: TrainingConfig):
    """
    Run against a generic agent.

    Steps the environment with randomly sampled actions for
    ``num_episodes`` episodes of up to ``num_steps`` steps each, then
    closes the environment.

    :param env: An instance of
        :class:`~primaite.environment.primaite_env.Primaite`.
    :param config_values: An instance of
        :class:`~primaite.config.training_config.TrainingConfig`.
    """
    for episode in range(0, config_values.num_episodes):
        env.reset()
        for step in range(0, config_values.num_steps):
            # Send the observation space to the agent to get an action
            # TEMP - random action for now
            # action = env.blue_agent_action(obs)
            action = env.action_space.sample()

            # Run the simulation step on the live environment
            obs, reward, done, info = env.step(action)

            # Break if done is True
            if done:
                break

            # Introduce a delay between steps
            time.sleep(config_values.time_delay / 1000)

        # Reset the environment at the end of the episode
        # NOTE(review): the reset actually happens at the top of the next
        # iteration; this comment appears to be a leftover.

    env.close()
|
||||
|
||||
|
||||
def run_stable_baselines3_ppo(
    env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
):
    """
    Run against a stable_baselines3 PPO agent.

    Either loads an agent from file (when ``config_values.load_agent`` is
    set) or creates a fresh PPO agent, then runs a training or evaluation
    session depending on ``config_values.session_type``.

    :param env: An instance of
        :class:`~primaite.environment.primaite_env.Primaite`.
    :param config_values: An instance of
        :class:`~primaite.config.training_config.TrainingConfig`.
    :param session_path: The directory path the session is writing to.
    :param timestamp_str: The session timestamp in the format:
        <yyyy-mm-dd>_<hh-mm-ss>.
    """
    if config_values.load_agent:
        try:
            agent = PPO.load(
                config_values.agent_load_file,
                env,
                verbose=0,
                n_steps=config_values.num_steps,
            )
        except Exception:
            # NOTE(review): if loading fails, `agent` is never bound and the
            # code below raises NameError; consider re-raising or falling
            # back to a fresh agent here.
            print(
                "ERROR: Could not load agent at location: "
                + config_values.agent_load_file
            )
            _LOGGER.error("Could not load agent")
            _LOGGER.error("Exception occured", exc_info=True)
    else:
        agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)

    # NOTE(review): TrainingConfig declares session_type as a SessionType
    # enum; comparing against the string "TRAINING" may never match —
    # confirm intended behavior.
    if config_values.session_type == "TRAINING":
        # We're in a training session
        print("Starting training session...")
        _LOGGER.debug("Starting training session...")
        for episode in range(config_values.num_episodes):
            agent.learn(total_timesteps=config_values.num_steps)
            _save_agent(agent, session_path, timestamp_str)
    else:
        # Default to being in an evaluation session
        print("Starting evaluation session...")
        _LOGGER.debug("Starting evaluation session...")
        evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)

    env.close()
|
||||
|
||||
|
||||
def _write_session_metadata_file(
    session_dir: Path, uuid: str, session_timestamp: datetime, env: Primaite
):
    """
    Write the ``session_metadata.json`` file.

    Creates a ``session_metadata.json`` in the ``session_dir`` directory
    and adds the following key/value pairs:

    - uuid: The UUID assigned to the session upon instantiation.
    - start_datetime: The date & time the session started in iso format.
    - end_datetime: NULL.
    - total_episodes: NULL.
    - total_time_steps: NULL.
    - env:
        - training_config:
            - All training config items
        - lay_down_config:
            - All lay down config items

    :param session_dir: The directory path the session is writing to.
    :param uuid: The UUID assigned to the session.
    :param session_timestamp: The datetime the session started.
    :param env: The environment whose configs are recorded.
    """
    # End/total fields start as null and are filled in when the session ends.
    metadata = dict(
        uuid=uuid,
        start_datetime=session_timestamp.isoformat(),
        end_datetime=None,
        total_episodes=None,
        total_time_steps=None,
        env=dict(
            training_config=env.training_config.to_dict(json_serializable=True),
            lay_down_config=env.lay_down_config,
        ),
    )
    filepath = session_dir / "session_metadata.json"
    _LOGGER.debug(f"Writing Session Metadata file: {filepath}")
    with open(filepath, "w") as file:
        json.dump(metadata, file)
|
||||
|
||||
|
||||
def _update_session_metadata_file(session_dir: Path, env: Primaite):
    """
    Update the ``session_metadata.json`` file.

    Updates the ``session_metadata.json`` in the ``session_dir`` directory
    with the following key/value pairs:

    - end_datetime: The date & time the session ended in iso format.
    - total_episodes: The total number of episodes the environment ran.
    - total_time_steps: The total number of time steps the environment ran.

    :param session_dir: The directory path the session is writing to.
    :param env: The environment whose episode/step counters are recorded.
    """
    # Bug fix (docs): the previous docstring claimed these fields were set
    # to NULL; they are set to the session's final values.
    # Construct the path once instead of building it separately for the
    # read and the write.
    filepath = session_dir / "session_metadata.json"
    with open(filepath, "r") as file:
        metadata_dict = json.load(file)

    metadata_dict["end_datetime"] = datetime.now().isoformat()
    metadata_dict["total_episodes"] = env.episode_count
    metadata_dict["total_time_steps"] = env.total_step_count

    _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
    with open(filepath, "w") as file:
        json.dump(metadata_dict, file)
|
||||
|
||||
|
||||
def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
    """
    Persist an agent.

    Only works for stable baselines3 agents at present.

    :param agent: The trained agent to persist.
    :param session_path: The directory path the session is writing to.
    :param timestamp_str: The session timestamp in the format:
        <yyyy-mm-dd>_<hh-mm-ss>.
    """
    if isinstance(agent, OnPolicyAlgorithm):
        filepath = session_path / f"agent_saved_{timestamp_str}"
        agent.save(filepath)
        _LOGGER.debug(f"Trained agent saved as: {filepath}")
    else:
        # Unsupported agent type: log and carry on (best-effort save).
        msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
        _LOGGER.error(msg)
|
||||
|
||||
|
||||
def _get_session_path(session_timestamp: datetime) -> Path:
    """
    Get the directory path the session will output to.

    This is set in the format of:
        ~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.

    :param session_timestamp: This is the datetime that the session started.
    :return: The session directory path.
    """
    # Nest the per-session directory under a per-day directory.
    session_path = (
        SESSIONS_DIR
        / session_timestamp.strftime("%Y-%m-%d")
        / session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    )
    session_path.mkdir(exist_ok=True, parents=True)
    return session_path
|
||||
|
||||
|
||||
def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
    """Run the PrimAITE Session.

    :param training_config_path: The training config filepath.
    :param lay_down_config_path: The lay down config filepath.
    """
    # Welcome message
    print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
    uuid = str(uuid4())
    session_timestamp: Final[datetime] = datetime.now()
    session_dir = _get_session_path(session_timestamp)
    timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")

    print(f"The output directory for this session is: {session_dir}")

    # Create a list of transactions
    # A transaction is an object holding the:
    # - episode #
    # - step #
    # - initial observation space
    # - action
    # - reward
    # - new observation space
    transaction_list = []

    # Create the Primaite environment
    env = Primaite(
        training_config_path=training_config_path,
        lay_down_config_path=lay_down_config_path,
        transaction_list=transaction_list,
        session_path=session_dir,
        timestamp_str=timestamp_str,
    )

    print("Writing Session Metadata file...")

    _write_session_metadata_file(
        session_dir=session_dir, uuid=uuid, session_timestamp=session_timestamp, env=env
    )

    config_values = env.training_config

    # Get the number of steps (which is stored in the child config file)
    config_values.num_steps = env.episode_steps

    # Run environment against an agent
    if config_values.agent_identifier == "GENERIC":
        run_generic(env=env, config_values=config_values)
    elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
        run_stable_baselines3_ppo(
            env=env,
            config_values=config_values,
            session_path=session_dir,
            timestamp_str=timestamp_str,
        )
    elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
        run_stable_baselines3_a2c(
            env=env,
            config_values=config_values,
            session_path=session_dir,
            timestamp_str=timestamp_str,
        )
    else:
        # Bug fix: an unrecognised agent identifier previously fell through
        # silently and the session ran no agent at all. Surface it.
        _LOGGER.error(
            f"Unknown agent identifier: {config_values.agent_identifier}"
        )

    print("Session finished")
    _LOGGER.debug("Session finished")

    print("Saving transaction logs...")
    write_transaction_to_file(
        transaction_list=transaction_list,
        session_path=session_dir,
        timestamp_str=timestamp_str,
    )

    print("Updating Session Metadata file...")
    _update_session_metadata_file(session_dir=session_dir, env=env)

    print("Finished")
    _LOGGER.debug("Finished")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tc")
    parser.add_argument("--ldc")
    args = parser.parse_args()
    if not args.tc:
        _LOGGER.error(
            "Please provide a training config file using the --tc " "argument"
        )
    if not args.ldc:
        _LOGGER.error(
            "Please provide a lay down config file using the --ldc " "argument"
        )
    if not (args.tc and args.ldc):
        # Bug fix: previously run() was still invoked with missing config
        # paths, crashing later with an unhelpful traceback. Exit cleanly
        # after the errors above have been logged.
        raise SystemExit(1)
    run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
|
||||
|
||||
# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# """
|
||||
# The main PrimAITE session runner module.
|
||||
#
|
||||
# TODO: This will eventually be refactored out into a proper Session class.
|
||||
# TODO: The passing about of session_path and timestamp_str is temporary and
|
||||
# will be cleaned up once we move to a proper Session class.
|
||||
# """
|
||||
# import argparse
|
||||
# import json
|
||||
# import time
|
||||
# from datetime import datetime
|
||||
# from pathlib import Path
|
||||
# from typing import Final, Union
|
||||
# from uuid import uuid4
|
||||
#
|
||||
# from stable_baselines3 import A2C, PPO
|
||||
# from stable_baselines3.common.evaluation import evaluate_policy
|
||||
# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
# from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
#
|
||||
# from primaite import SESSIONS_DIR, getLogger
|
||||
# from primaite.config.training_config import TrainingConfig
|
||||
# from primaite.environment.primaite_env import Primaite
|
||||
# from primaite.transactions.transactions_to_file import \
|
||||
# write_transaction_to_file
|
||||
#
|
||||
# _LOGGER = getLogger(__name__)
|
||||
#
|
||||
#
|
||||
# def run_generic(env: Primaite, config_values: TrainingConfig):
|
||||
# """
|
||||
# Run against a generic agent.
|
||||
#
|
||||
# :param env: An instance of
|
||||
# :class:`~primaite.environment.primaite_env.Primaite`.
|
||||
# :param config_values: An instance of
|
||||
# :class:`~primaite.config.training_config.TrainingConfig`.
|
||||
# """
|
||||
# for episode in range(0, config_values.num_episodes):
|
||||
# env.reset()
|
||||
# for step in range(0, config_values.num_steps):
|
||||
# # Send the observation space to the agent to get an action
|
||||
# # TEMP - random action for now
|
||||
# # action = env.blue_agent_action(obs)
|
||||
# action = env.action_space.sample()
|
||||
#
|
||||
# # Run the simulation step on the live environment
|
||||
# obs, reward, done, info = env.step(action)
|
||||
#
|
||||
# # Break if done is True
|
||||
# if done:
|
||||
# break
|
||||
#
|
||||
# # Introduce a delay between steps
|
||||
# time.sleep(config_values.time_delay / 1000)
|
||||
#
|
||||
# # Reset the environment at the end of the episode
|
||||
#
|
||||
# env.close()
|
||||
#
|
||||
#
|
||||
# def run_stable_baselines3_ppo(
|
||||
# env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
|
||||
# ):
|
||||
# """
|
||||
# Run against a stable_baselines3 PPO agent.
|
||||
#
|
||||
# :param env: An instance of
|
||||
# :class:`~primaite.environment.primaite_env.Primaite`.
|
||||
# :param config_values: An instance of
|
||||
# :class:`~primaite.config.training_config.TrainingConfig`.
|
||||
# :param session_path: The directory path the session is writing to.
|
||||
# :param timestamp_str: The session timestamp in the format:
|
||||
# <yyyy-mm-dd>_<hh-mm-ss>.
|
||||
# """
|
||||
# if config_values.load_agent:
|
||||
# try:
|
||||
# agent = PPO.load(
|
||||
# config_values.agent_load_file,
|
||||
# env,
|
||||
# verbose=0,
|
||||
# n_steps=config_values.num_steps,
|
||||
# )
|
||||
# except Exception:
|
||||
# print(
|
||||
# "ERROR: Could not load agent at location: "
|
||||
# + config_values.agent_load_file
|
||||
# )
|
||||
# _LOGGER.error("Could not load agent")
|
||||
# _LOGGER.error("Exception occured", exc_info=True)
|
||||
# else:
|
||||
# agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
|
||||
#
|
||||
# if config_values.session_type == "TRAINING":
|
||||
# # We're in a training session
|
||||
# print("Starting training session...")
|
||||
# _LOGGER.debug("Starting training session...")
|
||||
# for episode in range(config_values.num_episodes):
|
||||
# agent.learn(total_timesteps=config_values.num_steps)
|
||||
# _save_agent(agent, session_path, timestamp_str)
|
||||
# else:
|
||||
# # Default to being in an evaluation session
|
||||
# print("Starting evaluation session...")
|
||||
# _LOGGER.debug("Starting evaluation session...")
|
||||
# evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
#
|
||||
# env.close()
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
# def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
|
||||
# """
|
||||
# Persist an agent.
|
||||
#
|
||||
# Only works for stable baselines3 agents at present.
|
||||
#
|
||||
# :param session_path: The directory path the session is writing to.
|
||||
# :param timestamp_str: The session timestamp in the format:
|
||||
# <yyyy-mm-dd>_<hh-mm-ss>.
|
||||
# """
|
||||
# if not isinstance(agent, OnPolicyAlgorithm):
|
||||
# msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
|
||||
# _LOGGER.error(msg)
|
||||
# else:
|
||||
# filepath = session_path / f"agent_saved_{timestamp_str}"
|
||||
# agent.save(filepath)
|
||||
# _LOGGER.debug(f"Trained agent saved as: {filepath}")
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
# def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
|
||||
# """Run the PrimAITE Session.
|
||||
#
|
||||
# :param training_config_path: The training config filepath.
|
||||
# :param lay_down_config_path: The lay down config filepath.
|
||||
# """
|
||||
# # Welcome message
|
||||
# print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
|
||||
# uuid = str(uuid4())
|
||||
# session_timestamp: Final[datetime] = datetime.now()
|
||||
# session_path = _get_session_path(session_timestamp)
|
||||
# timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
#
|
||||
# print(f"The output directory for this session is: {session_path}")
|
||||
#
|
||||
# # Create a list of transactions
|
||||
# # A transaction is an object holding the:
|
||||
# # - episode #
|
||||
# # - step #
|
||||
# # - initial observation space
|
||||
# # - action
|
||||
# # - reward
|
||||
# # - new observation space
|
||||
# transaction_list = []
|
||||
#
|
||||
# # Create the Primaite environment
|
||||
# env = Primaite(
|
||||
# training_config_path=training_config_path,
|
||||
# lay_down_config_path=lay_down_config_path,
|
||||
# transaction_list=transaction_list,
|
||||
# session_path=session_path,
|
||||
# timestamp_str=timestamp_str,
|
||||
# )
|
||||
#
|
||||
# print("Writing Session Metadata file...")
|
||||
#
|
||||
# _write_session_metadata_file(
|
||||
# session_path=session_path, uuid=uuid, session_timestamp=session_timestamp, env=env
|
||||
# )
|
||||
#
|
||||
# config_values = env.training_config
|
||||
#
|
||||
# # Get the number of steps (which is stored in the child config file)
|
||||
# config_values.num_steps = env.episode_steps
|
||||
#
|
||||
# # Run environment against an agent
|
||||
# if config_values.agent_identifier == "GENERIC":
|
||||
# run_generic(env=env, config_values=config_values)
|
||||
# elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
|
||||
# run_stable_baselines3_ppo(
|
||||
# env=env,
|
||||
# config_values=config_values,
|
||||
# session_path=session_path,
|
||||
# timestamp_str=timestamp_str,
|
||||
# )
|
||||
# elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
|
||||
# run_stable_baselines3_a2c(
|
||||
# env=env,
|
||||
# config_values=config_values,
|
||||
# session_path=session_path,
|
||||
# timestamp_str=timestamp_str,
|
||||
# )
|
||||
#
|
||||
# print("Session finished")
|
||||
# _LOGGER.debug("Session finished")
|
||||
#
|
||||
# print("Saving transaction logs...")
|
||||
# write_transaction_to_file(
|
||||
# transaction_list=transaction_list,
|
||||
# session_path=session_path,
|
||||
# timestamp_str=timestamp_str,
|
||||
# )
|
||||
#
|
||||
# print("Updating Session Metadata file...")
|
||||
# _update_session_metadata_file(session_path=session_path, env=env)
|
||||
#
|
||||
# print("Finished")
|
||||
# _LOGGER.debug("Finished")
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# parser = argparse.ArgumentParser()
|
||||
# parser.add_argument("--tc")
|
||||
# parser.add_argument("--ldc")
|
||||
# args = parser.parse_args()
|
||||
# if not args.tc:
|
||||
# _LOGGER.error(
|
||||
# "Please provide a training config file using the --tc " "argument"
|
||||
# )
|
||||
# if not args.ldc:
|
||||
# _LOGGER.error(
|
||||
# "Please provide a lay down config file using the --ldc " "argument"
|
||||
# )
|
||||
# run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
#
|
||||
#
|
||||
|
||||
@@ -3,12 +3,16 @@ from __future__ import annotations
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Final, Optional, Union
|
||||
from typing import Final, Optional, Union, Dict
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger, SESSIONS_DIR
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.rllib import RLlibPPO
|
||||
from primaite.agents.sb3 import SB3PPO
|
||||
from primaite.common.enums import AgentFramework, RedAgentIdentifier, \
|
||||
ActionType
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
@@ -26,8 +30,8 @@ def _get_session_path(session_timestamp: datetime) -> Path:
|
||||
:return: The session directory path.
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = SESSIONS_DIR / date_dir / session_dir
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = SESSIONS_DIR / date_dir / session_path
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
_LOGGER.debug(f"Created PrimAITE Session path: {session_path}")
|
||||
|
||||
@@ -45,211 +49,100 @@ class PrimaiteSession:
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path]] = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(
|
||||
self._training_config_path
|
||||
)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
|
||||
|
||||
self._auto: Final[bool] = auto
|
||||
|
||||
self._uuid: str = str(uuid4())
|
||||
self._session_timestamp: Final[datetime] = datetime.now()
|
||||
self._session_path: Final[Path] = _get_session_path(
|
||||
self._session_timestamp
|
||||
self._lay_down_config: Dict = lay_down_config.load(
|
||||
self._lay_down_config_path
|
||||
)
|
||||
self._timestamp_str: Final[str] = self._session_timestamp.strftime(
|
||||
"%Y-%m-%d_%H-%M-%S")
|
||||
self._metadata_path = self._session_path / "session_metadata.json"
|
||||
|
||||
|
||||
self._env = None
|
||||
self._training_config: TrainingConfig
|
||||
self._can_learn: bool = False
|
||||
_LOGGER.debug("")
|
||||
self._auto: bool = auto
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
|
||||
if self._auto:
|
||||
self.setup()
|
||||
self.learn()
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
"""The session UUID."""
|
||||
return self._uuid
|
||||
|
||||
def _setup_primaite_env(self, transaction_list: Optional[list] = None):
|
||||
if not transaction_list:
|
||||
transaction_list = []
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
transaction_list=transaction_list,
|
||||
session_path=self._session_path,
|
||||
timestamp_str=self._timestamp_str
|
||||
)
|
||||
self._training_config: TrainingConfig = self._env.training_config
|
||||
|
||||
def _write_session_metadata_file(self):
|
||||
"""
|
||||
Write the ``session_metadata.json`` file.
|
||||
|
||||
Creates a ``session_metadata.json`` in the ``session_dir`` directory
|
||||
and adds the following key/value pairs:
|
||||
|
||||
- uuid: The UUID assigned to the session upon instantiation.
|
||||
- start_datetime: The date & time the session started in iso format.
|
||||
- end_datetime: NULL.
|
||||
- total_episodes: NULL.
|
||||
- total_time_steps: NULL.
|
||||
- env:
|
||||
- training_config:
|
||||
- All training config items
|
||||
- lay_down_config:
|
||||
- All lay down config items
|
||||
"""
|
||||
metadata_dict = {
|
||||
"uuid": self._uuid,
|
||||
"start_datetime": self._session_timestamp.isoformat(),
|
||||
"end_datetime": None,
|
||||
"total_episodes": None,
|
||||
"total_time_steps": None,
|
||||
"env": {
|
||||
"training_config": self._env.training_config.to_dict(
|
||||
json_serializable=True
|
||||
),
|
||||
"lay_down_config": self._env.lay_down_config,
|
||||
},
|
||||
}
|
||||
_LOGGER.debug(f"Writing Session Metadata file: {self._metadata_path}")
|
||||
with open(self._metadata_path, "w") as file:
|
||||
json.dump(metadata_dict, file)
|
||||
|
||||
def _update_session_metadata_file(self):
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
|
||||
Updates the `session_metadata.json`` in the ``session_dir`` directory
|
||||
with the following key/value pairs:
|
||||
|
||||
- end_datetime: NULL.
|
||||
- total_episodes: NULL.
|
||||
- total_time_steps: NULL.
|
||||
"""
|
||||
with open(self._metadata_path, "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
metadata_dict["total_episodes"] = self._env.episode_count
|
||||
metadata_dict["total_time_steps"] = self._env.total_step_count
|
||||
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {self._metadata_path}")
|
||||
with open(self._metadata_path, "w") as file:
|
||||
json.dump(metadata_dict, file)
|
||||
|
||||
def setup(self):
|
||||
self._setup_primaite_env()
|
||||
self._can_learn = True
|
||||
pass
|
||||
if self._training_config.agent_framework == AgentFramework.NONE:
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM:
|
||||
# Stochastic Random Agent
|
||||
raise NotImplementedError
|
||||
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED:
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
raise NotImplementedError
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
raise NotImplementedError
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
# Invalid RedAgentIdentifier ActionType combo
|
||||
pass
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework RedAgentIdentifier combo
|
||||
pass
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||
# Stable Baselines3/Proximal Policy Optimization
|
||||
self._agent_session = SB3PPO(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||
# Stable Baselines3/Advantage Actor Critic
|
||||
raise NotImplementedError
|
||||
else:
|
||||
# Invalid AgentFramework RedAgentIdentifier combo
|
||||
pass
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||
# Ray RLlib/Proximal Policy Optimization
|
||||
self._agent_session = RLlibPPO(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path
|
||||
)
|
||||
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||
# Ray RLlib/Advantage Actor Critic
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework RedAgentIdentifier combo
|
||||
pass
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int],
|
||||
episodes: Optional[int],
|
||||
iterations: Optional[int],
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
if self._can_learn:
|
||||
# Run environment against an agent
|
||||
if self._training_config.agent_framework == AgentFramework.NONE:
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM:
|
||||
# Stochastic Random Agent
|
||||
run_generic(env=env, config_values=config_values)
|
||||
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED:
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
pass
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
pass
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
pass
|
||||
|
||||
else:
|
||||
# Invalid RedAgentIdentifier ActionType combo
|
||||
pass
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework RedAgentIdentifier combo
|
||||
pass
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||
# Stable Baselines3/Proximal Policy Optimization
|
||||
run_stable_baselines3_ppo(
|
||||
env=env,
|
||||
config_values=config_values,
|
||||
session_path=session_dir,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||
# Stable Baselines3/Advantage Actor Critic
|
||||
run_stable_baselines3_a2c(
|
||||
env=env,
|
||||
config_values=config_values,
|
||||
session_path=session_dir,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework RedAgentIdentifier combo
|
||||
pass
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||
# Ray RLlib/Proximal Policy Optimization
|
||||
pass
|
||||
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||
# Ray RLlib/Advantage Actor Critic
|
||||
pass
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework RedAgentIdentifier combo
|
||||
pass
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
pass
|
||||
|
||||
print("Session finished")
|
||||
_LOGGER.debug("Session finished")
|
||||
|
||||
print("Saving transaction logs...")
|
||||
write_transaction_to_file(
|
||||
transaction_list=transaction_list,
|
||||
session_path=session_dir,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
|
||||
print("Updating Session Metadata file...")
|
||||
_update_session_metadata_file(session_dir=session_dir, env=env)
|
||||
|
||||
print("Finished")
|
||||
_LOGGER.debug("Finished")
|
||||
self._agent_session.learn(time_steps, episodes, **kwargs)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int],
|
||||
episodes: Optional[int],
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
pass
|
||||
|
||||
def export(self):
|
||||
pass
|
||||
self._agent_session.evaluate(time_steps, episodes, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def import_agent(
|
||||
|
||||
@@ -108,5 +108,6 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
|
||||
csv_writer.writerow(csv_data)
|
||||
|
||||
csv_file.close()
|
||||
_LOGGER.debug("Finished writing transactions")
|
||||
except Exception:
|
||||
_LOGGER.error("Could not save the transaction file", exc_info=True)
|
||||
|
||||
@@ -19,8 +19,8 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
:return: The session directory path.
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_dir
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return session_path
|
||||
|
||||
Reference in New Issue
Block a user