temp commit

This commit is contained in:
Chris McCarthy
2023-06-13 09:42:54 +01:00
parent 9b0e24c27b
commit eb3368edd6
11 changed files with 626 additions and 173 deletions

View File

@@ -0,0 +1,36 @@
from abc import ABC, abstractmethod
from typing import Optional
from primaite.environment.primaite_env import Primaite
class AgentABC(ABC):
    """Abstract base class for PrimAITE agent wrappers.

    Concrete subclasses wrap a specific RL framework/algorithm and drive it
    against a ``Primaite`` environment.
    """

    @abstractmethod
    def __init__(self, env: "Primaite"):
        """
        Store the environment; subclasses extend this with framework setup.

        :param env: The PrimAITE environment the agent will interact with.
        """
        self._env: "Primaite" = env
        # Set by subclasses once the backing algorithm has been built.
        self._agent = None

    @abstractmethod
    def _setup(self):
        """Build the underlying agent/algorithm instance."""
        pass

    @abstractmethod
    def learn(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None
    ):
        """
        Train the agent.

        :param time_steps: Number of time steps to train for, if given.
        :param episodes: Number of episodes to train for, if given.
        """
        pass

    @abstractmethod
    def evaluate(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None
    ):
        """
        Evaluate the trained agent.

        :param time_steps: Number of time steps to evaluate for, if given.
        :param episodes: Number of episodes to evaluate for, if given.
        """
        pass

    @abstractmethod
    def load(self):
        """Load a saved agent from file."""
        pass

    @abstractmethod
    def save(self):
        """Save the agent to file."""
        pass

    @abstractmethod
    def export(self):
        """Export the agent in a portable format."""
        pass

View File

@@ -0,0 +1,177 @@
import glob
import time
from enum import Enum
from pathlib import Path
from typing import Union, Optional
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
from primaite.config import training_config
from primaite.environment.primaite_env import Primaite
class DLFramework(Enum):
    """The DL Frameworks enumeration.

    Values match the framework strings passed to RLlib via
    ``config.framework(framework=...)``.
    """

    TF = "tf"  # TensorFlow (static graph).
    TF2 = "tf2"  # TensorFlow 2.x (eager).
    TORCH = "torch"  # PyTorch.
def env_creator(env_config):
    """Factory used by RLlib's ``register_env`` to build a Primaite env.

    :param env_config: Mapping holding ``training_config_path`` and
        ``lay_down_config_path`` entries.
    :return: A freshly constructed ``Primaite`` environment.
    """
    return Primaite(
        env_config["training_config_path"],
        env_config["lay_down_config_path"],
        [],  # empty transaction list (third positional argument)
    )
