#915 - Created app dirs and set as constants in the top-level init.
- renamed _config_values_main to training_config.py and renamed the ConfigValuesMain class to TrainingConfig. Moved training_config.py to src/primaite/config/training_config.py - Renamed all training config yaml file keys to make creating an instance of TrainingConfig easier. Moved action_type and num_steps over to the training config. - Decoupled the training config and lay down config. - Refactored main.py so that it can be ran from CLI and can take a training config path and a lay down config path. - refactored all outputs so that they save to the session dir. - Added some necessary setup scripts that handle creating app dirs, fronting example config files to the user, fronting demo notebooks to the user, performing clean-up in between installations etc. - Added functions that attempt to retrieve the file path of users example config files that have been fronted by the primaite setup. - Added logging config and a getLogger function in the top-level init. - Refactored all logs entries logged to use a logger using the primaite logging config. - Added basic typer CLI for doing things like setup, viewing logs, viewing primaite version, running a basic session. - Updated test to use new features and config structures. - Began updating docs. More to do here.
This commit is contained in:
@@ -1,29 +1,41 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Primaite - main (harness) module.
|
||||
The main PrimAITE session runner module.
|
||||
|
||||
Coding Standards: PEP 8
|
||||
TODO: This will eventually be refactored out into a proper Session class.
|
||||
TODO: The passing about of session_dir and timestamp_str is temporary and
|
||||
will be cleaned up once we move to a proper Session class.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os.path
|
||||
import argparse
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Final, Union
|
||||
|
||||
import yaml
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite.common.config_values_main import ConfigValuesMain
|
||||
from primaite import SESSIONS_DIR, getLogger
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
from primaite.config.training_config import TrainingConfig, \
|
||||
main_training_config_path
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.transactions.transactions_to_file import write_transaction_to_file
|
||||
|
||||
# FUNCTIONS #
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run_generic():
|
||||
"""Run against a generic agent."""
|
||||
def run_generic(env: Primaite, config_values: TrainingConfig):
|
||||
"""
|
||||
Run against a generic agent.
|
||||
|
||||
:param env: An instance of
|
||||
:class:`~primaite.environment.primaite_env.Primaite`.
|
||||
:param config_values: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
"""
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
env.reset()
|
||||
for step in range(0, config_values.num_steps):
|
||||
@@ -47,9 +59,24 @@ def run_generic():
|
||||
env.close()
|
||||
|
||||
|
||||
def run_stable_baselines3_ppo():
|
||||
"""Run against a stable_baselines3 PPO agent."""
|
||||
if config_values.load_agent == True:
|
||||
def run_stable_baselines3_ppo(
|
||||
env: Primaite,
|
||||
config_values: TrainingConfig,
|
||||
session_path: Path,
|
||||
timestamp_str: str
|
||||
):
|
||||
"""
|
||||
Run against a stable_baselines3 PPO agent.
|
||||
|
||||
:param env: An instance of
|
||||
:class:`~primaite.environment.primaite_env.Primaite`.
|
||||
:param config_values: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
if config_values.load_agent:
|
||||
try:
|
||||
agent = PPO.load(
|
||||
config_values.agent_load_file,
|
||||
@@ -62,30 +89,44 @@ def run_stable_baselines3_ppo():
|
||||
"ERROR: Could not load agent at location: "
|
||||
+ config_values.agent_load_file
|
||||
)
|
||||
logging.error("Could not load agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
_LOGGER.error("Could not load agent")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
print("Starting training session...")
|
||||
logging.info("Starting training session...")
|
||||
_LOGGER.debug("Starting training session...")
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
agent.learn(total_timesteps=1)
|
||||
save_agent(agent)
|
||||
_save_agent(agent, session_path, timestamp_str)
|
||||
else:
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
logging.info("Starting evaluation session...")
|
||||
_LOGGER.debug("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def run_stable_baselines3_a2c():
|
||||
"""Run against a stable_baselines3 A2C agent."""
|
||||
if config_values.load_agent == True:
|
||||
def run_stable_baselines3_a2c(
|
||||
env: Primaite,
|
||||
config_values: TrainingConfig,
|
||||
session_path: Path, timestamp_str: str
|
||||
):
|
||||
"""
|
||||
Run against a stable_baselines3 A2C agent.
|
||||
|
||||
:param env: An instance of
|
||||
:class:`~primaite.environment.primaite_env.Primaite`.
|
||||
:param config_values: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
if config_values.load_agent:
|
||||
try:
|
||||
agent = A2C.load(
|
||||
config_values.agent_load_file,
|
||||
@@ -98,284 +139,151 @@ def run_stable_baselines3_a2c():
|
||||
"ERROR: Could not load agent at location: "
|
||||
+ config_values.agent_load_file
|
||||
)
|
||||
logging.error("Could not load agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
_LOGGER.error("Could not load agent")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
print("Starting training session...")
|
||||
logging.info("Starting training session...")
|
||||
_LOGGER.debug("Starting training session...")
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
agent.learn(total_timesteps=1)
|
||||
save_agent(agent)
|
||||
_save_agent(agent, session_path, timestamp_str)
|
||||
else:
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
logging.info("Starting evaluation session...")
|
||||
_LOGGER.debug("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def save_agent(_agent):
|
||||
"""Persist an agent (only works for stable baselines3 agents at present)."""
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
|
||||
"""
|
||||
Persist an agent.
|
||||
|
||||
try:
|
||||
path = "outputs/agents/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
filename = "outputs/agents/agent_saved_" + time
|
||||
_agent.save(filename)
|
||||
logging.info("Trained agent saved as " + filename)
|
||||
except Exception:
|
||||
logging.error("Could not save agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
Only works for stable baselines3 agents at present.
|
||||
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
if not isinstance(agent, OnPolicyAlgorithm):
|
||||
msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
|
||||
_LOGGER.error(msg)
|
||||
else:
|
||||
filepath = session_path / f"agent_saved_{timestamp_str}"
|
||||
agent.save(filepath)
|
||||
_LOGGER.debug(f"Trained agent saved as: {filepath}")
|
||||
|
||||
|
||||
def configure_logging():
|
||||
"""Configures logging."""
|
||||
try:
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
filename = "logs/app_" + time + ".log"
|
||||
path = "logs/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
logging.basicConfig(
|
||||
filename=filename,
|
||||
filemode="w",
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%d-%b-%y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
except Exception:
|
||||
print("ERROR: Could not start logging")
|
||||
def _get_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
Get the directory path the session will output to.
|
||||
|
||||
This is set in the format of:
|
||||
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
|
||||
:param session_timestamp: This is the datetime that the session started.
|
||||
:return: The session directory path.
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = SESSIONS_DIR / date_dir / session_dir
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return session_path
|
||||
|
||||
|
||||
def load_config_values():
|
||||
"""Loads the config values from the main config file into a config object."""
|
||||
try:
|
||||
# Generic
|
||||
config_values.agent_identifier = config_data["agentIdentifier"]
|
||||
config_values.num_episodes = int(config_data["numEpisodes"])
|
||||
config_values.time_delay = int(config_data["timeDelay"])
|
||||
config_values.config_filename_use_case = (
|
||||
"config/" + config_data["configFilename"]
|
||||
def run(
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path]
|
||||
):
|
||||
"""Run the PrimAITE Session.
|
||||
|
||||
:param training_config_path: The training config filepath.
|
||||
:param lay_down_config_path: The lay down config filepath.
|
||||
"""
|
||||
# Welcome message
|
||||
print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
|
||||
|
||||
session_timestamp: Final[datetime] = datetime.now()
|
||||
session_path = _get_session_path(session_timestamp)
|
||||
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
print(f"The output directory for this session is: {session_path}")
|
||||
|
||||
# Create a list of transactions
|
||||
# A transaction is an object holding the:
|
||||
# - episode #
|
||||
# - step #
|
||||
# - initial observation space
|
||||
# - action
|
||||
# - reward
|
||||
# - new observation space
|
||||
transaction_list = []
|
||||
|
||||
# Create the Primaite environment
|
||||
env = Primaite(
|
||||
training_config_path=training_config_path,
|
||||
lay_down_config_path=lay_down_config_path,
|
||||
transaction_list=transaction_list,
|
||||
session_path=session_path,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
|
||||
config_values = env.config_values
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
config_values.num_steps = env.episode_steps
|
||||
|
||||
# Run environment against an agent
|
||||
if config_values.agent_identifier == "GENERIC":
|
||||
run_generic(env=env, config_values=config_values)
|
||||
elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
|
||||
run_stable_baselines3_ppo(
|
||||
env=env,
|
||||
config_values=config_values,
|
||||
session_path=session_path,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
config_values.session_type = config_data["sessionType"]
|
||||
config_values.load_agent = bool(config_data["loadAgent"])
|
||||
config_values.agent_load_file = config_data["agentLoadFile"]
|
||||
# Environment
|
||||
config_values.observation_space_high_value = int(
|
||||
config_data["observationSpaceHighValue"]
|
||||
)
|
||||
# Reward values
|
||||
# Generic
|
||||
config_values.all_ok = int(config_data["allOk"])
|
||||
# Node Hardware State
|
||||
config_values.off_should_be_on = int(config_data["offShouldBeOn"])
|
||||
config_values.off_should_be_resetting = int(config_data["offShouldBeResetting"])
|
||||
config_values.on_should_be_off = int(config_data["onShouldBeOff"])
|
||||
config_values.on_should_be_resetting = int(config_data["onShouldBeResetting"])
|
||||
config_values.resetting_should_be_on = int(config_data["resettingShouldBeOn"])
|
||||
config_values.resetting_should_be_off = int(config_data["resettingShouldBeOff"])
|
||||
config_values.resetting = int(config_data["resetting"])
|
||||
# Node Software or Service State
|
||||
config_values.good_should_be_patching = int(config_data["goodShouldBePatching"])
|
||||
config_values.good_should_be_compromised = int(
|
||||
config_data["goodShouldBeCompromised"]
|
||||
)
|
||||
config_values.good_should_be_overwhelmed = int(
|
||||
config_data["goodShouldBeOverwhelmed"]
|
||||
)
|
||||
config_values.patching_should_be_good = int(config_data["patchingShouldBeGood"])
|
||||
config_values.patching_should_be_compromised = int(
|
||||
config_data["patchingShouldBeCompromised"]
|
||||
)
|
||||
config_values.patching_should_be_overwhelmed = int(
|
||||
config_data["patchingShouldBeOverwhelmed"]
|
||||
)
|
||||
config_values.patching = int(config_data["patching"])
|
||||
config_values.compromised_should_be_good = int(
|
||||
config_data["compromisedShouldBeGood"]
|
||||
)
|
||||
config_values.compromised_should_be_patching = int(
|
||||
config_data["compromisedShouldBePatching"]
|
||||
)
|
||||
config_values.compromised_should_be_overwhelmed = int(
|
||||
config_data["compromisedShouldBeOverwhelmed"]
|
||||
)
|
||||
config_values.compromised = int(config_data["compromised"])
|
||||
config_values.overwhelmed_should_be_good = int(
|
||||
config_data["overwhelmedShouldBeGood"]
|
||||
)
|
||||
config_values.overwhelmed_should_be_patching = int(
|
||||
config_data["overwhelmedShouldBePatching"]
|
||||
)
|
||||
config_values.overwhelmed_should_be_compromised = int(
|
||||
config_data["overwhelmedShouldBeCompromised"]
|
||||
)
|
||||
config_values.overwhelmed = int(config_data["overwhelmed"])
|
||||
# Node File System State
|
||||
config_values.good_should_be_repairing = int(
|
||||
config_data["goodShouldBeRepairing"]
|
||||
)
|
||||
config_values.good_should_be_restoring = int(
|
||||
config_data["goodShouldBeRestoring"]
|
||||
)
|
||||
config_values.good_should_be_corrupt = int(config_data["goodShouldBeCorrupt"])
|
||||
config_values.good_should_be_destroyed = int(
|
||||
config_data["goodShouldBeDestroyed"]
|
||||
)
|
||||
config_values.repairing_should_be_good = int(
|
||||
config_data["repairingShouldBeGood"]
|
||||
)
|
||||
config_values.repairing_should_be_restoring = int(
|
||||
config_data["repairingShouldBeRestoring"]
|
||||
)
|
||||
config_values.repairing_should_be_corrupt = int(
|
||||
config_data["repairingShouldBeCorrupt"]
|
||||
)
|
||||
config_values.repairing_should_be_destroyed = int(
|
||||
config_data["repairingShouldBeDestroyed"]
|
||||
)
|
||||
config_values.repairing = int(config_data["repairing"])
|
||||
config_values.restoring_should_be_good = int(
|
||||
config_data["restoringShouldBeGood"]
|
||||
)
|
||||
config_values.restoring_should_be_repairing = int(
|
||||
config_data["restoringShouldBeRepairing"]
|
||||
)
|
||||
config_values.restoring_should_be_corrupt = int(
|
||||
config_data["restoringShouldBeCorrupt"]
|
||||
)
|
||||
config_values.restoring_should_be_destroyed = int(
|
||||
config_data["restoringShouldBeDestroyed"]
|
||||
)
|
||||
config_values.restoring = int(config_data["restoring"])
|
||||
config_values.corrupt_should_be_good = int(config_data["corruptShouldBeGood"])
|
||||
config_values.corrupt_should_be_repairing = int(
|
||||
config_data["corruptShouldBeRepairing"]
|
||||
)
|
||||
config_values.corrupt_should_be_restoring = int(
|
||||
config_data["corruptShouldBeRestoring"]
|
||||
)
|
||||
config_values.corrupt_should_be_destroyed = int(
|
||||
config_data["corruptShouldBeDestroyed"]
|
||||
)
|
||||
config_values.corrupt = int(config_data["corrupt"])
|
||||
config_values.destroyed_should_be_good = int(
|
||||
config_data["destroyedShouldBeGood"]
|
||||
)
|
||||
config_values.destroyed_should_be_repairing = int(
|
||||
config_data["destroyedShouldBeRepairing"]
|
||||
)
|
||||
config_values.destroyed_should_be_restoring = int(
|
||||
config_data["destroyedShouldBeRestoring"]
|
||||
)
|
||||
config_values.destroyed_should_be_corrupt = int(
|
||||
config_data["destroyedShouldBeCorrupt"]
|
||||
)
|
||||
config_values.destroyed = int(config_data["destroyed"])
|
||||
config_values.scanning = int(config_data["scanning"])
|
||||
# IER status
|
||||
config_values.red_ier_running = int(config_data["redIerRunning"])
|
||||
config_values.green_ier_blocked = int(config_data["greenIerBlocked"])
|
||||
# Patching / Reset durations
|
||||
config_values.os_patching_duration = int(config_data["osPatchingDuration"])
|
||||
config_values.node_reset_duration = int(config_data["nodeResetDuration"])
|
||||
config_values.service_patching_duration = int(
|
||||
config_data["servicePatchingDuration"]
|
||||
)
|
||||
config_values.file_system_repairing_limit = int(
|
||||
config_data["fileSystemRepairingLimit"]
|
||||
)
|
||||
config_values.file_system_restoring_limit = int(
|
||||
config_data["fileSystemRestoringLimit"]
|
||||
)
|
||||
config_values.file_system_scanning_limit = int(
|
||||
config_data["fileSystemScanningLimit"]
|
||||
elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
|
||||
run_stable_baselines3_a2c(
|
||||
env=env,
|
||||
config_values=config_values,
|
||||
session_path=session_path,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
|
||||
logging.info("Training agent: " + config_values.agent_identifier)
|
||||
logging.info(
|
||||
"Training environment config: " + config_values.config_filename_use_case
|
||||
print("Session finished")
|
||||
_LOGGER.debug("Session finished")
|
||||
|
||||
print("Saving transaction logs...")
|
||||
_LOGGER.debug("Saving transaction logs...")
|
||||
|
||||
write_transaction_to_file(
|
||||
transaction_list=transaction_list,
|
||||
session_path=session_path,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
|
||||
print("Finished")
|
||||
_LOGGER.debug("Finished")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tc")
|
||||
parser.add_argument("--ldc")
|
||||
args = parser.parse_args()
|
||||
if not args.tc:
|
||||
_LOGGER.error(
|
||||
"Please provide a training config file using the --tc " "argument"
|
||||
)
|
||||
logging.info(
|
||||
"Training cycle has " + str(config_values.num_episodes) + " episodes"
|
||||
if not args.ldc:
|
||||
_LOGGER.error(
|
||||
"Please provide a lay down config file using the --ldc " "argument"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logging.error("Could not save load config data")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
|
||||
# MAIN PROCESS #
|
||||
|
||||
# Starting point
|
||||
|
||||
# Welcome message
|
||||
print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
|
||||
|
||||
# Configure logging
|
||||
configure_logging()
|
||||
|
||||
# Open the main config file
|
||||
try:
|
||||
config_file_main = open("config/config_main.yaml", "r")
|
||||
config_data = yaml.safe_load(config_file_main)
|
||||
# Create a config class
|
||||
config_values = ConfigValuesMain()
|
||||
# Load in config data
|
||||
load_config_values()
|
||||
except Exception:
|
||||
logging.error("Could not load main config")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
# Create a list of transactions
|
||||
# A transaction is an object holding the:
|
||||
# - episode #
|
||||
# - step #
|
||||
# - initial observation space
|
||||
# - action
|
||||
# - reward
|
||||
# - new observation space
|
||||
transaction_list = []
|
||||
|
||||
# Create the Primaite environment
|
||||
# try:
|
||||
env = Primaite(config_values, transaction_list)
|
||||
# logging.info("PrimAITE environment created")
|
||||
# except Exception:
|
||||
# logging.error("Could not create PrimAITE environment")
|
||||
# logging.error("Exception occured", exc_info=True)
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
config_values.num_steps = env.episode_steps
|
||||
|
||||
# Run environment against an agent
|
||||
if config_values.agent_identifier == "GENERIC":
|
||||
run_generic()
|
||||
elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
|
||||
run_stable_baselines3_ppo()
|
||||
elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
|
||||
run_stable_baselines3_a2c()
|
||||
|
||||
print("Session finished")
|
||||
logging.info("Session finished")
|
||||
|
||||
print("Saving transaction logs...")
|
||||
logging.info("Saving transaction logs...")
|
||||
|
||||
write_transaction_to_file(transaction_list)
|
||||
|
||||
config_file_main.close()
|
||||
|
||||
print("Finished")
|
||||
logging.info("Finished")
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
|
||||
Reference in New Issue
Block a user