#915 - Created app dirs and set as constants in the top-level init.

- renamed _config_values_main to training_config.py and renamed the ConfigValuesMain class to TrainingConfig.
Moved training_config.py to src/primaite/config/training_config.py
- Renamed all training config yaml file keys to make creating an instance of TrainingConfig easier.
Moved action_type and num_steps over to the training config.
- Decoupled the training config and lay down config.
- Refactored main.py so that it can be ran from CLI and can take a training config path and a lay down config path.
- refactored all outputs so that they save to the session dir.
- Added some necessary setup scripts that handle creating app dirs, fronting example config files to the user, fronting demo notebooks to the user, performing clean-up in between installations etc.
- Added functions that attempt to retrieve the file path of users example config files that have been fronted by the primaite setup.
- Added logging config and a getLogger function in the top-level init.
- Refactored all logs entries logged to use a logger using the primaite logging config.
- Added basic typer CLI for doing things like setup, viewing logs, viewing primaite version, running a basic session.
- Updated test to use new features and config structures.
- Began updating docs. More to do here.
This commit is contained in:
Chris McCarthy
2023-06-07 22:40:16 +01:00
parent a8ce699df3
commit 98fc1e4c71
44 changed files with 1527 additions and 1356 deletions

98
src/primaite/__init__.py Normal file
View File

@@ -0,0 +1,98 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import logging
import logging.config
import sys
from logging import Logger, StreamHandler
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Final
from platformdirs import PlatformDirs
_PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
"""An instance of `PlatformDirs` set with appname='primaite'."""
_USER_DIRS: Final[Path] = Path.home() / "primaite"
"""The users home space for PrimAITE which is located at: ~/primaite."""
NOTEBOOKS_DIR: Final[Path] = _USER_DIRS / "notebooks"
"""
The path to the users notebooks directory as an instance of `Path` or
`PosixPath`, depending on the OS.
Users notebooks are stored at: ``~/primaite/notebooks``.
"""
USERS_CONFIG_DIR: Final[Path] = _USER_DIRS / "config"
"""
The path to the users config directory as an instance of `Path` or
`PosixPath`, depending on the OS.
Users config files are stored at: ``~/primaite/config``.
"""
SESSIONS_DIR: Final[Path] = _USER_DIRS / "sessions"
"""
The path to the users PrimAITE Sessions directory as an instance of `Path` or
`PosixPath`, depending on the OS.
Users PrimAITE Sessions are stored at: ``~/primaite/sessions``.
"""
# region Setup Logging
def _log_dir() -> Path:
if sys.platform == "win32":
dir_path = _PLATFORM_DIRS.user_data_path / "logs"
else:
dir_path = _PLATFORM_DIRS.user_log_path
return dir_path
LOG_DIR: Final[Path] = _log_dir()
"""The path to the app log directory as an instance of `Path` or `PosixPath`, depending on the OS."""
LOG_PATH: Final[Path] = LOG_DIR / "primaite.log"
"""The primaite.log file path as an instance of `Path` or `PosixPath`, depending on the OS."""
_STREAM_HANDLER: Final[StreamHandler] = StreamHandler()
_FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
filename=LOG_PATH,
maxBytes=10485760, # 10MB
backupCount=9, # Max 100MB of logs
encoding="utf8",
)
_STREAM_HANDLER.setLevel(logging.INFO)
_FILE_HANDLER.setLevel(logging.INFO)
_LOG_FORMAT_STR: Final[
str
] = "%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s"
_STREAM_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR))
_FILE_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR))
_LOGGER = logging.getLogger(__name__)
_LOGGER.addHandler(_STREAM_HANDLER)
_LOGGER.addHandler(_FILE_HANDLER)
def getLogger(name: str) -> Logger:
"""
Get a PrimAITE logger.
:param name: The logger name. Use ``__name__``.
:return: An instance of :py:class:`logging.Logger` with the PrimAITE
logging config.
"""
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
return logger
# endregion
with open(Path(__file__).parent.resolve() / "VERSION", "r") as file:
__version__ = file.readline()

125
src/primaite/cli.py Normal file
View File

@@ -0,0 +1,125 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Provides a CLI using Typer as an entry point."""
import os
import sys
import typer
app = typer.Typer()
@app.command()
def build_dirs():
"""Build the PrimAITE app directories."""
from primaite.setup import setup_app_dirs
setup_app_dirs.run()
@app.command()
def reset_notebooks(overwrite: bool = True):
"""
Force a reset of the demo notebooks in the users notebooks directory.
:param overwrite: If True, will overwrite existing demo notebooks.
"""
from primaite.setup import reset_demo_notebooks
reset_demo_notebooks.run(overwrite)
@app.command()
def logs(last_n: int = 10):
"""
Print the PrimAITE log file.
:param last_n: The number of lines to print. Default value is 10.
"""
import re
from platformdirs import PlatformDirs
yt_platform_dirs = PlatformDirs(appname="primaite")
if sys.platform == "win32":
log_dir = yt_platform_dirs.user_data_path / "logs"
else:
log_dir = yt_platform_dirs.user_log_path
log_path = os.path.join(log_dir, "primaite.log")
if os.path.isfile(log_path):
with open(log_path) as file:
lines = file.readlines()
for line in lines[-last_n:]:
print(re.sub(r"\n*", "", line))
@app.command()
def notebooks():
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
from primaite.notebooks import start_jupyter_session
start_jupyter_session()
@app.command()
def version():
"""Get the installed PrimAITE version number."""
import primaite
print(primaite.__version__)
@app.command()
def clean_up():
"""Cleans up left over files from previous version installations."""
from primaite.setup import old_installation_clean_up
old_installation_clean_up.run()
@app.command()
def setup():
"""
Perform the PrimAITE first-time setup.
WARNING: All user-data will be lost.
"""
from primaite import getLogger
from primaite.setup import (
old_installation_clean_up,
reset_demo_notebooks,
reset_example_configs,
setup_app_dirs,
)
_LOGGER = getLogger(__name__)
_LOGGER.info("Performing the PrimAITE first-time setup...")
_LOGGER.info("Building the PrimAITE app directories...")
setup_app_dirs.run()
_LOGGER.info("Rebuilding the demo notebooks...")
reset_demo_notebooks.run(overwrite_existing=True)
_LOGGER.info("Rebuilding the example notebooks...")
reset_example_configs.run(overwrite_existing=True)
_LOGGER.info("Performing a clean-up of previous PrimAITE installations...")
old_installation_clean_up.run()
_LOGGER.info("PrimAITE setup complete!")
@app.command()
def session(tc: str, ldc: str):
"""
Run a PrimAITE session.
:param tc: The training config filepath.
:param ldc: The lay down config file path.
"""
from primaite.main import run
run(training_config_path=tc, lay_down_config_path=ldc)

View File

@@ -1,90 +0,0 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""The config class."""
class ConfigValuesMain(object):
"""Class to hold main config values."""
def __init__(self):
"""Init."""
# Generic
self.agent_identifier = "" # the agent in use
self.num_episodes = 0 # number of episodes to train over
self.num_steps = 0 # number of steps in an episode
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only
self.config_filename_use_case = "" # the filename for the Use Case config file
self.session_type = "" # the session type to run (TRAINING or EVALUATION)
# Environment
self.observation_space_high_value = (
0 # The high value for the observation space
)
# Reward values
# Generic
self.all_ok = 0
# Node Hardware State
self.off_should_be_on = 0
self.off_should_be_resetting = 0
self.on_should_be_off = 0
self.on_should_be_resetting = 0
self.resetting_should_be_on = 0
self.resetting_should_be_off = 0
self.resetting = 0
# Node Software or Service State
self.good_should_be_patching = 0
self.good_should_be_compromised = 0
self.good_should_be_overwhelmed = 0
self.patching_should_be_good = 0
self.patching_should_be_compromised = 0
self.patching_should_be_overwhelmed = 0
self.patching = 0
self.compromised_should_be_good = 0
self.compromised_should_be_patching = 0
self.compromised_should_be_overwhelmed = 0
self.compromised = 0
self.overwhelmed_should_be_good = 0
self.overwhelmed_should_be_patching = 0
self.overwhelmed_should_be_compromised = 0
self.overwhelmed = 0
# Node File System State
self.good_should_be_repairing = 0
self.good_should_be_restoring = 0
self.good_should_be_corrupt = 0
self.good_should_be_destroyed = 0
self.repairing_should_be_good = 0
self.repairing_should_be_restoring = 0
self.repairing_should_be_corrupt = 0
self.repairing_should_be_destroyed = (
0 # Repairing does not fix destroyed state - you need to restore
)
self.repairing = 0
self.restoring_should_be_good = 0
self.restoring_should_be_repairing = 0
self.restoring_should_be_corrupt = (
0 # Not the optimal method (as repair will fix corruption)
)
self.restoring_should_be_destroyed = 0
self.restoring = 0
self.corrupt_should_be_good = 0
self.corrupt_should_be_repairing = 0
self.corrupt_should_be_restoring = 0
self.corrupt_should_be_destroyed = 0
self.corrupt = 0
self.destroyed_should_be_good = 0
self.destroyed_should_be_repairing = 0
self.destroyed_should_be_restoring = 0
self.destroyed_should_be_corrupt = 0
self.destroyed = 0
self.scanning = 0
# IER status
self.red_ier_running = 0
self.green_ier_blocked = 0
# Patching / Reset
self.os_patching_duration = 0 # The time taken to patch the OS
self.node_reset_duration = 0 # The time taken to reset a node (hardware)
self.service_patching_duration = 0 # The time taken to patch a service
self.file_system_repairing_limit = 0 # The time take to repair a file
self.file_system_restoring_limit = 0 # The time take to restore a file
self.file_system_scanning_limit = 0 # The time taken to scan the file system

