313 lines
13 KiB
Python
313 lines
13 KiB
Python
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
|
"""
|
|
PRIMAITE - main (harness) module
|
|
|
|
Coding Standards: PEP 8
|
|
"""
|
|
|
|
from sys import exc_info
|
|
import time
|
|
import yaml
|
|
import os.path
|
|
import logging
|
|
from datetime import datetime
|
|
|
|
from environment.primaite import PRIMAITE
|
|
from transactions.transactions_to_file import write_transaction_to_file
|
|
from common.config_values_main import config_values_main
|
|
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
|
from stable_baselines3 import A2C
|
|
from stable_baselines3.common.env_checker import check_env
|
|
from stable_baselines3.common.evaluation import evaluate_policy
|
|
|
|
################################# FUNCTIONS ######################################
|
|
|
|
def run_generic():
    """
    Drive the environment with a generic (randomly acting) agent

    Runs config_values.num_episodes episodes of at most
    config_values.num_steps steps each. The environment is reset after
    every episode and closed once all episodes have been run.
    """

    for _episode in range(config_values.num_episodes):
        for _step in range(config_values.num_steps):
            # TEMP - the agent is not consulted yet (see
            # env.blue_agent_action); a random action is sampled instead
            chosen_action = env.action_space.sample()

            # Advance the live environment by one simulation step
            _obs, _reward, episode_over, _info = env.step(chosen_action)

            # Stop stepping as soon as the environment reports completion
            if episode_over:
                break

            # Pause between steps (configured delay is divided by 1000)
            time.sleep(config_values.time_delay / 1000)

        # End of episode: put the environment back to its starting state
        env.reset()

    env.close()
|
|
|
|
|
|
def run_stable_baselines3_ppo():
    """
    Run against a stable_baselines3 PPO agent

    When config_values.load_agent is set, the agent is loaded from
    config_values.agent_load_file; otherwise a new PPO agent is created.
    A session type of "TRAINING" trains the agent and then saves it via
    save_agent(); any other session type runs an evaluation session.
    The environment is closed before returning.

    If a saved agent cannot be loaded, the failure is printed and logged,
    the environment is closed and no session is run.
    """

    if config_values.load_agent:
        try:
            agent = PPO.load(config_values.agent_load_file, env, verbose=0, n_steps=config_values.num_steps)
        except Exception:
            # str() keeps the report itself from failing on a non-string path
            print("ERROR: Could not load agent at location: " + str(config_values.agent_load_file))
            logging.error("Could not load agent")
            logging.error("Exception occured", exc_info=True)
            # There is no agent to train or evaluate; carrying on would only
            # fail later on the unbound 'agent' name
            env.close()
            return
    else:
        agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)

    if config_values.session_type == "TRAINING":
        # We're in a training session
        print("Starting training session...")
        logging.info("Starting training session...")
        for _episode in range(config_values.num_episodes):
            agent.learn(total_timesteps=1)
        # Persist the agent once all training episodes have run
        save_agent(agent)
    else:
        # Default to being in an evaluation session
        print("Starting evaluation session...")
        logging.info("Starting evaluation session...")
        evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)

    env.close()
|
|
|
|
def run_stable_baselines3_a2c():
    """
    Run against a stable_baselines3 A2C agent

    When config_values.load_agent is set, the agent is loaded from
    config_values.agent_load_file; otherwise a new A2C agent using the
    "MlpPolicy" policy is created. A session type of "TRAINING" trains the
    agent and then saves it via save_agent(); any other session type runs
    an evaluation session. The environment is closed before returning.

    If a saved agent cannot be loaded, the failure is printed and logged,
    the environment is closed and no session is run.
    """

    if config_values.load_agent:
        try:
            agent = A2C.load(config_values.agent_load_file, env, verbose=0, n_steps=config_values.num_steps)
        except Exception:
            # str() keeps the report itself from failing on a non-string path
            print("ERROR: Could not load agent at location: " + str(config_values.agent_load_file))
            logging.error("Could not load agent")
            logging.error("Exception occured", exc_info=True)
            # There is no agent to train or evaluate; carrying on would only
            # fail later on the unbound 'agent' name
            env.close()
            return
    else:
        agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)

    if config_values.session_type == "TRAINING":
        # We're in a training session
        print("Starting training session...")
        logging.info("Starting training session...")
        for _episode in range(config_values.num_episodes):
            agent.learn(total_timesteps=1)
        # Persist the agent once all training episodes have run
        save_agent(agent)
    else:
        # Default to being in an evaluation session
        print("Starting evaluation session...")
        logging.info("Starting evaluation session...")
        evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)

    env.close()
