#915 - Created app dirs and set as constants in the top-level init.
- renamed _config_values_main to training_config.py and renamed the ConfigValuesMain class to TrainingConfig. Moved training_config.py to src/primaite/config/training_config.py - Renamed all training config yaml file keys to make creating an instance of TrainingConfig easier. Moved action_type and num_steps over to the training config. - Decoupled the training config and lay down config. - Refactored main.py so that it can be run from the CLI and can take a training config path and a lay down config path. - refactored all outputs so that they save to the session dir. - Added some necessary setup scripts that handle creating app dirs, fronting example config files to the user, fronting demo notebooks to the user, performing clean-up in between installations etc. - Added functions that attempt to retrieve the file path of users example config files that have been fronted by the primaite setup. - Added logging config and a getLogger function in the top-level init. - Refactored all log entries to use a logger using the primaite logging config. - Added basic typer CLI for doing things like setup, viewing logs, viewing primaite version, running a basic session. - Updated test to use new features and config structures. - Began updating docs. More to do here.
This commit is contained in:
98
src/primaite/__init__.py
Normal file
98
src/primaite/__init__.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
import logging
|
||||
import logging.config
|
||||
import sys
|
||||
from logging import Logger, StreamHandler
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
from platformdirs import PlatformDirs
|
||||
|
||||
_PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
"""An instance of `PlatformDirs` set with appname='primaite'."""

_USER_DIRS: Final[Path] = Path.home() / "primaite"
"""The users home space for PrimAITE which is located at: ~/primaite."""

NOTEBOOKS_DIR: Final[Path] = _USER_DIRS / "notebooks"
"""The users notebooks directory (``~/primaite/notebooks``) as a `Path`/`PosixPath`, depending on the OS."""

USERS_CONFIG_DIR: Final[Path] = _USER_DIRS / "config"
"""The users config directory (``~/primaite/config``) as a `Path`/`PosixPath`, depending on the OS."""

SESSIONS_DIR: Final[Path] = _USER_DIRS / "sessions"
"""The users PrimAITE Sessions directory (``~/primaite/sessions``) as a `Path`/`PosixPath`, depending on the OS."""
|
||||
|
||||
|
||||
# region Setup Logging
|
||||
def _log_dir() -> Path:
    """
    Return the directory that PrimAITE log files should live in.

    On Windows the logs are nested under the platformdirs user data
    directory; on every other OS the platform user log path is used.
    """
    if sys.platform != "win32":
        return _PLATFORM_DIRS.user_log_path
    return _PLATFORM_DIRS.user_data_path / "logs"
|
||||
|
||||
|
||||
LOG_DIR: Final[Path] = _log_dir()
"""The path to the app log directory as an instance of `Path` or `PosixPath`, depending on the OS."""

LOG_PATH: Final[Path] = LOG_DIR / "primaite.log"
"""The primaite.log file path as an instance of `Path` or `PosixPath`, depending on the OS."""

# RotatingFileHandler opens LOG_PATH immediately; on a fresh install the log
# directory may not exist yet, which would raise FileNotFoundError at import
# time. Create it up front.
LOG_DIR.mkdir(parents=True, exist_ok=True)

_LOG_FORMAT_STR: Final[
    str
] = "%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s"

_STREAM_HANDLER: Final[StreamHandler] = StreamHandler()
_FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
    filename=LOG_PATH,
    maxBytes=10485760,  # 10MB
    backupCount=9,  # Max 100MB of logs
    encoding="utf8",
)

# Both handlers emit INFO and above using the same primaite log-line format.
for _handler in (_STREAM_HANDLER, _FILE_HANDLER):
    _handler.setLevel(logging.INFO)
    _handler.setFormatter(logging.Formatter(_LOG_FORMAT_STR))

# Handlers are attached to the package logger; loggers created under the
# ``primaite`` namespace propagate their records up to these handlers.
_LOGGER = logging.getLogger(__name__)
_LOGGER.addHandler(_STREAM_HANDLER)
_LOGGER.addHandler(_FILE_HANDLER)
|
||||
|
||||
|
||||
def getLogger(name: str) -> Logger:
    """
    Get a PrimAITE logger.

    :param name: The logger name. Use ``__name__``.
    :return: An instance of :py:class:`logging.Logger` with the PrimAITE
        logging config.
    """
    # Only the threshold level is set here; handlers/formatting come from
    # the package-level logging config via logger propagation (presumably
    # for names under the ``primaite`` namespace).
    configured_logger = logging.getLogger(name)
    configured_logger.setLevel(logging.INFO)
    return configured_logger
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# Single source of truth for the package version is the bundled VERSION file.
with open(Path(__file__).parent.resolve() / "VERSION", "r") as file:
    # strip() drops the trailing newline that readline() would otherwise
    # keep, so __version__ is a clean version string.
    __version__ = file.readline().strip()
|
||||
125
src/primaite/cli.py
Normal file
125
src/primaite/cli.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""Provides a CLI using Typer as an entry point."""
|
||||
import os
|
||||
import sys
|
||||
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def build_dirs():
    """Build the PrimAITE app directories."""
    # Imported lazily so plain CLI invocations stay fast.
    from primaite.setup import setup_app_dirs as _app_dirs

    _app_dirs.run()
|
||||
|
||||
|
||||
@app.command()
def reset_notebooks(overwrite: bool = True):
    """
    Force a reset of the demo notebooks in the users notebooks directory.

    :param overwrite: If True, will overwrite existing demo notebooks.
    """
    # Imported lazily so plain CLI invocations stay fast.
    from primaite.setup import reset_demo_notebooks as _demo_notebooks

    _demo_notebooks.run(overwrite)
|
||||
|
||||
|
||||
@app.command()
def logs(last_n: int = 10):
    """
    Print the tail of the PrimAITE log file.

    :param last_n: The number of lines to print. Default value is 10.
    """
    from platformdirs import PlatformDirs

    # Mirrors the log-path resolution in primaite.__init__: on Windows the
    # logs live under the user data directory, elsewhere under the platform
    # user log path. (Renamed from the copy-pasted "yt_platform_dirs".)
    platform_dirs = PlatformDirs(appname="primaite")
    if sys.platform == "win32":
        log_dir = platform_dirs.user_data_path / "logs"
    else:
        log_dir = platform_dirs.user_log_path
    log_path = os.path.join(log_dir, "primaite.log")

    if os.path.isfile(log_path):
        with open(log_path) as file:
            for line in file.readlines()[-last_n:]:
                # Strip the trailing newline; print() adds its own.
                print(line.rstrip("\n"))
|
||||
|
||||
|
||||
@app.command()
def notebooks():
    """Start Jupyter Lab in the users PrimAITE notebooks directory."""
    # Imported lazily so plain CLI invocations stay fast.
    from primaite.notebooks import start_jupyter_session as _start

    _start()
|
||||
|
||||
|
||||
@app.command()
def version():
    """Get the installed PrimAITE version number."""
    # Imported lazily so plain CLI invocations stay fast.
    import primaite as _primaite

    print(_primaite.__version__)
|
||||
|
||||
|
||||
@app.command()
def clean_up():
    """Cleans up left over files from previous version installations."""
    # Imported lazily so plain CLI invocations stay fast.
    from primaite.setup import old_installation_clean_up as _old_clean_up

    _old_clean_up.run()
|
||||
|
||||
|
||||
@app.command()
def setup():
    """
    Perform the PrimAITE first-time setup.

    WARNING: All user-data will be lost.
    """
    from primaite import getLogger
    from primaite.setup import (
        old_installation_clean_up,
        reset_demo_notebooks,
        reset_example_configs,
        setup_app_dirs,
    )

    _LOGGER = getLogger(__name__)

    _LOGGER.info("Performing the PrimAITE first-time setup...")

    _LOGGER.info("Building the PrimAITE app directories...")
    setup_app_dirs.run()

    _LOGGER.info("Rebuilding the demo notebooks...")
    reset_demo_notebooks.run(overwrite_existing=True)

    # Fixed log message: this step resets example config files, not notebooks.
    _LOGGER.info("Rebuilding the example config files...")
    reset_example_configs.run(overwrite_existing=True)

    _LOGGER.info("Performing a clean-up of previous PrimAITE installations...")
    old_installation_clean_up.run()

    _LOGGER.info("PrimAITE setup complete!")
|
||||
|
||||
|
||||
@app.command()
def session(tc: str, ldc: str):
    """
    Run a PrimAITE session.

    :param tc: The training config filepath.
    :param ldc: The lay down config file path.
    """
    # Imported lazily so plain CLI invocations stay fast.
    from primaite.main import run as _run_session

    _run_session(training_config_path=tc, lay_down_config_path=ldc)
|
||||
@@ -1,90 +0,0 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""The config class."""
|
||||
|
||||
|
||||
class ConfigValuesMain(object):
    """Class to hold main config values."""

    # Attributes that default to the empty string.
    _STR_ATTRS = (
        "agent_identifier",  # the agent in use
        "config_filename_use_case",  # the filename for the Use Case config file
        "session_type",  # the session type to run (TRAINING or EVALUATION)
    )

    # Attributes that default to zero: generic settings, environment,
    # reward values, IER status and patching/reset durations.
    _INT_ATTRS = (
        # Generic
        "num_episodes",  # number of episodes to train over
        "num_steps",  # number of steps in an episode
        "time_delay",  # delay between steps (ms) - generic agents only
        # Environment
        "observation_space_high_value",  # high value for the observation space
        # Reward values - generic
        "all_ok",
        # Node Hardware State
        "off_should_be_on",
        "off_should_be_resetting",
        "on_should_be_off",
        "on_should_be_resetting",
        "resetting_should_be_on",
        "resetting_should_be_off",
        "resetting",
        # Node Software or Service State
        "good_should_be_patching",
        "good_should_be_compromised",
        "good_should_be_overwhelmed",
        "patching_should_be_good",
        "patching_should_be_compromised",
        "patching_should_be_overwhelmed",
        "patching",
        "compromised_should_be_good",
        "compromised_should_be_patching",
        "compromised_should_be_overwhelmed",
        "compromised",
        "overwhelmed_should_be_good",
        "overwhelmed_should_be_patching",
        "overwhelmed_should_be_compromised",
        "overwhelmed",
        # Node File System State
        "good_should_be_repairing",
        "good_should_be_restoring",
        "good_should_be_corrupt",
        "good_should_be_destroyed",
        "repairing_should_be_good",
        "repairing_should_be_restoring",
        "repairing_should_be_corrupt",
        "repairing_should_be_destroyed",  # Repairing does not fix destroyed state - you need to restore
        "repairing",
        "restoring_should_be_good",
        "restoring_should_be_repairing",
        "restoring_should_be_corrupt",  # Not the optimal method (as repair will fix corruption)
        "restoring_should_be_destroyed",
        "restoring",
        "corrupt_should_be_good",
        "corrupt_should_be_repairing",
        "corrupt_should_be_restoring",
        "corrupt_should_be_destroyed",
        "corrupt",
        "destroyed_should_be_good",
        "destroyed_should_be_repairing",
        "destroyed_should_be_restoring",
        "destroyed_should_be_corrupt",
        "destroyed",
        "scanning",
        # IER status
        "red_ier_running",
        "green_ier_blocked",
        # Patching / Reset
        "os_patching_duration",  # time taken to patch the OS
        "node_reset_duration",  # time taken to reset a node (hardware)
        "service_patching_duration",  # time taken to patch a service
        "file_system_repairing_limit",  # time taken to repair a file
        "file_system_restoring_limit",  # time taken to restore a file
        "file_system_scanning_limit",  # time taken to scan the file system
    )

    def __init__(self):
        """Initialise every config value to its empty/zero default."""
        for attr_name in self._STR_ATTRS:
            setattr(self, attr_name, "")
        for attr_name in self._INT_ATTRS:
            setattr(self, attr_name, 0)