View File

@@ -81,6 +81,7 @@ class ActionType(Enum):
NODE = 0
ACL = 1
ANY = 2
class ObservationType(Enum):

View File

@@ -0,0 +1,91 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""The config class."""
from dataclasses import dataclass
from primaite.common.enums import ActionType
@dataclass()
class TrainingConfig:
"""Class to hold main config values."""
# Generic
agent_identifier: str # The Red Agent algo/class to be used
action_type: ActionType # type of action to use (NODE/ACL/ANY)
num_episodes: int # number of episodes to train over
num_steps: int # number of steps in an episode
time_delay: int # delay between steps (ms) - applies to generic agents only
# file
session_type: str # the session type to run (TRAINING or EVALUATION)
load_agent: str # Determine whether to load an agent from file
agent_load_file: str # File path and file name of agent if you're loading one in
# Environment
observation_space_high_value: int # The high value for the observation space
# Reward values
# Generic
all_ok: int
# Node Hardware State
off_should_be_on: int
off_should_be_resetting: int
on_should_be_off: int
on_should_be_resetting: int
resetting_should_be_on: int
resetting_should_be_off: int
resetting: int
# Node Software or Service State
good_should_be_patching: int
good_should_be_compromised: int
good_should_be_overwhelmed: int
patching_should_be_good: int
patching_should_be_compromised: int
patching_should_be_overwhelmed: int
patching: int
compromised_should_be_good: int
compromised_should_be_patching: int
compromised_should_be_overwhelmed: int
compromised: int
overwhelmed_should_be_good: int
overwhelmed_should_be_patching: int
overwhelmed_should_be_compromised: int
overwhelmed: int
# Node File System State
good_should_be_repairing: int
good_should_be_restoring: int
good_should_be_corrupt: int
good_should_be_destroyed: int
repairing_should_be_good: int
repairing_should_be_restoring: int
repairing_should_be_corrupt: int
repairing_should_be_destroyed: int # Repairing does not fix destroyed state - you need to restore
repairing: int
restoring_should_be_good: int
restoring_should_be_repairing: int
restoring_should_be_corrupt: int # Not the optimal method (as repair will fix corruption)
restoring_should_be_destroyed: int
restoring: int
corrupt_should_be_good: int
corrupt_should_be_repairing: int
corrupt_should_be_restoring: int
corrupt_should_be_destroyed: int
corrupt: int
destroyed_should_be_good: int
destroyed_should_be_repairing: int
destroyed_should_be_restoring: int
destroyed_should_be_corrupt: int
destroyed: int
scanning: int
# IER status
red_ier_running: int
green_ier_blocked: int
# Patching / Reset
os_patching_duration: int # The time taken to patch the OS
node_reset_duration: int # The time taken to reset a node (hardware)
service_patching_duration: int # The time taken to patch a service
file_system_repairing_limit: int # The time take to repair a file
file_system_restoring_limit: int # The time take to restore a file
file_system_scanning_limit: int # The time taken to scan the file system

View File

@@ -1,7 +1,4 @@
- itemType: ACTIONS
type: NODE
- itemType: STEPS
steps: 256
- itemType: PORTS
portsList:
- port: '80'

View File

@@ -0,0 +1,94 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# Number of episodes to run per session
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
# Time delay between steps (for generic agents)
time_delay: 10
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
# IER status
red_ier_running: -5
green_ier_blocked: -10
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time take to repair the file system
file_system_restoring_limit: 5 # The time take to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -1,533 +0,0 @@
- itemType: ACTIONS
type: NODE
- itemType: STEPS
steps: 256
- itemType: PORTS
portsList:
- port: '80'
- port: '1433'
- port: '53'
- itemType: SERVICES
serviceList:
- name: TCP
- name: TCP_SQL
- name: UDP
- itemType: NODE
node_id: '1'
name: CLIENT_1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.11
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- itemType: NODE
node_id: '2'
name: CLIENT_2
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.12
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- itemType: NODE
node_id: '3'
name: SWITCH_1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.10.1
software_state: GOOD
file_system_state: GOOD
- itemType: NODE
node_id: '4'
name: SECURITY_SUITE
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.10
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- itemType: NODE
node_id: '5'
name: MANAGEMENT_CONSOLE
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.12
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- itemType: NODE
node_id: '6'
name: SWITCH_2
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.2.1
software_state: GOOD
file_system_state: GOOD
- itemType: NODE
node_id: '7'
name: WEB_SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.2.10
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: TCP_SQL
port: '1433'
state: GOOD
- itemType: NODE
node_id: '8'
name: DATABASE_SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.2.14
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: TCP_SQL
port: '1433'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- itemType: NODE
node_id: '9'
name: BACKUP_SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.2.16
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- itemType: LINK
id: '10'
name: LINK_1
bandwidth: 1000000000
source: '1'
destination: '3'
- itemType: LINK
id: '11'
name: LINK_2
bandwidth: 1000000000
source: '2'
destination: '3'
- itemType: LINK
id: '12'
name: LINK_3
bandwidth: 1000000000
source: '3'
destination: '4'
- itemType: LINK
id: '13'
name: LINK_4
bandwidth: 1000000000
source: '3'
destination: '5'
- itemType: LINK
id: '14'
name: LINK_5
bandwidth: 1000000000
source: '4'
destination: '6'
- itemType: LINK
id: '15'
name: LINK_6
bandwidth: 1000000000
source: '5'
destination: '6'
- itemType: LINK
id: '16'
name: LINK_7
bandwidth: 1000000000
source: '6'
destination: '7'
- itemType: LINK
id: '17'
name: LINK_8
bandwidth: 1000000000
source: '6'
destination: '8'
- itemType: LINK
id: '18'
name: LINK_9
bandwidth: 1000000000
source: '6'
destination: '9'
- itemType: GREEN_IER
id: '19'
startStep: 1
endStep: 256
load: 10000
protocol: TCP
port: '80'
source: '1'
destination: '7'
missionCriticality: 5
- itemType: GREEN_IER
id: '20'
startStep: 1
endStep: 256
load: 10000
protocol: TCP
port: '80'
source: '7'
destination: '1'
missionCriticality: 5
- itemType: GREEN_IER
id: '21'
startStep: 1
endStep: 256
load: 10000
protocol: TCP
port: '80'
source: '2'
destination: '7'
missionCriticality: 5
- itemType: GREEN_IER
id: '22'
startStep: 1
endStep: 256
load: 10000
protocol: TCP
port: '80'
source: '7'
destination: '2'
missionCriticality: 5
- itemType: GREEN_IER
id: '23'
startStep: 1
endStep: 256
load: 5000
protocol: TCP_SQL
port: '1433'
source: '7'
destination: '8'
missionCriticality: 5
- itemType: GREEN_IER
id: '24'
startStep: 1
endStep: 256
load: 100000
protocol: TCP_SQL
port: '1433'
source: '8'
destination: '7'
missionCriticality: 5
- itemType: GREEN_IER
id: '25'
startStep: 1
endStep: 256
load: 50000
protocol: TCP
port: '80'
source: '1'
destination: '9'
missionCriticality: 2
- itemType: GREEN_IER
id: '26'
startStep: 1
endStep: 256
load: 50000
protocol: TCP
port: '80'
source: '2'
destination: '9'
missionCriticality: 2
- itemType: GREEN_IER
id: '27'
startStep: 1
endStep: 256
load: 5000
protocol: TCP
port: '80'
source: '5'
destination: '7'
missionCriticality: 1
- itemType: GREEN_IER
id: '28'
startStep: 1
endStep: 256
load: 5000
protocol: TCP
port: '80'
source: '7'
destination: '5'
missionCriticality: 1
- itemType: GREEN_IER
id: '29'
startStep: 1
endStep: 256
load: 5000
protocol: TCP
port: '80'
source: '5'
destination: '8'
missionCriticality: 1
- itemType: GREEN_IER
id: '30'
startStep: 1
endStep: 256
load: 5000
protocol: TCP
port: '80'
source: '8'
destination: '5'
missionCriticality: 1
- itemType: GREEN_IER
id: '31'
startStep: 1
endStep: 256
load: 5000
protocol: TCP
port: '80'
source: '5'
destination: '9'
missionCriticality: 1
- itemType: GREEN_IER
id: '32'
startStep: 1
endStep: 256
load: 5000
protocol: TCP
port: '80'
source: '9'
destination: '5'
missionCriticality: 1
- itemType: ACL_RULE
id: '33'
permission: ALLOW
source: 192.168.10.11
destination: 192.168.2.10
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '34'
permission: ALLOW
source: 192.168.10.11
destination: 192.168.2.14
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '35'
permission: ALLOW
source: 192.168.10.12
destination: 192.168.2.14
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '36'
permission: ALLOW
source: 192.168.10.12
destination: 192.168.2.10
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '37'
permission: ALLOW
source: 192.168.2.10
destination: 192.168.10.11
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '38'
permission: ALLOW
source: 192.168.2.10
destination: 192.168.10.12
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '39'
permission: ALLOW
source: 192.168.2.10
destination: 192.168.2.14
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '40'
permission: ALLOW
source: 192.168.2.14
destination: 192.168.2.10
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '41'
permission: ALLOW
source: 192.168.10.11
destination: 192.168.2.16
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '42'
permission: ALLOW
source: 192.168.10.12
destination: 192.168.2.16
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '43'
permission: ALLOW
source: 192.168.1.12
destination: 192.168.2.10
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '44'
permission: ALLOW
source: 192.168.1.12
destination: 192.168.2.14
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '45'
permission: ALLOW
source: 192.168.1.12
destination: 192.168.2.16
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '46'
permission: ALLOW
source: 192.168.2.10
destination: 192.168.1.12
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '47'
permission: ALLOW
source: 192.168.2.14
destination: 192.168.1.12
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '48'
permission: ALLOW
source: 192.168.2.16
destination: 192.168.1.12
protocol: ANY
port: ANY
- itemType: ACL_RULE
id: '49'
permission: DENY
source: ANY
destination: ANY
protocol: ANY
port: ANY
- itemType: RED_POL
id: '50'
startStep: 50
endStep: 50
targetNodeId: '1'
initiator: DIRECT
type: SERVICE
protocol: UDP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- itemType: RED_IER
id: '51'
startStep: 75
endStep: 105
load: 10000
protocol: UDP
port: '53'
source: '1'
destination: '8'
missionCriticality: 0
- itemType: RED_POL
id: '52'
startStep: 100
endStep: 100
targetNodeId: '8'
initiator: IER
type: SERVICE
protocol: UDP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- itemType: RED_POL
id: '53'
startStep: 105
endStep: 105
targetNodeId: '8'
initiator: SERVICE
type: FILE
protocol: NA
state: CORRUPT
sourceNodeId: '8'
sourceNodeService: UDP
sourceNodeServiceState: COMPROMISED
- itemType: RED_POL
id: '54'
startStep: 105
endStep: 105
targetNodeId: '8'
initiator: SERVICE
type: SERVICE
protocol: TCP_SQL
state: COMPROMISED
sourceNodeId: '8'
sourceNodeService: UDP
sourceNodeServiceState: COMPROMISED
- itemType: RED_POL
id: '55'
startStep: 125
endStep: 125
targetNodeId: '7'
initiator: SERVICE
type: SERVICE
protocol: TCP
state: OVERWHELMED
sourceNodeId: '8'
sourceNodeService: TCP_SQL
sourceNodeServiceState: COMPROMISED

View File

@@ -1,89 +0,0 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agentIdentifier: STABLE_BASELINES3_A2C
# Number of episodes to run per session
numEpisodes: 10
# Time delay between steps (for generic agents)
timeDelay: 10
# Filename of the scenario / laydown
configFilename: config_5_DATA_MANIPULATION.yaml
# Type of session to be run (TRAINING or EVALUATION)
sessionType: TRAINING
# Determine whether to load an agent from file
loadAgent: False
# File path and file name of agent if you're loading one in
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observationSpaceHighValue: 1000000000
# Reward values
# Generic
allOk: 0
# Node Hardware State
offShouldBeOn: -10
offShouldBeResetting: -5
onShouldBeOff: -2
onShouldBeResetting: -5
resettingShouldBeOn: -5
resettingShouldBeOff: -2
resetting: -3
# Node Software or Service State
goodShouldBePatching: 2
goodShouldBeCompromised: 5
goodShouldBeOverwhelmed: 5
patchingShouldBeGood: -5
patchingShouldBeCompromised: 2
patchingShouldBeOverwhelmed: 2
patching: -3
compromisedShouldBeGood: -20
compromisedShouldBePatching: -20
compromisedShouldBeOverwhelmed: -20
compromised: -20
overwhelmedShouldBeGood: -20
overwhelmedShouldBePatching: -20
overwhelmedShouldBeCompromised: -20
overwhelmed: -20
# Node File System State
goodShouldBeRepairing: 2
goodShouldBeRestoring: 2
goodShouldBeCorrupt: 5
goodShouldBeDestroyed: 10
repairingShouldBeGood: -5
repairingShouldBeRestoring: 2
repairingShouldBeCorrupt: 2
repairingShouldBeDestroyed: 0
repairing: -3
restoringShouldBeGood: -10
restoringShouldBeRepairing: -2
restoringShouldBeCorrupt: 1
restoringShouldBeDestroyed: 2
restoring: -6
corruptShouldBeGood: -10
corruptShouldBeRepairing: -10
corruptShouldBeRestoring: -10
corruptShouldBeDestroyed: 2
corrupt: -10
destroyedShouldBeGood: -20
destroyedShouldBeRepairing: -20
destroyedShouldBeRestoring: -20
destroyedShouldBeCorrupt: -20
destroyed: -20
scanning: -2
# IER status
redIerRunning: -5
greenIerBlocked: -10
# Patching / Reset durations
osPatchingDuration: 5 # The time taken to patch the OS
nodeResetDuration: 5 # The time taken to reset a node (hardware)
servicePatchingDuration: 5 # The time taken to patch a service
fileSystemRepairingLimit: 5 # The time take to repair the file system
fileSystemRestoringLimit: 5 # The time take to restore the file system
fileSystemScanningLimit: 5 # The time taken to scan the file system

View File

@@ -0,0 +1,69 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Final
from primaite import getLogger, USERS_CONFIG_DIR
_LOGGER = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
def ddos_basic_one_config_path() -> Path:
"""
The path to the example lay_down_config_1_DDOS_basic.yaml file.
:return: The file path.
"""
path = _EXAMPLE_LAY_DOWN / "lay_down_config_1_DDOS_basic.yaml"
if not path.exists():
msg = f"Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
def ddos_basic_two_config_path() -> Path:
"""
The path to the example lay_down_config_2_DDOS_basic.yaml file.
:return: The file path.
"""
path = _EXAMPLE_LAY_DOWN / "lay_down_config_2_DDOS_basic.yaml"
if not path.exists():
msg = f"Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
def dos_very_basic_config_path() -> Path:
"""
The path to the example lay_down_config_3_DOS_very_basic.yaml file.
:return: The file path.
"""
path = _EXAMPLE_LAY_DOWN / "lay_down_config_3_DOS_very_basic.yaml"
if not path.exists():
msg = f"Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
def data_manipulation_config_path() -> Path:
"""
The path to the example lay_down_config_5_data_manipulation.yaml file.
:return: The file path.
"""
path = _EXAMPLE_LAY_DOWN / "lay_down_config_5_data_manipulation.yaml"
if not path.exists():
msg = f"Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path

View File

@@ -0,0 +1,260 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Union, Final
import yaml
from primaite import getLogger, USERS_CONFIG_DIR
from primaite.common.enums import ActionType
_LOGGER = getLogger(__name__)
_EXAMPLE_TRAINING: Final[
Path] = USERS_CONFIG_DIR / "example_config" / "training"
@dataclass()
class TrainingConfig:
"""The Training Config class."""
# Generic
agent_identifier: str # The Red Agent algo/class to be used
action_type: ActionType # type of action to use (NODE/ACL/ANY)
num_episodes: int # number of episodes to train over
num_steps: int # number of steps in an episode
time_delay: int # delay between steps (ms) - applies to generic agents only
# file
session_type: str # the session type to run (TRAINING or EVALUATION)
load_agent: str # Determine whether to load an agent from file
agent_load_file: str # File path and file name of agent if you're loading one in
# Environment
observation_space_high_value: int # The high value for the observation space
# Reward values
# Generic
all_ok: int
# Node Hardware State
off_should_be_on: int
off_should_be_resetting: int
on_should_be_off: int
on_should_be_resetting: int
resetting_should_be_on: int
resetting_should_be_off: int
resetting: int
# Node Software or Service State
good_should_be_patching: int
good_should_be_compromised: int
good_should_be_overwhelmed: int
patching_should_be_good: int
patching_should_be_compromised: int
patching_should_be_overwhelmed: int
patching: int
compromised_should_be_good: int
compromised_should_be_patching: int
compromised_should_be_overwhelmed: int
compromised: int
overwhelmed_should_be_good: int
overwhelmed_should_be_patching: int
overwhelmed_should_be_compromised: int
overwhelmed: int
# Node File System State
good_should_be_repairing: int
good_should_be_restoring: int
good_should_be_corrupt: int
good_should_be_destroyed: int
repairing_should_be_good: int
repairing_should_be_restoring: int
repairing_should_be_corrupt: int
repairing_should_be_destroyed: int # Repairing does not fix destroyed state - you need to restore
repairing: int
restoring_should_be_good: int
restoring_should_be_repairing: int
restoring_should_be_corrupt: int # Not the optimal method (as repair will fix corruption)
restoring_should_be_destroyed: int
restoring: int
corrupt_should_be_good: int
corrupt_should_be_repairing: int
corrupt_should_be_restoring: int
corrupt_should_be_destroyed: int
corrupt: int
destroyed_should_be_good: int
destroyed_should_be_repairing: int
destroyed_should_be_restoring: int
destroyed_should_be_corrupt: int
destroyed: int
scanning: int
# IER status
red_ier_running: int
green_ier_blocked: int
# Patching / Reset
os_patching_duration: int # The time taken to patch the OS
node_reset_duration: int # The time taken to reset a node (hardware)
service_patching_duration: int # The time taken to patch a service
file_system_repairing_limit: int # The time take to repair a file
file_system_restoring_limit: int # The time take to restore a file
file_system_scanning_limit: int # The time taken to scan the file system
def main_training_config_path() -> Path:
"""
The path to the example training_config_main.yaml file
:return: The file path.
"""
path = _EXAMPLE_TRAINING / "training_config_main.yaml"
if not path.exists():
msg = f"Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
def load(file_path: Union[str, Path],
legacy_file: bool = False) -> TrainingConfig:
"""
Read in a training config yaml file.
:param file_path: The config file path.
:param legacy_file: True if the config file is legacy format, otherwise
False.
:return: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
:raises ValueError: If the file_path does not exist.
:raises TypeError: When the TrainingConfig object cannot be created
using the values from the config file read from ``file_path``.
"""
if not isinstance(file_path, Path):
file_path = Path(file_path)
if file_path.exists():
with open(file_path, "r") as file:
config = yaml.safe_load(file)
_LOGGER.debug(f"Loading training config file: {file_path}")
if legacy_file:
try:
config = convert_legacy_training_config_dict(config)
except KeyError:
msg = (
f"Failed to convert training config file {file_path} "
f"from legacy format. Attempting to use file as is."
)
_LOGGER.error(msg)
# Convert values to Enums
config["action_type"] = ActionType[config["action_type"]]
try:
return TrainingConfig(**config)
except TypeError as e:
msg = (
f"Error when creating an instance of {TrainingConfig} "
f"from the training config file {file_path}"
)
_LOGGER.critical(msg, exc_info=True)
raise e
msg = f"Cannot load the training config as it does not exist: {file_path}"
_LOGGER.error(msg)
raise ValueError(msg)
def convert_legacy_training_config_dict(
legacy_config_dict: Dict[str, Any],
num_steps: int = 256,
action_type: str = "ANY"
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
:param legacy_config_dict: A legacy training config dict.
:param num_steps: The number of steps to set as legacy training configs
don't have num_steps values.
:param action_type: The action space type to set as legacy training configs
don't have action_type values.
:return: The converted training config dict.
"""
config_dict = {"num_steps": num_steps, "action_type": action_type}
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)
if new_key:
config_dict[new_key] = value
return config_dict
def _get_new_key_from_legacy(legacy_key: str) -> str:
"""
Maps legacy training config keys to the new format keys.
:param legacy_key: A legacy training config key.
:return: The mapped key.
"""
key_mapping = {
"agentIdentifier": "agent_identifier",
"numEpisodes": "num_episodes",
"timeDelay": "time_delay",
"configFilename": None,
"sessionType": "session_type",
"loadAgent": "load_agent",
"agentLoadFile": "agent_load_file",
"observationSpaceHighValue": "observation_space_high_value",
"allOk": "all_ok",
"offShouldBeOn": "off_should_be_on",
"offShouldBeResetting": "off_should_be_resetting",
"onShouldBeOff": "on_should_be_off",
"onShouldBeResetting": "on_should_be_resetting",
"resettingShouldBeOn": "resetting_should_be_on",
"resettingShouldBeOff": "resetting_should_be_off",
"resetting": "resetting",
"goodShouldBePatching": "good_should_be_patching",
"goodShouldBeCompromised": "good_should_be_compromised",
"goodShouldBeOverwhelmed": "good_should_be_overwhelmed",
"patchingShouldBeGood": "patching_should_be_good",
"patchingShouldBeCompromised": "patching_should_be_compromised",
"patchingShouldBeOverwhelmed": "patching_should_be_overwhelmed",
"patching": "patching",
"compromisedShouldBeGood": "compromised_should_be_good",
"compromisedShouldBePatching": "compromised_should_be_patching",
"compromisedShouldBeOverwhelmed": "compromised_should_be_overwhelmed",
"compromised": "compromised",
"overwhelmedShouldBeGood": "overwhelmed_should_be_good",
"overwhelmedShouldBePatching": "overwhelmed_should_be_patching",
"overwhelmedShouldBeCompromised": "overwhelmed_should_be_compromised",
"overwhelmed": "overwhelmed",
"goodShouldBeRepairing": "good_should_be_repairing",
"goodShouldBeRestoring": "good_should_be_restoring",
"goodShouldBeCorrupt": "good_should_be_corrupt",
"goodShouldBeDestroyed": "good_should_be_destroyed",
"repairingShouldBeGood": "repairing_should_be_good",
"repairingShouldBeRestoring": "repairing_should_be_restoring",
"repairingShouldBeCorrupt": "repairing_should_be_corrupt",
"repairingShouldBeDestroyed": "repairing_should_be_destroyed",
"repairing": "repairing",
"restoringShouldBeGood": "restoring_should_be_good",
"restoringShouldBeRepairing": "restoring_should_be_repairing",
"restoringShouldBeCorrupt": "restoring_should_be_corrupt",
"restoringShouldBeDestroyed": "restoring_should_be_destroyed",
"restoring": "restoring",
"corruptShouldBeGood": "corrupt_should_be_good",
"corruptShouldBeRepairing": "corrupt_should_be_repairing",
"corruptShouldBeRestoring": "corrupt_should_be_restoring",
"corruptShouldBeDestroyed": "corrupt_should_be_destroyed",
"corrupt": "corrupt",
"destroyedShouldBeGood": "destroyed_should_be_good",
"destroyedShouldBeRepairing": "destroyed_should_be_repairing",
"destroyedShouldBeRestoring": "destroyed_should_be_restoring",
"destroyedShouldBeCorrupt": "destroyed_should_be_corrupt",
"destroyed": "destroyed",
"scanning": "scanning",
"redIerRunning": "red_ier_running",
"greenIerBlocked": "green_ier_blocked",
"osPatchingDuration": "os_patching_duration",
"nodeResetDuration": "node_reset_duration",
"servicePatchingDuration": "service_patching_duration",
"fileSystemRepairingLimit": "file_system_repairing_limit",
"fileSystemRestoringLimit": "file_system_restoring_limit",
"fileSystemScanningLimit": "file_system_scanning_limit",
}
return key_mapping[legacy_key]