def get_ppo_config(
    training_config_path: Union[str, Path],
    lay_down_config_path: Union[str, Path],
    framework: Optional[DLFramework] = DLFramework.TORCH
) -> PPOConfig:
    """
    Build an RLlib ``PPOConfig`` for the Primaite environment.

    Bug fix: the return annotation was ``-> PPOConfig():`` which
    instantiated a PPOConfig at function-definition time; it now refers
    to the class itself.

    :param training_config_path: Path to the training config YAML file.
    :param lay_down_config_path: Path to the lay down config YAML file.
    :param framework: The deep-learning framework RLlib should use.
    :return: The configured ``PPOConfig``.
    """
    # Register the environment factory under the name "primaite".
    register_env("primaite", env_creator)

    config = PPOConfig()
    config_values = training_config.load(training_config_path)

    # Point the config at our registered env and give it the paths it
    # needs to construct Primaite instances on each rollout worker.
    config.environment(
        env="primaite",
        env_config=dict(
            training_config_path=training_config_path,
            lay_down_config_path=lay_down_config_path
        )
    )

    action_type = config_values.action_type
    red_agent = config_values.red_agent_identifier
    # NOTE(review): these compare against bare strings; if
    # red_agent_identifier / action_type are Enum members the comparisons
    # never match and the default branch is always taken -- confirm and
    # compare against Enum members (or .name) if so.
    if red_agent == "RANDOM":
        # train_batch_size = number of env steps per training iteration.
        # Both RANDOM branches in the original used identical values, so
        # they are merged.
        config.training(train_batch_size=6000, lr=5e-5)
    elif red_agent == "CONFIG" and action_type == "NODE":
        config.training(train_batch_size=400, lr=5e-5)
    else:
        # The original had a branch testing action_type != "NONE" (a
        # likely typo for "NODE") with the same values as this default,
        # so it is merged here; behavior is unchanged.
        config.training(train_batch_size=500, lr=5e-5)

    # DL framework: "tf", "tf2" or "torch".
    config.framework(framework=framework.value)

    # Set the log level to DEBUG, INFO, WARN, or ERROR.
    config.debugging(seed=415, log_level="ERROR")

    # Setup evaluation
    # Explicitly set "explore"=False to override default
    # config.evaluation(
    #     evaluation_interval=100,
    #     evaluation_duration=20,
    #     # evaluation_duration_unit="timesteps",) #default episodes
    #     evaluation_config={"explore": False},
    # )

    # Setup sampling rollout workers.
    config.rollouts(
        num_rollout_workers=4,  # number of parallel workers
        num_envs_per_worker=1,
        horizon=128,  # max number of steps in an episode
    )
    # NOTE(review): build() constructs (and discards) an Algorithm here,
    # presumably only to validate the config -- confirm this is intended.
    config.build()
    return config
def train(
    num_iterations: int,
    config: Optional["PPOConfig"] = None,
    algo: Optional["Algorithm"] = None
):
    """
    Train an RLlib algorithm for a number of training iterations.

    Requires either the algorithm config (new model) or the algorithm
    itself (continue training from a checkpoint). Note: iterations are
    not the same as episodes.

    :param num_iterations: Number of RLlib training iterations to run.
    :param config: The algorithm config used to build a new model.
    :param algo: An existing algorithm to continue training from.
    :return: The result dict of the final training iteration.
    :raises ValueError: If neither ``config`` nor ``algo`` is supplied,
        or if ``num_iterations`` is not positive.
    """
    # Bug fix: the original dereferenced config.build() even when both
    # arguments were None (AttributeError), and referenced `result` after
    # a zero-iteration loop (NameError). Fail fast with clear errors.
    if config is None and algo is None:
        raise ValueError("Either 'config' or 'algo' must be provided.")
    if num_iterations < 1:
        raise ValueError("num_iterations must be a positive integer.")
    start_time = time.time()
    if algo is None:
        algo = config.build()
    elif config is None:
        config = algo.get_config()
    print(f"Algorithm type: {type(algo)}")
    for i in range(num_iterations):
        result = algo.train()
        print(
            f"Iteration={i}, Mean Reward={result['episode_reward_mean']:.2f}")
        # Save a checkpoint each iteration (the original "save every 10
        # iterations" condition was commented out).
        checkpoint_file = algo.save("./")
        print(f"Checkpoint saved at {checkpoint_file}")
    # Estimate episode/timestep totals from the final iteration's stats,
    # scaled by the iteration count.
    episode_lengths = result["hist_stats"]["episode_lengths"]
    num_episodes = len(episode_lengths) * num_iterations
    # The original summed the list repeated num_iterations times, which
    # allocates a throwaway list; multiplying the sum gives the same value.
    num_timesteps = sum(episode_lengths) * num_iterations
    print(f"Training took {time.time() - start_time:.2f} seconds")
    print(
        f"Number of episodes {num_episodes}, Number of timesteps: {num_timesteps}")
    return result
def load_model_from_checkpoint(config, checkpoint=None):
    """Build an algorithm from ``config`` and restore checkpoint weights.

    :param config: The algorithm config used to build an empty algorithm.
    :param checkpoint: Path to a checkpoint; when ``None`` the checkpoint
        with the highest iteration number is located and used.
    :return: The restored algorithm.
    """
    restored_algo = config.build()
    # Fall back to the latest checkpoint on disk when none was given.
    checkpoint = (
        checkpoint
        if checkpoint is not None
        else get_most_recent_checkpoint(config)
    )
    restored_algo.restore(checkpoint)
    return restored_algo
