#915 - Synced with dev
This commit is contained in:
@@ -7,11 +7,33 @@ from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
from platformdirs import PlatformDirs
|
||||
|
||||
_PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
|
||||
"""An instance of `PlatformDirs` set with appname='primaite'."""
|
||||
|
||||
def _get_primaite_config():
|
||||
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
|
||||
if not config_path.exists():
|
||||
config_path = Path(
|
||||
pkg_resources.resource_filename(
|
||||
"primaite", "setup/_package_data/primaite_config.yaml"
|
||||
)
|
||||
)
|
||||
with open(config_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
return primaite_config
|
||||
|
||||
|
||||
_PRIMAITE_CONFIG = _get_primaite_config()
|
||||
|
||||
# PrimAITE config items
|
||||
_LOG_LEVEL: int = _PRIMAITE_CONFIG["log_level"]
|
||||
_LOGGER_FORMAT: str = _PRIMAITE_CONFIG["logger_format"]
|
||||
|
||||
|
||||
_USER_DIRS: Final[Path] = Path.home() / "primaite"
|
||||
"""The users home space for PrimAITE which is located at: ~/primaite."""
|
||||
|
||||
@@ -64,12 +86,10 @@ _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
|
||||
backupCount=9, # Max 100MB of logs
|
||||
encoding="utf8",
|
||||
)
|
||||
_STREAM_HANDLER.setLevel(logging.DEBUG)
|
||||
_FILE_HANDLER.setLevel(logging.DEBUG)
|
||||
_STREAM_HANDLER.setLevel(_LOG_LEVEL)
|
||||
_FILE_HANDLER.setLevel(_LOG_LEVEL)
|
||||
|
||||
_LOG_FORMAT_STR: Final[
|
||||
str
|
||||
] = "%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s"
|
||||
_LOG_FORMAT_STR: Final[str] = _LOGGER_FORMAT
|
||||
_STREAM_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR))
|
||||
_FILE_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR))
|
||||
|
||||
@@ -88,7 +108,7 @@ def getLogger(name: str) -> Logger:
|
||||
logging config.
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.setLevel(_LOG_LEVEL)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""Provides a CLI using Typer as an entry point."""
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
import typer
|
||||
from platformdirs import PlatformDirs
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
@@ -79,12 +83,24 @@ def clean_up():
|
||||
|
||||
|
||||
@app.command()
|
||||
def setup():
|
||||
def setup(overwrite_existing: bool = True):
|
||||
"""
|
||||
Perform the PrimAITE first-time setup.
|
||||
|
||||
WARNING: All user-data will be lost.
|
||||
"""
|
||||
app_dirs = PlatformDirs(appname="primaite")
|
||||
user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
|
||||
build_config = overwrite_existing or (not user_config_path.exists())
|
||||
if build_config:
|
||||
pkg_config_path = Path(
|
||||
pkg_resources.resource_filename(
|
||||
"primaite", "setup/_package_data/primaite_config.yaml"
|
||||
)
|
||||
)
|
||||
|
||||
shutil.copy2(pkg_config_path, user_config_path)
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.setup import (
|
||||
old_installation_clean_up,
|
||||
@@ -97,6 +113,9 @@ def setup():
|
||||
|
||||
_LOGGER.info("Performing the PrimAITE first-time setup...")
|
||||
|
||||
if build_config:
|
||||
_LOGGER.info("Building primaite_config.yaml...")
|
||||
|
||||
_LOGGER.info("Building the PrimAITE app directories...")
|
||||
setup_app_dirs.run()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Union
|
||||
from typing import Any, Dict, Final, Union, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -18,85 +18,118 @@ class TrainingConfig:
|
||||
"""The Training Config class."""
|
||||
|
||||
# Generic
|
||||
agent_identifier: str # The Red Agent algo/class to be used
|
||||
action_type: ActionType # type of action to use (NODE/ACL/ANY)
|
||||
num_episodes: int # number of episodes to train over
|
||||
num_steps: int # number of steps in an episode
|
||||
time_delay: int # delay between steps (ms) - applies to generic agents only
|
||||
agent_identifier: str = "STABLE_BASELINES3_A2C"
|
||||
"The Red Agent algo/class to be used."
|
||||
|
||||
action_type: ActionType = ActionType.ANY
|
||||
"The ActionType to use."
|
||||
|
||||
num_episodes: int = 10
|
||||
"The number of episodes to train over."
|
||||
|
||||
num_steps: int = 256
|
||||
"The number of steps in an episode."
|
||||
observation_space: dict = field(
|
||||
default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}
|
||||
)
|
||||
"The observation space config dict."
|
||||
|
||||
time_delay: int = 10
|
||||
"The delay between steps (ms). Applies to generic agents only."
|
||||
|
||||
# file
|
||||
session_type: str # the session type to run (TRAINING or EVALUATION)
|
||||
load_agent: str # Determine whether to load an agent from file
|
||||
agent_load_file: str # File path and file name of agent if you're loading one in
|
||||
session_type: str = "TRAINING"
|
||||
"the session type to run (TRAINING or EVALUATION)"
|
||||
|
||||
load_agent: str = False
|
||||
"Determine whether to load an agent from file."
|
||||
|
||||
agent_load_file: Optional[str] = None
|
||||
"File path and file name of agent if you're loading one in."
|
||||
|
||||
# Environment
|
||||
observation_space_high_value: int # The high value for the observation space
|
||||
observation_space_high_value: int = 1000000000
|
||||
"The high value for the observation space."
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: int
|
||||
all_ok: int = 0
|
||||
|
||||
# Node Hardware State
|
||||
off_should_be_on: int
|
||||
off_should_be_resetting: int
|
||||
on_should_be_off: int
|
||||
on_should_be_resetting: int
|
||||
resetting_should_be_on: int
|
||||
resetting_should_be_off: int
|
||||
resetting: int
|
||||
off_should_be_on: int = -10
|
||||
off_should_be_resetting: int = -5
|
||||
on_should_be_off: int = -2
|
||||
on_should_be_resetting: int = -5
|
||||
resetting_should_be_on: int = -5
|
||||
resetting_should_be_off: int = -2
|
||||
resetting: int = -3
|
||||
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: int
|
||||
good_should_be_compromised: int
|
||||
good_should_be_overwhelmed: int
|
||||
patching_should_be_good: int
|
||||
patching_should_be_compromised: int
|
||||
patching_should_be_overwhelmed: int
|
||||
patching: int
|
||||
compromised_should_be_good: int
|
||||
compromised_should_be_patching: int
|
||||
compromised_should_be_overwhelmed: int
|
||||
compromised: int
|
||||
overwhelmed_should_be_good: int
|
||||
overwhelmed_should_be_patching: int
|
||||
overwhelmed_should_be_compromised: int
|
||||
overwhelmed: int
|
||||
good_should_be_patching: int = 2
|
||||
good_should_be_compromised: int = 5
|
||||
good_should_be_overwhelmed: int = 5
|
||||
patching_should_be_good: int = -5
|
||||
patching_should_be_compromised: int = 2
|
||||
patching_should_be_overwhelmed: int = 2
|
||||
patching: int = -3
|
||||
compromised_should_be_good: int = -20
|
||||
compromised_should_be_patching: int = -20
|
||||
compromised_should_be_overwhelmed: int = -20
|
||||
compromised: int = -20
|
||||
overwhelmed_should_be_good: int = -20
|
||||
overwhelmed_should_be_patching: int = -20
|
||||
overwhelmed_should_be_compromised: int = -20
|
||||
overwhelmed: int = -20
|
||||
|
||||
# Node File System State
|
||||
good_should_be_repairing: int
|
||||
good_should_be_restoring: int
|
||||
good_should_be_corrupt: int
|
||||
good_should_be_destroyed: int
|
||||
repairing_should_be_good: int
|
||||
repairing_should_be_restoring: int
|
||||
repairing_should_be_corrupt: int
|
||||
repairing_should_be_destroyed: int # Repairing does not fix destroyed state - you need to restore
|
||||
good_should_be_repairing: int = 2
|
||||
good_should_be_restoring: int = 2
|
||||
good_should_be_corrupt: int = 5
|
||||
good_should_be_destroyed: int = 10
|
||||
repairing_should_be_good: int = -5
|
||||
repairing_should_be_restoring: int = 2
|
||||
repairing_should_be_corrupt: int = 2
|
||||
repairing_should_be_destroyed: int = 0
|
||||
repairing: int = -3
|
||||
restoring_should_be_good: int = -10
|
||||
restoring_should_be_repairing: int = -2
|
||||
restoring_should_be_corrupt: int = 1
|
||||
restoring_should_be_destroyed: int = 2
|
||||
restoring: int = -6
|
||||
corrupt_should_be_good: int = -10
|
||||
corrupt_should_be_repairing: int = -10
|
||||
corrupt_should_be_restoring: int = -10
|
||||
corrupt_should_be_destroyed: int = 2
|
||||
corrupt: int = -10
|
||||
destroyed_should_be_good: int = -20
|
||||
destroyed_should_be_repairing: int = -20
|
||||
destroyed_should_be_restoring: int = -20
|
||||
destroyed_should_be_corrupt: int = -20
|
||||
destroyed: int = -20
|
||||
scanning: int = -2
|
||||
|
||||
repairing: int
|
||||
restoring_should_be_good: int
|
||||
restoring_should_be_repairing: int
|
||||
restoring_should_be_corrupt: int # Not the optimal method (as repair will fix corruption)
|
||||
|
||||
restoring_should_be_destroyed: int
|
||||
restoring: int
|
||||
corrupt_should_be_good: int
|
||||
corrupt_should_be_repairing: int
|
||||
corrupt_should_be_restoring: int
|
||||
corrupt_should_be_destroyed: int
|
||||
corrupt: int
|
||||
destroyed_should_be_good: int
|
||||
destroyed_should_be_repairing: int
|
||||
destroyed_should_be_restoring: int
|
||||
destroyed_should_be_corrupt: int
|
||||
destroyed: int
|
||||
scanning: int
|
||||
# IER status
|
||||
red_ier_running: int
|
||||
green_ier_blocked: int
|
||||
red_ier_running: int = -5
|
||||
green_ier_blocked: int = -10
|
||||
|
||||
# Patching / Reset
|
||||
os_patching_duration: int # The time taken to patch the OS
|
||||
node_reset_duration: int # The time taken to reset a node (hardware)
|
||||
service_patching_duration: int # The time taken to patch a service
|
||||
file_system_repairing_limit: int # The time take to repair a file
|
||||
file_system_restoring_limit: int # The time take to restore a file
|
||||
file_system_scanning_limit: int # The time taken to scan the file system
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: int = 5
|
||||
"The time taken to patch the OS."
|
||||
|
||||
node_reset_duration: int = 5
|
||||
"The time taken to reset a node (hardware)."
|
||||
|
||||
service_patching_duration: int = 5
|
||||
"The time taken to patch a service."
|
||||
|
||||
file_system_repairing_limit: int = 5
|
||||
"The time take to repair the file system."
|
||||
|
||||
file_system_restoring_limit: int = 5
|
||||
"The time take to restore the file system."
|
||||
|
||||
file_system_scanning_limit: int = 5
|
||||
"The time taken to scan the file system."
|
||||
|
||||
def to_dict(self, json_serializable: bool = True):
|
||||
"""
|
||||
@@ -128,7 +161,8 @@ def main_training_config_path() -> Path:
|
||||
return path
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
|
||||
def load(file_path: Union[str, Path],
|
||||
legacy_file: bool = False) -> TrainingConfig:
|
||||
"""
|
||||
Read in a training config yaml file.
|
||||
|
||||
@@ -173,7 +207,8 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConf
|
||||
|
||||
|
||||
def convert_legacy_training_config_dict(
|
||||
legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY"
|
||||
legacy_config_dict: Dict[str, Any], num_steps: int = 256,
|
||||
action_type: str = "ANY"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy training config dict to the new format.
|
||||
|
||||
@@ -168,8 +168,8 @@ class Primaite(Env):
|
||||
# TODO fix up with TrainingConfig
|
||||
# stores the observation config from the yaml, default is NODE_LINK_TABLE
|
||||
self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]}
|
||||
if self.config_values.observation_config is not None:
|
||||
self.obs_config = self.config_values.observation_config
|
||||
if self.training_config.observation_space is not None:
|
||||
self.obs_config = self.training_config.observation_space
|
||||
|
||||
# Observation Handler manages the user-configurable observation space.
|
||||
# It will be initialised later.
|
||||
|
||||
5
src/primaite/setup/_package_data/primaite_config.yaml
Normal file
5
src/primaite/setup/_package_data/primaite_config.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
# The main PrimAITE application config file
|
||||
|
||||
# Logging
|
||||
log_level: 'INFO'
|
||||
logger_format: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
Reference in New Issue
Block a user