View File

@@ -6,7 +6,8 @@ import csv
import logging
import os.path
from datetime import datetime
from typing import Dict, Tuple
from pathlib import Path
from typing import Dict, Tuple, Union
import networkx as nx
import numpy as np
@@ -28,6 +29,8 @@ from primaite.common.enums import (
SoftwareState,
)
from primaite.common.service import Service
from primaite.config import training_config
from primaite.config.training_config import TrainingConfig
from primaite.environment.reward import calculate_reward_function
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
@@ -56,26 +59,36 @@ class Primaite(Env):
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
def __init__(self, _config_values, _transaction_list):
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
transaction_list,
session_path: Path,
timestamp_str: str,
):
"""
Init.
The Primaite constructor.
Args:
_episode_steps: The number of steps for the episode
_config_filename: The name of config file
_transaction_list: The list of transactions to populate
_agent_identifier: Identifier for the agent
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.
:param transaction_list: The list of transactions to populate.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
super(Primaite, self).__init__()
self._training_config_path = training_config_path
self._lay_down_config_path = lay_down_config_path
# Take a copy of the config values
self.config_values = _config_values
self.config_values: TrainingConfig = training_config.load(training_config_path)
# Number of steps in an episode
self.episode_steps = 0
self.episode_steps = self.config_values.num_steps
super(Primaite, self).__init__()
# Transaction list
self.transaction_list = _transaction_list
self.transaction_list = transaction_list
# The agent in use
self.agent_identifier = self.config_values.agent_identifier
@@ -153,13 +166,10 @@ class Primaite(Env):
self.observation_type = ObservationType.BOX
# Open the config file and build the environment laydown
try:
self.config_file = open(self.config_values.config_filename_use_case, "r")
self.config_data = yaml.safe_load(self.config_file)
self.load_config()
except Exception:
_LOGGER.error("Could not load the environment configuration")
_LOGGER.error("Exception occured", exc_info=True)
with open(self._lay_down_config_path, "r") as file:
# Open the config file and build the environment laydown
self.config_data = yaml.safe_load(file)
self.load_lay_down_config()
# Store the node objects as node attributes
# (This is so we can access them as objects)
@@ -179,12 +189,8 @@ class Primaite(Env):
now = datetime.now() # current date and time
time = now.strftime("%Y%m%d_%H%M%S")
path = "outputs/diagrams"
is_dir = os.path.isdir(path)
if not is_dir:
os.makedirs(path)
filename = "outputs/diagrams/network_" + time + ".png"
plt.savefig(filename, format="PNG")
file_path = session_path / f"network_{timestamp_str}.png"
plt.savefig(file_path, format="PNG")
plt.clf()
except Exception:
_LOGGER.error("Could not save network diagram")
@@ -236,13 +242,9 @@ class Primaite(Env):
time = now.strftime("%Y%m%d_%H%M%S")
header = ["Episode", "Average Reward"]
# Check whether the output/rerults folder exists (doesn't exist by default install)
path = "outputs/results/"
is_dir = os.path.isdir(path)
if not is_dir:
os.makedirs(path)
filename = "outputs/results/average_reward_per_episode_" + time + ".csv"
self.csv_file = open(filename, "w", encoding="UTF8", newline="")
file_name = f"average_reward_per_episode_{timestamp_str}.csv"
file_path = session_path / file_name
self.csv_file = open(file_path, "w", encoding="UTF8", newline="")
self.csv_writer = csv.writer(self.csv_file)
self.csv_writer.writerow(header)
except Exception:
@@ -404,7 +406,6 @@ class Primaite(Env):
def __close__(self):
"""Override close function."""
self.csv_file.close()
self.config_file.close()
def init_acl(self):
"""Initialise the Access Control List."""
@@ -888,7 +889,7 @@ class Primaite(Env):
elif self.observation_type == ObservationType.MULTIDISCRETE:
self._update_env_obs_multidiscrete()
def load_config(self):
def load_lay_down_config(self):
"""Loads config data in order to build the environment configuration."""
for item in self.config_data:
if item["itemType"] == "NODE":
@@ -918,15 +919,9 @@ class Primaite(Env):
elif item["itemType"] == "PORTS":
# Create the list of ports
self.create_ports_list(item)
elif item["itemType"] == "ACTIONS":
# Get the action information
self.get_action_info(item)
elif item["itemType"] == "OBSERVATIONS":
# Get the observation information
self.get_observation_info(item)
elif item["itemType"] == "STEPS":
# Get the steps information
self.get_steps_info(item)
else:
# Do nothing (bad formatting)
pass
@@ -1247,15 +1242,6 @@ class Primaite(Env):
# Set the number of ports
self.num_ports = len(self.ports_list)
def get_action_info(self, action_info):
"""
Extracts action_info.
Args:
item: A config data item representing action info
"""
self.action_type = ActionType[action_info["type"]]
def get_observation_info(self, observation_info):
"""Extracts observation_info.
@@ -1264,16 +1250,6 @@ class Primaite(Env):
"""
self.observation_type = ObservationType[observation_info["type"]]
def get_steps_info(self, steps_info):
"""
Extracts steps_info.
Args:
item: A config data item representing steps info
"""
self.episode_steps = int(steps_info["steps"])
_LOGGER.info("Training episodes have " + str(self.episode_steps) + " steps")
def reset_environment(self):
"""
# Resets environment.