|
|
|
|
def save_agent(_agent):
    """
    Persist an agent (only works for stable baselines3 agents at present)

    The agent is saved under outputs/agents/ (created when missing) using a
    file name that carries the current date and time. Any failure is logged
    and not propagated to the caller.

    Args:
        _agent: Object exposing a save(filename) method
    """

    # Called 'timestamp' so the module-level 'time' import is not shadowed
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    try:
        path = 'outputs/agents/'
        # exist_ok avoids the check-then-create race of isdir() + makedirs()
        os.makedirs(path, exist_ok=True)
        filename = path + "agent_saved_" + timestamp
        _agent.save(filename)
        logging.info("Trained agent saved as " + filename)
    except Exception:
        logging.error("Could not save agent")
        logging.error("Exception occured", exc_info=True)
|
|
|
|
def configure_logging():
    """
    Configures logging

    Calls logging.basicConfig with a log file named
    logs/app_<YYYYmmdd_HHMMSS>.log (opened in write mode) and a level of
    INFO, creating the logs/ directory when it is missing. If anything goes
    wrong an error is printed and the exception is not propagated.
    """

    try:
        # Called 'timestamp' so the module-level 'time' import is not shadowed
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        path = 'logs/'
        filename = path + "app_" + timestamp + ".log"
        # exist_ok avoids the check-then-create race of isdir() + makedirs()
        os.makedirs(path, exist_ok=True)
        logging.basicConfig(
            filename=filename,
            filemode='w',
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%d-%b-%y %H:%M:%S',
            level=logging.INFO,
        )
    except Exception:
        # Deliberately not a bare except: Ctrl-C and SystemExit must still
        # propagate
        print("ERROR: Could not start logging")
|
|
|
|
def load_config_values():
    """
    Loads the config values from the main config file into a config object

    Each entry is read from config_data, converted where a converter is
    listed, and stored as an attribute on config_values, in the order given
    by the tables below. Any failure is logged instead of raised; values
    stored before the failure stay in place.
    """

    # (attribute on config_values, key in config_data, converter or None)
    generic_settings = (
        ('agent_identifier', 'agentIdentifier', None),
        ('num_episodes', 'numEpisodes', int),
        ('time_delay', 'timeDelay', int),
        ('config_filename_use_case', 'configFilename', None),
        ('session_type', 'sessionType', None),
        ('load_agent', 'loadAgent', bool),
        ('agent_load_file', 'agentLoadFile', None),
    )

    # (attribute on config_values, key in config_data); all go through int()
    int_settings = (
        # Environment
        ('observation_space_high_value', 'observationSpaceHighValue'),
        # Reward values
        # Generic
        ('all_ok', 'allOk'),
        # Node Operating State
        ('off_should_be_on', 'offShouldBeOn'),
        ('off_should_be_resetting', 'offShouldBeResetting'),
        ('on_should_be_off', 'onShouldBeOff'),
        ('on_should_be_resetting', 'onShouldBeResetting'),
        ('resetting_should_be_on', 'resettingShouldBeOn'),
        ('resetting_should_be_off', 'resettingShouldBeOff'),
        ('resetting', 'resetting'),
        # Node O/S or Service State
        ('good_should_be_patching', 'goodShouldBePatching'),
        ('good_should_be_compromised', 'goodShouldBeCompromised'),
        ('good_should_be_overwhelmed', 'goodShouldBeOverwhelmed'),
        ('patching_should_be_good', 'patchingShouldBeGood'),
        ('patching_should_be_compromised', 'patchingShouldBeCompromised'),
        ('patching_should_be_overwhelmed', 'patchingShouldBeOverwhelmed'),
        ('patching', 'patching'),
        ('compromised_should_be_good', 'compromisedShouldBeGood'),
        ('compromised_should_be_patching', 'compromisedShouldBePatching'),
        ('compromised_should_be_overwhelmed', 'compromisedShouldBeOverwhelmed'),
        ('compromised', 'compromised'),
        ('overwhelmed_should_be_good', 'overwhelmedShouldBeGood'),
        ('overwhelmed_should_be_patching', 'overwhelmedShouldBePatching'),
        ('overwhelmed_should_be_compromised', 'overwhelmedShouldBeCompromised'),
        ('overwhelmed', 'overwhelmed'),
        # Node File System State
        ('good_should_be_repairing', 'goodShouldBeRepairing'),
        ('good_should_be_restoring', 'goodShouldBeRestoring'),
        ('good_should_be_corrupt', 'goodShouldBeCorrupt'),
        ('good_should_be_destroyed', 'goodShouldBeDestroyed'),
        ('repairing_should_be_good', 'repairingShouldBeGood'),
        ('repairing_should_be_restoring', 'repairingShouldBeRestoring'),
        ('repairing_should_be_corrupt', 'repairingShouldBeCorrupt'),
        ('repairing_should_be_destroyed', 'repairingShouldBeDestroyed'),
        ('repairing', 'repairing'),
        ('restoring_should_be_good', 'restoringShouldBeGood'),
        ('restoring_should_be_repairing', 'restoringShouldBeRepairing'),
        ('restoring_should_be_corrupt', 'restoringShouldBeCorrupt'),
        ('restoring_should_be_destroyed', 'restoringShouldBeDestroyed'),
        ('restoring', 'restoring'),
        ('corrupt_should_be_good', 'corruptShouldBeGood'),
        ('corrupt_should_be_repairing', 'corruptShouldBeRepairing'),
        ('corrupt_should_be_restoring', 'corruptShouldBeRestoring'),
        ('corrupt_should_be_destroyed', 'corruptShouldBeDestroyed'),
        ('corrupt', 'corrupt'),
        ('destroyed_should_be_good', 'destroyedShouldBeGood'),
        ('destroyed_should_be_repairing', 'destroyedShouldBeRepairing'),
        ('destroyed_should_be_restoring', 'destroyedShouldBeRestoring'),
        ('destroyed_should_be_corrupt', 'destroyedShouldBeCorrupt'),
        ('destroyed', 'destroyed'),
        ('scanning', 'scanning'),
        # IER status
        ('red_ier_running', 'redIerRunning'),
        ('green_ier_blocked', 'greenIerBlocked'),
        # Patching / Reset durations
        ('os_patching_duration', 'osPatchingDuration'),
        ('node_reset_duration', 'nodeResetDuration'),
        ('service_patching_duration', 'servicePatchingDuration'),
        ('file_system_repairing_limit', 'fileSystemRepairingLimit'),
        ('file_system_restoring_limit', 'fileSystemRestoringLimit'),
        ('file_system_scanning_limit', 'fileSystemScanningLimit'),
    )

    try:
        # Generic values: stored as read unless a converter is listed
        for attribute, key, convert in generic_settings:
            raw_value = config_data[key]
            value = raw_value if convert is None else convert(raw_value)
            setattr(config_values, attribute, value)

        # Everything else is an integer value
        for attribute, key in int_settings:
            value = int(config_data[key])
            setattr(config_values, attribute, value)

        logging.info("Training agent: " + config_values.agent_identifier)
        logging.info("Training environment config: " + config_values.config_filename_use_case)
        logging.info("Training cycle has " + str(config_values.num_episodes) + " episodes")

    except Exception:
        logging.error("Could not save load config data")
        logging.error("Exception occured", exc_info=True)