def get_most_recent_checkpoint(config):
    """
    Get the most recent checkpoint for specified action type, red agent and algorithm.

    "Most recent" means the highest iteration number, not the newest
    datetime.

    :param config: An RLlib algorithm config whose env_config carries the
        action type and red agent identifier.
    :return: Path to the ``rllib_checkpoint.json`` of the latest checkpoint.
    """
    # NOTE(review): assumes config.env_config has exactly one entry whose
    # value exposes action_type / red_agent_identifier -- confirm against
    # how the config is assembled in get_ppo_config.
    env_config = list(config.env_config.values())[0]
    action_type = env_config.action_type
    red_agent = env_config.red_agent_identifier
    algo_name = config.algo_class.__name__
    # Gets the latest checkpoint (highest iteration not datetime) to use as the final trained model
    relevant_checkpoints = glob.glob(
        f"/app/outputs/agents/{action_type}/{red_agent}/{algo_name}/*"
    )
    # NOTE(review): splitting the full path on "_" assumes no underscore
    # appears before the iteration number in any path component; raises
    # IndexError/ValueError otherwise -- verify the on-disk naming scheme.
    checkpoint_numbers = [int(i.split("_")[1]) for i in relevant_checkpoints]
    max_checkpoint = str(max(checkpoint_numbers))
    # Zero-pad the iteration number back to 6 digits to rebuild the
    # checkpoint directory name (e.g. "checkpoint_000123").
    checkpoint_number_to_use = "0" * (6 - len(max_checkpoint)) + max_checkpoint
    checkpoint = (
        relevant_checkpoints[0].split("_")[0]
        + "_"
        + checkpoint_number_to_use
        + "/rllib_checkpoint.json"
    )
    return checkpoint

View File

@@ -0,0 +1,28 @@
# from typing import Optional
#
# from primaite.agents.agent_abc import AgentABC
# from primaite.environment.primaite_env import Primaite
#
#
# class SB3PPO(AgentABC):
# def __init__(self, env: Primaite):
# super().__init__(env)
#
# def _setup(self):
# if self._env.training_config
# pass
#
# def learn(self, time_steps: Optional[int], episodes: Optional[int]):
# pass
#
# def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
# pass
#
# def load(self):
# pass
#
# def save(self):
# pass
#
# def export(self):
# pass

View File

@@ -79,6 +79,33 @@ class Protocol(Enum):
NONE = 7
class SessionType(Enum):
    """The type of PrimAITE Session to be run."""

    TRAINING = 1
    EVALUATION = 2
    BOTH = 3  # Training followed by evaluation.
class VerboseLevel(Enum):
    """PrimAITE Session Output verbose level."""

    NO_OUTPUT = 0  # Suppress all session output.
    INFO = 1
    DEBUG = 2
class AgentFramework(Enum):
    """The agent framework enumeration."""

    NONE = 0
    SB3 = 1  # Stable Baselines3.
    RLLIB = 2  # Ray RLlib.
class RedAgentIdentifier(Enum):
    """The red agent algo/class enumeration."""

    A2C = 1
    PPO = 2
    HARDCODED = 3  # Scripted red agent driven by the config.
    RANDOM = 4  # Red agent acting randomly.
class ActionType(Enum):
"""Action type enumeration."""

View File