View File

@@ -1,29 +1,41 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""
Primaite - main (harness) module.
The main PrimAITE session runner module.
Coding Standards: PEP 8
TODO: This will eventually be refactored out into a proper Session class.
TODO: The passing about of session_dir and timestamp_str is temporary and
will be cleaned up once we move to a proper Session class.
"""
import logging
import os.path
import argparse
import time
from datetime import datetime
from pathlib import Path
from typing import Final, Union
import yaml
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite.common.config_values_main import ConfigValuesMain
from primaite import SESSIONS_DIR, getLogger
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.config.training_config import TrainingConfig, \
main_training_config_path
from primaite.environment.primaite_env import Primaite
from primaite.transactions.transactions_to_file import write_transaction_to_file
# FUNCTIONS #
_LOGGER = getLogger(__name__)
def run_generic():
"""Run against a generic agent."""
def run_generic(env: Primaite, config_values: TrainingConfig):
"""
Run against a generic agent.
:param env: An instance of
:class:`~primaite.environment.primaite_env.Primaite`.
:param config_values: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
"""
for episode in range(0, config_values.num_episodes):
env.reset()
for step in range(0, config_values.num_steps):
@@ -47,9 +59,24 @@ def run_generic():
env.close()
def run_stable_baselines3_ppo():
"""Run against a stable_baselines3 PPO agent."""
if config_values.load_agent == True:
def run_stable_baselines3_ppo(
env: Primaite,
config_values: TrainingConfig,
session_path: Path,
timestamp_str: str
):
"""
Run against a stable_baselines3 PPO agent.
:param env: An instance of
:class:`~primaite.environment.primaite_env.Primaite`.
:param config_values: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
if config_values.load_agent:
try:
agent = PPO.load(
config_values.agent_load_file,
@@ -62,30 +89,44 @@ def run_stable_baselines3_ppo():
"ERROR: Could not load agent at location: "
+ config_values.agent_load_file
)
logging.error("Could not load agent")
logging.error("Exception occured", exc_info=True)
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occured", exc_info=True)
else:
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
if config_values.session_type == "TRAINING":
# We're in a training session
print("Starting training session...")
logging.info("Starting training session...")
_LOGGER.debug("Starting training session...")
for episode in range(0, config_values.num_episodes):
agent.learn(total_timesteps=1)
save_agent(agent)
_save_agent(agent, session_path, timestamp_str)
else:
# Default to being in an evaluation session
print("Starting evaluation session...")
logging.info("Starting evaluation session...")
_LOGGER.debug("Starting evaluation session...")
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
env.close()
def run_stable_baselines3_a2c():
"""Run against a stable_baselines3 A2C agent."""
if config_values.load_agent == True:
def run_stable_baselines3_a2c(
env: Primaite,
config_values: TrainingConfig,
session_path: Path, timestamp_str: str
):
"""
Run against a stable_baselines3 A2C agent.
:param env: An instance of
:class:`~primaite.environment.primaite_env.Primaite`.
:param config_values: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
if config_values.load_agent:
try:
agent = A2C.load(
config_values.agent_load_file,
@@ -98,284 +139,151 @@ def run_stable_baselines3_a2c():
"ERROR: Could not load agent at location: "
+ config_values.agent_load_file
)
logging.error("Could not load agent")
logging.error("Exception occured", exc_info=True)
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occured", exc_info=True)
else:
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)
if config_values.session_type == "TRAINING":
# We're in a training session
print("Starting training session...")
logging.info("Starting training session...")
_LOGGER.debug("Starting training session...")
for episode in range(0, config_values.num_episodes):
agent.learn(total_timesteps=1)
save_agent(agent)
_save_agent(agent, session_path, timestamp_str)
else:
# Default to being in an evaluation session
print("Starting evaluation session...")
logging.info("Starting evaluation session...")
_LOGGER.debug("Starting evaluation session...")
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
env.close()
def save_agent(_agent):
"""Persist an agent (only works for stable baselines3 agents at present)."""
now = datetime.now() # current date and time
time = now.strftime("%Y%m%d_%H%M%S")
def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
"""
Persist an agent.
try:
path = "outputs/agents/"
is_dir = os.path.isdir(path)
if not is_dir:
os.makedirs(path)
filename = "outputs/agents/agent_saved_" + time
_agent.save(filename)
logging.info("Trained agent saved as " + filename)
except Exception:
logging.error("Could not save agent")
logging.error("Exception occured", exc_info=True)
Only works for stable baselines3 agents at present.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
if not isinstance(agent, OnPolicyAlgorithm):
msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
_LOGGER.error(msg)
else:
filepath = session_path / f"agent_saved_{timestamp_str}"
agent.save(filepath)
_LOGGER.debug(f"Trained agent saved as: {filepath}")
def configure_logging():
"""Configures logging."""
try:
now = datetime.now() # current date and time
time = now.strftime("%Y%m%d_%H%M%S")
filename = "logs/app_" + time + ".log"
path = "logs/"
is_dir = os.path.isdir(path)
if not is_dir:
os.makedirs(path)
logging.basicConfig(
filename=filename,
filemode="w",
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%d-%b-%y %H:%M:%S",
level=logging.INFO,
)
except Exception:
print("ERROR: Could not start logging")
def _get_session_path(session_timestamp: datetime) -> Path:
"""
Get the directory path the session will output to.
This is set in the format of:
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = SESSIONS_DIR / date_dir / session_dir
session_path.mkdir(exist_ok=True, parents=True)
return session_path
def load_config_values():
"""Loads the config values from the main config file into a config object."""
try:
# Generic
config_values.agent_identifier = config_data["agentIdentifier"]
config_values.num_episodes = int(config_data["numEpisodes"])
config_values.time_delay = int(config_data["timeDelay"])
config_values.config_filename_use_case = (
"config/" + config_data["configFilename"]
def run(
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path]
):
"""Run the PrimAITE Session.
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.
"""
# Welcome message
print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
session_timestamp: Final[datetime] = datetime.now()
session_path = _get_session_path(session_timestamp)
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
print(f"The output directory for this session is: {session_path}")
# Create a list of transactions
# A transaction is an object holding the:
# - episode #
# - step #
# - initial observation space
# - action
# - reward
# - new observation space
transaction_list = []
# Create the Primaite environment
env = Primaite(
training_config_path=training_config_path,
lay_down_config_path=lay_down_config_path,
transaction_list=transaction_list,
session_path=session_path,
timestamp_str=timestamp_str,
)
config_values = env.config_values
# Get the number of steps (which is stored in the child config file)
config_values.num_steps = env.episode_steps
# Run environment against an agent
if config_values.agent_identifier == "GENERIC":
run_generic(env=env, config_values=config_values)
elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
run_stable_baselines3_ppo(
env=env,
config_values=config_values,
session_path=session_path,
timestamp_str=timestamp_str,
)
config_values.session_type = config_data["sessionType"]
config_values.load_agent = bool(config_data["loadAgent"])
config_values.agent_load_file = config_data["agentLoadFile"]
# Environment
config_values.observation_space_high_value = int(
config_data["observationSpaceHighValue"]
)
# Reward values
# Generic
config_values.all_ok = int(config_data["allOk"])
# Node Hardware State
config_values.off_should_be_on = int(config_data["offShouldBeOn"])
config_values.off_should_be_resetting = int(config_data["offShouldBeResetting"])
config_values.on_should_be_off = int(config_data["onShouldBeOff"])
config_values.on_should_be_resetting = int(config_data["onShouldBeResetting"])
config_values.resetting_should_be_on = int(config_data["resettingShouldBeOn"])
config_values.resetting_should_be_off = int(config_data["resettingShouldBeOff"])
config_values.resetting = int(config_data["resetting"])
# Node Software or Service State
config_values.good_should_be_patching = int(config_data["goodShouldBePatching"])
config_values.good_should_be_compromised = int(
config_data["goodShouldBeCompromised"]
)
config_values.good_should_be_overwhelmed = int(
config_data["goodShouldBeOverwhelmed"]
)
config_values.patching_should_be_good = int(config_data["patchingShouldBeGood"])
config_values.patching_should_be_compromised = int(
config_data["patchingShouldBeCompromised"]
)
config_values.patching_should_be_overwhelmed = int(
config_data["patchingShouldBeOverwhelmed"]
)
config_values.patching = int(config_data["patching"])
config_values.compromised_should_be_good = int(
config_data["compromisedShouldBeGood"]
)
config_values.compromised_should_be_patching = int(
config_data["compromisedShouldBePatching"]
)
config_values.compromised_should_be_overwhelmed = int(
config_data["compromisedShouldBeOverwhelmed"]
)
config_values.compromised = int(config_data["compromised"])
config_values.overwhelmed_should_be_good = int(
config_data["overwhelmedShouldBeGood"]
)
config_values.overwhelmed_should_be_patching = int(
config_data["overwhelmedShouldBePatching"]
)
config_values.overwhelmed_should_be_compromised = int(
config_data["overwhelmedShouldBeCompromised"]
)
config_values.overwhelmed = int(config_data["overwhelmed"])
# Node File System State
config_values.good_should_be_repairing = int(
config_data["goodShouldBeRepairing"]
)
config_values.good_should_be_restoring = int(
config_data["goodShouldBeRestoring"]
)
config_values.good_should_be_corrupt = int(config_data["goodShouldBeCorrupt"])
config_values.good_should_be_destroyed = int(
config_data["goodShouldBeDestroyed"]
)
config_values.repairing_should_be_good = int(
config_data["repairingShouldBeGood"]
)
config_values.repairing_should_be_restoring = int(
config_data["repairingShouldBeRestoring"]
)
config_values.repairing_should_be_corrupt = int(
config_data["repairingShouldBeCorrupt"]
)
config_values.repairing_should_be_destroyed = int(
config_data["repairingShouldBeDestroyed"]
)
config_values.repairing = int(config_data["repairing"])
config_values.restoring_should_be_good = int(
config_data["restoringShouldBeGood"]
)
config_values.restoring_should_be_repairing = int(
config_data["restoringShouldBeRepairing"]
)
config_values.restoring_should_be_corrupt = int(
config_data["restoringShouldBeCorrupt"]
)
config_values.restoring_should_be_destroyed = int(
config_data["restoringShouldBeDestroyed"]
)
config_values.restoring = int(config_data["restoring"])
config_values.corrupt_should_be_good = int(config_data["corruptShouldBeGood"])
config_values.corrupt_should_be_repairing = int(
config_data["corruptShouldBeRepairing"]
)
config_values.corrupt_should_be_restoring = int(
config_data["corruptShouldBeRestoring"]
)
config_values.corrupt_should_be_destroyed = int(
config_data["corruptShouldBeDestroyed"]
)
config_values.corrupt = int(config_data["corrupt"])
config_values.destroyed_should_be_good = int(
config_data["destroyedShouldBeGood"]
)
config_values.destroyed_should_be_repairing = int(
config_data["destroyedShouldBeRepairing"]
)
config_values.destroyed_should_be_restoring = int(
config_data["destroyedShouldBeRestoring"]
)
config_values.destroyed_should_be_corrupt = int(
config_data["destroyedShouldBeCorrupt"]
)
config_values.destroyed = int(config_data["destroyed"])
config_values.scanning = int(config_data["scanning"])
# IER status
config_values.red_ier_running = int(config_data["redIerRunning"])
config_values.green_ier_blocked = int(config_data["greenIerBlocked"])
# Patching / Reset durations
config_values.os_patching_duration = int(config_data["osPatchingDuration"])
config_values.node_reset_duration = int(config_data["nodeResetDuration"])
config_values.service_patching_duration = int(
config_data["servicePatchingDuration"]
)
config_values.file_system_repairing_limit = int(
config_data["fileSystemRepairingLimit"]
)
config_values.file_system_restoring_limit = int(
config_data["fileSystemRestoringLimit"]
)
config_values.file_system_scanning_limit = int(
config_data["fileSystemScanningLimit"]
elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
run_stable_baselines3_a2c(
env=env,
config_values=config_values,
session_path=session_path,
timestamp_str=timestamp_str,
)
logging.info("Training agent: " + config_values.agent_identifier)
logging.info(
"Training environment config: " + config_values.config_filename_use_case
print("Session finished")
_LOGGER.debug("Session finished")
print("Saving transaction logs...")
_LOGGER.debug("Saving transaction logs...")
write_transaction_to_file(
transaction_list=transaction_list,
session_path=session_path,
timestamp_str=timestamp_str,
)
print("Finished")
_LOGGER.debug("Finished")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tc")
parser.add_argument("--ldc")
args = parser.parse_args()
if not args.tc:
_LOGGER.error(
"Please provide a training config file using the --tc " "argument"
)
logging.info(
"Training cycle has " + str(config_values.num_episodes) + " episodes"
if not args.ldc:
_LOGGER.error(
"Please provide a lay down config file using the --ldc " "argument"
)
except Exception:
logging.error("Could not save load config data")
logging.error("Exception occured", exc_info=True)
# MAIN PROCESS #
# Starting point
# Welcome message
print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
# Configure logging
configure_logging()
# Open the main config file
try:
config_file_main = open("config/config_main.yaml", "r")
config_data = yaml.safe_load(config_file_main)
# Create a config class
config_values = ConfigValuesMain()
# Load in config data
load_config_values()
except Exception:
logging.error("Could not load main config")
logging.error("Exception occured", exc_info=True)
# Create a list of transactions
# A transaction is an object holding the:
# - episode #
# - step #
# - initial observation space
# - action
# - reward
# - new observation space
transaction_list = []
# Create the Primaite environment
# try:
env = Primaite(config_values, transaction_list)
# logging.info("PrimAITE environment created")
# except Exception:
# logging.error("Could not create PrimAITE environment")
# logging.error("Exception occured", exc_info=True)
# Get the number of steps (which is stored in the child config file)
config_values.num_steps = env.episode_steps
# Run environment against an agent
if config_values.agent_identifier == "GENERIC":
run_generic()
elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
run_stable_baselines3_ppo()
elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
run_stable_baselines3_a2c()
print("Session finished")
logging.info("Session finished")
print("Saving transaction logs...")
logging.info("Saving transaction logs...")
write_transaction_to_file(transaction_list)
config_file_main.close()
print("Finished")
logging.info("Finished")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)

View File

@@ -3,7 +3,6 @@
import logging
from typing import Final
from primaite.common.config_values_main import ConfigValuesMain
from primaite.common.enums import (
FileSystemState,
HardwareState,
@@ -11,6 +10,7 @@ from primaite.common.enums import (
Priority,
SoftwareState,
)
from primaite.config.training_config import TrainingConfig
from primaite.nodes.node import Node
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ class ActiveNode(Node):
ip_address: str,
software_state: SoftwareState,
file_system_state: FileSystemState,
config_values: ConfigValuesMain,
config_values: TrainingConfig,
):
"""
Init.

