temp commit
@@ -22,46 +22,64 @@ The environment config file consists of the following attributes:
* **agent_identifier** [enum]

  This identifies the agent to use for the session. Select from one of the following:

  * GENERIC - Where a user-developed agent is to be used
  * STABLE_BASELINES3_PPO - Use an SB3 PPO agent
  * STABLE_BASELINES3_A2C - Use an SB3 A2C agent

* **agent_framework** [enum]

  This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following:

  * NONE - Where a user-developed agent is to be used
  * SB3 - Stable Baselines3
  * RLLIB - Ray RLlib

* **red_agent_identifier** [enum]

  This identifies the red agent to use for the session. Select from one of the following:

  * A2C - Advantage Actor Critic
  * PPO - Proximal Policy Optimization
  * HARDCODED - A custom-built deterministic agent
  * RANDOM - A stochastic random agent

* **action_type** [enum]

  Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session.

* **num_episodes** [int]

  This defines the number of episodes that the agent will train or be evaluated over.

* **num_steps** [int]

  Determines the number of steps to run in each episode of the session.

* **time_delay** [int]

  The time delay (in milliseconds) to take between each step when running a GENERIC agent session.

* **session_type** [text]

  Type of session to be run (TRAINING, EVALUATION, or BOTH).

* **load_agent** [bool]

  Determines whether to load an agent from file.

* **agent_load_file** [text]

  File path and file name of the agent, if you're loading one in.

* **observation_space_high_value** [int]

  The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases.

**Reward-Based Config Values**
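Taken together, the attributes above make up the general section of the environment config. A minimal sketch of such a config, expressed as a Python dict (the attribute names come from this document; the values are illustrative assumptions, not defaults from PrimAITE):

```python
# Hypothetical values for the general training-config attributes described
# above; names match the documented attributes, values are illustrative.
training_config = {
    "agent_framework": "SB3",
    "agent_identifier": "STABLE_BASELINES3_PPO",
    "red_agent_identifier": "RANDOM",
    "action_type": "NODE",
    "num_episodes": 10,
    "num_steps": 256,
    "time_delay": 10,
    "session_type": "TRAINING",
    "load_agent": False,
    "agent_load_file": "",
    "observation_space_high_value": 1_000_000_000,
}

# Basic sanity checks a config loader might perform on the enum-valued fields
assert training_config["session_type"] in ("TRAINING", "EVALUATION", "BOTH")
assert training_config["action_type"] in ("NODE", "ACL", "ANY")
```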
@@ -69,95 +87,95 @@ Rewards are calculated based on the difference between the current state and ref
* **Generic [all_ok]** [int]

  The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)

* **Node Hardware State [off_should_be_on]** [int]

  The score to give when the node should be on, but is off

* **Node Hardware State [off_should_be_resetting]** [int]

  The score to give when the node should be resetting, but is off

* **Node Hardware State [on_should_be_off]** [int]

  The score to give when the node should be off, but is on

* **Node Hardware State [on_should_be_resetting]** [int]

  The score to give when the node should be resetting, but is on

* **Node Hardware State [resetting_should_be_on]** [int]

  The score to give when the node should be on, but is resetting

* **Node Hardware State [resetting_should_be_off]** [int]

  The score to give when the node should be off, but is resetting

* **Node Hardware State [resetting]** [int]

  The score to give when the node is resetting

* **Node Operating System or Service State [good_should_be_patching]** [int]

  The score to give when the state should be patching, but is good

* **Node Operating System or Service State [good_should_be_compromised]** [int]

  The score to give when the state should be compromised, but is good

* **Node Operating System or Service State [good_should_be_overwhelmed]** [int]

  The score to give when the state should be overwhelmed, but is good

* **Node Operating System or Service State [patching_should_be_good]** [int]

  The score to give when the state should be good, but is patching

* **Node Operating System or Service State [patching_should_be_compromised]** [int]

  The score to give when the state should be compromised, but is patching

* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int]

  The score to give when the state should be overwhelmed, but is patching

* **Node Operating System or Service State [patching]** [int]

  The score to give when the state is patching

* **Node Operating System or Service State [compromised_should_be_good]** [int]

  The score to give when the state should be good, but is compromised

* **Node Operating System or Service State [compromised_should_be_patching]** [int]

  The score to give when the state should be patching, but is compromised

* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int]

  The score to give when the state should be overwhelmed, but is compromised

* **Node Operating System or Service State [compromised]** [int]

  The score to give when the state is compromised

* **Node Operating System or Service State [overwhelmed_should_be_good]** [int]

  The score to give when the state should be good, but is overwhelmed

* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int]

  The score to give when the state should be patching, but is overwhelmed

* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int]

  The score to give when the state should be compromised, but is overwhelmed

* **Node Operating System or Service State [overwhelmed]** [int]

  The score to give when the state is overwhelmed

* **Node File System State [good_should_be_repairing]** [int]
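These per-component scores are combined into a step reward. A minimal sketch of how such a lookup-and-sum could work (the scoring table, values, and function names here are illustrative assumptions, not the actual PrimAITE implementation):

```python
# Illustrative (actual_state, expected_state) -> score table, mirroring the
# naming convention above; the numbers are assumptions.
REWARDS = {
    ("ON", "ON"): 0,          # all_ok
    ("OFF", "ON"): -10,       # off_should_be_on
    ("RESETTING", "ON"): -5,  # resetting_should_be_on
}


def score_component(actual_state: str, reference_state: str) -> int:
    """Look up the score for an (actual, expected) hardware state pair."""
    return REWARDS.get((actual_state, reference_state), 0)


def total_reward(nodes):
    """Sum the per-node scores for one step."""
    return sum(score_component(a, r) for a, r in nodes)


print(total_reward([("ON", "ON"), ("OFF", "ON")]))  # -10
```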
@@ -261,37 +279,37 @@ Rewards are calculated based on the difference between the current state and ref
* **IER Status [red_ier_running]** [int]

  The score to give when a red agent IER is permitted to run

* **IER Status [green_ier_blocked]** [int]

  The score to give when a green agent IER is prevented from running

**Patching / Reset Durations**

* **os_patching_duration** [int]

  The number of steps to take when patching an Operating System

* **node_reset_duration** [int]

  The number of steps to take when resetting a node's hardware state

* **service_patching_duration** [int]

  The number of steps to take when patching a service

* **file_system_repairing_limit** [int]

  The number of steps to take when repairing the file system

* **file_system_restoring_limit** [int]

  The number of steps to take when restoring the file system

* **file_system_scanning_limit** [int]

  The number of steps to take when scanning the file system

The Lay Down Config
*******************
@@ -300,22 +318,22 @@ The lay down config file consists of the following attributes:
* **item_type: ACTIONS** [enum]

  Determines whether a NODE or ACL action space format is adopted for the session

* **item_type: OBSERVATION_SPACE** [dict]

  Allows the user to configure the observation space by combining one or more observation components. The list of available components is in :py:mod:`primaite.environment.observations`.

  The observation space config item should have a ``components`` key which is a list of components. Each component config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the component while it is being initialised.

  This example illustrates the correct format for the observation space config item:

  .. code-block:: yaml

      - item_type: OBSERVATION_SPACE
        components:
          - name: LINK_TRAFFIC_LEVELS
            options:
@@ -328,15 +346,15 @@ The lay down config file consists of the following attributes:
* **item_type: PORTS** [int]

  Provides a list of ports modelled in this session

* **item_type: SERVICES** [freetext]

  Provides a list of services modelled in this session

* **item_type: NODE**

  Defines a node included in the system laydown being simulated. It should consist of the following attributes:

  * **id** [int]: Unique ID for this YAML item
  * **name** [freetext]: Human-readable name of the component
@@ -355,7 +373,7 @@ The lay down config file consists of the following attributes:
* **item_type: LINK**

  Defines a link included in the system laydown being simulated. It should consist of the following attributes:

  * **id** [int]: Unique ID for this YAML item
  * **name** [freetext]: Human-readable name of the component
@@ -365,7 +383,7 @@ The lay down config file consists of the following attributes:
* **item_type: GREEN_IER**

  Defines a green agent Information Exchange Requirement (IER). It should consist of:

  * **id** [int]: Unique ID for this YAML item
  * **start_step** [int]: The start step (in the episode) for this IER to begin
@@ -379,7 +397,7 @@ The lay down config file consists of the following attributes:
* **item_type: RED_IER**

  Defines a red agent Information Exchange Requirement (IER). It should consist of:

  * **id** [int]: Unique ID for this YAML item
  * **start_step** [int]: The start step (in the episode) for this IER to begin
@@ -32,6 +32,7 @@ dependencies = [
    "numpy==1.23.5",
    "platformdirs==3.5.1",
    "PyYAML==6.0",
    "ray[rllib]==2.2.0",
    "stable-baselines3==1.6.2",
    "typer[all]==0.9.0"
]
src/primaite/agents/agent_abc.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from abc import ABC, abstractmethod
from typing import Optional

from primaite.environment.primaite_env import Primaite


class AgentABC(ABC):
    """Abstract base class for PrimAITE agents."""

    @abstractmethod
    def __init__(self, env: Primaite):
        self._env: Primaite = env
        self._agent = None

    @abstractmethod
    def _setup(self):
        pass

    @abstractmethod
    def learn(self, time_steps: Optional[int], episodes: Optional[int]):
        pass

    @abstractmethod
    def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
        pass

    @abstractmethod
    def load(self):
        pass

    @abstractmethod
    def save(self):
        pass

    @abstractmethod
    def export(self):
        pass
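A hypothetical concrete subclass shows how the abstract methods above are intended to be filled in. ``DummyAgent`` and its return values are illustrative only, and a trimmed stand-in ABC replaces the real ``Primaite`` env dependency so the snippet is self-contained:

```python
from abc import ABC, abstractmethod
from typing import Optional


class AgentABC(ABC):
    """Stand-in mirroring the abstract interface above, minus the env."""

    @abstractmethod
    def learn(self, time_steps: Optional[int], episodes: Optional[int]): ...

    @abstractmethod
    def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): ...


class DummyAgent(AgentABC):
    """Illustrative concrete agent; real agents would wrap an RL algorithm."""

    def learn(self, time_steps, episodes):
        return f"learn for {time_steps} steps"

    def evaluate(self, time_steps, episodes):
        return f"evaluate for {episodes} episodes"


agent = DummyAgent()
print(agent.learn(100, None))  # learn for 100 steps
```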
src/primaite/agents/rllib.py (new file, 177 lines)
@@ -0,0 +1,177 @@
import glob
import time
from enum import Enum
from pathlib import Path
from typing import Union, Optional

from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env

from primaite.config import training_config
from primaite.environment.primaite_env import Primaite


class DLFramework(Enum):
    """The DL Frameworks enumeration."""

    TF = "tf"
    TF2 = "tf2"
    TORCH = "torch"


def env_creator(env_config):
    """Create a Primaite environment from an RLlib env_config dict."""
    training_config_path = env_config["training_config_path"]
    lay_down_config_path = env_config["lay_down_config_path"]
    return Primaite(training_config_path, lay_down_config_path, [])


def get_ppo_config(
    training_config_path: Union[str, Path],
    lay_down_config_path: Union[str, Path],
    framework: Optional[DLFramework] = DLFramework.TORCH
) -> PPOConfig:
    """Build an RLlib PPOConfig for the Primaite environment."""
    # Register environment
    register_env("primaite", env_creator)

    # Set up PPO
    config = PPOConfig()

    config_values = training_config.load(training_config_path)

    # Set up our config object to use our environment
    config.environment(
        env="primaite",
        env_config=dict(
            training_config_path=training_config_path,
            lay_down_config_path=lay_down_config_path
        )
    )

    env_config = config_values
    action_type = env_config.action_type
    red_agent = env_config.red_agent_identifier

    # train_batch_size is the number of steps in a training iteration
    if red_agent == "RANDOM" and action_type == "NODE":
        config.training(train_batch_size=6000, lr=5e-5)
    elif red_agent == "RANDOM" and action_type != "NODE":
        config.training(train_batch_size=6000, lr=5e-5)
    elif red_agent == "CONFIG" and action_type == "NODE":
        config.training(train_batch_size=400, lr=5e-5)
    elif red_agent == "CONFIG" and action_type != "NODE":
        config.training(train_batch_size=500, lr=5e-5)
    else:
        config.training(train_batch_size=500, lr=5e-5)

    # Decide if you want the torch or tensorflow DL framework. Default is "tf"
    config.framework(framework=framework.value)

    # Set the log level to DEBUG, INFO, WARN, or ERROR
    config.debugging(seed=415, log_level="ERROR")

    # Set up evaluation
    # Explicitly set "explore"=False to override the default
    # config.evaluation(
    #     evaluation_interval=100,
    #     evaluation_duration=20,
    #     # evaluation_duration_unit="timesteps",  # default is episodes
    #     evaluation_config={"explore": False},
    # )

    # Set up sampling rollout workers
    config.rollouts(
        num_rollout_workers=4,  # number of parallel rollout workers
        num_envs_per_worker=1,
        horizon=128,  # max number of steps in an episode
    )

    config.build()  # Build config

    return config
def train(
    num_iterations: int,
    config: Optional[PPOConfig] = None,
    algo: Optional[Algorithm] = None
):
    """
    Train an RLlib algorithm for a number of iterations.

    Requires either the algorithm config (new model) or the algorithm itself
    (continue training from checkpoint).
    """
    start_time = time.time()

    if algo is None:
        algo = config.build()
    elif config is None:
        config = algo.get_config()

    print(f"Algorithm type: {type(algo)}")

    # Iterations are not the same as episodes.
    for i in range(num_iterations):
        result = algo.train()
        # # Save every 10 iterations or after the last iteration in training
        # if (i % 100 == 0) or (i == num_iterations - 1):
        print(f"Iteration={i}, Mean Reward={result['episode_reward_mean']:.2f}")
        # Save checkpoint file
        checkpoint_file = algo.save("./")
        print(f"Checkpoint saved at {checkpoint_file}")

    # Convert num_iterations to num_episodes
    num_episodes = len(result["hist_stats"]["episode_lengths"]) * num_iterations
    # Convert num_iterations to num_timesteps
    num_timesteps = sum(result["hist_stats"]["episode_lengths"]) * num_iterations
    # Calculate number of wins

    # Train time
    print(f"Training took {time.time() - start_time:.2f} seconds")
    print(f"Number of episodes {num_episodes}, Number of timesteps: {num_timesteps}")
    return result
def load_model_from_checkpoint(config, checkpoint=None):
    """Restore an Algorithm from a checkpoint, defaulting to the most recent."""
    # Create an empty Algorithm
    algo = config.build()

    if checkpoint is None:
        # Get the checkpoint with the highest iteration number
        checkpoint = get_most_recent_checkpoint(config)

    # Restore the agent from the checkpoint
    algo.restore(checkpoint)

    return algo
def get_most_recent_checkpoint(config):
    """
    Get the most recent checkpoint for the specified action type, red agent and algorithm.
    """
    env_config = list(config.env_config.values())[0]
    action_type = env_config.action_type
    red_agent = env_config.red_agent_identifier
    algo_name = config.algo_class.__name__

    # Gets the latest checkpoint (highest iteration, not datetime) to use as the final trained model
    relevant_checkpoints = glob.glob(
        f"/app/outputs/agents/{action_type}/{red_agent}/{algo_name}/*"
    )
    checkpoint_numbers = [int(i.split("_")[1]) for i in relevant_checkpoints]
    # Checkpoint directory names are zero-padded to six digits
    checkpoint_number_to_use = str(max(checkpoint_numbers)).zfill(6)
    checkpoint = (
        relevant_checkpoints[0].split("_")[0]
        + "_"
        + checkpoint_number_to_use
        + "/rllib_checkpoint.json"
    )

    return checkpoint
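The checkpoint-selection logic above can be illustrated standalone: take the highest iteration number among checkpoint directory names and zero-pad it to six digits. The example names below are assumptions standing in for real RLlib checkpoint directories:

```python
# Hypothetical checkpoint directory names, as RLlib-style "checkpoint_NNNNNN"
checkpoints = ["checkpoint_000010", "checkpoint_000002", "checkpoint_000120"]

# Extract the iteration number from each name
numbers = [int(name.split("_")[1]) for name in checkpoints]

# Pick the highest iteration and restore the six-digit zero padding
latest = str(max(numbers)).zfill(6)
print(f"checkpoint_{latest}")  # checkpoint_000120
```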
src/primaite/agents/sb3.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# from typing import Optional
#
# from primaite.agents.agent_abc import AgentABC
# from primaite.environment.primaite_env import Primaite
#
#
# class SB3PPO(AgentABC):
#     def __init__(self, env: Primaite):
#         super().__init__(env)
#
#     def _setup(self):
#         if self._env.training_config:
#             pass
#
#     def learn(self, time_steps: Optional[int], episodes: Optional[int]):
#         pass
#
#     def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
#         pass
#
#     def load(self):
#         pass
#
#     def save(self):
#         pass
#
#     def export(self):
#         pass
@@ -79,6 +79,33 @@ class Protocol(Enum):
    NONE = 7


class SessionType(Enum):
    """The type of PrimAITE Session to be run."""

    TRAINING = 1
    EVALUATION = 2
    BOTH = 3


class VerboseLevel(Enum):
    """PrimAITE Session output verbose level."""

    NO_OUTPUT = 0
    INFO = 1
    DEBUG = 2


class AgentFramework(Enum):
    """The agent framework enumeration."""

    NONE = 0
    SB3 = 1
    RLLIB = 2


class RedAgentIdentifier(Enum):
    """The red agent identifier enumeration."""

    A2C = 1
    PPO = 2
    HARDCODED = 3
    RANDOM = 4


class ActionType(Enum):
    """Action type enumeration."""
@@ -1,95 +0,0 @@
# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# """The config class."""
# from dataclasses import dataclass
#
# from primaite.common.enums import ActionType
#
#
# @dataclass()
# class TrainingConfig:
#     """Class to hold main config values."""
#
#     # Generic
#     agent_identifier: str  # The Red Agent algo/class to be used
#     action_type: ActionType  # type of action to use (NODE/ACL/ANY)
#     num_episodes: int  # number of episodes to train over
#     num_steps: int  # number of steps in an episode
#     time_delay: int  # delay between steps (ms) - applies to generic agents only
#     # file
#     session_type: str  # the session type to run (TRAINING or EVALUATION)
#     load_agent: str  # Determine whether to load an agent from file
#     agent_load_file: str  # File path and file name of agent if you're loading one in
#
#     # Environment
#     observation_space_high_value: int  # The high value for the observation space
#
#     # Reward values
#     # Generic
#     all_ok: int
#     # Node Hardware State
#     off_should_be_on: int
#     off_should_be_resetting: int
#     on_should_be_off: int
#     on_should_be_resetting: int
#     resetting_should_be_on: int
#     resetting_should_be_off: int
#     resetting: int
#     # Node Software or Service State
#     good_should_be_patching: int
#     good_should_be_compromised: int
#     good_should_be_overwhelmed: int
#     patching_should_be_good: int
#     patching_should_be_compromised: int
#     patching_should_be_overwhelmed: int
#     patching: int
#     compromised_should_be_good: int
#     compromised_should_be_patching: int
#     compromised_should_be_overwhelmed: int
#     compromised: int
#     overwhelmed_should_be_good: int
#     overwhelmed_should_be_patching: int
#     overwhelmed_should_be_compromised: int
#     overwhelmed: int
#     # Node File System State
#     good_should_be_repairing: int
#     good_should_be_restoring: int
#     good_should_be_corrupt: int
#     good_should_be_destroyed: int
#     repairing_should_be_good: int
#     repairing_should_be_restoring: int
#     repairing_should_be_corrupt: int
#     repairing_should_be_destroyed: int  # Repairing does not fix destroyed state - you need to restore
#     repairing: int
#     restoring_should_be_good: int
#     restoring_should_be_repairing: int
#     restoring_should_be_corrupt: int  # Not the optimal method (as repair will fix corruption)
#     restoring_should_be_destroyed: int
#     restoring: int
#     corrupt_should_be_good: int
#     corrupt_should_be_repairing: int
#     corrupt_should_be_restoring: int
#     corrupt_should_be_destroyed: int
#     corrupt: int
#     destroyed_should_be_good: int
#     destroyed_should_be_repairing: int
#     destroyed_should_be_restoring: int
#     destroyed_should_be_corrupt: int
#     destroyed: int
#     scanning: int
#     # IER status
#     red_ier_running: int
#     green_ier_blocked: int
#
#     # Patching / Reset
#     os_patching_duration: int  # The time taken to patch the OS
#     node_reset_duration: int  # The time taken to reset a node (hardware)
#     node_booting_duration = 0  # The time taken to turn on the node
#     node_shutdown_duration = 0  # The time taken to turn off the node
#     service_patching_duration: int  # The time taken to patch a service
#     file_system_repairing_limit: int  # The time taken to repair a file
#     file_system_restoring_limit: int  # The time taken to restore a file
#     file_system_scanning_limit: int  # The time taken to scan the file system
#
@@ -2,6 +2,8 @@
from pathlib import Path
from typing import Final

import networkx

from primaite import USERS_CONFIG_DIR, getLogger

_LOGGER = getLogger(__name__)
@@ -9,6 +11,12 @@ _LOGGER = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"


# class LayDownConfig:
#     network: networkx.Graph
#     # POL
#     # IER
#     # ACL


def ddos_basic_one_config_path() -> Path:
    """
    The path to the example lay_down_config_1_DDOS_basic.yaml file.
@@ -1,4 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Final, Union, Optional
@@ -6,7 +8,8 @@ from typing import Any, Dict, Final, Union, Optional
import yaml

from primaite import USERS_CONFIG_DIR, getLogger
from primaite.common.enums import ActionType, RedAgentIdentifier, \
    AgentFramework, SessionType

_LOGGER = getLogger(__name__)
@@ -16,10 +19,11 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training
@dataclass()
class TrainingConfig:
    """The Training Config class."""

    agent_framework: AgentFramework = AgentFramework.SB3
    "The agent framework."

    # Generic
    agent_identifier: str = "STABLE_BASELINES3_A2C"
    "The Red Agent algo/class to be used."
    red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO
    "The red agent/algo class."

    action_type: ActionType = ActionType.ANY
    "The ActionType to use."
@@ -38,8 +42,8 @@ class TrainingConfig:
    "The delay between steps (ms). Applies to generic agents only."

    # file
    session_type: SessionType = SessionType.TRAINING
    "The type of PrimAITE session to run."

    load_agent: bool = False
    "Determine whether to load an agent from file."
@@ -137,6 +141,24 @@ class TrainingConfig:
    file_system_scanning_limit: int = 5
    "The time taken to scan the file system."
    @classmethod
    def from_dict(
        cls,
        config_dict: Dict[str, Union[str, int, bool]]
    ) -> TrainingConfig:
        """Create a ``TrainingConfig`` from a dict, converting enum member names to enum values."""
        field_enum_map = {
            "agent_framework": AgentFramework,
            "red_agent_identifier": RedAgentIdentifier,
            "action_type": ActionType,
            "session_type": SessionType
        }

        # Look the config value up by enum member name
        for field_name, enum_class in field_enum_map.items():
            if field_name in config_dict:
                config_dict[field_name] = enum_class[config_dict[field_name]]

        return TrainingConfig(**config_dict)

    def to_dict(self, json_serializable: bool = True):
        """
        Serialise the ``TrainingConfig`` as dict.
@@ -196,10 +218,8 @@ def load(file_path: Union[str, Path],
                f"from legacy format. Attempting to use file as is."
            )
            _LOGGER.error(msg)
    try:
        return TrainingConfig.from_dict(config)
    except TypeError as e:
        msg = (
            f"Error when creating an instance of {TrainingConfig} "
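The ``from_dict`` conversion relies on Python's ``Enum`` lookup-by-name. A minimal, self-contained demonstration of the pattern (the enum is redefined here so the snippet runs on its own):

```python
from enum import Enum


class SessionType(Enum):
    """Trimmed copy of the SessionType enum defined earlier in this commit."""

    TRAINING = 1
    EVALUATION = 2
    BOTH = 3


# A YAML loader yields the member name as a string...
config_dict = {"session_type": "TRAINING"}

# ...and indexing the enum class by that name yields the member itself.
config_dict["session_type"] = SessionType[config_dict["session_type"]]
print(config_dict["session_type"])  # SessionType.TRAINING
```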
@@ -214,22 +234,30 @@ def load(file_path: Union[str, Path],


 def convert_legacy_training_config_dict(
         legacy_config_dict: Dict[str, Any],
-        num_steps: int = 256,
-        action_type: str = "ANY"
+        agent_framework: AgentFramework = AgentFramework.SB3,
+        red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO,
+        action_type: ActionType = ActionType.ANY,
+        num_steps: int = 256
 ) -> Dict[str, Any]:
     """
     Convert a legacy training config dict to the new format.

     :param legacy_config_dict: A legacy training config dict.
-    :param num_steps: The number of steps to set as legacy training configs
-        don't have num_steps values.
+    :param agent_framework: The agent framework to use as legacy training
+        configs don't have agent_framework values.
+    :param red_agent_identifier: The red agent identifier to use as legacy
+        training configs don't have red_agent_identifier values.
     :param action_type: The action space type to set as legacy training configs
         don't have action_type values.
+    :param num_steps: The number of steps to set as legacy training configs
+        don't have num_steps values.
     :return: The converted training config dict.
     """
     config_dict = {
-        "num_steps": num_steps,
-        "action_type": action_type
+        "agent_framework": agent_framework.name,
+        "red_agent_identifier": red_agent_identifier.name,
+        "action_type": action_type.name,
+        "num_steps": num_steps
     }
     for legacy_key, value in legacy_config_dict.items():
         new_key = _get_new_key_from_legacy(legacy_key)
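The legacy-conversion pattern above (new-format defaults first, then mapped legacy keys, with `None`-mapped keys dropped) can be sketched as follows. `convert_legacy`, `key_map`, and the sample dicts are hypothetical stand-ins for the real functions and configs:

```python
# Sketch of convert_legacy_training_config_dict's key-mapping loop
def convert_legacy(legacy: dict, defaults: dict, key_map: dict) -> dict:
    config = dict(defaults)  # new-format defaults go in first
    for legacy_key, value in legacy.items():
        new_key = key_map.get(legacy_key)
        if new_key is not None:  # legacy keys mapped to None are dropped
            config[new_key] = value
    return config

key_map = {"numEpisodes": "num_episodes", "timeDelay": "time_delay",
           "agentIdentifier": None, "configFilename": None}
legacy = {"numEpisodes": 10, "agentIdentifier": "GENERIC"}
defaults = {"num_steps": 256, "action_type": "ANY"}
print(convert_legacy(legacy, defaults, key_map))
# {'num_steps': 256, 'action_type': 'ANY', 'num_episodes': 10}
```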
@@ -246,7 +274,7 @@ def _get_new_key_from_legacy(legacy_key: str) -> str:
     :return: The mapped key.
     """
     key_mapping = {
-        "agentIdentifier": "agent_identifier",
+        "agentIdentifier": None,
         "numEpisodes": "num_episodes",
         "timeDelay": "time_delay",
         "configFilename": None,


216
src/primaite/primaite_session.py
Normal file
@@ -0,0 +1,216 @@
from __future__ import annotations

import json
from datetime import datetime
from pathlib import Path
from typing import Final, Optional, Union
from uuid import uuid4

from primaite import getLogger, SESSIONS_DIR
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite

_LOGGER = getLogger(__name__)


def _get_session_path(session_timestamp: datetime) -> Path:
    """
    Get the directory path the session will output to.

    This is set in the format of:
    ~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.

    :param session_timestamp: This is the datetime that the session started.
    :return: The session directory path.
    """
    date_dir = session_timestamp.strftime("%Y-%m-%d")
    session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    session_path = SESSIONS_DIR / date_dir / session_dir
    session_path.mkdir(exist_ok=True, parents=True)
    _LOGGER.debug(f"Created PrimAITE Session path: {session_path}")

    return session_path

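The directory layout `_get_session_path` builds can be demonstrated with plain `pathlib`; the base path below is a stand-in for `SESSIONS_DIR` (and `mkdir` is omitted so nothing is written):

```python
from datetime import datetime
from pathlib import PurePosixPath

# <base>/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>, as in the docstring above
ts = datetime(2023, 4, 1, 9, 30, 5)
base = PurePosixPath("~/primaite/sessions")  # stand-in for SESSIONS_DIR
session_path = base / ts.strftime("%Y-%m-%d") / ts.strftime("%Y-%m-%d_%H-%M-%S")
print(session_path)  # ~/primaite/sessions/2023-04-01/2023-04-01_09-30-05
```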
class PrimaiteSession:

    def __init__(
            self,
            training_config_path: Union[str, Path],
            lay_down_config_path: Union[str, Path],
            auto: bool = True
    ):
        if not isinstance(training_config_path, Path):
            training_config_path = Path(training_config_path)
        self._training_config_path: Final[Path] = training_config_path

        if not isinstance(lay_down_config_path, Path):
            lay_down_config_path = Path(lay_down_config_path)
        self._lay_down_config_path: Final[Path] = lay_down_config_path

        self._auto: Final[bool] = auto

        self._uuid: str = str(uuid4())
        self._session_timestamp: Final[datetime] = datetime.now()
        self._session_path: Final[Path] = _get_session_path(
            self._session_timestamp
        )
        self._timestamp_str: Final[str] = self._session_timestamp.strftime(
            "%Y-%m-%d_%H-%M-%S")
        self._metadata_path = self._session_path / "session_metadata.json"

        self._env = None
        self._training_config = None
        self._can_learn: bool = False
        _LOGGER.debug(f"Created PrimaiteSession: {self._uuid}")

        if self._auto:
            self.setup()
            self.learn()

    @property
    def uuid(self):
        """The session UUID."""
        return self._uuid

    def _setup_primaite_env(self, transaction_list: Optional[list] = None):
        if not transaction_list:
            transaction_list = []
        self._env: Primaite = Primaite(
            training_config_path=self._training_config_path,
            lay_down_config_path=self._lay_down_config_path,
            transaction_list=transaction_list,
            session_path=self._session_path,
            timestamp_str=self._timestamp_str
        )
        self._training_config: TrainingConfig = self._env.training_config

    def _write_session_metadata_file(self):
        """
        Write the ``session_metadata.json`` file.

        Creates a ``session_metadata.json`` in the ``session_dir`` directory
        and adds the following key/value pairs:

        - uuid: The UUID assigned to the session upon instantiation.
        - start_datetime: The date & time the session started in iso format.
        - end_datetime: NULL.
        - total_episodes: NULL.
        - total_time_steps: NULL.
        - env:
            - training_config:
                - All training config items
            - lay_down_config:
                - All lay down config items
        """
        metadata_dict = {
            "uuid": self._uuid,
            "start_datetime": self._session_timestamp.isoformat(),
            "end_datetime": None,
            "total_episodes": None,
            "total_time_steps": None,
            "env": {
                "training_config": self._env.training_config.to_dict(
                    json_serializable=True
                ),
                "lay_down_config": self._env.lay_down_config,
            },
        }
        _LOGGER.debug(f"Writing Session Metadata file: {self._metadata_path}")
        with open(self._metadata_path, "w") as file:
            json.dump(metadata_dict, file)
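The write-then-update flow shared by `_write_session_metadata_file` and `_update_session_metadata_file` can be sketched with a temporary file; every value below is a stand-in:

```python
import json
import tempfile
from datetime import datetime
from pathlib import Path

# Write step: NULL placeholders for fields only known at session end
meta_path = Path(tempfile.mkdtemp()) / "session_metadata.json"
metadata = {
    "uuid": "0000-demo",
    "start_datetime": datetime(2023, 1, 1, 9, 0, 0).isoformat(),
    "end_datetime": None,
    "total_episodes": None,
    "total_time_steps": None,
}
meta_path.write_text(json.dumps(metadata))

# Update step: re-read the file and fill in the NULL fields
updated = json.loads(meta_path.read_text())
updated["end_datetime"] = datetime(2023, 1, 1, 9, 45, 0).isoformat()
updated["total_episodes"] = 10
updated["total_time_steps"] = 2560
meta_path.write_text(json.dumps(updated))

print(json.loads(meta_path.read_text())["total_episodes"])  # 10
```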

    def _update_session_metadata_file(self):
        """
        Update the ``session_metadata.json`` file.

        Updates the ``session_metadata.json`` in the ``session_dir`` directory
        with the following key/value pairs:

        - end_datetime: The date & time the session ended in iso format.
        - total_episodes: The total number of episodes the session ran for.
        - total_time_steps: The total number of time steps the session ran for.
        """
        with open(self._metadata_path, "r") as file:
            metadata_dict = json.load(file)

        metadata_dict["end_datetime"] = datetime.now().isoformat()
        metadata_dict["total_episodes"] = self._env.episode_count
        metadata_dict["total_time_steps"] = self._env.total_step_count

        _LOGGER.debug(f"Updating Session Metadata file: {self._metadata_path}")
        with open(self._metadata_path, "w") as file:
            json.dump(metadata_dict, file)

    def setup(self):
        """Set up the session by instantiating the PrimAITE environment."""
        self._setup_primaite_env()
        self._can_learn = True

    def learn(
            self,
            time_steps: Optional[int] = None,
            episodes: Optional[int] = None,
            iterations: Optional[int] = None,
            **kwargs
    ):
        if not self._can_learn:
            _LOGGER.error("Cannot learn: the session has not been set up")
            return

        # Run the environment against the configured agent.
        # run_generic, run_stable_baselines3_ppo, run_stable_baselines3_a2c
        # and write_transaction_to_file are assumed to be imported from their
        # existing modules.
        if self._training_config.agent_identifier == "GENERIC":
            run_generic(env=self._env, config_values=self._training_config)
        elif self._training_config.agent_identifier == "STABLE_BASELINES3_PPO":
            run_stable_baselines3_ppo(
                env=self._env,
                config_values=self._training_config,
                session_path=self._session_path,
                timestamp_str=self._timestamp_str,
            )
        elif self._training_config.agent_identifier == "STABLE_BASELINES3_A2C":
            run_stable_baselines3_a2c(
                env=self._env,
                config_values=self._training_config,
                session_path=self._session_path,
                timestamp_str=self._timestamp_str,
            )

        print("Session finished")
        _LOGGER.debug("Session finished")

        print("Saving transaction logs...")
        # transaction_list is assumed to be exposed by the Primaite env,
        # which received it in _setup_primaite_env
        write_transaction_to_file(
            transaction_list=self._env.transaction_list,
            session_path=self._session_path,
            timestamp_str=self._timestamp_str,
        )

        print("Updating Session Metadata file...")
        self._update_session_metadata_file()

        print("Finished")
        _LOGGER.debug("Finished")

    def evaluate(
            self,
            time_steps: Optional[int] = None,
            episodes: Optional[int] = None,
            **kwargs
    ):
        """Evaluate a trained agent. Not yet implemented."""
        pass

    def export(self):
        """Export the session. Not yet implemented."""
        pass

    @classmethod
    def import_agent(
            cls,
            agent_path: str,
            training_config_path: str,
            lay_down_config_path: str
    ) -> PrimaiteSession:
        session = PrimaiteSession(training_config_path, lay_down_config_path)

        # Reset the UUID
        session._uuid = ""

        return session
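The `agent_identifier` dispatch inside `learn` can be reduced to a lookup table; the stub runners below are stand-ins for the real `run_generic` / `run_stable_baselines3_ppo` / `run_stable_baselines3_a2c` functions:

```python
# Sketch of learn()'s agent dispatch with stub runners
def run_session(agent_identifier: str) -> str:
    runners = {
        "GENERIC": lambda: "ran generic agent",
        "STABLE_BASELINES3_PPO": lambda: "ran SB3 PPO agent",
        "STABLE_BASELINES3_A2C": lambda: "ran SB3 A2C agent",
    }
    return runners[agent_identifier]()

print(run_session("STABLE_BASELINES3_PPO"))  # ran SB3 PPO agent
```

A dict of callables avoids the repeated `elif` chain and raises a `KeyError` on an unknown identifier instead of silently doing nothing.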
@@ -1,11 +1,20 @@
# Main Config File

# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C

# Sets which agent algorithm framework will be used:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray[RLlib])
# "NONE" (Custom Agent)
agent_framework: RLLIB

# Sets which Red Agent algo/class will be used:
# "PPO" (Proximal Policy Optimization)
# "A2C" (Advantage Actor Critic)
# "HARDCODED" (Custom Agent)
# "RANDOM" (Random Action)
red_agent_identifier: PPO

# Sets how the Action Space is defined:
# "NODE"
# "ACL"