@@ -1,95 +0,0 @@
# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# """The config class."""
# from dataclasses import dataclass
#
# from primaite.common.enums import ActionType
#
#
# @dataclass()
# class TrainingConfig:
# """Class to hold main config values."""
#
# # Generic
# agent_identifier: str # The Red Agent algo/class to be used
# action_type: ActionType # type of action to use (NODE/ACL/ANY)
# num_episodes: int # number of episodes to train over
# num_steps: int # number of steps in an episode
# time_delay: int # delay between steps (ms) - applies to generic agents only
# # file
# session_type: str # the session type to run (TRAINING or EVALUATION)
# load_agent: str # Determine whether to load an agent from file
# agent_load_file: str # File path and file name of agent if you're loading one in
#
# # Environment
# observation_space_high_value: int # The high value for the observation space
#
# # Reward values
# # Generic
# all_ok: int
# # Node Hardware State
# off_should_be_on: int
# off_should_be_resetting: int
# on_should_be_off: int
# on_should_be_resetting: int
# resetting_should_be_on: int
# resetting_should_be_off: int
# resetting: int
# # Node Software or Service State
# good_should_be_patching: int
# good_should_be_compromised: int
# good_should_be_overwhelmed: int
# patching_should_be_good: int
# patching_should_be_compromised: int
# patching_should_be_overwhelmed: int
# patching: int
# compromised_should_be_good: int
# compromised_should_be_patching: int
# compromised_should_be_overwhelmed: int
# compromised: int
# overwhelmed_should_be_good: int
# overwhelmed_should_be_patching: int
# overwhelmed_should_be_compromised: int
# overwhelmed: int
# # Node File System State
# good_should_be_repairing: int
# good_should_be_restoring: int
# good_should_be_corrupt: int
# good_should_be_destroyed: int
# repairing_should_be_good: int
# repairing_should_be_restoring: int
# repairing_should_be_corrupt: int
# repairing_should_be_destroyed: int # Repairing does not fix destroyed state - you need to restore
#
# repairing: int
# restoring_should_be_good: int
# restoring_should_be_repairing: int
# restoring_should_be_corrupt: int # Not the optimal method (as repair will fix corruption)
#
# restoring_should_be_destroyed: int
# restoring: int
# corrupt_should_be_good: int
# corrupt_should_be_repairing: int
# corrupt_should_be_restoring: int
# corrupt_should_be_destroyed: int
# corrupt: int
# destroyed_should_be_good: int
# destroyed_should_be_repairing: int
# destroyed_should_be_restoring: int
# destroyed_should_be_corrupt: int
# destroyed: int
# scanning: int
# # IER status
# red_ier_running: int
# green_ier_blocked: int
#
# # Patching / Reset
# os_patching_duration: int # The time taken to patch the OS
# node_reset_duration: int # The time taken to reset a node (hardware)
# node_booting_duration = 0 # The Time taken to turn on the node
# node_shutdown_duration = 0 # The time taken to turn off the node
# service_patching_duration: int # The time taken to patch a service
# file_system_repairing_limit: int # The time take to repair a file
# file_system_restoring_limit: int # The time take to restore a file
# file_system_scanning_limit: int # The time taken to scan the file system
# # Patching / Reset
#

View File