View File

@@ -2,8 +2,8 @@
"""The base Node class."""
from typing import Final
from primaite.common.config_values_main import ConfigValuesMain
from primaite.common.enums import HardwareState, NodeType, Priority
from primaite.config.training_config import TrainingConfig
class Node:
@@ -16,7 +16,7 @@ class Node:
node_type: NodeType,
priority: Priority,
hardware_state: HardwareState,
config_values: ConfigValuesMain,
config_values: TrainingConfig,
):
"""
Init.
@@ -34,7 +34,7 @@ class Node:
self.priority = priority
self.hardware_state: HardwareState = hardware_state
self.resetting_count: int = 0
self.config_values: ConfigValuesMain = config_values
self.config_values: TrainingConfig = config_values
def __repr__(self):
"""Returns the name of the node."""

View File

@@ -1,7 +1,7 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""The Passive Node class (i.e. an actuator)."""
from primaite.common.config_values_main import ConfigValuesMain
from primaite.common.enums import HardwareState, NodeType, Priority
from primaite.config.training_config import TrainingConfig
from primaite.nodes.node import Node
@@ -15,7 +15,7 @@ class PassiveNode(Node):
node_type: NodeType,
priority: Priority,
hardware_state: HardwareState,
config_values: ConfigValuesMain,
config_values: TrainingConfig,
):
"""
Init.

View File

@@ -3,7 +3,6 @@
import logging
from typing import Dict, Final
from primaite.common.config_values_main import ConfigValuesMain
from primaite.common.enums import (
FileSystemState,
HardwareState,
@@ -12,6 +11,7 @@ from primaite.common.enums import (
SoftwareState,
)
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
@@ -30,7 +30,7 @@ class ServiceNode(ActiveNode):
ip_address: str,
software_state: SoftwareState,
file_system_state: FileSystemState,
config_values: ConfigValuesMain,
config_values: TrainingConfig,
):
"""
Init.