|
||||
@@ -81,6 +81,7 @@ class ActionType(Enum):
|
||||
|
||||
NODE = 0
|
||||
ACL = 1
|
||||
ANY = 2
|
||||
|
||||
|
||||
class ObservationType(Enum):
|
||||
|
||||
91
src/primaite/common/training_config.py
Normal file
91
src/primaite/common/training_config.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""The config class."""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from primaite.common.enums import ActionType
|
||||
|
||||
|
||||
@dataclass()
class TrainingConfig:
    """
    Class to hold main config values.

    Field names match the keys of the training config yaml file, so an
    instance can be built directly from the loaded yaml mapping. Field
    order is part of the positional-init interface — do not reorder.
    """

    # Generic
    agent_identifier: str  # The agent algo/class to be used
    action_type: ActionType  # type of action to use (NODE/ACL/ANY)
    num_episodes: int  # number of episodes to train over
    num_steps: int  # number of steps in an episode
    time_delay: int  # delay between steps (ms) - applies to generic agents only
    # file
    session_type: str  # the session type to run (TRAINING or EVALUATION)
    # Fixed annotation: the yaml supplies a bool (load_agent: False), not a str.
    load_agent: bool  # Determine whether to load an agent from file
    agent_load_file: str  # File path and file name of agent if you're loading one in

    # Environment
    observation_space_high_value: int  # The high value for the observation space

    # Reward values
    # Generic
    all_ok: int
    # Node Hardware State
    off_should_be_on: int
    off_should_be_resetting: int
    on_should_be_off: int
    on_should_be_resetting: int
    resetting_should_be_on: int
    resetting_should_be_off: int
    resetting: int
    # Node Software or Service State
    good_should_be_patching: int
    good_should_be_compromised: int
    good_should_be_overwhelmed: int
    patching_should_be_good: int
    patching_should_be_compromised: int
    patching_should_be_overwhelmed: int
    patching: int
    compromised_should_be_good: int
    compromised_should_be_patching: int
    compromised_should_be_overwhelmed: int
    compromised: int
    overwhelmed_should_be_good: int
    overwhelmed_should_be_patching: int
    overwhelmed_should_be_compromised: int
    overwhelmed: int
    # Node File System State
    good_should_be_repairing: int
    good_should_be_restoring: int
    good_should_be_corrupt: int
    good_should_be_destroyed: int
    repairing_should_be_good: int
    repairing_should_be_restoring: int
    repairing_should_be_corrupt: int
    repairing_should_be_destroyed: int  # Repairing does not fix destroyed state - you need to restore

    repairing: int
    restoring_should_be_good: int
    restoring_should_be_repairing: int
    restoring_should_be_corrupt: int  # Not the optimal method (as repair will fix corruption)

    restoring_should_be_destroyed: int
    restoring: int
    corrupt_should_be_good: int
    corrupt_should_be_repairing: int
    corrupt_should_be_restoring: int
    corrupt_should_be_destroyed: int
    corrupt: int
    destroyed_should_be_good: int
    destroyed_should_be_repairing: int
    destroyed_should_be_restoring: int
    destroyed_should_be_corrupt: int
    destroyed: int
    scanning: int
    # IER status
    red_ier_running: int
    green_ier_blocked: int

    # Patching / Reset
    os_patching_duration: int  # The time taken to patch the OS
    node_reset_duration: int  # The time taken to reset a node (hardware)
    service_patching_duration: int  # The time taken to patch a service
    file_system_repairing_limit: int  # The time taken to repair a file
    file_system_restoring_limit: int  # The time taken to restore a file
    file_system_scanning_limit: int  # The time taken to scan the file system
|
||||
@@ -1,7 +1,4 @@
|
||||
- itemType: ACTIONS
|
||||
type: NODE
|
||||
- itemType: STEPS
|
||||
steps: 256
|
||||
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '80'
|
||||
@@ -0,0 +1,94 @@
|
||||
# Main Config File
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_identifier: STABLE_BASELINES3_A2C
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 10
|
||||
# Number of time_steps per episode
|
||||
num_steps: 256
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 10
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAINING
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
off_should_be_on: -10
|
||||
off_should_be_resetting: -5
|
||||
on_should_be_off: -2
|
||||
on_should_be_resetting: -5
|
||||
resetting_should_be_on: -5
|
||||
resetting_should_be_off: -2
|
||||
resetting: -3
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 2
|
||||
good_should_be_compromised: 5
|
||||
good_should_be_overwhelmed: 5
|
||||
patching_should_be_good: -5
|
||||
patching_should_be_compromised: 2
|
||||
patching_should_be_overwhelmed: 2
|
||||
patching: -3
|
||||
compromised_should_be_good: -20
|
||||
compromised_should_be_patching: -20
|
||||
compromised_should_be_overwhelmed: -20
|
||||
compromised: -20
|
||||
overwhelmed_should_be_good: -20
|
||||
overwhelmed_should_be_patching: -20
|
||||
overwhelmed_should_be_compromised: -20
|
||||
overwhelmed: -20
|
||||
# Node File System State
|
||||
good_should_be_repairing: 2
|
||||
good_should_be_restoring: 2
|
||||
good_should_be_corrupt: 5
|
||||
good_should_be_destroyed: 10
|
||||
repairing_should_be_good: -5
|
||||
repairing_should_be_restoring: 2
|
||||
repairing_should_be_corrupt: 2
|
||||
repairing_should_be_destroyed: 0
|
||||
repairing: -3
|
||||
restoring_should_be_good: -10
|
||||
restoring_should_be_repairing: -2
|
||||
restoring_should_be_corrupt: 1
|
||||
restoring_should_be_destroyed: 2
|
||||
restoring: -6
|
||||
corrupt_should_be_good: -10
|
||||
corrupt_should_be_repairing: -10
|
||||
corrupt_should_be_restoring: -10
|
||||
corrupt_should_be_destroyed: 2
|
||||
corrupt: -10
|
||||
destroyed_should_be_good: -20
|
||||
destroyed_should_be_repairing: -20
|
||||
destroyed_should_be_restoring: -20
|
||||
destroyed_should_be_corrupt: -20
|
||||
destroyed: -20
|
||||
scanning: -2
|
||||
# IER status
|
||||
red_ier_running: -5
|
||||
green_ier_blocked: -10
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||
service_patching_duration: 5 # The time taken to patch a service
|
||||
file_system_repairing_limit: 5 # The time taken to repair the file system
|
||||
file_system_restoring_limit: 5 # The time taken to restore the file system
|
||||
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||
@@ -1,533 +0,0 @@
|
||||
- itemType: ACTIONS
|
||||
type: NODE
|
||||
- itemType: STEPS
|
||||
steps: 256
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '80'
|
||||
- port: '1433'
|
||||
- port: '53'
|
||||
- itemType: SERVICES
|
||||
serviceList:
|
||||
- name: TCP
|
||||
- name: TCP_SQL
|
||||
- name: UDP
|
||||
- itemType: NODE
|
||||
node_id: '1'
|
||||
name: CLIENT_1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.11
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
node_id: '2'
|
||||
name: CLIENT_2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
node_id: '3'
|
||||
name: SWITCH_1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.1
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- itemType: NODE
|
||||
node_id: '4'
|
||||
name: SECURITY_SUITE
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.10
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
node_id: '5'
|
||||
name: MANAGEMENT_CONSOLE
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
node_id: '6'
|
||||
name: SWITCH_2
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.1
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- itemType: NODE
|
||||
node_id: '7'
|
||||
name: WEB_SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.10
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
node_id: '8'
|
||||
name: DATABASE_SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.14
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
node_id: '9'
|
||||
name: BACKUP_SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.16
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: LINK
|
||||
id: '10'
|
||||
name: LINK_1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '3'
|
||||
- itemType: LINK
|
||||
id: '11'
|
||||
name: LINK_2
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '3'
|
||||
- itemType: LINK
|
||||
id: '12'
|
||||
name: LINK_3
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '4'
|
||||
- itemType: LINK
|
||||
id: '13'
|
||||
name: LINK_4
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '5'
|
||||
- itemType: LINK
|
||||
id: '14'
|
||||
name: LINK_5
|
||||
bandwidth: 1000000000
|
||||
source: '4'
|
||||
destination: '6'
|
||||
- itemType: LINK
|
||||
id: '15'
|
||||
name: LINK_6
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '6'
|
||||
- itemType: LINK
|
||||
id: '16'
|
||||
name: LINK_7
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '7'
|
||||
- itemType: LINK
|
||||
id: '17'
|
||||
name: LINK_8
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '8'
|
||||
- itemType: LINK
|
||||
id: '18'
|
||||
name: LINK_9
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '9'
|
||||
- itemType: GREEN_IER
|
||||
id: '19'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '7'
|
||||
missionCriticality: 5
|
||||
- itemType: GREEN_IER
|
||||
id: '20'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '7'
|
||||
destination: '1'
|
||||
missionCriticality: 5
|
||||
- itemType: GREEN_IER
|
||||
id: '21'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '7'
|
||||
missionCriticality: 5
|
||||
- itemType: GREEN_IER
|
||||
id: '22'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '7'
|
||||
destination: '2'
|
||||
missionCriticality: 5
|
||||
- itemType: GREEN_IER
|
||||
id: '23'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 5000
|
||||
protocol: TCP_SQL
|
||||
port: '1433'
|
||||
source: '7'
|
||||
destination: '8'
|
||||
missionCriticality: 5
|
||||
- itemType: GREEN_IER
|
||||
id: '24'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 100000
|
||||
protocol: TCP_SQL
|
||||
port: '1433'
|
||||
source: '8'
|
||||
destination: '7'
|
||||
missionCriticality: 5
|
||||
- itemType: GREEN_IER
|
||||
id: '25'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 50000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '9'
|
||||
missionCriticality: 2
|
||||
- itemType: GREEN_IER
|
||||
id: '26'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 50000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '9'
|
||||
missionCriticality: 2
|
||||
- itemType: GREEN_IER
|
||||
id: '27'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '5'
|
||||
destination: '7'
|
||||
missionCriticality: 1
|
||||
- itemType: GREEN_IER
|
||||
id: '28'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '7'
|
||||
destination: '5'
|
||||
missionCriticality: 1
|
||||
- itemType: GREEN_IER
|
||||
id: '29'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '5'
|
||||
destination: '8'
|
||||
missionCriticality: 1
|
||||
- itemType: GREEN_IER
|
||||
id: '30'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '8'
|
||||
destination: '5'
|
||||
missionCriticality: 1
|
||||
- itemType: GREEN_IER
|
||||
id: '31'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '5'
|
||||
destination: '9'
|
||||
missionCriticality: 1
|
||||
- itemType: GREEN_IER
|
||||
id: '32'
|
||||
startStep: 1
|
||||
endStep: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '9'
|
||||
destination: '5'
|
||||
missionCriticality: 1
|
||||
- itemType: ACL_RULE
|
||||
id: '33'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '34'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '35'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '36'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '37'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.10.11
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '38'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.10.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '39'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '40'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.14
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '41'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '42'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '43'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.12
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '44'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.12
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '45'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.12
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '46'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '47'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.14
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '48'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.16
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: ACL_RULE
|
||||
id: '49'
|
||||
permission: DENY
|
||||
source: ANY
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
- itemType: RED_POL
|
||||
id: '50'
|
||||
startStep: 50
|
||||
endStep: 50
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: UDP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- itemType: RED_IER
|
||||
id: '51'
|
||||
startStep: 75
|
||||
endStep: 105
|
||||
load: 10000
|
||||
protocol: UDP
|
||||
port: '53'
|
||||
source: '1'
|
||||
destination: '8'
|
||||
missionCriticality: 0
|
||||
- itemType: RED_POL
|
||||
id: '52'
|
||||
startStep: 100
|
||||
endStep: 100
|
||||
targetNodeId: '8'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: UDP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- itemType: RED_POL
|
||||
id: '53'
|
||||
startStep: 105
|
||||
endStep: 105
|
||||
targetNodeId: '8'
|
||||
initiator: SERVICE
|
||||
type: FILE
|
||||
protocol: NA
|
||||
state: CORRUPT
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: UDP
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
- itemType: RED_POL
|
||||
id: '54'
|
||||
startStep: 105
|
||||
endStep: 105
|
||||
targetNodeId: '8'
|
||||
initiator: SERVICE
|
||||
type: SERVICE
|
||||
protocol: TCP_SQL
|
||||
state: COMPROMISED
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: UDP
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
- itemType: RED_POL
|
||||
id: '55'
|
||||
startStep: 125
|
||||
endStep: 125
|
||||
targetNodeId: '7'
|
||||
initiator: SERVICE
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: OVERWHELMED
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: TCP_SQL
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
@@ -1,89 +0,0 @@
|
||||
# Main Config File
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agentIdentifier: STABLE_BASELINES3_A2C
|
||||
# Number of episodes to run per session
|
||||
numEpisodes: 10
|
||||
# Time delay between steps (for generic agents)
|
||||
timeDelay: 10
|
||||
# Filename of the scenario / laydown
|
||||
configFilename: config_5_DATA_MANIPULATION.yaml
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
sessionType: TRAINING
|
||||
# Determine whether to load an agent from file
|
||||
loadAgent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observationSpaceHighValue: 1000000000
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
allOk: 0
|
||||
# Node Hardware State
|
||||
offShouldBeOn: -10
|
||||
offShouldBeResetting: -5
|
||||
onShouldBeOff: -2
|
||||
onShouldBeResetting: -5
|
||||
resettingShouldBeOn: -5
|
||||
resettingShouldBeOff: -2
|
||||
resetting: -3
|
||||
# Node Software or Service State
|
||||
goodShouldBePatching: 2
|
||||
goodShouldBeCompromised: 5
|
||||
goodShouldBeOverwhelmed: 5
|
||||
patchingShouldBeGood: -5
|
||||
patchingShouldBeCompromised: 2
|
||||
patchingShouldBeOverwhelmed: 2
|
||||
patching: -3
|
||||
compromisedShouldBeGood: -20
|
||||
compromisedShouldBePatching: -20
|
||||
compromisedShouldBeOverwhelmed: -20
|
||||
compromised: -20
|
||||
overwhelmedShouldBeGood: -20
|
||||
overwhelmedShouldBePatching: -20
|
||||
overwhelmedShouldBeCompromised: -20
|
||||
overwhelmed: -20
|
||||
# Node File System State
|
||||
goodShouldBeRepairing: 2
|
||||
goodShouldBeRestoring: 2
|
||||
goodShouldBeCorrupt: 5
|
||||
goodShouldBeDestroyed: 10
|
||||
repairingShouldBeGood: -5
|
||||
repairingShouldBeRestoring: 2
|
||||
repairingShouldBeCorrupt: 2
|
||||
repairingShouldBeDestroyed: 0
|
||||
repairing: -3
|
||||
restoringShouldBeGood: -10
|
||||
restoringShouldBeRepairing: -2
|
||||
restoringShouldBeCorrupt: 1
|
||||
restoringShouldBeDestroyed: 2
|
||||
restoring: -6
|
||||
corruptShouldBeGood: -10
|
||||
corruptShouldBeRepairing: -10
|
||||
corruptShouldBeRestoring: -10
|
||||
corruptShouldBeDestroyed: 2
|
||||
corrupt: -10
|
||||
destroyedShouldBeGood: -20
|
||||
destroyedShouldBeRepairing: -20
|
||||
destroyedShouldBeRestoring: -20
|
||||
destroyedShouldBeCorrupt: -20
|
||||
destroyed: -20
|
||||
scanning: -2
|
||||
# IER status
|
||||
redIerRunning: -5
|
||||
greenIerBlocked: -10
|
||||
|
||||
# Patching / Reset durations
|
||||
osPatchingDuration: 5 # The time taken to patch the OS
|
||||
nodeResetDuration: 5 # The time taken to reset a node (hardware)
|
||||
servicePatchingDuration: 5 # The time taken to patch a service
|
||||
fileSystemRepairingLimit: 5 # The time take to repair the file system
|
||||
fileSystemRestoringLimit: 5 # The time take to restore the file system
|
||||
fileSystemScanningLimit: 5 # The time taken to scan the file system
|
||||
69
src/primaite/config/lay_down_config.py
Normal file
69
src/primaite/config/lay_down_config.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
from primaite import getLogger, USERS_CONFIG_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
|
||||
|
||||
|
||||
def ddos_basic_one_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_1_DDOS_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_1_DDOS_basic.yaml"
|
||||
if not path.exists():
|
||||
msg = f"Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def ddos_basic_two_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_2_DDOS_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_2_DDOS_basic.yaml"
|
||||
if not path.exists():
|
||||
msg = f"Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def dos_very_basic_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_3_DOS_very_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_3_DOS_very_basic.yaml"
|
||||
if not path.exists():
|
||||
msg = f"Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def data_manipulation_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_5_data_manipulation.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_5_data_manipulation.yaml"
|
||||
if not path.exists():
|
||||
msg = f"Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
260
src/primaite/config/training_config.py
Normal file
260
src/primaite/config/training_config.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union, Final
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, USERS_CONFIG_DIR
|
||||
from primaite.common.enums import ActionType
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_TRAINING: Final[
|
||||
Path] = USERS_CONFIG_DIR / "example_config" / "training"
|
||||
|
||||
|
||||
@dataclass()
|
||||
class TrainingConfig:
|
||||
"""The Training Config class."""
|
||||
|
||||
# Generic
|
||||
agent_identifier: str # The Red Agent algo/class to be used
|
||||
action_type: ActionType # type of action to use (NODE/ACL/ANY)
|
||||
num_episodes: int # number of episodes to train over
|
||||
num_steps: int # number of steps in an episode
|
||||
time_delay: int # delay between steps (ms) - applies to generic agents only
|
||||
# file
|
||||
session_type: str # the session type to run (TRAINING or EVALUATION)
|
||||
load_agent: str # Determine whether to load an agent from file
|
||||
agent_load_file: str # File path and file name of agent if you're loading one in
|
||||
|
||||
# Environment
|
||||
observation_space_high_value: int # The high value for the observation space
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: int
|
||||
# Node Hardware State
|
||||
off_should_be_on: int
|
||||
off_should_be_resetting: int
|
||||
on_should_be_off: int
|
||||
on_should_be_resetting: int
|
||||
resetting_should_be_on: int
|
||||
resetting_should_be_off: int
|
||||
resetting: int
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: int
|
||||
good_should_be_compromised: int
|
||||
good_should_be_overwhelmed: int
|
||||
patching_should_be_good: int
|
||||
patching_should_be_compromised: int
|
||||
patching_should_be_overwhelmed: int
|
||||
patching: int
|
||||
compromised_should_be_good: int
|
||||
compromised_should_be_patching: int
|
||||
compromised_should_be_overwhelmed: int
|
||||
compromised: int
|
||||
overwhelmed_should_be_good: int
|
||||
overwhelmed_should_be_patching: int
|
||||
overwhelmed_should_be_compromised: int
|
||||
overwhelmed: int
|
||||
# Node File System State
|
||||
good_should_be_repairing: int
|
||||
good_should_be_restoring: int
|
||||
good_should_be_corrupt: int
|
||||
good_should_be_destroyed: int
|
||||
repairing_should_be_good: int
|
||||
repairing_should_be_restoring: int
|
||||
repairing_should_be_corrupt: int
|
||||
repairing_should_be_destroyed: int # Repairing does not fix destroyed state - you need to restore
|
||||
|
||||
repairing: int
|
||||
restoring_should_be_good: int
|
||||
restoring_should_be_repairing: int
|
||||
restoring_should_be_corrupt: int # Not the optimal method (as repair will fix corruption)
|
||||
|
||||
restoring_should_be_destroyed: int
|
||||
restoring: int
|
||||
corrupt_should_be_good: int
|
||||
corrupt_should_be_repairing: int
|
||||
corrupt_should_be_restoring: int
|
||||
corrupt_should_be_destroyed: int
|
||||
corrupt: int
|
||||
destroyed_should_be_good: int
|
||||
destroyed_should_be_repairing: int
|
||||
destroyed_should_be_restoring: int
|
||||
destroyed_should_be_corrupt: int
|
||||
destroyed: int
|
||||
scanning: int
|
||||
# IER status
|
||||
red_ier_running: int
|
||||
green_ier_blocked: int
|
||||
|
||||
# Patching / Reset
|
||||
os_patching_duration: int # The time taken to patch the OS
|
||||
node_reset_duration: int # The time taken to reset a node (hardware)
|
||||
service_patching_duration: int # The time taken to patch a service
|
||||
file_system_repairing_limit: int # The time take to repair a file
|
||||
file_system_restoring_limit: int # The time take to restore a file
|
||||
file_system_scanning_limit: int # The time taken to scan the file system
|
||||
|
||||
|
||||
def main_training_config_path() -> Path:
|
||||
"""
|
||||
The path to the example training_config_main.yaml file
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
|
||||
path = _EXAMPLE_TRAINING / "training_config_main.yaml"
|
||||
if not path.exists():
|
||||
msg = f"Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path],
|
||||
legacy_file: bool = False) -> TrainingConfig:
|
||||
"""
|
||||
Read in a training config yaml file.
|
||||
|
||||
:param file_path: The config file path.
|
||||
:param legacy_file: True if the config file is legacy format, otherwise
|
||||
False.
|
||||
:return: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
:raises ValueError: If the file_path does not exist.
|
||||
:raises TypeError: When the TrainingConfig object cannot be created
|
||||
using the values from the config file read from ``file_path``.
|
||||
"""
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
if file_path.exists():
|
||||
with open(file_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
_LOGGER.debug(f"Loading training config file: {file_path}")
|
||||
if legacy_file:
|
||||
try:
|
||||
config = convert_legacy_training_config_dict(config)
|
||||
except KeyError:
|
||||
msg = (
|
||||
f"Failed to convert training config file {file_path} "
|
||||
f"from legacy format. Attempting to use file as is."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
# Convert values to Enums
|
||||
config["action_type"] = ActionType[config["action_type"]]
|
||||
try:
|
||||
return TrainingConfig(**config)
|
||||
except TypeError as e:
|
||||
msg = (
|
||||
f"Error when creating an instance of {TrainingConfig} "
|
||||
f"from the training config file {file_path}"
|
||||
)
|
||||
_LOGGER.critical(msg, exc_info=True)
|
||||
raise e
|
||||
msg = f"Cannot load the training config as it does not exist: {file_path}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def convert_legacy_training_config_dict(
|
||||
legacy_config_dict: Dict[str, Any],
|
||||
num_steps: int = 256,
|
||||
action_type: str = "ANY"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy training config dict to the new format.
|
||||
|
||||
:param legacy_config_dict: A legacy training config dict.
|
||||
:param num_steps: The number of steps to set as legacy training configs
|
||||
don't have num_steps values.
|
||||
:param action_type: The action space type to set as legacy training configs
|
||||
don't have action_type values.
|
||||
:return: The converted training config dict.
|
||||
"""
|
||||
config_dict = {"num_steps": num_steps, "action_type": action_type}
|
||||
for legacy_key, value in legacy_config_dict.items():
|
||||
new_key = _get_new_key_from_legacy(legacy_key)
|
||||
if new_key:
|
||||
config_dict[new_key] = value
|
||||
return config_dict
|
||||
|
||||
|
||||
def _get_new_key_from_legacy(legacy_key: str) -> str:
|
||||
"""
|
||||
Maps legacy training config keys to the new format keys.
|
||||
|
||||
:param legacy_key: A legacy training config key.
|
||||
:return: The mapped key.
|
||||
"""
|
||||
key_mapping = {
|
||||
"agentIdentifier": "agent_identifier",
|
||||
"numEpisodes": "num_episodes",
|
||||
"timeDelay": "time_delay",
|
||||
"configFilename": None,
|
||||
"sessionType": "session_type",
|
||||
"loadAgent": "load_agent",
|
||||
"agentLoadFile": "agent_load_file",
|
||||
"observationSpaceHighValue": "observation_space_high_value",
|
||||
"allOk": "all_ok",
|
||||
"offShouldBeOn": "off_should_be_on",
|
||||
"offShouldBeResetting": "off_should_be_resetting",
|
||||
"onShouldBeOff": "on_should_be_off",
|
||||
"onShouldBeResetting": "on_should_be_resetting",
|
||||
"resettingShouldBeOn": "resetting_should_be_on",
|
||||
"resettingShouldBeOff": "resetting_should_be_off",
|
||||
"resetting": "resetting",
|
||||
"goodShouldBePatching": "good_should_be_patching",
|
||||
"goodShouldBeCompromised": "good_should_be_compromised",
|
||||
"goodShouldBeOverwhelmed": "good_should_be_overwhelmed",
|
||||
"patchingShouldBeGood": "patching_should_be_good",
|
||||
"patchingShouldBeCompromised": "patching_should_be_compromised",
|
||||
"patchingShouldBeOverwhelmed": "patching_should_be_overwhelmed",
|
||||
"patching": "patching",
|
||||
"compromisedShouldBeGood": "compromised_should_be_good",
|
||||
"compromisedShouldBePatching": "compromised_should_be_patching",
|
||||
"compromisedShouldBeOverwhelmed": "compromised_should_be_overwhelmed",
|
||||
"compromised": "compromised",
|
||||
"overwhelmedShouldBeGood": "overwhelmed_should_be_good",
|
||||
"overwhelmedShouldBePatching": "overwhelmed_should_be_patching",
|
||||
"overwhelmedShouldBeCompromised": "overwhelmed_should_be_compromised",
|
||||
"overwhelmed": "overwhelmed",
|
||||
"goodShouldBeRepairing": "good_should_be_repairing",
|
||||
"goodShouldBeRestoring": "good_should_be_restoring",
|
||||
"goodShouldBeCorrupt": "good_should_be_corrupt",
|
||||
"goodShouldBeDestroyed": "good_should_be_destroyed",
|
||||
"repairingShouldBeGood": "repairing_should_be_good",
|
||||
"repairingShouldBeRestoring": "repairing_should_be_restoring",
|
||||
"repairingShouldBeCorrupt": "repairing_should_be_corrupt",
|
||||
"repairingShouldBeDestroyed": "repairing_should_be_destroyed",
|
||||
"repairing": "repairing",
|
||||
"restoringShouldBeGood": "restoring_should_be_good",
|
||||
"restoringShouldBeRepairing": "restoring_should_be_repairing",
|
||||
"restoringShouldBeCorrupt": "restoring_should_be_corrupt",
|
||||
"restoringShouldBeDestroyed": "restoring_should_be_destroyed",
|
||||
"restoring": "restoring",
|
||||
"corruptShouldBeGood": "corrupt_should_be_good",
|
||||
"corruptShouldBeRepairing": "corrupt_should_be_repairing",
|
||||
"corruptShouldBeRestoring": "corrupt_should_be_restoring",
|
||||
"corruptShouldBeDestroyed": "corrupt_should_be_destroyed",
|
||||
"corrupt": "corrupt",
|
||||
"destroyedShouldBeGood": "destroyed_should_be_good",
|
||||
"destroyedShouldBeRepairing": "destroyed_should_be_repairing",
|
||||
"destroyedShouldBeRestoring": "destroyed_should_be_restoring",
|
||||
"destroyedShouldBeCorrupt": "destroyed_should_be_corrupt",
|
||||
"destroyed": "destroyed",
|
||||
"scanning": "scanning",
|
||||
"redIerRunning": "red_ier_running",
|
||||
"greenIerBlocked": "green_ier_blocked",
|
||||
"osPatchingDuration": "os_patching_duration",
|
||||
"nodeResetDuration": "node_reset_duration",
|
||||
"servicePatchingDuration": "service_patching_duration",
|
||||
"fileSystemRepairingLimit": "file_system_repairing_limit",
|
||||
"fileSystemRestoringLimit": "file_system_restoring_limit",
|
||||
"fileSystemScanningLimit": "file_system_scanning_limit",
|
||||
}
|
||||
return key_mapping[legacy_key]
|
||||
@@ -6,7 +6,8 @@ import csv
|
||||
import logging
|
||||
import os.path
|
||||
from datetime import datetime
|
||||
from typing import Dict, Tuple
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -28,6 +29,8 @@ from primaite.common.enums import (
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.config import training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
@@ -56,26 +59,36 @@ class Primaite(Env):
|
||||
|
||||
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
|
||||
|
||||
def __init__(self, _config_values, _transaction_list):
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
transaction_list,
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
The Primaite constructor.
|
||||
|
||||
Args:
|
||||
_episode_steps: The number of steps for the episode
|
||||
_config_filename: The name of config file
|
||||
_transaction_list: The list of transactions to populate
|
||||
_agent_identifier: Identifier for the agent
|
||||
:param training_config_path: The training config filepath.
|
||||
:param lay_down_config_path: The lay down config filepath.
|
||||
:param transaction_list: The list of transactions to populate.
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
super(Primaite, self).__init__()
|
||||
self._training_config_path = training_config_path
|
||||
self._lay_down_config_path = lay_down_config_path
|
||||
|
||||
# Take a copy of the config values
|
||||
self.config_values = _config_values
|
||||
self.config_values: TrainingConfig = training_config.load(training_config_path)
|
||||
|
||||
# Number of steps in an episode
|
||||
self.episode_steps = 0
|
||||
self.episode_steps = self.config_values.num_steps
|
||||
|
||||
super(Primaite, self).__init__()
|
||||
|
||||
# Transaction list
|
||||
self.transaction_list = _transaction_list
|
||||
self.transaction_list = transaction_list
|
||||
|
||||
# The agent in use
|
||||
self.agent_identifier = self.config_values.agent_identifier
|
||||
@@ -153,13 +166,10 @@ class Primaite(Env):
|
||||
self.observation_type = ObservationType.BOX
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
try:
|
||||
self.config_file = open(self.config_values.config_filename_use_case, "r")
|
||||
self.config_data = yaml.safe_load(self.config_file)
|
||||
self.load_config()
|
||||
except Exception:
|
||||
_LOGGER.error("Could not load the environment configuration")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
with open(self._lay_down_config_path, "r") as file:
|
||||
# Open the config file and build the environment laydown
|
||||
self.config_data = yaml.safe_load(file)
|
||||
self.load_lay_down_config()
|
||||
|
||||
# Store the node objects as node attributes
|
||||
# (This is so we can access them as objects)
|
||||
@@ -179,12 +189,8 @@ class Primaite(Env):
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
path = "outputs/diagrams"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
filename = "outputs/diagrams/network_" + time + ".png"
|
||||
plt.savefig(filename, format="PNG")
|
||||
file_path = session_path / f"network_{timestamp_str}.png"
|
||||
plt.savefig(file_path, format="PNG")
|
||||
plt.clf()
|
||||
except Exception:
|
||||
_LOGGER.error("Could not save network diagram")
|
||||
@@ -236,13 +242,9 @@ class Primaite(Env):
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
header = ["Episode", "Average Reward"]
|
||||
|
||||
# Check whether the output/rerults folder exists (doesn't exist by default install)
|
||||
path = "outputs/results/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
filename = "outputs/results/average_reward_per_episode_" + time + ".csv"
|
||||
self.csv_file = open(filename, "w", encoding="UTF8", newline="")
|
||||
file_name = f"average_reward_per_episode_{timestamp_str}.csv"
|
||||
file_path = session_path / file_name
|
||||
self.csv_file = open(file_path, "w", encoding="UTF8", newline="")
|
||||
self.csv_writer = csv.writer(self.csv_file)
|
||||
self.csv_writer.writerow(header)
|
||||
except Exception:
|
||||
@@ -404,7 +406,6 @@ class Primaite(Env):
|
||||
def __close__(self):
|
||||
"""Override close function."""
|
||||
self.csv_file.close()
|
||||
self.config_file.close()
|
||||
|
||||
def init_acl(self):
|
||||
"""Initialise the Access Control List."""
|
||||
@@ -888,7 +889,7 @@ class Primaite(Env):
|
||||
elif self.observation_type == ObservationType.MULTIDISCRETE:
|
||||
self._update_env_obs_multidiscrete()
|
||||
|
||||
def load_config(self):
|
||||
def load_lay_down_config(self):
|
||||
"""Loads config data in order to build the environment configuration."""
|
||||
for item in self.config_data:
|
||||
if item["itemType"] == "NODE":
|
||||
@@ -918,15 +919,9 @@ class Primaite(Env):
|
||||
elif item["itemType"] == "PORTS":
|
||||
# Create the list of ports
|
||||
self.create_ports_list(item)
|
||||
elif item["itemType"] == "ACTIONS":
|
||||
# Get the action information
|
||||
self.get_action_info(item)
|
||||
elif item["itemType"] == "OBSERVATIONS":
|
||||
# Get the observation information
|
||||
self.get_observation_info(item)
|
||||
elif item["itemType"] == "STEPS":
|
||||
# Get the steps information
|
||||
self.get_steps_info(item)
|
||||
else:
|
||||
# Do nothing (bad formatting)
|
||||
pass
|
||||
@@ -1247,15 +1242,6 @@ class Primaite(Env):
|
||||
# Set the number of ports
|
||||
self.num_ports = len(self.ports_list)
|
||||
|
||||
def get_action_info(self, action_info):
|
||||
"""
|
||||
Extracts action_info.
|
||||
|
||||
Args:
|
||||
item: A config data item representing action info
|
||||
"""
|
||||
self.action_type = ActionType[action_info["type"]]
|
||||
|
||||
def get_observation_info(self, observation_info):
|
||||
"""Extracts observation_info.
|
||||
|
||||
@@ -1264,16 +1250,6 @@ class Primaite(Env):
|
||||
"""
|
||||
self.observation_type = ObservationType[observation_info["type"]]
|
||||
|
||||
def get_steps_info(self, steps_info):
|
||||
"""
|
||||
Extracts steps_info.
|
||||
|
||||
Args:
|
||||
item: A config data item representing steps info
|
||||
"""
|
||||
self.episode_steps = int(steps_info["steps"])
|
||||
_LOGGER.info("Training episodes have " + str(self.episode_steps) + " steps")
|
||||
|
||||
def reset_environment(self):
|
||||
"""
|
||||
# Resets environment.
|
||||
|
||||
@@ -1,29 +1,41 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Primaite - main (harness) module.
|
||||
The main PrimAITE session runner module.
|
||||
|
||||
Coding Standards: PEP 8
|
||||
TODO: This will eventually be refactored out into a proper Session class.
|
||||
TODO: The passing about of session_dir and timestamp_str is temporary and
|
||||
will be cleaned up once we move to a proper Session class.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os.path
|
||||
import argparse
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Final, Union
|
||||
|
||||
import yaml
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite.common.config_values_main import ConfigValuesMain
|
||||
from primaite import SESSIONS_DIR, getLogger
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
from primaite.config.training_config import TrainingConfig, \
|
||||
main_training_config_path
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.transactions.transactions_to_file import write_transaction_to_file
|
||||
|
||||
# FUNCTIONS #
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run_generic():
|
||||
"""Run against a generic agent."""
|
||||
def run_generic(env: Primaite, config_values: TrainingConfig):
|
||||
"""
|
||||
Run against a generic agent.
|
||||
|
||||
:param env: An instance of
|
||||
:class:`~primaite.environment.primaite_env.Primaite`.
|
||||
:param config_values: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
"""
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
env.reset()
|
||||
for step in range(0, config_values.num_steps):
|
||||
@@ -47,9 +59,24 @@ def run_generic():
|
||||
env.close()
|
||||
|
||||
|
||||
def run_stable_baselines3_ppo():
|
||||
"""Run against a stable_baselines3 PPO agent."""
|
||||
if config_values.load_agent == True:
|
||||
def run_stable_baselines3_ppo(
|
||||
env: Primaite,
|
||||
config_values: TrainingConfig,
|
||||
session_path: Path,
|
||||
timestamp_str: str
|
||||
):
|
||||
"""
|
||||
Run against a stable_baselines3 PPO agent.
|
||||
|
||||
:param env: An instance of
|
||||
:class:`~primaite.environment.primaite_env.Primaite`.
|
||||
:param config_values: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
if config_values.load_agent:
|
||||
try:
|
||||
agent = PPO.load(
|
||||
config_values.agent_load_file,
|
||||
@@ -62,30 +89,44 @@ def run_stable_baselines3_ppo():
|
||||
"ERROR: Could not load agent at location: "
|
||||
+ config_values.agent_load_file
|
||||
)
|
||||
logging.error("Could not load agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
_LOGGER.error("Could not load agent")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
print("Starting training session...")
|
||||
logging.info("Starting training session...")
|
||||
_LOGGER.debug("Starting training session...")
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
agent.learn(total_timesteps=1)
|
||||
save_agent(agent)
|
||||
_save_agent(agent, session_path, timestamp_str)
|
||||
else:
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
logging.info("Starting evaluation session...")
|
||||
_LOGGER.debug("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def run_stable_baselines3_a2c():
|
||||
"""Run against a stable_baselines3 A2C agent."""
|
||||
if config_values.load_agent == True:
|
||||
def run_stable_baselines3_a2c(
|
||||
env: Primaite,
|
||||
config_values: TrainingConfig,
|
||||
session_path: Path, timestamp_str: str
|
||||
):
|
||||
"""
|
||||
Run against a stable_baselines3 A2C agent.
|
||||
|
||||
:param env: An instance of
|
||||
:class:`~primaite.environment.primaite_env.Primaite`.
|
||||
:param config_values: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
if config_values.load_agent:
|
||||
try:
|
||||
agent = A2C.load(
|
||||
config_values.agent_load_file,
|
||||
@@ -98,284 +139,151 @@ def run_stable_baselines3_a2c():
|
||||
"ERROR: Could not load agent at location: "
|
||||
+ config_values.agent_load_file
|
||||
)
|
||||
logging.error("Could not load agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
_LOGGER.error("Could not load agent")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
print("Starting training session...")
|
||||
logging.info("Starting training session...")
|
||||
_LOGGER.debug("Starting training session...")
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
agent.learn(total_timesteps=1)
|
||||
save_agent(agent)
|
||||
_save_agent(agent, session_path, timestamp_str)
|
||||
else:
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
logging.info("Starting evaluation session...")
|
||||
_LOGGER.debug("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def save_agent(_agent):
|
||||
"""Persist an agent (only works for stable baselines3 agents at present)."""
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
|
||||
"""
|
||||
Persist an agent.
|
||||
|
||||
try:
|
||||
path = "outputs/agents/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
filename = "outputs/agents/agent_saved_" + time
|
||||
_agent.save(filename)
|
||||
logging.info("Trained agent saved as " + filename)
|
||||
except Exception:
|
||||
logging.error("Could not save agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
Only works for stable baselines3 agents at present.
|
||||
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
if not isinstance(agent, OnPolicyAlgorithm):
|
||||
msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
|
||||
_LOGGER.error(msg)
|
||||
else:
|
||||
filepath = session_path / f"agent_saved_{timestamp_str}"
|
||||
agent.save(filepath)
|
||||
_LOGGER.debug(f"Trained agent saved as: {filepath}")
|
||||
|
||||
|
||||
def configure_logging():
|
||||
"""Configures logging."""
|
||||
try:
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
filename = "logs/app_" + time + ".log"
|
||||
path = "logs/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
logging.basicConfig(
|
||||
filename=filename,
|
||||
filemode="w",
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%d-%b-%y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
except Exception:
|
||||
print("ERROR: Could not start logging")
|
||||
def _get_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
Get the directory path the session will output to.
|
||||
|
||||
This is set in the format of:
|
||||
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
|
||||
:param session_timestamp: This is the datetime that the session started.
|
||||
:return: The session directory path.
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = SESSIONS_DIR / date_dir / session_dir
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return session_path
|
||||
|
||||
|
||||
def load_config_values():
|
||||
"""Loads the config values from the main config file into a config object."""
|
||||
try:
|
||||
# Generic
|
||||
config_values.agent_identifier = config_data["agentIdentifier"]
|
||||
config_values.num_episodes = int(config_data["numEpisodes"])
|
||||
config_values.time_delay = int(config_data["timeDelay"])
|
||||
config_values.config_filename_use_case = (
|
||||
"config/" + config_data["configFilename"]
|
||||
def run(
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path]
|
||||
):
|
||||
"""Run the PrimAITE Session.
|
||||
|
||||
:param training_config_path: The training config filepath.
|
||||
:param lay_down_config_path: The lay down config filepath.
|
||||
"""
|
||||
# Welcome message
|
||||
print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
|
||||
|
||||
session_timestamp: Final[datetime] = datetime.now()
|
||||
session_path = _get_session_path(session_timestamp)
|
||||
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
print(f"The output directory for this session is: {session_path}")
|
||||
|
||||
# Create a list of transactions
|
||||
# A transaction is an object holding the:
|
||||
# - episode #
|
||||
# - step #
|
||||
# - initial observation space
|
||||
# - action
|
||||
# - reward
|
||||
# - new observation space
|
||||
transaction_list = []
|
||||
|
||||
# Create the Primaite environment
|
||||
env = Primaite(
|
||||
training_config_path=training_config_path,
|
||||
lay_down_config_path=lay_down_config_path,
|
||||
transaction_list=transaction_list,
|
||||
session_path=session_path,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
|
||||
config_values = env.config_values
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
config_values.num_steps = env.episode_steps
|
||||
|
||||
# Run environment against an agent
|
||||
if config_values.agent_identifier == "GENERIC":
|
||||
run_generic(env=env, config_values=config_values)
|
||||
elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
|
||||
run_stable_baselines3_ppo(
|
||||
env=env,
|
||||
config_values=config_values,
|
||||
session_path=session_path,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
config_values.session_type = config_data["sessionType"]
|
||||
config_values.load_agent = bool(config_data["loadAgent"])
|
||||
config_values.agent_load_file = config_data["agentLoadFile"]
|
||||
# Environment
|
||||
config_values.observation_space_high_value = int(
|
||||
config_data["observationSpaceHighValue"]
|
||||
)
|
||||
# Reward values
|
||||
# Generic
|
||||
config_values.all_ok = int(config_data["allOk"])
|
||||
# Node Hardware State
|
||||
config_values.off_should_be_on = int(config_data["offShouldBeOn"])
|
||||
config_values.off_should_be_resetting = int(config_data["offShouldBeResetting"])
|
||||
config_values.on_should_be_off = int(config_data["onShouldBeOff"])
|
||||
config_values.on_should_be_resetting = int(config_data["onShouldBeResetting"])
|
||||
config_values.resetting_should_be_on = int(config_data["resettingShouldBeOn"])
|
||||
config_values.resetting_should_be_off = int(config_data["resettingShouldBeOff"])
|
||||
config_values.resetting = int(config_data["resetting"])
|
||||
# Node Software or Service State
|
||||
config_values.good_should_be_patching = int(config_data["goodShouldBePatching"])
|
||||
config_values.good_should_be_compromised = int(
|
||||
config_data["goodShouldBeCompromised"]
|
||||
)
|
||||
config_values.good_should_be_overwhelmed = int(
|
||||
config_data["goodShouldBeOverwhelmed"]
|
||||
)
|
||||
config_values.patching_should_be_good = int(config_data["patchingShouldBeGood"])
|
||||
config_values.patching_should_be_compromised = int(
|
||||
config_data["patchingShouldBeCompromised"]
|
||||
)
|
||||
config_values.patching_should_be_overwhelmed = int(
|
||||
config_data["patchingShouldBeOverwhelmed"]
|
||||
)
|
||||
config_values.patching = int(config_data["patching"])
|
||||
config_values.compromised_should_be_good = int(
|
||||
config_data["compromisedShouldBeGood"]
|
||||
)
|
||||
config_values.compromised_should_be_patching = int(
|
||||
config_data["compromisedShouldBePatching"]
|
||||
)
|
||||
config_values.compromised_should_be_overwhelmed = int(
|
||||
config_data["compromisedShouldBeOverwhelmed"]
|
||||
)
|
||||
config_values.compromised = int(config_data["compromised"])
|
||||
config_values.overwhelmed_should_be_good = int(
|
||||
config_data["overwhelmedShouldBeGood"]
|
||||
)
|
||||
config_values.overwhelmed_should_be_patching = int(
|
||||
config_data["overwhelmedShouldBePatching"]
|
||||
)
|
||||
config_values.overwhelmed_should_be_compromised = int(
|
||||
config_data["overwhelmedShouldBeCompromised"]
|
||||
)
|
||||
config_values.overwhelmed = int(config_data["overwhelmed"])
|
||||
# Node File System State
|
||||
config_values.good_should_be_repairing = int(
|
||||
config_data["goodShouldBeRepairing"]
|
||||
)
|
||||
config_values.good_should_be_restoring = int(
|
||||
config_data["goodShouldBeRestoring"]
|
||||
)
|
||||
config_values.good_should_be_corrupt = int(config_data["goodShouldBeCorrupt"])
|
||||
config_values.good_should_be_destroyed = int(
|
||||
config_data["goodShouldBeDestroyed"]
|
||||
)
|
||||
config_values.repairing_should_be_good = int(
|
||||
config_data["repairingShouldBeGood"]
|
||||
)
|
||||
config_values.repairing_should_be_restoring = int(
|
||||
config_data["repairingShouldBeRestoring"]
|
||||
)
|
||||
config_values.repairing_should_be_corrupt = int(
|
||||
config_data["repairingShouldBeCorrupt"]
|
||||
)
|
||||
config_values.repairing_should_be_destroyed = int(
|
||||
config_data["repairingShouldBeDestroyed"]
|
||||
)
|
||||
config_values.repairing = int(config_data["repairing"])
|
||||
config_values.restoring_should_be_good = int(
|
||||
config_data["restoringShouldBeGood"]
|
||||
)
|
||||
config_values.restoring_should_be_repairing = int(
|
||||
config_data["restoringShouldBeRepairing"]
|
||||
)
|
||||
config_values.restoring_should_be_corrupt = int(
|
||||
config_data["restoringShouldBeCorrupt"]
|
||||
)
|
||||
config_values.restoring_should_be_destroyed = int(
|
||||
config_data["restoringShouldBeDestroyed"]
|
||||
)
|
||||
config_values.restoring = int(config_data["restoring"])
|
||||
config_values.corrupt_should_be_good = int(config_data["corruptShouldBeGood"])
|
||||
config_values.corrupt_should_be_repairing = int(
|
||||
config_data["corruptShouldBeRepairing"]
|
||||
)
|
||||
config_values.corrupt_should_be_restoring = int(
|
||||
config_data["corruptShouldBeRestoring"]
|
||||
)
|
||||
config_values.corrupt_should_be_destroyed = int(
|
||||
config_data["corruptShouldBeDestroyed"]
|
||||
)
|
||||
config_values.corrupt = int(config_data["corrupt"])
|
||||
config_values.destroyed_should_be_good = int(
|
||||
config_data["destroyedShouldBeGood"]
|
||||
)
|
||||
config_values.destroyed_should_be_repairing = int(
|
||||
config_data["destroyedShouldBeRepairing"]
|
||||
)
|
||||
config_values.destroyed_should_be_restoring = int(
|
||||
config_data["destroyedShouldBeRestoring"]
|
||||
)
|
||||
config_values.destroyed_should_be_corrupt = int(
|
||||
config_data["destroyedShouldBeCorrupt"]
|
||||
)
|
||||
config_values.destroyed = int(config_data["destroyed"])
|
||||
config_values.scanning = int(config_data["scanning"])
|
||||
# IER status
|
||||
config_values.red_ier_running = int(config_data["redIerRunning"])
|
||||
config_values.green_ier_blocked = int(config_data["greenIerBlocked"])
|
||||
# Patching / Reset durations
|
||||
config_values.os_patching_duration = int(config_data["osPatchingDuration"])
|
||||
config_values.node_reset_duration = int(config_data["nodeResetDuration"])
|
||||
config_values.service_patching_duration = int(
|
||||
config_data["servicePatchingDuration"]
|
||||
)
|
||||
config_values.file_system_repairing_limit = int(
|
||||
config_data["fileSystemRepairingLimit"]
|
||||
)
|
||||
config_values.file_system_restoring_limit = int(
|
||||
config_data["fileSystemRestoringLimit"]
|
||||
)
|
||||
config_values.file_system_scanning_limit = int(
|
||||
config_data["fileSystemScanningLimit"]
|
||||
elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
|
||||
run_stable_baselines3_a2c(
|
||||
env=env,
|
||||
config_values=config_values,
|
||||
session_path=session_path,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
|
||||
logging.info("Training agent: " + config_values.agent_identifier)
|
||||
logging.info(
|
||||
"Training environment config: " + config_values.config_filename_use_case
|
||||
print("Session finished")
|
||||
_LOGGER.debug("Session finished")
|
||||
|
||||
print("Saving transaction logs...")
|
||||
_LOGGER.debug("Saving transaction logs...")
|
||||
|
||||
write_transaction_to_file(
|
||||
transaction_list=transaction_list,
|
||||
session_path=session_path,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
|
||||
print("Finished")
|
||||
_LOGGER.debug("Finished")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tc")
|
||||
parser.add_argument("--ldc")
|
||||
args = parser.parse_args()
|
||||
if not args.tc:
|
||||
_LOGGER.error(
|
||||
"Please provide a training config file using the --tc " "argument"
|
||||
)
|
||||
logging.info(
|
||||
"Training cycle has " + str(config_values.num_episodes) + " episodes"
|
||||
if not args.ldc:
|
||||
_LOGGER.error(
|
||||
"Please provide a lay down config file using the --ldc " "argument"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logging.error("Could not save load config data")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
|
||||
# MAIN PROCESS #
|
||||
|
||||
# Starting point
|
||||
|
||||
# Welcome message
|
||||
print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
|
||||
|
||||
# Configure logging
|
||||
configure_logging()
|
||||
|
||||
# Open the main config file
|
||||
try:
|
||||
config_file_main = open("config/config_main.yaml", "r")
|
||||
config_data = yaml.safe_load(config_file_main)
|
||||
# Create a config class
|
||||
config_values = ConfigValuesMain()
|
||||
# Load in config data
|
||||
load_config_values()
|
||||
except Exception:
|
||||
logging.error("Could not load main config")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
# Create a list of transactions
|
||||
# A transaction is an object holding the:
|
||||
# - episode #
|
||||
# - step #
|
||||
# - initial observation space
|
||||
# - action
|
||||
# - reward
|
||||
# - new observation space
|
||||
transaction_list = []
|
||||
|
||||
# Create the Primaite environment
|
||||
# try:
|
||||
env = Primaite(config_values, transaction_list)
|
||||
# logging.info("PrimAITE environment created")
|
||||
# except Exception:
|
||||
# logging.error("Could not create PrimAITE environment")
|
||||
# logging.error("Exception occured", exc_info=True)
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
config_values.num_steps = env.episode_steps
|
||||
|
||||
# Run environment against an agent
|
||||
if config_values.agent_identifier == "GENERIC":
|
||||
run_generic()
|
||||
elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
|
||||
run_stable_baselines3_ppo()
|
||||
elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
|
||||
run_stable_baselines3_a2c()
|
||||
|
||||
print("Session finished")
|
||||
logging.info("Session finished")
|
||||
|
||||
print("Saving transaction logs...")
|
||||
logging.info("Saving transaction logs...")
|
||||
|
||||
write_transaction_to_file(transaction_list)
|
||||
|
||||
config_file_main.close()
|
||||
|
||||
print("Finished")
|
||||
logging.info("Finished")
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import logging
|
||||
from typing import Final
|
||||
|
||||
from primaite.common.config_values_main import ConfigValuesMain
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
@@ -11,6 +10,7 @@ from primaite.common.enums import (
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
@@ -29,7 +29,7 @@ class ActiveNode(Node):
|
||||
ip_address: str,
|
||||
software_state: SoftwareState,
|
||||
file_system_state: FileSystemState,
|
||||
config_values: ConfigValuesMain,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
"""The base Node class."""
|
||||
from typing import Final
|
||||
|
||||
from primaite.common.config_values_main import ConfigValuesMain
|
||||
from primaite.common.enums import HardwareState, NodeType, Priority
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
|
||||
|
||||
class Node:
|
||||
@@ -16,7 +16,7 @@ class Node:
|
||||
node_type: NodeType,
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
config_values: ConfigValuesMain,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
@@ -34,7 +34,7 @@ class Node:
|
||||
self.priority = priority
|
||||
self.hardware_state: HardwareState = hardware_state
|
||||
self.resetting_count: int = 0
|
||||
self.config_values: ConfigValuesMain = config_values
|
||||
self.config_values: TrainingConfig = config_values
|
||||
|
||||
def __repr__(self):
|
||||
"""Returns the name of the node."""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""The Passive Node class (i.e. an actuator)."""
|
||||
from primaite.common.config_values_main import ConfigValuesMain
|
||||
from primaite.common.enums import HardwareState, NodeType, Priority
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class PassiveNode(Node):
|
||||
node_type: NodeType,
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
config_values: ConfigValuesMain,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import logging
|
||||
from typing import Dict, Final
|
||||
|
||||
from primaite.common.config_values_main import ConfigValuesMain
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
@@ -12,6 +11,7 @@ from primaite.common.enums import (
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
@@ -30,7 +30,7 @@ class ServiceNode(ActiveNode):
|
||||
ip_address: str,
|
||||
software_state: SoftwareState,
|
||||
file_system_state: FileSystemState,
|
||||
config_values: ConfigValuesMain,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
"""
|
||||
Init.
|
||||
|
||||
36
src/primaite/notebooks/__init__.py
Normal file
36
src/primaite/notebooks/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
import importlib.util
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from primaite import NOTEBOOKS_DIR, getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def start_jupyter_session():
|
||||
"""
|
||||
Starts a new Jupyter notebook session in the app notebooks directory.
|
||||
|
||||
Currently only works on Windows OS.
|
||||
|
||||
.. todo:: Figure out how to get this working for Linux and MacOS too.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
if importlib.util.find_spec("jupyter") is not None:
|
||||
# Jupyter is installed
|
||||
working_dir = os.getcwd()
|
||||
os.chdir(NOTEBOOKS_DIR)
|
||||
subprocess.Popen("jupyter lab")
|
||||
os.chdir(working_dir)
|
||||
else:
|
||||
# Jupyter is not installed
|
||||
_LOGGER.error("Cannot start jupyter lab as it is not installed")
|
||||
else:
|
||||
msg = (
|
||||
"Feature currently only supported on Windows OS. For "
|
||||
"Linux/MacOS users, run 'cd ~/primaite/notebooks; jupyter "
|
||||
"lab' from your Python environment."
|
||||
)
|
||||
_LOGGER.warning(msg)
|
||||
1
src/primaite/setup/__init__.py
Normal file
1
src/primaite/setup/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
13
src/primaite/setup/old_installation_clean_up.py
Normal file
13
src/primaite/setup/old_installation_clean_up.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run():
|
||||
"""Perform the full clean-up."""
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
39
src/primaite/setup/reset_demo_notebooks.py
Normal file
39
src/primaite/setup/reset_demo_notebooks.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
import filecmp
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import NOTEBOOKS_DIR, getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing: bool = True):
|
||||
"""
|
||||
Resets the demo jupyter notebooks in the users app notebooks directory.
|
||||
|
||||
:param overwrite_existing: A bool to toggle replacing existing edited
|
||||
notebooks on or off.
|
||||
"""
|
||||
notebooks_package_data_root = pkg_resources.resource_filename(
|
||||
"primaite", "notebooks/_package_data"
|
||||
)
|
||||
for subdir, dirs, files in os.walk(notebooks_package_data_root):
|
||||
for file in files:
|
||||
fp = os.path.join(subdir, file)
|
||||
path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep)
|
||||
target_fp = NOTEBOOKS_DIR / Path(*path_split)
|
||||
target_fp.parent.mkdir(exist_ok=True, parents=True)
|
||||
copy_file = not target_fp.is_file()
|
||||
|
||||
if overwrite_existing and not copy_file:
|
||||
copy_file = (not filecmp.cmp(fp, target_fp)) and (
|
||||
".ipynb_checkpoints" not in str(target_fp)
|
||||
)
|
||||
|
||||
if copy_file:
|
||||
shutil.copy2(fp, target_fp)
|
||||
_LOGGER.info(f"Reset example notebook: {target_fp}")
|
||||
37
src/primaite/setup/reset_example_configs.py
Normal file
37
src/primaite/setup/reset_example_configs.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import filecmp
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import USERS_CONFIG_DIR, getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing=True):
|
||||
"""
|
||||
Resets the example config files in the users app config directory.
|
||||
|
||||
:param overwrite_existing: A bool to toggle replacing existing edited
|
||||
config on or off.
|
||||
"""
|
||||
configs_package_data_root = pkg_resources.resource_filename(
|
||||
"primaite", "config/_package_data"
|
||||
)
|
||||
|
||||
for subdir, dirs, files in os.walk(configs_package_data_root):
|
||||
for file in files:
|
||||
fp = os.path.join(subdir, file)
|
||||
path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep)
|
||||
target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split)
|
||||
target_fp.parent.mkdir(exist_ok=True, parents=True)
|
||||
copy_file = not target_fp.is_file()
|
||||
|
||||
if overwrite_existing and not copy_file:
|
||||
copy_file = not filecmp.cmp(fp, target_fp)
|
||||
|
||||
if copy_file:
|
||||
shutil.copy2(fp, target_fp)
|
||||
_LOGGER.info(f"Reset example config: {target_fp}")
|
||||
27
src/primaite/setup/setup_app_dirs.py
Normal file
27
src/primaite/setup/setup_app_dirs.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from primaite import _USER_DIRS, LOG_DIR, NOTEBOOKS_DIR, getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run():
|
||||
"""
|
||||
Handles creation of application directories and user directories.
|
||||
|
||||
Uses `platformdirs.PlatformDirs` and `pathlib.Path` to create the required
|
||||
app directories in the correct locations based on the users OS.
|
||||
"""
|
||||
app_dirs = [
|
||||
_USER_DIRS,
|
||||
NOTEBOOKS_DIR,
|
||||
LOG_DIR,
|
||||
]
|
||||
|
||||
for app_dir in app_dirs:
|
||||
if not app_dir.is_dir():
|
||||
app_dir.mkdir(parents=True, exist_ok=True)
|
||||
_LOGGER.info(f"Created directory: {app_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
@@ -2,9 +2,11 @@
|
||||
"""Writes the Transaction log list out to file for evaluation to utilse."""
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import os.path
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def turn_action_space_to_array(_action_space):
|
||||
@@ -38,18 +40,22 @@ def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
|
||||
return return_array
|
||||
|
||||
|
||||
def write_transaction_to_file(_transaction_list):
|
||||
def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str):
|
||||
"""
|
||||
Writes transaction logs to file to support training evaluation.
|
||||
|
||||
Args:
|
||||
_transaction_list: The list of transactions from all steps and all episodes
|
||||
_num_episodes: The number of episodes that were conducted.
|
||||
:param transaction_list: The list of transactions from all steps and all
|
||||
episodes.
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
# Get the first transaction and use it to determine the makeup of the observation space and action space
|
||||
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action space as "AS_1"
|
||||
# Get the first transaction and use it to determine the makeup of the
|
||||
# observation space and action space
|
||||
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action
|
||||
# space as "AS_1"
|
||||
# This will be tied into the PrimAITE Use Case so that they make sense
|
||||
template_transation = _transaction_list[0]
|
||||
template_transation = transaction_list[0]
|
||||
action_length = template_transation.action_space.size
|
||||
obs_shape = template_transation.obs_space_post.shape
|
||||
obs_assets = template_transation.obs_space_post.shape[0]
|
||||
@@ -75,21 +81,15 @@ def write_transaction_to_file(_transaction_list):
|
||||
# Open up a csv file
|
||||
header = ["Timestamp", "Episode", "Step", "Reward"]
|
||||
header = header + action_header + obs_header_initial + obs_header_new
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
try:
|
||||
path = "outputs/results/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
|
||||
filename = "outputs/results/all_transactions_" + time + ".csv"
|
||||
filename = session_path / f"all_transactions_{timestamp_str}.csv"
|
||||
csv_file = open(filename, "w", encoding="UTF8", newline="")
|
||||
csv_writer = csv.writer(csv_file)
|
||||
csv_writer.writerow(header)
|
||||
|
||||
for transaction in _transaction_list:
|
||||
for transaction in transaction_list:
|
||||
csv_data = [
|
||||
str(transaction.timestamp),
|
||||
str(transaction.episode_number),
|
||||
@@ -110,5 +110,4 @@ def write_transaction_to_file(_transaction_list):
|
||||
|
||||
csv_file.close()
|
||||
except Exception:
|
||||
logging.error("Could not save the transaction file")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
_LOGGER.error("Could not save the transaction file", exc_info=True)
|
||||
|
||||
0
src/primaite/utils/__init__.py
Normal file
0
src/primaite/utils/__init__.py
Normal file
32
src/primaite/utils/package_data.py
Normal file
32
src/primaite/utils/package_data.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def get_file_path(path: str) -> Path:
|
||||
"""
|
||||
Get PrimAITE package data.
|
||||
|
||||
:Example:
|
||||
|
||||
>>> from primaite.utils.package_data import get_file_path
|
||||
>>> main_env_config = get_file_path("config/_package_data/training_config_main.yaml")
|
||||
|
||||
|
||||
:param path: The path from the primaite root.
|
||||
:return: The file path of the package data file.
|
||||
:raise FileNotFoundError: When the filepath does not exist.
|
||||
"""
|
||||
fp = pkg_resources.resource_filename("primaite", path)
|
||||
if os.path.isfile(fp):
|
||||
return Path(fp)
|
||||
else:
|
||||
msg = f"Cannot PrimAITE package data: {fp}"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
Reference in New Issue
Block a user