|
|
|
|
|
|
################################# MAIN PROCESS ############################################
|
|
|
|
# Starting point
#
# Runs one complete session: set up logging, load the main config, create
# the environment, run it against the configured agent and write out the
# collected transactions. config_data, config_values, env and
# transaction_list stay module-level names because the functions above
# read them as globals.

# Welcome message
print("Welcome to the Primary-level AI Training Environment (PrimAITE)")

# Configure logging
configure_logging()

# Open the main config file
try:
    # The context manager closes the file as soon as it has been parsed
    with open("config/config_main.yaml", "r") as config_file_main:
        config_data = yaml.safe_load(config_file_main)
    # Create a config class
    config_values = config_values_main()
    # Load in config data
    load_config_values()
except Exception:
    print("ERROR: Could not load main config")
    logging.error("Could not load main config")
    logging.error("Exception occured", exc_info=True)
    # Nothing below can run without the config object
    raise SystemExit(1)

# Create a list of transactions
# A transaction is an object holding the:
# - episode #
# - step #
# - initial observation space
# - action
# - reward
# - new observation space
transaction_list = []

# Create the PRIMAITE environment
try:
    env = PRIMAITE(config_values, transaction_list)
    logging.info("PrimAITE environment created")
except Exception:
    print("ERROR: Could not create PrimAITE environment")
    logging.error("Could not create PrimAITE environment")
    logging.error("Exception occured", exc_info=True)
    # Every agent runner needs the environment
    raise SystemExit(1)

# Get the number of steps (which is stored in the child config file)
config_values.num_steps = env.episode_steps

# Run environment against an agent
if config_values.agent_identifier == "GENERIC":
    run_generic()
elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
    run_stable_baselines3_ppo()
elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
    run_stable_baselines3_a2c()
else:
    # Report an unrecognised identifier instead of silently running nothing
    print("ERROR: Unknown agent identifier: " + str(config_values.agent_identifier))
    logging.error("Unknown agent identifier: " + str(config_values.agent_identifier))

print("Session finished")
logging.info("Session finished")

print("Saving transaction logs...")
logging.info("Saving transaction logs...")

write_transaction_to_file(transaction_list)

print("Finished")
logging.info("Finished")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|