View File

@@ -0,0 +1,36 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import importlib.util
import os
import subprocess
import sys
from primaite import NOTEBOOKS_DIR, getLogger
_LOGGER = getLogger(__name__)
def start_jupyter_session():
"""
Starts a new Jupyter notebook session in the app notebooks directory.
Currently only works on Windows OS.
.. todo:: Figure out how to get this working for Linux and MacOS too.
"""
if sys.platform == "win32":
if importlib.util.find_spec("jupyter") is not None:
# Jupyter is installed
working_dir = os.getcwd()
os.chdir(NOTEBOOKS_DIR)
subprocess.Popen("jupyter lab")
os.chdir(working_dir)
else:
# Jupyter is not installed
_LOGGER.error("Cannot start jupyter lab as it is not installed")
else:
msg = (
"Feature currently only supported on Windows OS. For "
"Linux/MacOS users, run 'cd ~/primaite/notebooks; jupyter "
"lab' from your Python environment."
)
_LOGGER.warning(msg)

View File

@@ -0,0 +1 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.

View File

@@ -0,0 +1,13 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from primaite import getLogger
_LOGGER = getLogger(__name__)
def run():
"""Perform the full clean-up."""
pass
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,39 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import filecmp
import os
import shutil
from pathlib import Path
import pkg_resources
from primaite import NOTEBOOKS_DIR, getLogger
_LOGGER = getLogger(__name__)
def run(overwrite_existing: bool = True):
"""
Resets the demo jupyter notebooks in the users app notebooks directory.
:param overwrite_existing: A bool to toggle replacing existing edited
notebooks on or off.
"""
notebooks_package_data_root = pkg_resources.resource_filename(
"primaite", "notebooks/_package_data"
)
for subdir, dirs, files in os.walk(notebooks_package_data_root):
for file in files:
fp = os.path.join(subdir, file)
path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep)
target_fp = NOTEBOOKS_DIR / Path(*path_split)
target_fp.parent.mkdir(exist_ok=True, parents=True)
copy_file = not target_fp.is_file()
if overwrite_existing and not copy_file:
copy_file = (not filecmp.cmp(fp, target_fp)) and (
".ipynb_checkpoints" not in str(target_fp)
)
if copy_file:
shutil.copy2(fp, target_fp)
_LOGGER.info(f"Reset example notebook: {target_fp}")