@@ -2,6 +2,8 @@
from pathlib import Path
from typing import Final
import networkx
from primaite import USERS_CONFIG_DIR, getLogger
_LOGGER = getLogger(__name__)
@@ -9,6 +11,12 @@ _LOGGER = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
# class LayDownConfig:
# network: networkx.Graph
# POL
# EIR
# ACL
def ddos_basic_one_config_path() -> Path:
"""
The path to the example lay_down_config_1_DDOS_basic.yaml file.

View File

@@ -1,4 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Final, Union, Optional
@@ -6,7 +8,8 @@ from typing import Any, Dict, Final, Union, Optional
import yaml
from primaite import USERS_CONFIG_DIR, getLogger
from primaite.common.enums import ActionType
from primaite.common.enums import ActionType, RedAgentIdentifier, \
AgentFramework, SessionType
_LOGGER = getLogger(__name__)
@@ -16,10 +19,11 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training
@dataclass()
class TrainingConfig:
"""The Training Config class."""
agent_framework: AgentFramework = AgentFramework.SB3
"The agent framework."
# Generic
agent_identifier: str = "STABLE_BASELINES3_A2C"
"The Red Agent algo/class to be used."
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO
"The red agent/algo class."
action_type: ActionType = ActionType.ANY
"The ActionType to use."
@@ -38,8 +42,8 @@ class TrainingConfig:
"The delay between steps (ms). Applies to generic agents only."
# file
session_type: str = "TRAINING"
"the session type to run (TRAINING or EVALUATION)"
session_type: SessionType = SessionType.TRAINING
"The type of PrimAITE session to run."
load_agent: str = False
"Determine whether to load an agent from file."
@@ -137,6 +141,24 @@ class TrainingConfig:
file_system_scanning_limit: int = 5
"The time taken to scan the file system."
@classmethod
def from_dict(
    cls,
    config_dict: Dict[str, Union[str, int, bool]]
) -> TrainingConfig:
    """
    Create a ``TrainingConfig`` from a config dict.

    Enum-valued fields are given in the dict by member name (e.g.
    ``"action_type": "NODE"``) and are converted to Enum members here.

    :param config_dict: The config values keyed by field name.
    :return: A ``TrainingConfig`` instance.
    :raises KeyError: If an enum field's value is not a valid member name.
    """
    field_enum_map = {
        "agent_framework": AgentFramework,
        "red_agent_identifier": RedAgentIdentifier,
        "action_type": ActionType,
        "session_type": SessionType
    }
    # Work on a copy so the caller's dict is not mutated.
    config_dict = dict(config_dict)
    for key, enum_class in field_enum_map.items():
        if key in config_dict:
            # Bug fix: look up the enum member named by the dict VALUE.
            # The original indexed the enum with the field name itself
            # (enum_class[field]), which always raised KeyError. The loop
            # variable is also renamed so it no longer shadows
            # dataclasses.field imported at the top of the file.
            config_dict[key] = enum_class[config_dict[key]]
    return cls(**config_dict)
def to_dict(self, json_serializable: bool = True):
"""
Serialise the ``TrainingConfig`` as dict.
@@ -196,10 +218,8 @@ def load(file_path: Union[str, Path],
f"from legacy format. Attempting to use file as is."
)
_LOGGER.error(msg)
# Convert values to Enums
config["action_type"] = ActionType[config["action_type"]]
try:
return TrainingConfig(**config)
return TrainingConfig.from_dict(**config)
except TypeError as e:
msg = (
f"Error when creating an instance of {TrainingConfig} "
@@ -214,22 +234,30 @@ def load(file_path: Union[str, Path],
def convert_legacy_training_config_dict(
legacy_config_dict: Dict[str, Any],
num_steps: int = 256,
action_type: str = "ANY"
agent_framework: AgentFramework = AgentFramework.SB3,
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO,
action_type: ActionType = ActionType.ANY,
num_steps: int = 256
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
:param legacy_config_dict: A legacy training config dict.
:param num_steps: The number of steps to set as legacy training configs
don't have num_steps values.
:param agent_framework: The agent framework to use as legacy training
configs don't have agent_framework values.
:param red_agent_identifier: The red agent identifier to use as legacy
training configs don't have red_agent_identifier values.
:param action_type: The action space type to set as legacy training configs
don't have action_type values.
:param num_steps: The number of steps to set as legacy training configs
don't have num_steps values.
:return: The converted training config dict.
"""
config_dict = {
"num_steps": num_steps,
"action_type": action_type
"agent_framework": agent_framework.name,
"red_agent_identifier": red_agent_identifier.name,
"action_type": action_type.name,
"num_steps": num_steps
}
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)
@@ -246,7 +274,7 @@ def _get_new_key_from_legacy(legacy_key: str) -> str:
:return: The mapped key.
"""
key_mapping = {
"agentIdentifier": "agent_identifier",
"agentIdentifier": None,
"numEpisodes": "num_episodes",
"timeDelay": "time_delay",
"configFilename": None,

View File

@@ -0,0 +1,216 @@
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Final, Optional, Union
from uuid import uuid4
from primaite import getLogger, SESSIONS_DIR
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
def _get_session_path(session_timestamp: datetime) -> Path:
    """
    Get the directory path the session will output to.

    This is set in the format of:
    ~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.

    :param session_timestamp: This is the datetime that the session started.
    :return: The session directory path.
    """
    # Date directory, then timestamped session directory beneath it.
    session_path = (
        SESSIONS_DIR
        / session_timestamp.strftime("%Y-%m-%d")
        / session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    )
    session_path.mkdir(exist_ok=True, parents=True)
    _LOGGER.debug(f"Created PrimAITE Session path: {session_path}")
    return session_path
class PrimaiteSession:
    """
    Manages a single PrimAITE session: environment setup, learning,
    evaluation, and session metadata output.
    """

    def __init__(
        self,
        training_config_path: Union[str, Path],
        lay_down_config_path: Union[str, Path],
        auto: bool = True
    ):
        """
        :param training_config_path: Path to the training config file.
        :param lay_down_config_path: Path to the lay down config file.
        :param auto: When True, ``setup()`` and ``learn()`` run immediately.
        """
        # Normalise both config paths to Path objects.
        if not isinstance(training_config_path, Path):
            training_config_path = Path(training_config_path)
        self._training_config_path: Final[Union[Path]] = training_config_path
        if not isinstance(lay_down_config_path, Path):
            lay_down_config_path = Path(lay_down_config_path)
        self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
        self._auto: Final[bool] = auto
        self._uuid: str = str(uuid4())
        self._session_timestamp: Final[datetime] = datetime.now()
        # Output directory for this session (created on disk here).
        self._session_path: Final[Path] = _get_session_path(
            self._session_timestamp
        )
        self._timestamp_str: Final[str] = self._session_timestamp.strftime(
            "%Y-%m-%d_%H-%M-%S")
        self._metadata_path = self._session_path / "session_metadata.json"
        # Populated by _setup_primaite_env().
        self._env = None
        self._training_config = None
        self._can_learn: bool = False
        _LOGGER.debug("")  # NOTE(review): empty debug message -- intended?
        if self._auto:
            self.setup()
            # NOTE(review): learn() has required positional parameters
            # with no defaults, so this no-arg call raises TypeError.
            self.learn()

    @property
    def uuid(self):
        """The session UUID."""
        return self._uuid

    def _setup_primaite_env(self, transaction_list: Optional[list] = None):
        """
        Construct the Primaite environment for this session.

        :param transaction_list: Optional list used to collect
            transactions; a fresh empty list is used when falsy.
        """
        if not transaction_list:
            transaction_list = []
        self._env: Primaite = Primaite(
            training_config_path=self._training_config_path,
            lay_down_config_path=self._lay_down_config_path,
            transaction_list=transaction_list,
            session_path=self._session_path,
            timestamp_str=self._timestamp_str
        )
        # Cache the env's parsed training config for convenience.
        self._training_config: TrainingConfig = self._env.training_config

    def _write_session_metadata_file(self):
        """
        Write the ``session_metadata.json`` file.

        Creates a ``session_metadata.json`` in the ``session_dir`` directory
        and adds the following key/value pairs:

        - uuid: The UUID assigned to the session upon instantiation.
        - start_datetime: The date & time the session started in iso format.
        - end_datetime: NULL.
        - total_episodes: NULL.
        - total_time_steps: NULL.
        - env:
            - training_config:
                - All training config items
            - lay_down_config:
                - All lay down config items
        """
        metadata_dict = {
            "uuid": self._uuid,
            "start_datetime": self._session_timestamp.isoformat(),
            "end_datetime": None,  # Filled in by _update_session_metadata_file.
            "total_episodes": None,
            "total_time_steps": None,
            "env": {
                "training_config": self._env.training_config.to_dict(
                    json_serializable=True
                ),
                "lay_down_config": self._env.lay_down_config,
            },
        }
        _LOGGER.debug(f"Writing Session Metadata file: {self._metadata_path}")
        with open(self._metadata_path, "w") as file:
            json.dump(metadata_dict, file)

    def _update_session_metadata_file(self):
        """
        Update the ``session_metadata.json`` file.

        Updates the ``session_metadata.json`` in the ``session_dir``
        directory, filling in the values that were NULL at session start:

        - end_datetime: The date & time the session ended in iso format.
        - total_episodes: The env's episode count.
        - total_time_steps: The env's total step count.
        """
        with open(self._metadata_path, "r") as file:
            metadata_dict = json.load(file)
        metadata_dict["end_datetime"] = datetime.now().isoformat()
        metadata_dict["total_episodes"] = self._env.episode_count
        metadata_dict["total_time_steps"] = self._env.total_step_count
        _LOGGER.debug(f"Updating Session Metadata file: {self._metadata_path}")
        with open(self._metadata_path, "w") as file:
            json.dump(metadata_dict, file)

    def setup(self):
        """Set up the session: build the environment and enable learning."""
        self._setup_primaite_env()
        self._can_learn = True
        pass  # NOTE(review): redundant statement.

    def learn(
        self,
        time_steps: Optional[int],
        episodes: Optional[int],
        iterations: Optional[int],
        **kwargs
    ):
        """
        Run the training session against the configured agent.

        NOTE(review): this method is clearly unfinished work-in-progress:

        * ``env``, ``config_values``, ``session_dir``, ``timestamp_str``
          and ``transaction_list`` are not defined in this scope
          (NameError at runtime).
        * ``run_generic``, ``run_stable_baselines3_ppo``,
          ``run_stable_baselines3_a2c`` and ``write_transaction_to_file``
          are not imported in this file.
        * ``self._training_config == "STABLE_BASELINES3_..."`` compares
          the whole config object to a string and can never be True --
          presumably ``self._training_config.agent_identifier`` was meant.
        * ``_update_session_metadata_file`` is called as a free function
          with keyword arguments, but is defined above as a zero-argument
          method.
        """
        if self._can_learn:
            # Run environment against an agent
            if self._training_config.agent_identifier == "GENERIC":
                run_generic(env=env, config_values=config_values)
            elif self._training_config == "STABLE_BASELINES3_PPO":
                run_stable_baselines3_ppo(
                    env=env,
                    config_values=config_values,
                    session_path=session_dir,
                    timestamp_str=timestamp_str,
                )
            elif self._training_config == "STABLE_BASELINES3_A2C":
                run_stable_baselines3_a2c(
                    env=env,
                    config_values=config_values,
                    session_path=session_dir,
                    timestamp_str=timestamp_str,
                )
            print("Session finished")
            _LOGGER.debug("Session finished")
            print("Saving transaction logs...")
            write_transaction_to_file(
                transaction_list=transaction_list,
                session_path=session_dir,
                timestamp_str=timestamp_str,
            )
            print("Updating Session Metadata file...")
            _update_session_metadata_file(session_dir=session_dir, env=env)
            print("Finished")
            _LOGGER.debug("Finished")

    def evaluate(
        self,
        time_steps: Optional[int],
        episodes: Optional[int],
        **kwargs
    ):
        """Evaluate a trained agent. Not yet implemented."""
        pass

    def export(self):
        """Export the session/agent. Not yet implemented."""
        pass

    @classmethod
    def import_agent(
        cls,
        gent_path: str,
        training_config_path: str,
        lay_down_config_path: str
    ) -> PrimaiteSession:
        """
        Create a session from a previously exported agent.

        NOTE(review): ``gent_path`` looks like a typo for ``agent_path``
        and is never used; the UUID is reset to an empty string rather
        than the imported agent's UUID -- confirm intended behaviour.

        :param gent_path: Path to the exported agent (currently unused).
        :param training_config_path: Path to the training config file.
        :param lay_down_config_path: Path to the lay down config file.
        :return: The new ``PrimaiteSession``.
        """
        session = PrimaiteSession(training_config_path, lay_down_config_path)
        # Reset the UUID
        session._uuid = ""
        return session