View File

@@ -0,0 +1,37 @@
import filecmp
import os
import shutil
from pathlib import Path
import pkg_resources
from primaite import USERS_CONFIG_DIR, getLogger
_LOGGER = getLogger(__name__)
def run(overwrite_existing=True):
"""
Resets the example config files in the users app config directory.
:param overwrite_existing: A bool to toggle replacing existing edited
config on or off.
"""
configs_package_data_root = pkg_resources.resource_filename(
"primaite", "config/_package_data"
)
for subdir, dirs, files in os.walk(configs_package_data_root):
for file in files:
fp = os.path.join(subdir, file)
path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep)
target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split)
target_fp.parent.mkdir(exist_ok=True, parents=True)
copy_file = not target_fp.is_file()
if overwrite_existing and not copy_file:
copy_file = not filecmp.cmp(fp, target_fp)
if copy_file:
shutil.copy2(fp, target_fp)
_LOGGER.info(f"Reset example config: {target_fp}")

View File

@@ -0,0 +1,27 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from primaite import _USER_DIRS, LOG_DIR, NOTEBOOKS_DIR, getLogger
_LOGGER = getLogger(__name__)
def run():
"""
Handles creation of application directories and user directories.
Uses `platformdirs.PlatformDirs` and `pathlib.Path` to create the required
app directories in the correct locations based on the users OS.
"""
app_dirs = [
_USER_DIRS,
NOTEBOOKS_DIR,
LOG_DIR,
]
for app_dir in app_dirs:
if not app_dir.is_dir():
app_dir.mkdir(parents=True, exist_ok=True)
_LOGGER.info(f"Created directory: {app_dir}")
if __name__ == "__main__":
run()

View File

@@ -2,9 +2,11 @@
"""Writes the Transaction log list out to file for evaluation to utilse."""
import csv
import logging
import os.path
from datetime import datetime
from pathlib import Path
from primaite import getLogger
_LOGGER = getLogger(__name__)
def turn_action_space_to_array(_action_space):
@@ -38,18 +40,22 @@ def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
return return_array
def write_transaction_to_file(_transaction_list):
def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str):
"""
Writes transaction logs to file to support training evaluation.
Args:
_transaction_list: The list of transactions from all steps and all episodes
_num_episodes: The number of episodes that were conducted.
:param transaction_list: The list of transactions from all steps and all
episodes.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
# Get the first transaction and use it to determine the makeup of the observation space and action space
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action space as "AS_1"
# Get the first transaction and use it to determine the makeup of the
# observation space and action space
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action
# space as "AS_1"
# This will be tied into the PrimAITE Use Case so that they make sense
template_transation = _transaction_list[0]
template_transation = transaction_list[0]
action_length = template_transation.action_space.size
obs_shape = template_transation.obs_space_post.shape
obs_assets = template_transation.obs_space_post.shape[0]
@@ -75,21 +81,15 @@ def write_transaction_to_file(_transaction_list):
# Open up a csv file
header = ["Timestamp", "Episode", "Step", "Reward"]
header = header + action_header + obs_header_initial + obs_header_new
now = datetime.now() # current date and time
time = now.strftime("%Y%m%d_%H%M%S")
try:
path = "outputs/results/"
is_dir = os.path.isdir(path)
if not is_dir:
os.makedirs(path)
filename = "outputs/results/all_transactions_" + time + ".csv"
filename = session_path / f"all_transactions_{timestamp_str}.csv"
csv_file = open(filename, "w", encoding="UTF8", newline="")
csv_writer = csv.writer(csv_file)
csv_writer.writerow(header)
for transaction in _transaction_list:
for transaction in transaction_list:
csv_data = [
str(transaction.timestamp),
str(transaction.episode_number),
@@ -110,5 +110,4 @@ def write_transaction_to_file(_transaction_list):
csv_file.close()
except Exception:
logging.error("Could not save the transaction file")
logging.error("Exception occured", exc_info=True)
_LOGGER.error("Could not save the transaction file", exc_info=True)

View File

View File

@@ -0,0 +1,32 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import os
from pathlib import Path
import pkg_resources
from primaite import getLogger
_LOGGER = getLogger(__name__)
def get_file_path(path: str) -> Path:
"""
Get PrimAITE package data.
:Example:
>>> from primaite.utils.package_data import get_file_path
>>> main_env_config = get_file_path("config/_package_data/training_config_main.yaml")
:param path: The path from the primaite root.
:return: The file path of the package data file.
:raise FileNotFoundError: When the filepath does not exist.
"""
fp = pkg_resources.resource_filename("primaite", path)
if os.path.isfile(fp):
return Path(fp)
else:
msg = f"Cannot PrimAITE package data: {fp}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)