Merged PR 53: release v1.2.0
This commit is contained in:
@@ -3,4 +3,4 @@ Index-servers =
|
||||
PrimAITE
|
||||
|
||||
[PrimAITE]
|
||||
Repository = https://pkgs.dev.azure.com/ma-dev-uk/PrimAITE/_packaging/PrimAITE/pypi/upload/
|
||||
Repository = https://pkgs.dev.azure.com/ma-dev-uk/PrimAITE/_packaging/PrimAITE/pypi/upload/
|
||||
|
||||
@@ -15,16 +15,17 @@ steps:
|
||||
displayName: 'Use Python $(python.version)'
|
||||
|
||||
- script: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install build
|
||||
pip install wheel
|
||||
python -m pip install --upgrade pip==23.0.1
|
||||
pip install wheel==0.38.4 --upgrade
|
||||
pip install setuptools==66 --upgrade
|
||||
pip install build==0.10.0
|
||||
pip install twine
|
||||
pip install keyring
|
||||
pip install artifacts-keyring
|
||||
displayName: 'Install build dependencies'
|
||||
|
||||
- script: |
|
||||
python setup.py sdist bdist_wheel
|
||||
python -m build
|
||||
displayName: 'Build PrimAITE sdist and wheel'
|
||||
|
||||
- task: TwineAuthenticate@1
|
||||
@@ -33,5 +34,5 @@ steps:
|
||||
artifactFeed: PrimAITE/PrimAITE
|
||||
|
||||
- script: |
|
||||
python -m twine upload --verbose -r PrimAITE --config-file $(PYPIRC_PATH) dist/*
|
||||
python -m twine upload --verbose -r PrimAITE --config-file $(PYPIRC_PATH) dist/*.whl
|
||||
displayName: 'Artifact Upload'
|
||||
|
||||
45
.azure/azure-build-deploy-docs-pipeline.yml
Normal file
45
.azure/azure-build-deploy-docs-pipeline.yml
Normal file
@@ -0,0 +1,45 @@
|
||||
name: Azure Static Web Apps CI/CD
|
||||
|
||||
pr: none
|
||||
trigger:
|
||||
branches:
|
||||
include:
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
- job: build_and_deploy_job
|
||||
displayName: Build and Deploy Job
|
||||
condition: or(eq(variables['Build.Reason'], 'Manual'),or(eq(variables['Build.Reason'], 'PullRequest'),eq(variables['Build.Reason'], 'IndividualCI')))
|
||||
pool:
|
||||
vmImage: ubuntu-latest
|
||||
variables:
|
||||
- group: Azure-Static-Web-Apps-nice-bay-0ad032c03-variable-group
|
||||
steps:
|
||||
- checkout: self
|
||||
submodules: true
|
||||
|
||||
- script: |
|
||||
python -m pip install --upgrade pip==23.0.1
|
||||
pip install wheel==0.38.4 --upgrade
|
||||
pip install setuptools==66 --upgrade
|
||||
pip install build==0.10.0
|
||||
displayName: 'Install build dependencies'
|
||||
|
||||
- script: |
|
||||
pip install -e .[dev]
|
||||
displayName: 'Install Yawning-Titan for docs autosummary'
|
||||
|
||||
- script: |
|
||||
cd docs
|
||||
make html
|
||||
cd ..
|
||||
cd ..
|
||||
displayName: 'Build Docs'
|
||||
|
||||
- task: AzureStaticWebApp@0
|
||||
inputs:
|
||||
azure_static_web_apps_api_token: $(AZURE_STATIC_WEB_APPS_API_TOKEN_NICE_BAY_0AD032C03)
|
||||
app_location: "/docs/_build/html"
|
||||
api_location: ""
|
||||
output_location: "/"
|
||||
displayName: 'Deploy Docs to nice-bay-0ad032c03'
|
||||
52
.azure/azure-ci-build-pipeline.yaml
Normal file
52
.azure/azure-ci-build-pipeline.yaml
Normal file
@@ -0,0 +1,52 @@
|
||||
trigger:
|
||||
- main
|
||||
- dev
|
||||
- feature/*
|
||||
- hotfix/*
|
||||
- bugfix/*
|
||||
- release/*
|
||||
|
||||
pool:
|
||||
vmImage: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
Python38:
|
||||
python.version: '3.8'
|
||||
Python39:
|
||||
python.version: '3.9'
|
||||
Python310:
|
||||
python.version: '3.10'
|
||||
Python311:
|
||||
python.version: '3.11'
|
||||
|
||||
steps:
|
||||
- task: UsePythonVersion@0
|
||||
inputs:
|
||||
versionSpec: '$(python.version)'
|
||||
displayName: 'Use Python $(python.version)'
|
||||
|
||||
- script: |
|
||||
python -m pip install --upgrade pip==23.0.1
|
||||
pip install wheel==0.38.4 --upgrade
|
||||
pip install setuptools==66 --upgrade
|
||||
pip install build==0.10.0
|
||||
pip install pytest-azurepipelines
|
||||
displayName: 'Install build dependencies'
|
||||
|
||||
- script: |
|
||||
python -m build
|
||||
displayName: 'Build PrimAITE'
|
||||
|
||||
- script: |
|
||||
PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
|
||||
python -m pip install $PRIMAITE_WHEEL[dev]
|
||||
displayName: 'Install PrimAITE'
|
||||
|
||||
#- script: |
|
||||
# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
||||
# displayName: 'Lint with flake8'
|
||||
|
||||
- script: |
|
||||
pytest tests/
|
||||
displayName: 'Run unmarked tests'
|
||||
12
.flake8
Normal file
12
.flake8
Normal file
@@ -0,0 +1,12 @@
|
||||
[flake8]
|
||||
max-line-length=120
|
||||
extend-ignore =
|
||||
D105
|
||||
D107
|
||||
D100
|
||||
D104
|
||||
E203
|
||||
E712
|
||||
D401
|
||||
exclude =
|
||||
docs/source/*
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1,6 +1,8 @@
|
||||
# PrimAITE Package
|
||||
PRIMAITE/outputs
|
||||
PRIMAITE/outputs/*
|
||||
src/primaite/outputs
|
||||
src/primaite/outputs/*
|
||||
src/primaite/logs
|
||||
src/primaite/logs/*
|
||||
TestResults
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
|
||||
25
.pre-commit-config.yaml
Normal file
25
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
repos:
|
||||
- repo: http://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- id: check-added-large-files
|
||||
args: ['--maxkb=1000']
|
||||
- id: mixed-line-ending
|
||||
- id: requirements-txt-fixer
|
||||
- repo: http://github.com/psf/black
|
||||
rev: 23.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: http://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
args: [ "--profile", "black" ]
|
||||
- repo: http://github.com/PyCQA/flake8
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
additional_dependencies: [ flake8-docstrings ]
|
||||
1
MANIFEST.in
Normal file
1
MANIFEST.in
Normal file
@@ -0,0 +1 @@
|
||||
include src/primaite/config/*.yaml
|
||||
312
PRIMAITE/Main.py
312
PRIMAITE/Main.py
@@ -1,312 +0,0 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
PRIMAITE - main (harness) module
|
||||
|
||||
Coding Standards: PEP 8
|
||||
"""
|
||||
|
||||
from sys import exc_info
|
||||
import time
|
||||
import yaml
|
||||
import os.path
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from environment.primaite import PRIMAITE
|
||||
from transactions.transactions_to_file import write_transaction_to_file
|
||||
from common.config_values_main import config_values_main
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
||||
################################# FUNCTIONS ######################################
|
||||
|
||||
def run_generic():
|
||||
"""
|
||||
Run against a generic agent
|
||||
"""
|
||||
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
for step in range(0, config_values.num_steps):
|
||||
|
||||
# Send the observation space to the agent to get an action
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
action = env.action_space.sample()
|
||||
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(config_values.time_delay / 1000)
|
||||
|
||||
# Reset the environment at the end of the episode
|
||||
env.reset()
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def run_stable_baselines3_ppo():
|
||||
"""
|
||||
Run against a stable_baselines3 PPO agent
|
||||
"""
|
||||
|
||||
if config_values.load_agent == True:
|
||||
try:
|
||||
agent = PPO.load(config_values.agent_load_file, env, verbose=0, n_steps=config_values.num_steps)
|
||||
except:
|
||||
print("ERROR: Could not load agent at location: " + config_values.agent_load_file)
|
||||
logging.error("Could not load agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
print("Starting training session...")
|
||||
logging.info("Starting training session...")
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
agent.learn(total_timesteps=1)
|
||||
save_agent(agent)
|
||||
else:
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
logging.info("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
|
||||
env.close()
|
||||
|
||||
def run_stable_baselines3_a2c():
|
||||
"""
|
||||
Run against a stable_baselines3 A2C agent
|
||||
"""
|
||||
|
||||
if config_values.load_agent == True:
|
||||
try:
|
||||
agent = A2C.load(config_values.agent_load_file, env, verbose=0, n_steps=config_values.num_steps)
|
||||
except:
|
||||
print("ERROR: Could not load agent at location: " + config_values.agent_load_file)
|
||||
logging.error("Could not load agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
print("Starting training session...")
|
||||
logging.info("Starting training session...")
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
agent.learn(total_timesteps=1)
|
||||
save_agent(agent)
|
||||
else:
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
logging.info("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
|
||||
env.close()
|
||||
|
||||
def save_agent(_agent):
|
||||
"""
|
||||
Persist an agent (only works for stable baselines3 agents at present)
|
||||
"""
|
||||
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
try:
|
||||
path = 'outputs/agents/'
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
filename = "outputs/agents/agent_saved_" + time
|
||||
_agent.save(filename)
|
||||
logging.info("Trained agent saved as " + filename)
|
||||
except Exception as e:
|
||||
logging.error("Could not save agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
def configure_logging():
|
||||
"""
|
||||
Configures logging
|
||||
"""
|
||||
|
||||
try:
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
filename = "logs/app_" + time + ".log"
|
||||
path = 'logs/'
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
logging.basicConfig(filename=filename, filemode='w', format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO)
|
||||
except:
|
||||
print("ERROR: Could not start logging")
|
||||
|
||||
def load_config_values():
|
||||
"""
|
||||
Loads the config values from the main config file into a config object
|
||||
"""
|
||||
|
||||
try:
|
||||
# Generic
|
||||
config_values.agent_identifier = config_data['agentIdentifier']
|
||||
config_values.num_episodes = int(config_data['numEpisodes'])
|
||||
config_values.time_delay = int(config_data['timeDelay'])
|
||||
config_values.config_filename_use_case = config_data['configFilename']
|
||||
config_values.session_type = config_data['sessionType']
|
||||
config_values.load_agent = bool(config_data['loadAgent'])
|
||||
config_values.agent_load_file = config_data['agentLoadFile']
|
||||
# Environment
|
||||
config_values.observation_space_high_value = int(config_data['observationSpaceHighValue'])
|
||||
# Reward values
|
||||
# Generic
|
||||
config_values.all_ok = int(config_data['allOk'])
|
||||
# Node Operating State
|
||||
config_values.off_should_be_on = int(config_data['offShouldBeOn'])
|
||||
config_values.off_should_be_resetting = int(config_data['offShouldBeResetting'])
|
||||
config_values.on_should_be_off = int(config_data['onShouldBeOff'])
|
||||
config_values.on_should_be_resetting = int(config_data['onShouldBeResetting'])
|
||||
config_values.resetting_should_be_on = int(config_data['resettingShouldBeOn'])
|
||||
config_values.resetting_should_be_off = int(config_data['resettingShouldBeOff'])
|
||||
config_values.resetting = int(config_data['resetting'])
|
||||
# Node O/S or Service State
|
||||
config_values.good_should_be_patching = int(config_data['goodShouldBePatching'])
|
||||
config_values.good_should_be_compromised = int(config_data['goodShouldBeCompromised'])
|
||||
config_values.good_should_be_overwhelmed = int(config_data['goodShouldBeOverwhelmed'])
|
||||
config_values.patching_should_be_good = int(config_data['patchingShouldBeGood'])
|
||||
config_values.patching_should_be_compromised = int(config_data['patchingShouldBeCompromised'])
|
||||
config_values.patching_should_be_overwhelmed = int(config_data['patchingShouldBeOverwhelmed'])
|
||||
config_values.patching = int(config_data['patching'])
|
||||
config_values.compromised_should_be_good = int(config_data['compromisedShouldBeGood'])
|
||||
config_values.compromised_should_be_patching = int(config_data['compromisedShouldBePatching'])
|
||||
config_values.compromised_should_be_overwhelmed = int(config_data['compromisedShouldBeOverwhelmed'])
|
||||
config_values.compromised = int(config_data['compromised'])
|
||||
config_values.overwhelmed_should_be_good = int(config_data['overwhelmedShouldBeGood'])
|
||||
config_values.overwhelmed_should_be_patching = int(config_data['overwhelmedShouldBePatching'])
|
||||
config_values.overwhelmed_should_be_compromised = int(config_data['overwhelmedShouldBeCompromised'])
|
||||
config_values.overwhelmed = int(config_data['overwhelmed'])
|
||||
# Node File System State
|
||||
config_values.good_should_be_repairing = int(config_data['goodShouldBeRepairing'])
|
||||
config_values.good_should_be_restoring = int(config_data['goodShouldBeRestoring'])
|
||||
config_values.good_should_be_corrupt = int(config_data['goodShouldBeCorrupt'])
|
||||
config_values.good_should_be_destroyed = int(config_data['goodShouldBeDestroyed'])
|
||||
config_values.repairing_should_be_good = int(config_data['repairingShouldBeGood'])
|
||||
config_values.repairing_should_be_restoring = int(config_data['repairingShouldBeRestoring'])
|
||||
config_values.repairing_should_be_corrupt = int(config_data['repairingShouldBeCorrupt'])
|
||||
config_values.repairing_should_be_destroyed = int(config_data['repairingShouldBeDestroyed'])
|
||||
config_values.repairing = int(config_data['repairing'])
|
||||
config_values.restoring_should_be_good = int(config_data['restoringShouldBeGood'])
|
||||
config_values.restoring_should_be_repairing = int(config_data['restoringShouldBeRepairing'])
|
||||
config_values.restoring_should_be_corrupt = int(config_data['restoringShouldBeCorrupt'])
|
||||
config_values.restoring_should_be_destroyed = int(config_data['restoringShouldBeDestroyed'])
|
||||
config_values.restoring = int(config_data['restoring'])
|
||||
config_values.corrupt_should_be_good = int(config_data['corruptShouldBeGood'])
|
||||
config_values.corrupt_should_be_repairing = int(config_data['corruptShouldBeRepairing'])
|
||||
config_values.corrupt_should_be_restoring = int(config_data['corruptShouldBeRestoring'])
|
||||
config_values.corrupt_should_be_destroyed = int(config_data['corruptShouldBeDestroyed'])
|
||||
config_values.corrupt = int(config_data['corrupt'])
|
||||
config_values.destroyed_should_be_good = int(config_data['destroyedShouldBeGood'])
|
||||
config_values.destroyed_should_be_repairing = int(config_data['destroyedShouldBeRepairing'])
|
||||
config_values.destroyed_should_be_restoring = int(config_data['destroyedShouldBeRestoring'])
|
||||
config_values.destroyed_should_be_corrupt = int(config_data['destroyedShouldBeCorrupt'])
|
||||
config_values.destroyed = int(config_data['destroyed'])
|
||||
config_values.scanning = int(config_data['scanning'])
|
||||
# IER status
|
||||
config_values.red_ier_running = int(config_data['redIerRunning'])
|
||||
config_values.green_ier_blocked = int(config_data['greenIerBlocked'])
|
||||
# Patching / Reset durations
|
||||
config_values.os_patching_duration = int(config_data['osPatchingDuration'])
|
||||
config_values.node_reset_duration = int(config_data['nodeResetDuration'])
|
||||
config_values.service_patching_duration = int(config_data['servicePatchingDuration'])
|
||||
config_values.file_system_repairing_limit = int(config_data['fileSystemRepairingLimit'])
|
||||
config_values.file_system_restoring_limit = int(config_data['fileSystemRestoringLimit'])
|
||||
config_values.file_system_scanning_limit = int(config_data['fileSystemScanningLimit'])
|
||||
|
||||
logging.info("Training agent: " + config_values.agent_identifier)
|
||||
logging.info("Training environment config: " + config_values.config_filename_use_case)
|
||||
logging.info("Training cycle has " + str(config_values.num_episodes) + " episodes")
|
||||
|
||||
except Exception as e:
|
||||
logging.error("Could not save load config data")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
|
||||
################################# MAIN PROCESS ############################################
|
||||
|
||||
# Starting point
|
||||
|
||||
# Welcome message
|
||||
print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
|
||||
|
||||
# Configure logging
|
||||
configure_logging()
|
||||
|
||||
# Open the main config file
|
||||
try:
|
||||
config_file_main = open("config/config_main.yaml", "r")
|
||||
config_data = yaml.safe_load(config_file_main)
|
||||
# Create a config class
|
||||
config_values = config_values_main()
|
||||
# Load in config data
|
||||
load_config_values()
|
||||
except Exception as e:
|
||||
logging.error("Could not load main config")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
# Create a list of transactions
|
||||
# A transaction is an object holding the:
|
||||
# - episode #
|
||||
# - step #
|
||||
# - initial observation space
|
||||
# - action
|
||||
# - reward
|
||||
# - new observation space
|
||||
transaction_list = []
|
||||
|
||||
# Create the PRIMAITE environment
|
||||
try:
|
||||
env = PRIMAITE(config_values, transaction_list)
|
||||
logging.info("PrimAITE environment created")
|
||||
except Exception as e:
|
||||
logging.error("Could not create PrimAITE environment")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
config_values.num_steps = env.episode_steps
|
||||
|
||||
# Run environment against an agent
|
||||
if config_values.agent_identifier == "GENERIC":
|
||||
run_generic()
|
||||
elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
|
||||
run_stable_baselines3_ppo()
|
||||
elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
|
||||
run_stable_baselines3_a2c()
|
||||
|
||||
print("Session finished")
|
||||
logging.info("Session finished")
|
||||
|
||||
print("Saving transaction logs...")
|
||||
logging.info("Saving transaction logs...")
|
||||
|
||||
write_transaction_to_file(transaction_list)
|
||||
|
||||
config_file_main.close
|
||||
|
||||
print("Finished")
|
||||
logging.info("Finished")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
|
||||
project = 'PrimAITE'
|
||||
copyright = '2022, jashort'
|
||||
author = 'jashort'
|
||||
release = '0.1.0'
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
extensions = ['sphinx_rtd_theme']
|
||||
|
||||
templates_path = ['_templates']
|
||||
exclude_patterns = []
|
||||
|
||||
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_theme = 'sphinx_rtd_theme'
|
||||
html_static_path = ['_static']
|
||||
@@ -5,8 +5,8 @@
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = source
|
||||
BUILDDIR = build
|
||||
SOURCEDIR = .
|
||||
BUILDDIR = _build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
43
docs/conf.py
Normal file
43
docs/conf.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
import datetime
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
import os
|
||||
import sys
|
||||
|
||||
import furo # noqa
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../"))
|
||||
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
year = datetime.datetime.now().year
|
||||
project = "primaite"
|
||||
copyright = f"Copyright (C) QinetiQ Training and Simulation Ltd 2021 - {year}"
|
||||
author = "QinetiQ Training and Simulation Ltd"
|
||||
|
||||
# The short Major.Minor.Build version
|
||||
with open("../src/primaite/VERSION", "r") as file:
|
||||
version = file.readline()
|
||||
# The full version, including alpha/beta/rc tags
|
||||
release = version
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
extensions = []
|
||||
|
||||
templates_path = ["_templates"]
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_theme = "furo"
|
||||
html_static_path = ["_static"]
|
||||
@@ -9,11 +9,11 @@ Welcome to PrimAITE's documentation
|
||||
What is PrimAITE?
|
||||
------------------------
|
||||
|
||||
PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme. It incorporates the functionality required of a Primary-level environment, as specified in the Dstl ARCD Training Environment Matrix document:
|
||||
PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme. It incorporates the functionality required of a Primary-level environment, as specified in the Dstl ARCD Training Environment Matrix document:
|
||||
|
||||
* The ability to model a relevant platform / system context;
|
||||
* The ability to model key characteristics of a platform / system by representing connections, IP addresses, ports, traffic loading, operating systems, file system, services and processes;
|
||||
* Operates at machine-speed to enable fast training cycles.
|
||||
* The ability to model a relevant platform / system context;
|
||||
* The ability to model key characteristics of a platform / system by representing connections, IP addresses, ports, traffic loading, operating systems, file system, services and processes;
|
||||
* Operates at machine-speed to enable fast training cycles.
|
||||
|
||||
PrimAITE aims to evolve into an ARCD environment that could be used as the follow-on from Reception level approaches (e.g. YAWNING TITAN), and help bridge the Sim-to-Real gap into Secondary level environments (e.g. IMAGINARY YAK).
|
||||
|
||||
@@ -35,8 +35,8 @@ The best place to start is :ref:`about`
|
||||
:maxdepth: 8
|
||||
:caption: Contents:
|
||||
|
||||
about
|
||||
dependencies
|
||||
config
|
||||
session
|
||||
results
|
||||
source/about
|
||||
source/dependencies
|
||||
source/config
|
||||
source/session
|
||||
source/results
|
||||
@@ -7,8 +7,8 @@ REM Command file for Sphinx documentation
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=source
|
||||
set BUILDDIR=build
|
||||
set SOURCEDIR=.
|
||||
set BUILDDIR=_build
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
@@ -105,7 +105,7 @@ The status changes that can be made to a node are as follows:
|
||||
|
||||
* ON
|
||||
* OFF
|
||||
* RESETTING - when a status of resetting is entered, the node will automatically exit this state after a number of steps (as defined by the nodeResetDuration configuration item) after which it returns to an ON state
|
||||
* RESETTING - when a status of resetting is entered, the node will automatically exit this state after a number of steps (as defined by the nodeResetDuration configuration item) after which it returns to an ON state
|
||||
|
||||
* Active Nodes and Service Nodes:
|
||||
|
||||
@@ -194,7 +194,7 @@ An example observation space is provided below:
|
||||
:widths: 25 25 25 25 25 25 25
|
||||
:header-rows: 1
|
||||
|
||||
* -
|
||||
* -
|
||||
- ID
|
||||
- Operating State
|
||||
- O/S State
|
||||
@@ -326,8 +326,8 @@ A reward value is presented back to the blue agent on the conclusion of every st
|
||||
|
||||
**Node and service status**
|
||||
|
||||
On every step, the status of each node is compared against both a reference environment (simulating the situation if the red and blue agents had not impacted the environment)
|
||||
and the before and after state of the environment. If the comparison against the reference environment shows no difference, then the score provided is "AllOK". If there is a
|
||||
On every step, the status of each node is compared against both a reference environment (simulating the situation if the red and blue agents had not impacted the environment)
|
||||
and the before and after state of the environment. If the comparison against the reference environment shows no difference, then the score provided is "AllOK". If there is a
|
||||
difference with respect to the reference environment, the before and after states are compared, and a score determined. See :ref:`config` for details of reward values.
|
||||
|
||||
**IER status**
|
||||
@@ -66,83 +66,83 @@ The config_main.yaml file consists of the following attributes:
|
||||
The score to give when the node should be resetting, but is off
|
||||
|
||||
* **Node Operating State [onShouldBeOff]** [int]
|
||||
|
||||
|
||||
The score to give when the node should be off, but is on
|
||||
|
||||
* **Node Operating State [onShouldBeResetting]** [int]
|
||||
|
||||
|
||||
The score to give when the node should be resetting, but is on
|
||||
|
||||
* **Node Operating State [resettingShouldBeOn]** [int]
|
||||
|
||||
|
||||
The score to give when the node should be on, but is resetting
|
||||
|
||||
* **Node Operating State [resettingShouldBeOff]** [int]
|
||||
|
||||
|
||||
The score to give when the node should be off, but is resetting
|
||||
|
||||
* **Node Operating State [resetting]** [int]
|
||||
|
||||
|
||||
The score to give when the node is resetting
|
||||
|
||||
* **Node Operating System or Service State [goodShouldBePatching]** [int]
|
||||
|
||||
|
||||
The score to give when the state should be patching, but is good
|
||||
|
||||
* **Node Operating System or Service State [goodShouldBeCompromised]** [int]
|
||||
|
||||
|
||||
The score to give when the state should be compromised, but is good
|
||||
|
||||
* **Node Operating System or Service State [goodShouldBeOverwhelmed]** [int]
|
||||
|
||||
|
||||
The score to give when the state should be overwhelmed, but is good
|
||||
|
||||
* **Node Operating System or Service State [patchingShouldBeGood]** [int]
|
||||
|
||||
|
||||
The score to give when the state should be good, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patchingShouldBeCompromised]** [int]
|
||||
|
||||
|
||||
The score to give when the state should be compromised, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patchingShouldBeOverwhelmed]** [int]
|
||||
|
||||
|
||||
The score to give when the state should be overwhelmed, but is patching
|
||||
|
||||
* **Node Operating System or Service State [patching]** [int]
|
||||
|
||||
|
||||
The score to give when the state is patching
|
||||
|
||||
* **Node Operating System or Service State [compromisedShouldBeGood]** [int]
|
||||
|
||||
|
||||
The score to give when the state should be good, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromisedShouldBePatching]** [int]
|
||||
|
||||
|
||||
The score to give when the state should be patching, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromisedShouldBeOverwhelmed]** [int]
|
||||
|
||||
|
||||
The score to give when the state should be overwhelmed, but is compromised
|
||||
|
||||
* **Node Operating System or Service State [compromised]** [int]
|
||||
|
||||
|
||||
The score to give when the state is compromised
|
||||
|
||||
* **Node Operating System or Service State [overwhelmedShouldBeGood]** [int]
|
||||
|
||||
|
||||
The score to give when the state should be good, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmedShouldBePatching]** [int]
|
||||
|
||||
|
||||
The score to give when the state should be patching, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmedShouldBeCompromised]** [int]
|
||||
|
||||
|
||||
The score to give when the state should be compromised, but is overwhelmed
|
||||
|
||||
* **Node Operating System or Service State [overwhelmed]** [int]
|
||||
|
||||
|
||||
The score to give when the state is overwhelmed
|
||||
|
||||
* **Node File System State [goodShouldBeRepairing]** [int]
|
||||
@@ -246,11 +246,11 @@ The config_main.yaml file consists of the following attributes:
|
||||
The score to give when the state is scanning
|
||||
|
||||
* **IER Status [redIerRunning]** [int]
|
||||
|
||||
|
||||
The score to give when a red agent IER is permitted to run
|
||||
|
||||
* **IER Status [greenIerBlocked]** [int]
|
||||
|
||||
|
||||
The score to give when a green agent IER is prevented from running
|
||||
|
||||
**Patching / Reset Durations**
|
||||
@@ -260,14 +260,14 @@ The config_main.yaml file consists of the following attributes:
|
||||
The number of steps to take when patching an Operating System
|
||||
|
||||
* **nodeResetDuration** [int]
|
||||
|
||||
|
||||
The number of steps to take when resetting a node's operating state
|
||||
|
||||
* **servicePatchingDuration** [int]
|
||||
|
||||
|
||||
The number of steps to take when patching a service
|
||||
|
||||
* **fileSystemRepairingLimit** [int]:
|
||||
* **fileSystemRepairingLimit** [int]:
|
||||
|
||||
The number of steps to take when repairing the file system
|
||||
|
||||
@@ -285,23 +285,23 @@ config_[name].yaml:
|
||||
The config_[name].yaml file consists of the following attributes:
|
||||
|
||||
* **itemType: ACTIONS** [enum]
|
||||
|
||||
|
||||
Determines whether a NODE or ACL action space format is adopted for the session
|
||||
|
||||
* **itemType: STEPS** [int]
|
||||
|
||||
|
||||
Determines the number of steps to run in each episode of the session
|
||||
|
||||
* **itemType: PORTS** [int]
|
||||
|
||||
|
||||
Provides a list of ports modelled in this session
|
||||
|
||||
* **itemType: SERVICES** [freetext]
|
||||
|
||||
|
||||
Provides a list of services modelled in this session
|
||||
|
||||
* **itemType: NODE**
|
||||
|
||||
|
||||
Defines a node included in the system laydown being simulated. It should consist of the following attributes:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
@@ -318,9 +318,9 @@ The config_[name].yaml file consists of the following attributes:
|
||||
* **name** [freetext]: Free-text name of the service, but must match one of the services defined for the system in the services list
|
||||
* **port** [int]: Integer value of the port related to this service, but must match one of the ports defined for the system in the ports list
|
||||
* **state** [enum]: The initial state of the service. Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
|
||||
|
||||
|
||||
* **itemType: LINK**
|
||||
|
||||
|
||||
Defines a link included in the system laydown being simulated. It should consist of the following attributes:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
@@ -344,7 +344,7 @@ The config_[name].yaml file consists of the following attributes:
|
||||
* **missionCriticality** [enum]: The mission criticality of this IER (with 5 being highest, 1 lowest)
|
||||
|
||||
* **itemType: RED_IER**
|
||||
|
||||
|
||||
Defines a red agent Information Exchange Requirement (IER). It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
@@ -358,7 +358,7 @@ The config_[name].yaml file consists of the following attributes:
|
||||
* **missionCriticality** [enum]: Not currently used. Default to 0
|
||||
|
||||
* **itemType: GREEN_POL**
|
||||
|
||||
|
||||
Defines a green agent pattern-of-life instruction. It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
@@ -370,7 +370,7 @@ The config_[name].yaml file consists of the following attributes:
|
||||
* **state** [enuum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for operating system state) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state)
|
||||
|
||||
* **itemType: RED_POL**
|
||||
|
||||
|
||||
Defines a red agent pattern-of-life instruction. It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
@@ -386,7 +386,7 @@ The config_[name].yaml file consists of the following attributes:
|
||||
* **sourceNodeServiceState** [enum]: The state of the source node service to check (used for SERVICE initiator). Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
|
||||
|
||||
* **itemType: ACL_RULE**
|
||||
|
||||
|
||||
Defines an initial Access Control List (ACL) rule. It should consist of:
|
||||
|
||||
* **id** [int]: Unique ID for this YAML item
|
||||
@@ -394,4 +394,4 @@ The config_[name].yaml file consists of the following attributes:
|
||||
* **source** [IP address]: Defines the source IP address for the rule in xxx.xxx.xxx.xxx format
|
||||
* **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format
|
||||
* **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list
|
||||
* **port** [int]: Defines the port for the rule. Must match a value in the ports list
|
||||
* **port** [int]: Defines the port for the rule. Must match a value in the ports list
|
||||
@@ -10,7 +10,7 @@ PrimAITE is built with the following versions of dependencies:
|
||||
* numpy 1.23.5
|
||||
* networkx 2.8.8
|
||||
* gym 0.21.0
|
||||
* matplotlib 3.6.2
|
||||
* matplotlib 3.6.2
|
||||
* stable_baselines_3 1.6.2
|
||||
|
||||
The latest release of PrimAITE has been tested against the following versions of dependencies:
|
||||
@@ -20,7 +20,5 @@ The latest release of PrimAITE has been tested against the following versions of
|
||||
* numpy 1.23.5
|
||||
* networkx 2.8.8
|
||||
* gym 0.21.0
|
||||
* matplotlib 3.6.2
|
||||
* matplotlib 3.6.2
|
||||
* stable_baselines_3 1.6.2
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ PrimAITE produces four types of data:
|
||||
* Outputs - Saved agents
|
||||
* Logging
|
||||
|
||||
Outputs can be found in the *[Install Directory]\\PRIMAITE\\PRIMAITE\\outputs* directory
|
||||
Outputs can be found in the *[Install Directory]\\Primaite\\Primaite\\outputs* directory
|
||||
|
||||
Logging can be found in the *[Install Directory]\\PRIMAITE\\PRIMAITE\\logs* directory
|
||||
Logging can be found in the *[Install Directory]\\Primaite\\Primaite\\logs* directory
|
||||
|
||||
**Outputs - Results**
|
||||
|
||||
@@ -39,4 +39,4 @@ For each training session, assuming the agent being trained implements the *save
|
||||
|
||||
**Logging**
|
||||
|
||||
PrimAITE also provides output logs (for diagnosis) using the Python Logging package. These can be found in the *[Install Directory]\\PRIMAITE\\PRIMAITE\\logs* directory
|
||||
PrimAITE also provides output logs (for diagnosis) using the Python Logging package. These can be found in the *[Install Directory]\\Primaite\\Primaite\\logs* directory
|
||||
@@ -24,7 +24,7 @@ Integrating a blue agent with PrimAITE requires some modification of the code wi
|
||||
* Stable Baselines 3 PPO (run_stable_baselines3_ppo)
|
||||
* Stable Baselines 3 A2C (run_stable_baselines3_a2c)
|
||||
|
||||
The selection of which agent type to use is made via the config_main.yaml file. In order to train a user generated agent,
|
||||
The selection of which agent type to use is made via the config_main.yaml file. In order to train a user generated agent,
|
||||
the run_generic function should be selected, and should be modified (typically) to be:
|
||||
|
||||
.. code:: python
|
||||
@@ -46,7 +46,7 @@ Where:
|
||||
* the *save_agent()* assumes that a *save()* function has been defined in the user created agent. If not, this line can be ommitted (although it is encouraged, since it will allow the agent to be saved and ported)
|
||||
|
||||
The code below provides a suggested format for the learn() function within the user created agent.
|
||||
It's important to include the *self.environment.reset()* call within the episode loop in order that the
|
||||
It's important to include the *self.environment.reset()* call within the episode loop in order that the
|
||||
environment is reset between episodes. Note that the example below should not be considered exhaustive.
|
||||
|
||||
.. code:: python
|
||||
@@ -58,7 +58,7 @@ environment is reset between episodes. Note that the example below should not be
|
||||
# reset the environment
|
||||
self.environment.reset()
|
||||
done = False
|
||||
|
||||
|
||||
for step in range(max_steps):
|
||||
# calculate the action
|
||||
action = ...
|
||||
@@ -77,12 +77,10 @@ environment is reset between episodes. Note that the example below should not be
|
||||
break
|
||||
|
||||
**Running the session**
|
||||
|
||||
|
||||
In order to execute a session, carry out the following steps:
|
||||
|
||||
1. Navigate to "[Install directory]\\PRIMAITE\\PRIMAITE\\”
|
||||
2. Start a console window (type “CMD” in path window, or start a console window first and navigate to “[Install Directory]\\PRIMAITE\\PRIMAITE\\”)
|
||||
3. Type “python main.py”
|
||||
4. The session will start with an output indicating the current episode, and average reward value for the episode
|
||||
|
||||
|
||||
1. Navigate to "[Install directory]\\Primaite\\Primaite\\”
|
||||
2. Start a console window (type “CMD” in path window, or start a console window first and navigate to “[Install Directory]\\Primaite\\Primaite\\”)
|
||||
3. Type “python main.py”
|
||||
4. The session will start with an output indicating the current episode, and average reward value for the episode
|
||||
61
pyproject.toml
Normal file
61
pyproject.toml
Normal file
@@ -0,0 +1,61 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "setuptools-scm", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "primaite"
|
||||
description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme."
|
||||
authors = [{name="QinetiQ Training and Simulation Ltd"}]
|
||||
license = {text = "MIT License"}
|
||||
requires-python = ">=3.8"
|
||||
dynamic = ["version", "readme"]
|
||||
classifiers = [
|
||||
"License :: MIT License",
|
||||
"Development Status :: 4 - Beta",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
"Operating System :: MacOS",
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: Unix",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"gym==0.21.0",
|
||||
"matplotlib==3.7.1",
|
||||
"networkx==3.1",
|
||||
"numpy==1.23.5",
|
||||
"PyYAML==6.0",
|
||||
"stable-baselines3==1.6.2"
|
||||
]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = {file = ["src/primaite/VERSION"]}
|
||||
readme = {file = ["README.md"]}
|
||||
|
||||
[tool.setuptools]
|
||||
package-dir = {"" = "src"}
|
||||
include-package-data = true
|
||||
license-files = ["LICENSE"]
|
||||
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"setuptools==66",
|
||||
"pytest==7.2.0",
|
||||
"flake8==6.0.0",
|
||||
"Sphinx==6.1.3",
|
||||
"furo==2023.3.27",
|
||||
"sphinx-code-tabs==0.5.3",
|
||||
"sphinx-copybutton==0.5.2",
|
||||
"pytest-cov==4.0.0",
|
||||
"pytest-flake8==1.1.1",
|
||||
"pip-licenses==4.3.0",
|
||||
"pre-commit==2.20.0",
|
||||
"wheel==0.38.4",
|
||||
"build==0.10.0"
|
||||
]
|
||||
3
pytest.ini
Normal file
3
pytest.ini
Normal file
@@ -0,0 +1,3 @@
|
||||
[pytest]
|
||||
testpaths =
|
||||
tests
|
||||
34
setup.py
34
setup.py
@@ -1,26 +1,18 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Setup
|
||||
"""
|
||||
from setuptools import setup
|
||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||
|
||||
|
||||
class bdist_wheel(_bdist_wheel): # noqa
|
||||
def finalize_options(self): # noqa
|
||||
super().finalize_options()
|
||||
# forces whee to be platform and Python version specific
|
||||
# Source: https://stackoverflow.com/a/45150383
|
||||
self.root_is_pure = False
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
setup(
|
||||
name="primaite",
|
||||
maintainer="QinetiQ Training and Simulation Ltd",
|
||||
url="https://github.com/qtsl/PrimAITE",
|
||||
description="A primary-level simulation tool",
|
||||
python_requires=">=3.7",
|
||||
version="1.1.0",
|
||||
install_requires=[
|
||||
"gym==0.21.0",
|
||||
"matplotlib==3.6.2",
|
||||
"networkx==2.8.8",
|
||||
"numpy==1.23.5",
|
||||
"stable_baselines3==1.6.2",
|
||||
# Required for older versions of Gym that aren't compliant with
|
||||
# Setuptools>=67.
|
||||
"setuptools==66"
|
||||
],
|
||||
packages=find_packages()
|
||||
cmdclass={
|
||||
"bdist_wheel": bdist_wheel,
|
||||
}
|
||||
)
|
||||
|
||||
1
src/primaite/VERSION
Normal file
1
src/primaite/VERSION
Normal file
@@ -0,0 +1 @@
|
||||
1.2.0
|
||||
@@ -1,25 +1,19 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
A class that implements the access control list implementation for the network
|
||||
"""
|
||||
"""A class that implements the access control list implementation for the network."""
|
||||
|
||||
from acl.acl_rule import ACLRule
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
|
||||
class AccessControlList():
|
||||
"""
|
||||
Access Control List class
|
||||
"""
|
||||
|
||||
class AccessControlList:
|
||||
"""Access Control List class."""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Init
|
||||
"""
|
||||
"""Init."""
|
||||
self.acl = {} # A dictionary of ACL Rules
|
||||
|
||||
self.acl = {} # A dictionary of ACL Rules
|
||||
|
||||
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
|
||||
"""
|
||||
Checks for IP address matches
|
||||
Checks for IP address matches.
|
||||
|
||||
Args:
|
||||
_rule: The rule being checked
|
||||
@@ -29,18 +23,28 @@ class AccessControlList():
|
||||
Returns:
|
||||
True if match; False otherwise.
|
||||
"""
|
||||
|
||||
if ((_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address) or
|
||||
(_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address) or
|
||||
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY") or
|
||||
(_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")):
|
||||
if (
|
||||
(
|
||||
_rule.get_source_ip() == _source_ip_address
|
||||
and _rule.get_dest_ip() == _dest_ip_address
|
||||
)
|
||||
or (
|
||||
_rule.get_source_ip() == "ANY"
|
||||
and _rule.get_dest_ip() == _dest_ip_address
|
||||
)
|
||||
or (
|
||||
_rule.get_source_ip() == _source_ip_address
|
||||
and _rule.get_dest_ip() == "ANY"
|
||||
)
|
||||
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
|
||||
"""
|
||||
Checks for rules that block a protocol / port
|
||||
Checks for rules that block a protocol / port.
|
||||
|
||||
Args:
|
||||
_source_ip_address: the source IP address to check
|
||||
@@ -51,11 +55,17 @@ class AccessControlList():
|
||||
Returns:
|
||||
Indicates block if all conditions are satisfied.
|
||||
"""
|
||||
|
||||
for rule_key, rule_value in self.acl.items():
|
||||
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
|
||||
if ((rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and
|
||||
(str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY")):
|
||||
if self.check_address_match(
|
||||
rule_value, _source_ip_address, _dest_ip_address
|
||||
):
|
||||
if (
|
||||
rule_value.get_protocol() == _protocol
|
||||
or rule_value.get_protocol() == "ANY"
|
||||
) and (
|
||||
str(rule_value.get_port()) == str(_port)
|
||||
or rule_value.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule_value.get_permission() == "DENY":
|
||||
return True
|
||||
@@ -67,7 +77,7 @@ class AccessControlList():
|
||||
|
||||
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Adds a new rule
|
||||
Adds a new rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
@@ -76,14 +86,13 @@ class AccessControlList():
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(new_rule)
|
||||
self.acl[hash_value] = new_rule
|
||||
|
||||
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Removes a rule
|
||||
Removes a rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
@@ -92,25 +101,21 @@ class AccessControlList():
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(rule)
|
||||
# There will not always be something 'popable' since the agent will be trying random things
|
||||
try:
|
||||
self.acl.pop(hash_value)
|
||||
except:
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def remove_all_rules(self):
|
||||
"""
|
||||
Removes all rules
|
||||
"""
|
||||
|
||||
"""Removes all rules."""
|
||||
self.acl.clear()
|
||||
|
||||
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Produces a hash value for a rule
|
||||
Produces a hash value for a rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
@@ -122,13 +127,6 @@ class AccessControlList():
|
||||
Returns:
|
||||
Hash value based on rule parameters.
|
||||
"""
|
||||
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(rule)
|
||||
return hash_value
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
A class that implements an access control list rule
|
||||
"""
|
||||
"""A class that implements an access control list rule."""
|
||||
|
||||
class ACLRule():
|
||||
"""
|
||||
Access Control List Rule class
|
||||
"""
|
||||
|
||||
class ACLRule:
|
||||
"""Access Control List Rule class."""
|
||||
|
||||
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_permission: The permission (ALLOW or DENY)
|
||||
@@ -19,7 +16,6 @@ class ACLRule():
|
||||
_protocol: The rule protocol
|
||||
_port: The rule port
|
||||
"""
|
||||
|
||||
self.permission = _permission
|
||||
self.source_ip = _source_ip
|
||||
self.dest_ip = _dest_ip
|
||||
@@ -28,47 +24,45 @@ class ACLRule():
|
||||
|
||||
def __hash__(self):
|
||||
"""
|
||||
Override the hash function
|
||||
Override the hash function.
|
||||
|
||||
Returns:
|
||||
Returns hash of core parameters.
|
||||
"""
|
||||
|
||||
return hash((self.permission, self.source_ip, self.dest_ip, self.protocol, self.port))
|
||||
return hash(
|
||||
(self.permission, self.source_ip, self.dest_ip, self.protocol, self.port)
|
||||
)
|
||||
|
||||
def get_permission(self):
|
||||
"""
|
||||
Gets the permission attribute
|
||||
Gets the permission attribute.
|
||||
|
||||
Returns:
|
||||
Returns permission attribute
|
||||
"""
|
||||
|
||||
return self.permission
|
||||
|
||||
def get_source_ip(self):
|
||||
"""
|
||||
Gets the source IP address attribute
|
||||
Gets the source IP address attribute.
|
||||
|
||||
Returns:
|
||||
Returns source IP address attribute
|
||||
"""
|
||||
|
||||
return self.source_ip
|
||||
|
||||
def get_dest_ip(self):
|
||||
"""
|
||||
Gets the desintation IP address attribute
|
||||
Gets the desintation IP address attribute.
|
||||
|
||||
Returns:
|
||||
Returns destination IP address attribute
|
||||
"""
|
||||
|
||||
return self.dest_ip
|
||||
|
||||
def get_protocol(self):
|
||||
"""
|
||||
Gets the protocol attribute
|
||||
Gets the protocol attribute.
|
||||
|
||||
Returns:
|
||||
Returns protocol attribute
|
||||
@@ -77,12 +71,9 @@ class ACLRule():
|
||||
|
||||
def get_port(self):
|
||||
"""
|
||||
Gets the port attribute
|
||||
Gets the port attribute.
|
||||
|
||||
Returns:
|
||||
Returns port attribute
|
||||
"""
|
||||
|
||||
return self.port
|
||||
|
||||
|
||||
@@ -1,39 +1,35 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The config class
|
||||
"""
|
||||
"""The config class."""
|
||||
|
||||
|
||||
class config_values_main(object):
|
||||
"""
|
||||
Class to hold main config values
|
||||
"""
|
||||
"""Class to hold main config values."""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Init
|
||||
"""
|
||||
|
||||
"""Init."""
|
||||
# Generic
|
||||
self.agent_identifier = "" # the agent in use
|
||||
self.num_episodes = 0 # number of episodes to train over
|
||||
self.num_steps = 0 # number of steps in an episode
|
||||
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only
|
||||
self.config_filename_use_case = "" # the filename for the Use Case config file
|
||||
self.session_type = "" # the session type to run (TRAINING or EVALUATION)
|
||||
self.agent_identifier = "" # the agent in use
|
||||
self.num_episodes = 0 # number of episodes to train over
|
||||
self.num_steps = 0 # number of steps in an episode
|
||||
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only
|
||||
self.config_filename_use_case = "" # the filename for the Use Case config file
|
||||
self.session_type = "" # the session type to run (TRAINING or EVALUATION)
|
||||
|
||||
# Environment
|
||||
self.observation_space_high_value = 0 # The high value for the observation space
|
||||
self.observation_space_high_value = (
|
||||
0 # The high value for the observation space
|
||||
)
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
self.all_ok = 0
|
||||
self.all_ok = 0
|
||||
# Node Operating State
|
||||
self.off_should_be_on = 0
|
||||
self.off_should_be_resetting = 0
|
||||
self.on_should_be_off = 0
|
||||
self.on_should_be_resetting = 0
|
||||
self.resetting_should_be_on = 0
|
||||
self.resetting_should_be_off = 0
|
||||
self.resetting_should_be_off = 0
|
||||
self.resetting = 0
|
||||
# Node O/S or Service State
|
||||
self.good_should_be_patching = 0
|
||||
@@ -46,7 +42,7 @@ class config_values_main(object):
|
||||
self.compromised_should_be_good = 0
|
||||
self.compromised_should_be_patching = 0
|
||||
self.compromised_should_be_overwhelmed = 0
|
||||
self.compromised = 0
|
||||
self.compromised = 0
|
||||
self.overwhelmed_should_be_good = 0
|
||||
self.overwhelmed_should_be_patching = 0
|
||||
self.overwhelmed_should_be_compromised = 0
|
||||
@@ -59,11 +55,15 @@ class config_values_main(object):
|
||||
self.repairing_should_be_good = 0
|
||||
self.repairing_should_be_restoring = 0
|
||||
self.repairing_should_be_corrupt = 0
|
||||
self.repairing_should_be_destroyed = 0 # Repairing does not fix destroyed state - you need to restore
|
||||
self.repairing_should_be_destroyed = (
|
||||
0 # Repairing does not fix destroyed state - you need to restore
|
||||
)
|
||||
self.repairing = 0
|
||||
self.restoring_should_be_good = 0
|
||||
self.restoring_should_be_repairing = 0
|
||||
self.restoring_should_be_corrupt = 0 # Not the optimal method (as repair will fix corruption)
|
||||
self.restoring_should_be_corrupt = (
|
||||
0 # Not the optimal method (as repair will fix corruption)
|
||||
)
|
||||
self.restoring_should_be_destroyed = 0
|
||||
self.restoring = 0
|
||||
self.corrupt_should_be_good = 0
|
||||
@@ -82,10 +82,9 @@ class config_values_main(object):
|
||||
self.green_ier_blocked = 0
|
||||
|
||||
# Patching / Reset
|
||||
self.os_patching_duration = 0 # The time taken to patch the OS
|
||||
self.node_reset_duration = 0 # The time taken to reset a node (hardware)
|
||||
self.service_patching_duration = 0 # The time taken to patch a service
|
||||
self.file_system_repairing_limit = 0 # The time take to repair a file
|
||||
self.file_system_restoring_limit = 0 # The time take to restore a file
|
||||
self.file_system_scanning_limit = 0 # The time taken to scan the file system
|
||||
|
||||
self.os_patching_duration = 0 # The time taken to patch the OS
|
||||
self.node_reset_duration = 0 # The time taken to reset a node (hardware)
|
||||
self.service_patching_duration = 0 # The time taken to patch a service
|
||||
self.file_system_repairing_limit = 0 # The time take to repair a file
|
||||
self.file_system_restoring_limit = 0 # The time take to restore a file
|
||||
self.file_system_scanning_limit = 0 # The time taken to scan the file system
|
||||
@@ -1,14 +1,11 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Enumerations for APE
|
||||
"""
|
||||
"""Enumerations for APE."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TYPE(Enum):
|
||||
"""
|
||||
Node type enumeration
|
||||
"""
|
||||
"""Node type enumeration."""
|
||||
|
||||
CCTV = 1
|
||||
SWITCH = 2
|
||||
@@ -21,10 +18,9 @@ class TYPE(Enum):
|
||||
ACTUATOR = 9
|
||||
SERVER = 10
|
||||
|
||||
|
||||
class PRIORITY(Enum):
|
||||
"""
|
||||
Node priority enumeration
|
||||
"""
|
||||
"""Node priority enumeration."""
|
||||
|
||||
P1 = 1
|
||||
P2 = 2
|
||||
@@ -32,48 +28,43 @@ class PRIORITY(Enum):
|
||||
P4 = 4
|
||||
P5 = 5
|
||||
|
||||
|
||||
class HARDWARE_STATE(Enum):
|
||||
"""
|
||||
Node hardware state enumeration
|
||||
"""
|
||||
"""Node hardware state enumeration."""
|
||||
|
||||
ON = 1
|
||||
OFF = 2
|
||||
RESETTING = 3
|
||||
|
||||
|
||||
class SOFTWARE_STATE(Enum):
|
||||
"""
|
||||
O/S or Service state enumeration
|
||||
"""
|
||||
"""O/S or Service state enumeration."""
|
||||
|
||||
GOOD = 1
|
||||
PATCHING = 2
|
||||
COMPROMISED = 3
|
||||
OVERWHELMED = 4
|
||||
|
||||
|
||||
class NODE_POL_TYPE(Enum):
|
||||
"""
|
||||
Node Pattern of Life type enumeration
|
||||
"""
|
||||
"""Node Pattern of Life type enumeration."""
|
||||
|
||||
OPERATING = 1
|
||||
OS = 2
|
||||
SERVICE = 3
|
||||
FILE = 4
|
||||
|
||||
|
||||
class NODE_POL_INITIATOR(Enum):
|
||||
"""
|
||||
Node Pattern of Life initiator enumeration
|
||||
"""
|
||||
"""Node Pattern of Life initiator enumeration."""
|
||||
|
||||
DIRECT = 1
|
||||
IER = 2
|
||||
SERVICE = 3
|
||||
|
||||
|
||||
class PROTOCOL(Enum):
|
||||
"""
|
||||
Service protocol enumeration
|
||||
"""
|
||||
"""Service protocol enumeration."""
|
||||
|
||||
LDAP = 0
|
||||
FTP = 1
|
||||
@@ -84,18 +75,16 @@ class PROTOCOL(Enum):
|
||||
TCP = 6
|
||||
NONE = 7
|
||||
|
||||
|
||||
class ACTION_TYPE(Enum):
|
||||
"""
|
||||
Action type enumeration
|
||||
"""
|
||||
"""Action type enumeration."""
|
||||
|
||||
NODE = 0
|
||||
ACL = 1
|
||||
|
||||
|
||||
class FILE_SYSTEM_STATE(Enum):
|
||||
"""
|
||||
File System State
|
||||
"""
|
||||
"""File System State."""
|
||||
|
||||
GOOD = 1
|
||||
CORRUPT = 2
|
||||
@@ -1,59 +1,47 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The protocol class
|
||||
"""
|
||||
"""The protocol class."""
|
||||
|
||||
|
||||
class Protocol(object):
|
||||
"""
|
||||
Protocol class
|
||||
"""
|
||||
"""Protocol class."""
|
||||
|
||||
def __init__(self, _name):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_name: The protocol name
|
||||
"""
|
||||
|
||||
self.name = _name
|
||||
self.load = 0 # bps
|
||||
self.load = 0 # bps
|
||||
|
||||
def get_name(self):
|
||||
"""
|
||||
Gets the protocol name
|
||||
Gets the protocol name.
|
||||
|
||||
Returns:
|
||||
The protocol name
|
||||
"""
|
||||
|
||||
return self.name
|
||||
|
||||
def get_load(self):
|
||||
"""
|
||||
Gets the protocol load
|
||||
Gets the protocol load.
|
||||
|
||||
Returns:
|
||||
The protocol load (bps)
|
||||
"""
|
||||
|
||||
return self.load
|
||||
|
||||
def add_load(self, _load):
|
||||
"""
|
||||
Adds load to the protocol
|
||||
Adds load to the protocol.
|
||||
|
||||
Args:
|
||||
_load: The load to add
|
||||
"""
|
||||
|
||||
self.load += _load
|
||||
|
||||
def clear_load(self):
|
||||
"""
|
||||
Clears the load on this protocol
|
||||
"""
|
||||
|
||||
"""Clears the load on this protocol."""
|
||||
self.load = 0
|
||||
|
||||
|
||||
@@ -1,25 +1,21 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The Service class
|
||||
"""
|
||||
"""The Service class."""
|
||||
|
||||
from primaite.common.enums import SOFTWARE_STATE
|
||||
|
||||
from common.enums import SOFTWARE_STATE
|
||||
|
||||
class Service(object):
|
||||
"""
|
||||
Service class
|
||||
"""
|
||||
"""Service class."""
|
||||
|
||||
def __init__(self, _name, _port, _state):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_name: The service name
|
||||
_port: The service port
|
||||
_state: The service state
|
||||
"""
|
||||
|
||||
self.name = _name
|
||||
self.port = _port
|
||||
self.state = _state
|
||||
@@ -27,74 +23,61 @@ class Service(object):
|
||||
|
||||
def set_name(self, _name):
|
||||
"""
|
||||
Sets the service name
|
||||
Sets the service name.
|
||||
|
||||
Args:
|
||||
_name: The service name
|
||||
"""
|
||||
|
||||
self.name = _name
|
||||
|
||||
def get_name(self):
|
||||
"""
|
||||
Gets the service name
|
||||
Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
"""
|
||||
|
||||
return self.name
|
||||
|
||||
def set_port(self, _port):
|
||||
"""
|
||||
Sets the service port
|
||||
Sets the service port.
|
||||
|
||||
Args:
|
||||
_port: The service port
|
||||
"""
|
||||
|
||||
self.port = _port
|
||||
|
||||
def get_port(self):
|
||||
"""
|
||||
Gets the service port
|
||||
Gets the service port.
|
||||
|
||||
Returns:
|
||||
The service port
|
||||
"""
|
||||
|
||||
return self.port
|
||||
|
||||
def set_state(self, _state):
|
||||
"""
|
||||
Sets the service state
|
||||
Sets the service state.
|
||||
|
||||
Args:
|
||||
_state: The service state
|
||||
"""
|
||||
|
||||
self.state = _state
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Gets the service state
|
||||
Gets the service state.
|
||||
|
||||
Returns:
|
||||
The service state
|
||||
"""
|
||||
|
||||
return self.state
|
||||
|
||||
def reduce_patching_count(self):
|
||||
"""
|
||||
Reduces the patching count for the service
|
||||
"""
|
||||
|
||||
"""Reduces the patching count for the service."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
self.patching_count = 0
|
||||
self.state = SOFTWARE_STATE.GOOD
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
steps: 128
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '80'
|
||||
- port: '80'
|
||||
- itemType: SERVICES
|
||||
serviceList:
|
||||
- name: TCP
|
||||
- name: TCP
|
||||
- itemType: NODE
|
||||
id: '1'
|
||||
name: PC1
|
||||
@@ -19,9 +19,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '2'
|
||||
name: SERVER
|
||||
@@ -33,9 +33,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '3'
|
||||
name: PC2
|
||||
@@ -47,9 +47,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '4'
|
||||
name: SWITCH1
|
||||
@@ -4,10 +4,10 @@
|
||||
steps: 128
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '80'
|
||||
- port: '80'
|
||||
- itemType: SERVICES
|
||||
serviceList:
|
||||
- name: TCP
|
||||
- name: TCP
|
||||
- itemType: NODE
|
||||
id: '1'
|
||||
name: PC1
|
||||
@@ -19,9 +19,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '2'
|
||||
name: PC2
|
||||
@@ -33,9 +33,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '3'
|
||||
name: PC3
|
||||
@@ -47,9 +47,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '4'
|
||||
name: PC4
|
||||
@@ -61,9 +61,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '5'
|
||||
name: SWITCH1
|
||||
@@ -85,9 +85,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '7'
|
||||
name: SWITCH2
|
||||
@@ -109,9 +109,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '9'
|
||||
name: SERVER1
|
||||
@@ -123,9 +123,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '10'
|
||||
name: SERVER2
|
||||
@@ -137,9 +137,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: LINK
|
||||
id: '11'
|
||||
name: link1
|
||||
@@ -4,10 +4,10 @@
|
||||
steps: 256
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '80'
|
||||
- port: '80'
|
||||
- itemType: SERVICES
|
||||
serviceList:
|
||||
- name: TCP
|
||||
- name: TCP
|
||||
- itemType: NODE
|
||||
id: '1'
|
||||
name: PC1
|
||||
@@ -19,9 +19,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '2'
|
||||
name: PC2
|
||||
@@ -33,9 +33,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '3'
|
||||
name: SWITCH1
|
||||
@@ -57,9 +57,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: LINK
|
||||
id: '5'
|
||||
name: link1
|
||||
@@ -4,14 +4,14 @@
|
||||
steps: 256
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '80'
|
||||
- port: '1433'
|
||||
- port: '53'
|
||||
- port: '80'
|
||||
- port: '1433'
|
||||
- port: '53'
|
||||
- itemType: SERVICES
|
||||
serviceList:
|
||||
- name: TCP
|
||||
- name: TCP_SQL
|
||||
- name: UDP
|
||||
- name: TCP
|
||||
- name: TCP_SQL
|
||||
- name: UDP
|
||||
- itemType: NODE
|
||||
id: '1'
|
||||
name: CLIENT_1
|
||||
@@ -23,12 +23,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '2'
|
||||
name: CLIENT_2
|
||||
@@ -40,9 +40,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '3'
|
||||
name: SWITCH_1
|
||||
@@ -64,12 +64,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '5'
|
||||
name: MANAGEMENT_CONSOLE
|
||||
@@ -81,12 +81,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '6'
|
||||
name: SWITCH_2
|
||||
@@ -108,12 +108,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '8'
|
||||
name: DATABASE_SERVER
|
||||
@@ -125,15 +125,15 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '9'
|
||||
name: BACKUP_SERVER
|
||||
@@ -145,9 +145,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: LINK
|
||||
id: '10'
|
||||
name: LINK_1
|
||||
@@ -529,5 +529,5 @@
|
||||
protocol: TCP
|
||||
state: OVERWHELMED
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: TCP_SQL
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
sourceNodeService: TCP_SQL
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
@@ -4,14 +4,14 @@
|
||||
steps: 256
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '80'
|
||||
- port: '1433'
|
||||
- port: '53'
|
||||
- port: '80'
|
||||
- port: '1433'
|
||||
- port: '53'
|
||||
- itemType: SERVICES
|
||||
serviceList:
|
||||
- name: TCP
|
||||
- name: TCP_SQL
|
||||
- name: UDP
|
||||
- name: TCP
|
||||
- name: TCP_SQL
|
||||
- name: UDP
|
||||
- itemType: NODE
|
||||
id: '1'
|
||||
name: CLIENT_1
|
||||
@@ -23,12 +23,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '2'
|
||||
name: CLIENT_2
|
||||
@@ -40,9 +40,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '3'
|
||||
name: SWITCH_1
|
||||
@@ -64,12 +64,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '5'
|
||||
name: MANAGEMENT_CONSOLE
|
||||
@@ -81,12 +81,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '6'
|
||||
name: SWITCH_2
|
||||
@@ -108,12 +108,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '8'
|
||||
name: DATABASE_SERVER
|
||||
@@ -125,15 +125,15 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '9'
|
||||
name: BACKUP_SERVER
|
||||
@@ -145,9 +145,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: LINK
|
||||
id: '10'
|
||||
name: LINK_1
|
||||
@@ -529,5 +529,5 @@
|
||||
protocol: TCP
|
||||
state: OVERWHELMED
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: TCP_SQL
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
sourceNodeService: TCP_SQL
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
@@ -21,18 +21,18 @@ agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observationSpaceHighValue: 1000000000
|
||||
observationSpaceHighValue: 1000000000
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
allOk: 0
|
||||
allOk: 0
|
||||
# Node Operating State
|
||||
offShouldBeOn: -10
|
||||
offShouldBeResetting: -5
|
||||
onShouldBeOff: -2
|
||||
onShouldBeResetting: -5
|
||||
resettingShouldBeOn: -5
|
||||
resettingShouldBeOff: -2
|
||||
resettingShouldBeOff: -2
|
||||
resetting: -3
|
||||
# Node O/S or Service State
|
||||
goodShouldBePatching: 2
|
||||
@@ -45,7 +45,7 @@ patching: -3
|
||||
compromisedShouldBeGood: -20
|
||||
compromisedShouldBePatching: -20
|
||||
compromisedShouldBeOverwhelmed: -20
|
||||
compromised: -20
|
||||
compromised: -20
|
||||
overwhelmedShouldBeGood: -20
|
||||
overwhelmedShouldBePatching: -20
|
||||
overwhelmedShouldBeCompromised: -20
|
||||
@@ -62,7 +62,7 @@ repairingShouldBeDestroyed: 0
|
||||
repairing: -3
|
||||
restoringShouldBeGood: -10
|
||||
restoringShouldBeRepairing: -2
|
||||
restoringShouldBeCorrupt: 1
|
||||
restoringShouldBeCorrupt: 1
|
||||
restoringShouldBeDestroyed: 2
|
||||
restoring: -6
|
||||
corruptShouldBeGood: -10
|
||||
@@ -1,39 +1,45 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Main environment module containing the PRIMmary AI Training Evironment (PRIMAITE) class
|
||||
"""
|
||||
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
|
||||
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
import copy
|
||||
import csv
|
||||
import yaml
|
||||
import os.path
|
||||
import logging
|
||||
|
||||
from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
import os.path
|
||||
from datetime import datetime
|
||||
|
||||
from common.enums import *
|
||||
from links.link import Link
|
||||
from pol.ier import IER
|
||||
from nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from pol.green_pol import apply_iers, apply_node_pol
|
||||
from pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from nodes.active_node import ActiveNode
|
||||
from nodes.passive_node import PassiveNode
|
||||
from nodes.service_node import ServiceNode
|
||||
from common.service import Service
|
||||
from acl.access_control_list import AccessControlList
|
||||
from environment.reward import calculate_reward_function
|
||||
from transactions.transaction import Transaction
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import yaml
|
||||
from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
class PRIMAITE(Env):
|
||||
"""
|
||||
PRIMmary AI Training Evironment (PRIMAITE) class
|
||||
"""
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.enums import (
|
||||
ACTION_TYPE,
|
||||
FILE_SYSTEM_STATE,
|
||||
HARDWARE_STATE,
|
||||
NODE_POL_INITIATOR,
|
||||
NODE_POL_TYPE,
|
||||
PRIORITY,
|
||||
SOFTWARE_STATE,
|
||||
TYPE,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||
from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
|
||||
class Primaite(Env):
|
||||
"""PRIMmary AI Training Evironment (Primaite) class."""
|
||||
|
||||
# Observation / Action Space contants
|
||||
OBSERVATION_SPACE_FIXED_PARAMETERS = 4
|
||||
@@ -42,11 +48,11 @@ class PRIMAITE(Env):
|
||||
ACTION_SPACE_ACL_ACTION_VALUES = 3
|
||||
ACTION_SPACE_ACL_PERMISSION_VALUES = 2
|
||||
|
||||
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
|
||||
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
|
||||
|
||||
def __init__(self, _config_values, _transaction_list):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_episode_steps: The number of steps for the episode
|
||||
@@ -54,8 +60,7 @@ class PRIMAITE(Env):
|
||||
_transaction_list: The list of transactions to populate
|
||||
_agent_identifier: Identifier for the agent
|
||||
"""
|
||||
|
||||
super(PRIMAITE, self).__init__()
|
||||
super(Primaite, self).__init__()
|
||||
|
||||
# Take a copy of the config values
|
||||
self.config_values = _config_values
|
||||
@@ -140,10 +145,12 @@ class PRIMAITE(Env):
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
try:
|
||||
self.config_file = open("config/" + self.config_values.config_filename_use_case, "r")
|
||||
self.config_file = open(
|
||||
"config/" + self.config_values.config_filename_use_case, "r"
|
||||
)
|
||||
self.config_data = yaml.safe_load(self.config_file)
|
||||
self.load_config()
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.error("Could not load the environment configuration")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
@@ -162,17 +169,17 @@ class PRIMAITE(Env):
|
||||
try:
|
||||
plt.tight_layout()
|
||||
nx.draw_networkx(self.network, with_labels=True)
|
||||
now = datetime.now() # current date and time
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
path = 'outputs/diagrams'
|
||||
path = "outputs/diagrams"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
filename = "outputs/diagrams/network_" + time + ".png"
|
||||
plt.savefig(filename, format="PNG")
|
||||
plt.clf()
|
||||
except Exception as a:
|
||||
except Exception:
|
||||
logging.error("Could not save network diagram")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
print("Could not save network diagram")
|
||||
@@ -194,16 +201,22 @@ class PRIMAITE(Env):
|
||||
# - service F state | service F loading
|
||||
# - service G state | service G loading
|
||||
|
||||
# Calculate the number of items that need to be included in the observation space
|
||||
# Calculate the number of items that need to be included in the
|
||||
# observation space
|
||||
num_items = self.num_links + self.num_nodes
|
||||
# Set the number of observation parameters, being # of services plus id, operating state, file system state and O/S state (i.e. 4)
|
||||
self.num_observation_parameters = self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
|
||||
# Set the number of observation parameters, being # of services plus id,
|
||||
# operating state, file system state and O/S state (i.e. 4)
|
||||
self.num_observation_parameters = (
|
||||
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
|
||||
)
|
||||
# Define the observation shape
|
||||
self.observation_shape = (num_items, self.num_observation_parameters)
|
||||
self.observation_space = spaces.Box(low=0,
|
||||
high=self.config_values.observation_space_high_value,
|
||||
shape=self.observation_shape,
|
||||
dtype=np.int64)
|
||||
self.observation_space = spaces.Box(
|
||||
low=0,
|
||||
high=self.config_values.observation_space_high_value,
|
||||
shape=self.observation_shape,
|
||||
dtype=np.int64,
|
||||
)
|
||||
|
||||
# This is the observation that is sent back via the rest and step functions
|
||||
self.env_obs = np.zeros(self.observation_shape, dtype=np.int64)
|
||||
@@ -216,7 +229,14 @@ class PRIMAITE(Env):
|
||||
# [0, 4] - what property it's acting on (0 = nothing, state, o/s state, service state, file system state)
|
||||
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore)
|
||||
# [0, num services] - resolves to service ID (0 = nothing, resolves to service)
|
||||
self.action_space = spaces.MultiDiscrete([self.num_nodes, self.ACTION_SPACE_NODE_PROPERTY_VALUES, self.ACTION_SPACE_NODE_ACTION_VALUES, self.num_services])
|
||||
self.action_space = spaces.MultiDiscrete(
|
||||
[
|
||||
self.num_nodes,
|
||||
self.ACTION_SPACE_NODE_PROPERTY_VALUES,
|
||||
self.ACTION_SPACE_NODE_ACTION_VALUES,
|
||||
self.num_services,
|
||||
]
|
||||
)
|
||||
else:
|
||||
logging.info("Action space type ACL selected")
|
||||
# Terms (for ACL action space):
|
||||
@@ -226,42 +246,52 @@ class PRIMAITE(Env):
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
self.action_space = spaces.MultiDiscrete([self.ACTION_SPACE_ACL_ACTION_VALUES, self.ACTION_SPACE_ACL_PERMISSION_VALUES, self.num_nodes + 1, self.num_nodes + 1, self.num_services + 1, self.num_ports + 1])
|
||||
|
||||
self.action_space = spaces.MultiDiscrete(
|
||||
[
|
||||
self.ACTION_SPACE_ACL_ACTION_VALUES,
|
||||
self.ACTION_SPACE_ACL_PERMISSION_VALUES,
|
||||
self.num_nodes + 1,
|
||||
self.num_nodes + 1,
|
||||
self.num_services + 1,
|
||||
self.num_ports + 1,
|
||||
]
|
||||
)
|
||||
|
||||
# Set up a csv to store the results of the training
|
||||
try:
|
||||
now = datetime.now() # current date and time
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
header = ['Episode', 'Average Reward']
|
||||
header = ["Episode", "Average Reward"]
|
||||
|
||||
# Check whether the output/rerults folder exists (doesn't exist by default install)
|
||||
path = 'outputs/results/'
|
||||
path = "outputs/results/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
filename = "outputs/results/average_reward_per_episode_" + time + ".csv"
|
||||
self.csv_file = open(filename, 'w', encoding='UTF8', newline='')
|
||||
self.csv_file = open(filename, "w", encoding="UTF8", newline="")
|
||||
self.csv_writer = csv.writer(self.csv_file)
|
||||
self.csv_writer.writerow(header)
|
||||
except Exception as e:
|
||||
logging.error("Could not create csv file to hold average reward per episode")
|
||||
except Exception:
|
||||
logging.error(
|
||||
"Could not create csv file to hold average reward per episode"
|
||||
)
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
AI Gym Reset function
|
||||
AI Gym Reset function.
|
||||
|
||||
Returns:
|
||||
Environment observation space (reset)
|
||||
"""
|
||||
|
||||
csv_data = self.episode_count, self.average_reward
|
||||
self.csv_writer.writerow(csv_data)
|
||||
|
||||
self.episode_count += 1
|
||||
|
||||
|
||||
# Don't need to reset links, as they are cleared and recalculated every step
|
||||
|
||||
|
||||
# Clear the ACL
|
||||
self.init_acl()
|
||||
|
||||
@@ -280,7 +310,7 @@ class PRIMAITE(Env):
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
AI Gym Step function
|
||||
AI Gym Step function.
|
||||
|
||||
Args:
|
||||
action: Action space from agent
|
||||
@@ -291,7 +321,6 @@ class PRIMAITE(Env):
|
||||
done: Indicates episode is complete if True
|
||||
step_info: Additional information relating to this step
|
||||
"""
|
||||
|
||||
if self.step_count == 0:
|
||||
print("Episode: " + str(self.episode_count) + " running")
|
||||
|
||||
@@ -299,14 +328,16 @@ class PRIMAITE(Env):
|
||||
done = False
|
||||
|
||||
self.step_count += 1
|
||||
#print("Episode step: " + str(self.stepCount))
|
||||
|
||||
# print("Episode step: " + str(self.stepCount))
|
||||
|
||||
# Need to clear traffic on all links first
|
||||
for link_key, link_value in self.links.items():
|
||||
link_value.clear_traffic()
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(datetime.now(), self.agent_identifier, self.episode_count, self.step_count)
|
||||
transaction = Transaction(
|
||||
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
|
||||
)
|
||||
# Load the initial observation space into the transaction
|
||||
transaction.set_obs_space_pre(copy.deepcopy(self.env_obs))
|
||||
# Load the action space into the transaction
|
||||
@@ -316,50 +347,97 @@ class PRIMAITE(Env):
|
||||
self.apply_time_based_updates()
|
||||
|
||||
# 2. Apply PoL
|
||||
apply_node_pol(self.nodes, self.node_pol, self.step_count) # Node PoL
|
||||
apply_iers(self.network, self.nodes, self.links, self.green_iers, self.acl, self.step_count) # Network PoL
|
||||
apply_node_pol(self.nodes, self.node_pol, self.step_count) # Node PoL
|
||||
apply_iers(
|
||||
self.network,
|
||||
self.nodes,
|
||||
self.links,
|
||||
self.green_iers,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
) # Network PoL
|
||||
# Take snapshots of nodes and links
|
||||
self.nodes_post_pol = copy.deepcopy(self.nodes)
|
||||
self.links_post_pol = copy.deepcopy(self.links)
|
||||
self.links_post_pol = copy.deepcopy(self.links)
|
||||
# Reference
|
||||
apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL
|
||||
apply_iers(self.network_reference, self.nodes_reference, self.links_reference, self.green_iers, self.acl, self.step_count) # Network PoL
|
||||
apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL
|
||||
apply_iers(
|
||||
self.network_reference,
|
||||
self.nodes_reference,
|
||||
self.links_reference,
|
||||
self.green_iers,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
) # Network PoL
|
||||
|
||||
# 3. Implement Red Action
|
||||
apply_red_agent_iers(self.network, self.nodes, self.links, self.red_iers, self.acl, self.step_count)
|
||||
apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count)
|
||||
# 3. Implement Red Action
|
||||
apply_red_agent_iers(
|
||||
self.network,
|
||||
self.nodes,
|
||||
self.links,
|
||||
self.red_iers,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
)
|
||||
apply_red_agent_node_pol(
|
||||
self.nodes, self.red_iers, self.red_node_pol, self.step_count
|
||||
)
|
||||
# Take snapshots of nodes and links
|
||||
self.nodes_post_red = copy.deepcopy(self.nodes)
|
||||
self.links_post_red = copy.deepcopy(self.links)
|
||||
self.links_post_red = copy.deepcopy(self.links)
|
||||
|
||||
# 4. Implement Blue Action
|
||||
self.interpret_action_and_apply(action)
|
||||
|
||||
# 5. Reapply normal and Red agent IER PoL, as we need to see what effect the blue agent action has had (if any) on link status
|
||||
# 5. Reapply normal and Red agent IER PoL, as we need to see what
|
||||
# effect the blue agent action has had (if any) on link status
|
||||
# Need to clear traffic on all links first
|
||||
for link_key, link_value in self.links.items():
|
||||
link_value.clear_traffic()
|
||||
apply_iers(self.network, self.nodes, self.links, self.green_iers, self.acl, self.step_count)
|
||||
apply_red_agent_iers(self.network, self.nodes, self.links, self.red_iers, self.acl, self.step_count)
|
||||
apply_iers(
|
||||
self.network,
|
||||
self.nodes,
|
||||
self.links,
|
||||
self.green_iers,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
)
|
||||
apply_red_agent_iers(
|
||||
self.network,
|
||||
self.nodes,
|
||||
self.links,
|
||||
self.red_iers,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
)
|
||||
# Take snapshots of nodes and links
|
||||
self.nodes_post_blue = copy.deepcopy(self.nodes)
|
||||
self.links_post_blue = copy.deepcopy(self.links)
|
||||
self.links_post_blue = copy.deepcopy(self.links)
|
||||
|
||||
# 6. Calculate reward signal (for RL)
|
||||
reward = calculate_reward_function(self.nodes_post_pol, self.nodes_post_blue, self.nodes_reference, self.green_iers, self.red_iers, self.step_count, self.config_values)
|
||||
#print("Step reward: " + str(reward))
|
||||
reward = calculate_reward_function(
|
||||
self.nodes_post_pol,
|
||||
self.nodes_post_blue,
|
||||
self.nodes_reference,
|
||||
self.green_iers,
|
||||
self.red_iers,
|
||||
self.step_count,
|
||||
self.config_values,
|
||||
)
|
||||
# print("Step reward: " + str(reward))
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
if self.config_values.session_type == "EVALUATION":
|
||||
# For evaluation, need to trigger the done value = True when step count is reached in order to prevent neverending episode
|
||||
# For evaluation, need to trigger the done value = True when
|
||||
# step count is reached in order to prevent neverending episode
|
||||
done = True
|
||||
print("Average reward: " + str(self.average_reward))
|
||||
print("Average reward: " + str(self.average_reward))
|
||||
# Load the reward into the transaction
|
||||
transaction.set_reward(reward)
|
||||
|
||||
|
||||
# 7. Output Verbose
|
||||
#self.output_link_status()
|
||||
# self.output_link_status()
|
||||
|
||||
# 8. Update env_obs
|
||||
self.update_environent_obs()
|
||||
@@ -373,38 +451,33 @@ class PRIMAITE(Env):
|
||||
return self.env_obs, reward, done, self.step_info
|
||||
|
||||
def __close__(self):
|
||||
"""
|
||||
Override close function
|
||||
"""
|
||||
|
||||
"""Override close function."""
|
||||
self.csv_file.close()
|
||||
self.config_file.close()
|
||||
|
||||
def init_acl(self):
|
||||
"""
|
||||
Initialise the Access Control List
|
||||
"""
|
||||
"""Initialise the Access Control List."""
|
||||
self.acl.remove_all_rules()
|
||||
|
||||
self.acl.remove_all_rules()
|
||||
|
||||
def output_link_status(self):
|
||||
"""
|
||||
Output the link status of all links to the console
|
||||
"""
|
||||
|
||||
"""Output the link status of all links to the console."""
|
||||
for link_key, link_value in self.links.items():
|
||||
print("Link ID: " + link_value.get_id())
|
||||
for protocol in link_value.get_protocol_list():
|
||||
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
|
||||
print(
|
||||
" Protocol: "
|
||||
+ protocol.get_name().name
|
||||
+ ", Load: "
|
||||
+ str(protocol.get_load())
|
||||
)
|
||||
|
||||
def interpret_action_and_apply(self, _action):
|
||||
"""
|
||||
Applies agent actions to the nodes and Access Control List
|
||||
Applies agent actions to the nodes and Access Control List.
|
||||
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
|
||||
# At the moment, actions are only affecting nodes
|
||||
if self.action_type == ACTION_TYPE.NODE:
|
||||
self.apply_actions_to_nodes(_action)
|
||||
@@ -413,12 +486,11 @@ class PRIMAITE(Env):
|
||||
|
||||
def apply_actions_to_nodes(self, _action):
|
||||
"""
|
||||
Applies agent actions to the nodes
|
||||
Applies agent actions to the nodes.
|
||||
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
|
||||
node_id = _action[0]
|
||||
node_property = _action[1]
|
||||
property_action = _action[2]
|
||||
@@ -427,7 +499,7 @@ class PRIMAITE(Env):
|
||||
# Check that the action is requesting a valid node
|
||||
try:
|
||||
node = self.nodes[str(node_id)]
|
||||
except:
|
||||
except Exception:
|
||||
return
|
||||
|
||||
if node_property == 0:
|
||||
@@ -472,7 +544,9 @@ class PRIMAITE(Env):
|
||||
return
|
||||
elif property_action == 1:
|
||||
# Patch (valid action if it's good or compromised)
|
||||
node.set_service_state(self.services_list[service_index], SOFTWARE_STATE.PATCHING)
|
||||
node.set_service_state(
|
||||
self.services_list[service_index], SOFTWARE_STATE.PATCHING
|
||||
)
|
||||
else:
|
||||
# Node is not of Service Type
|
||||
return
|
||||
@@ -488,7 +562,10 @@ class PRIMAITE(Env):
|
||||
elif property_action == 2:
|
||||
# Repair
|
||||
# You cannot repair a destroyed file system - it needs restoring
|
||||
if node.get_file_system_state_actual() != FILE_SYSTEM_STATE.DESTROYED:
|
||||
if (
|
||||
node.get_file_system_state_actual()
|
||||
!= FILE_SYSTEM_STATE.DESTROYED
|
||||
):
|
||||
node.set_file_system_state(FILE_SYSTEM_STATE.REPAIRING)
|
||||
elif property_action == 3:
|
||||
# Restore
|
||||
@@ -501,12 +578,11 @@ class PRIMAITE(Env):
|
||||
|
||||
def apply_actions_to_acl(self, _action):
|
||||
"""
|
||||
Applies agent actions to the Access Control List [TO DO]
|
||||
Applies agent actions to the Access Control List [TO DO].
|
||||
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
|
||||
action_decision = _action[0]
|
||||
action_permission = _action[1]
|
||||
action_source_ip = _action[2]
|
||||
@@ -517,7 +593,7 @@ class PRIMAITE(Env):
|
||||
if action_decision == 0:
|
||||
# It's decided to do nothing
|
||||
return
|
||||
else:
|
||||
else:
|
||||
# It's decided to create a new ACL rule or remove an existing rule
|
||||
# Permission value
|
||||
if action_permission == 0:
|
||||
@@ -556,18 +632,31 @@ class PRIMAITE(Env):
|
||||
# Now add or remove
|
||||
if action_decision == 1:
|
||||
# Add the rule
|
||||
self.acl.add_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port)
|
||||
self.acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
)
|
||||
elif action_decision == 2:
|
||||
# Remove the rule
|
||||
self.acl.remove_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port)
|
||||
self.acl.remove_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
def apply_time_based_updates(self):
|
||||
"""
|
||||
Updates anything that needs to count down and then change state (e.g. reset / patching status)
|
||||
Updates anything that needs to count down and then change state.
|
||||
|
||||
e.g. reset / patching status
|
||||
"""
|
||||
|
||||
for node_key, node in self.nodes.items():
|
||||
if node.get_state() == HARDWARE_STATE.RESETTING:
|
||||
node.update_resetting_status()
|
||||
@@ -605,10 +694,7 @@ class PRIMAITE(Env):
|
||||
pass
|
||||
|
||||
def update_environent_obs(self):
|
||||
"""
|
||||
# Updates the observation space based on the node and link status
|
||||
"""
|
||||
|
||||
"""Updates the observation space based on the node and link status."""
|
||||
item_index = 0
|
||||
|
||||
# Do nodes first
|
||||
@@ -617,15 +703,19 @@ class PRIMAITE(Env):
|
||||
self.env_obs[item_index][1] = node.get_state().value
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
self.env_obs[item_index][2] = node.get_os_state().value
|
||||
self.env_obs[item_index][3] = node.get_file_system_state_observed().value
|
||||
self.env_obs[item_index][
|
||||
3
|
||||
] = node.get_file_system_state_observed().value
|
||||
else:
|
||||
self.env_obs[item_index][2] = 0
|
||||
self.env_obs[item_index][3] = 0
|
||||
self.env_obs[item_index][3] = 0
|
||||
service_index = 4
|
||||
if isinstance(node, ServiceNode):
|
||||
if isinstance(node, ServiceNode):
|
||||
for service in self.services_list:
|
||||
if node.has_service(service):
|
||||
self.env_obs[item_index][service_index] = node.get_service_state(service).value
|
||||
self.env_obs[item_index][
|
||||
service_index
|
||||
] = node.get_service_state(service).value
|
||||
else:
|
||||
self.env_obs[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
@@ -650,17 +740,14 @@ class PRIMAITE(Env):
|
||||
item_index += 1
|
||||
|
||||
def load_config(self):
|
||||
"""
|
||||
# Loads config data in order to build the environment configuration
|
||||
"""
|
||||
|
||||
"""Loads config data in order to build the environment configuration."""
|
||||
for item in self.config_data:
|
||||
if item["itemType"] == "NODE":
|
||||
# Create a node
|
||||
self.create_node(item)
|
||||
elif item["itemType"] == "LINK":
|
||||
# Create a link
|
||||
self.create_link(item)
|
||||
self.create_link(item)
|
||||
elif item["itemType"] == "GREEN_IER":
|
||||
# Create a Green IER
|
||||
self.create_green_ier(item)
|
||||
@@ -697,12 +784,11 @@ class PRIMAITE(Env):
|
||||
|
||||
def create_node(self, item):
|
||||
"""
|
||||
Creates a node from config data
|
||||
Creates a node from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
# All nodes have these parameters
|
||||
node_id = item["id"]
|
||||
node_name = item["name"]
|
||||
@@ -712,19 +798,46 @@ class PRIMAITE(Env):
|
||||
node_hardware_state = HARDWARE_STATE[item["hardwareState"]]
|
||||
|
||||
if node_base_type == "PASSIVE":
|
||||
node = PassiveNode(node_id, node_name, node_type, node_priority, node_hardware_state, self.config_values)
|
||||
node = PassiveNode(
|
||||
node_id,
|
||||
node_name,
|
||||
node_type,
|
||||
node_priority,
|
||||
node_hardware_state,
|
||||
self.config_values,
|
||||
)
|
||||
elif node_base_type == "ACTIVE":
|
||||
# Active nodes have IP address, operating system state and file system state
|
||||
node_ip_address = item["ipAddress"]
|
||||
node_software_state = SOFTWARE_STATE[item["softwareState"]]
|
||||
node_file_system_state = FILE_SYSTEM_STATE[item["fileSystemState"]]
|
||||
node = ActiveNode(node_id, node_name, node_type, node_priority, node_hardware_state, node_ip_address, node_software_state, node_file_system_state, self.config_values)
|
||||
node = ActiveNode(
|
||||
node_id,
|
||||
node_name,
|
||||
node_type,
|
||||
node_priority,
|
||||
node_hardware_state,
|
||||
node_ip_address,
|
||||
node_software_state,
|
||||
node_file_system_state,
|
||||
self.config_values,
|
||||
)
|
||||
elif node_base_type == "SERVICE":
|
||||
# Service nodes have IP address, operating system state, file system state and list of services
|
||||
node_ip_address = item["ipAddress"]
|
||||
node_software_state = SOFTWARE_STATE[item["softwareState"]]
|
||||
node_file_system_state = FILE_SYSTEM_STATE[item["fileSystemState"]]
|
||||
node = ServiceNode(node_id, node_name, node_type, node_priority, node_hardware_state, node_ip_address, node_software_state, node_file_system_state, self.config_values)
|
||||
node = ServiceNode(
|
||||
node_id,
|
||||
node_name,
|
||||
node_type,
|
||||
node_priority,
|
||||
node_hardware_state,
|
||||
node_ip_address,
|
||||
node_software_state,
|
||||
node_file_system_state,
|
||||
self.config_values,
|
||||
)
|
||||
node_services = item["services"]
|
||||
for service in node_services:
|
||||
service_protocol = service["name"]
|
||||
@@ -752,12 +865,11 @@ class PRIMAITE(Env):
|
||||
|
||||
def create_link(self, item):
|
||||
"""
|
||||
Creates a link from config data
|
||||
Creates a link from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
link_id = item["id"]
|
||||
link_name = item["name"]
|
||||
link_bandwidth = item["bandwidth"]
|
||||
@@ -771,7 +883,13 @@ class PRIMAITE(Env):
|
||||
self.network.add_edge(source_node, dest_node, id=link_name)
|
||||
|
||||
# Add link to link dictionary
|
||||
self.links[link_name] = Link(link_id, link_bandwidth, source_node.get_name(), dest_node.get_name(), self.services_list)
|
||||
self.links[link_name] = Link(
|
||||
link_id,
|
||||
link_bandwidth,
|
||||
source_node.get_name(),
|
||||
dest_node.get_name(),
|
||||
self.services_list,
|
||||
)
|
||||
|
||||
# Reference
|
||||
source_node_ref = self.nodes_reference[link_source]
|
||||
@@ -781,16 +899,21 @@ class PRIMAITE(Env):
|
||||
self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name)
|
||||
|
||||
# Add link to link dictionary (reference)
|
||||
self.links_reference[link_name] = Link(link_id, link_bandwidth, source_node_ref.get_name(), dest_node_ref.get_name(), self.services_list)
|
||||
self.links_reference[link_name] = Link(
|
||||
link_id,
|
||||
link_bandwidth,
|
||||
source_node_ref.get_name(),
|
||||
dest_node_ref.get_name(),
|
||||
self.services_list,
|
||||
)
|
||||
|
||||
def create_green_ier(self, item):
|
||||
"""
|
||||
Creates a green IER from config data
|
||||
Creates a green IER from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
ier_id = item["id"]
|
||||
ier_start_step = item["startStep"]
|
||||
ier_end_step = item["endStep"]
|
||||
@@ -802,16 +925,25 @@ class PRIMAITE(Env):
|
||||
ier_mission_criticality = item["missionCriticality"]
|
||||
|
||||
# Create IER and add to green IER dictionary
|
||||
self.green_iers[ier_id] = IER(ier_id, ier_start_step, ier_end_step, ier_load, ier_protocol, ier_port, ier_source, ier_destination, ier_mission_criticality)
|
||||
self.green_iers[ier_id] = IER(
|
||||
ier_id,
|
||||
ier_start_step,
|
||||
ier_end_step,
|
||||
ier_load,
|
||||
ier_protocol,
|
||||
ier_port,
|
||||
ier_source,
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_red_ier(self, item):
|
||||
"""
|
||||
Creates a red IER from config data
|
||||
Creates a red IER from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
ier_id = item["id"]
|
||||
ier_start_step = item["startStep"]
|
||||
ier_end_step = item["endStep"]
|
||||
@@ -823,21 +955,30 @@ class PRIMAITE(Env):
|
||||
ier_mission_criticality = item["missionCriticality"]
|
||||
|
||||
# Create IER and add to red IER dictionary
|
||||
self.red_iers[ier_id] = IER(ier_id, ier_start_step, ier_end_step, ier_load, ier_protocol, ier_port, ier_source, ier_destination, ier_mission_criticality)
|
||||
self.red_iers[ier_id] = IER(
|
||||
ier_id,
|
||||
ier_start_step,
|
||||
ier_end_step,
|
||||
ier_load,
|
||||
ier_protocol,
|
||||
ier_port,
|
||||
ier_source,
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_green_pol(self, item):
|
||||
"""
|
||||
Creates a green PoL object from config data
|
||||
Creates a green PoL object from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
pol_id = item["id"]
|
||||
pol_start_step = item["startStep"]
|
||||
pol_end_step = item["endStep"]
|
||||
pol_node = item["nodeId"]
|
||||
pol_type = NODE_POL_TYPE[item["type"]]
|
||||
pol_type = NODE_POL_TYPE[item["type"]]
|
||||
|
||||
# State depends on whether this is Operating, O/S, file system or Service PoL type
|
||||
if pol_type == NODE_POL_TYPE.OPERATING:
|
||||
@@ -850,16 +991,23 @@ class PRIMAITE(Env):
|
||||
pol_protocol = item["protocol"]
|
||||
pol_state = SOFTWARE_STATE[item["state"]]
|
||||
|
||||
self.node_pol[pol_id] = NodeStateInstructionGreen(pol_id, pol_start_step, pol_end_step, pol_node, pol_type, pol_protocol, pol_state)
|
||||
self.node_pol[pol_id] = NodeStateInstructionGreen(
|
||||
pol_id,
|
||||
pol_start_step,
|
||||
pol_end_step,
|
||||
pol_node,
|
||||
pol_type,
|
||||
pol_protocol,
|
||||
pol_state,
|
||||
)
|
||||
|
||||
def create_red_pol(self, item):
|
||||
"""
|
||||
Creates a red PoL object from config data
|
||||
Creates a red PoL object from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
pol_id = item["id"]
|
||||
pol_start_step = item["startStep"]
|
||||
pol_end_step = item["endStep"]
|
||||
@@ -880,32 +1028,48 @@ class PRIMAITE(Env):
|
||||
pol_source_node_service = item["sourceNodeService"]
|
||||
pol_source_node_service_state = item["sourceNodeServiceState"]
|
||||
|
||||
self.red_node_pol[pol_id] = NodeStateInstructionRed(pol_id, pol_start_step, pol_end_step, pol_target_node_id, pol_initiator, pol_type, pol_protocol, pol_state, pol_source_node_id, pol_source_node_service, pol_source_node_service_state)
|
||||
self.red_node_pol[pol_id] = NodeStateInstructionRed(
|
||||
pol_id,
|
||||
pol_start_step,
|
||||
pol_end_step,
|
||||
pol_target_node_id,
|
||||
pol_initiator,
|
||||
pol_type,
|
||||
pol_protocol,
|
||||
pol_state,
|
||||
pol_source_node_id,
|
||||
pol_source_node_service,
|
||||
pol_source_node_service_state,
|
||||
)
|
||||
|
||||
def create_acl_rule(self, item):
|
||||
"""
|
||||
Creates an ACL rule from config data
|
||||
Creates an ACL rule from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
acl_rule_permission = item["permission"]
|
||||
acl_rule_source = item["source"]
|
||||
acl_rule_destination = item["destination"]
|
||||
acl_rule_protocol = item["protocol"]
|
||||
acl_rule_port = item["port"]
|
||||
|
||||
self.acl.add_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port)
|
||||
self.acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
)
|
||||
|
||||
def create_services_list(self, services):
|
||||
"""
|
||||
Creates a list of services (enum) from config data
|
||||
Creates a list of services (enum) from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item representing the services
|
||||
"""
|
||||
|
||||
service_list = services["serviceList"]
|
||||
|
||||
for service in service_list:
|
||||
@@ -917,12 +1081,11 @@ class PRIMAITE(Env):
|
||||
|
||||
def create_ports_list(self, ports):
|
||||
"""
|
||||
Creates a list of ports from config data
|
||||
Creates a list of ports from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item representing the ports
|
||||
"""
|
||||
|
||||
ports_list = ports["portsList"]
|
||||
|
||||
for port in ports_list:
|
||||
@@ -934,35 +1097,34 @@ class PRIMAITE(Env):
|
||||
|
||||
def get_action_info(self, action_info):
|
||||
"""
|
||||
Extracts action_info
|
||||
Extracts action_info.
|
||||
|
||||
Args:
|
||||
item: A config data item representing action info
|
||||
"""
|
||||
|
||||
self.action_type = ACTION_TYPE[action_info["type"]]
|
||||
|
||||
|
||||
def get_steps_info(self, steps_info):
|
||||
"""
|
||||
Extracts steps_info
|
||||
Extracts steps_info.
|
||||
|
||||
Args:
|
||||
item: A config data item representing steps info
|
||||
"""
|
||||
|
||||
self.episode_steps = int(steps_info["steps"])
|
||||
logging.info("Training episodes have " + str(self.episode_steps) + " steps")
|
||||
|
||||
def reset_environment(self):
|
||||
"""
|
||||
# Resets environment using config data config data in order to build the environment configuration
|
||||
"""
|
||||
# Resets environment.
|
||||
|
||||
Uses config data config data in order to build the environment
|
||||
configuration.
|
||||
"""
|
||||
for item in self.config_data:
|
||||
if item["itemType"] == "NODE":
|
||||
# Reset a node's state (normal and reference)
|
||||
self.reset_node(item)
|
||||
self.reset_node(item)
|
||||
elif item["itemType"] == "ACL_RULE":
|
||||
# Create an ACL rule (these are cleared on reset, so just need to recreate them)
|
||||
self.create_acl_rule(item)
|
||||
@@ -970,7 +1132,6 @@ class PRIMAITE(Env):
|
||||
# Do nothing (bad formatting or not relevant to reset)
|
||||
pass
|
||||
|
||||
|
||||
# Reset the IER status so they are not running initially
|
||||
# Green IERs
|
||||
for ier_key, ier_value in self.green_iers.items():
|
||||
@@ -981,12 +1142,11 @@ class PRIMAITE(Env):
|
||||
|
||||
def reset_node(self, item):
|
||||
"""
|
||||
Resets the statuses of a node
|
||||
Resets the statuses of a node.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
# All nodes have these parameters
|
||||
node_id = item["id"]
|
||||
node_base_type = item["baseType"]
|
||||
@@ -1027,10 +1187,3 @@ class PRIMAITE(Env):
|
||||
else:
|
||||
# Bad formatting
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Implements reward function
|
||||
"""
|
||||
"""Implements reward function."""
|
||||
from primaite.common.enums import FILE_SYSTEM_STATE, HARDWARE_STATE, SOFTWARE_STATE
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
from common.enums import *
|
||||
from nodes.active_node import ActiveNode
|
||||
from nodes.service_node import ServiceNode
|
||||
|
||||
def calculate_reward_function(initial_nodes, final_nodes, reference_nodes, green_iers, red_iers, step_count, config_values):
|
||||
def calculate_reward_function(
|
||||
initial_nodes,
|
||||
final_nodes,
|
||||
reference_nodes,
|
||||
green_iers,
|
||||
red_iers,
|
||||
step_count,
|
||||
config_values,
|
||||
):
|
||||
"""
|
||||
Compares the states of the initial and final nodes/links to get a reward
|
||||
Compares the states of the initial and final nodes/links to get a reward.
|
||||
|
||||
Args:
|
||||
initial_nodes: The nodes before red and blue agents take effect
|
||||
@@ -20,29 +26,36 @@ def calculate_reward_function(initial_nodes, final_nodes, reference_nodes, green
|
||||
step_count: current step
|
||||
config_values: Config values
|
||||
"""
|
||||
|
||||
reward_value = 0
|
||||
|
||||
# For each node, compare operating state, o/s operating state, service states
|
||||
for node_key, final_node in final_nodes.items():
|
||||
initial_node = initial_nodes[node_key]
|
||||
reference_node = reference_nodes[node_key]
|
||||
|
||||
|
||||
# Operating State
|
||||
reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)
|
||||
reward_value += score_node_operating_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
|
||||
# Operating System State
|
||||
if (isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode)):
|
||||
reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
|
||||
reward_value += score_node_os_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
|
||||
# Service State
|
||||
if (isinstance(final_node, ServiceNode)):
|
||||
reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)
|
||||
if isinstance(final_node, ServiceNode):
|
||||
reward_value += score_node_service_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
|
||||
# File System State
|
||||
if isinstance(final_node, ActiveNode):
|
||||
reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
reward_value += score_node_file_system(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
|
||||
# Go through each red IER - penalise if it is running
|
||||
for ier_key, ier_value in red_iers.items():
|
||||
start_step = ier_value.get_start_step()
|
||||
@@ -57,14 +70,17 @@ def calculate_reward_function(initial_nodes, final_nodes, reference_nodes, green
|
||||
stop_step = ier_value.get_end_step()
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
if not ier_value.get_is_running():
|
||||
reward_value += config_values.green_ier_blocked * ier_value.get_mission_criticality()
|
||||
reward_value += (
|
||||
config_values.green_ier_blocked
|
||||
* ier_value.get_mission_criticality()
|
||||
)
|
||||
|
||||
return reward_value
|
||||
|
||||
|
||||
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the operating state of a node
|
||||
Calculates score relating to the operating state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
@@ -72,8 +88,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
|
||||
score = 0
|
||||
score = 0
|
||||
final_node_operating_state = final_node.get_state()
|
||||
initial_node_operating_state = initial_node.get_state()
|
||||
reference_node_operating_state = reference_node.get_state()
|
||||
@@ -81,7 +96,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
if final_node_operating_state == reference_node_operating_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare initial and final state of node (i.e. after red and blue actions)
|
||||
if initial_node_operating_state == HARDWARE_STATE.ON:
|
||||
@@ -95,7 +110,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
if final_node_operating_state == HARDWARE_STATE.ON:
|
||||
score += config_values.on_should_be_off
|
||||
elif final_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||
score += config_values.resetting_should_be_off
|
||||
score += config_values.resetting_should_be_off
|
||||
else:
|
||||
pass
|
||||
elif initial_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||
@@ -112,9 +127,10 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def score_node_os_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the operating system state of a node
|
||||
Calculates score relating to the operating system state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
@@ -122,8 +138,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
|
||||
score = 0
|
||||
score = 0
|
||||
final_node_os_state = final_node.get_os_state()
|
||||
initial_node_os_state = initial_node.get_os_state()
|
||||
reference_node_os_state = reference_node.get_os_state()
|
||||
@@ -131,7 +146,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
if final_node_os_state == reference_node_os_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare initial and final state of node (i.e. after red and blue actions)
|
||||
if initial_node_os_state == SOFTWARE_STATE.GOOD:
|
||||
@@ -145,18 +160,18 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
if final_node_os_state == SOFTWARE_STATE.GOOD:
|
||||
score += config_values.good_should_be_patching
|
||||
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||
score += config_values.compromised_should_be_patching
|
||||
score += config_values.compromised_should_be_patching
|
||||
elif final_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||
score += config_values.patching
|
||||
score += config_values.patching
|
||||
else:
|
||||
pass
|
||||
pass
|
||||
elif initial_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||
if final_node_os_state == SOFTWARE_STATE.GOOD:
|
||||
score += config_values.good_should_be_compromised
|
||||
elif final_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||
score += config_values.patching_should_be_compromised
|
||||
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||
score += config_values.compromised
|
||||
score += config_values.compromised
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
@@ -164,9 +179,10 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def score_node_service_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
@@ -174,12 +190,11 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
|
||||
score = 0
|
||||
score = 0
|
||||
final_node_services = final_node.get_services()
|
||||
initial_node_services = initial_node.get_services()
|
||||
reference_node_services = reference_node.get_services()
|
||||
|
||||
|
||||
for service_key, final_service in final_node_services.items():
|
||||
reference_service = reference_node_services[service_key]
|
||||
initial_service = initial_node_services[service_key]
|
||||
@@ -203,11 +218,11 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
if final_service.get_state() == SOFTWARE_STATE.GOOD:
|
||||
score += config_values.good_should_be_patching
|
||||
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
score += config_values.compromised_should_be_patching
|
||||
score += config_values.compromised_should_be_patching
|
||||
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
score += config_values.overwhelmed_should_be_patching
|
||||
score += config_values.overwhelmed_should_be_patching
|
||||
elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
score += config_values.patching
|
||||
score += config_values.patching
|
||||
else:
|
||||
pass
|
||||
elif initial_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
@@ -216,9 +231,9 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
score += config_values.patching_should_be_compromised
|
||||
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
score += config_values.compromised
|
||||
score += config_values.compromised
|
||||
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
score += config_values.overwhelmed_should_be_compromised
|
||||
score += config_values.overwhelmed_should_be_compromised
|
||||
else:
|
||||
pass
|
||||
elif initial_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
@@ -227,9 +242,9 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
score += config_values.patching_should_be_overwhelmed
|
||||
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
score += config_values.compromised_should_be_overwhelmed
|
||||
score += config_values.compromised_should_be_overwhelmed
|
||||
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
score += config_values.overwhelmed
|
||||
score += config_values.overwhelmed
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
@@ -237,17 +252,17 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def score_node_file_system(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the file system state of a node
|
||||
Calculates score relating to the file system state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
initial_node: The node before red and blue agents take effect
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
"""
|
||||
|
||||
score = 0
|
||||
score = 0
|
||||
final_node_file_system_state = final_node.get_file_system_state_actual()
|
||||
initial_node_file_system_state = initial_node.get_file_system_state_actual()
|
||||
reference_node_file_system_state = reference_node.get_file_system_state_actual()
|
||||
@@ -259,7 +274,7 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
|
||||
if final_node_file_system_state == reference_node_file_system_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare initial and final state of node (i.e. after red and blue actions)
|
||||
if initial_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
@@ -277,15 +292,15 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
|
||||
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
score += config_values.good_should_be_repairing
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
score += config_values.restoring_should_be_repairing
|
||||
score += config_values.restoring_should_be_repairing
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
score += config_values.corrupt_should_be_repairing
|
||||
score += config_values.corrupt_should_be_repairing
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
score += config_values.destroyed_should_be_repairing
|
||||
score += config_values.destroyed_should_be_repairing
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
score += config_values.repairing
|
||||
else:
|
||||
pass
|
||||
pass
|
||||
elif initial_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
score += config_values.good_should_be_restoring
|
||||
@@ -294,9 +309,9 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
score += config_values.corrupt_should_be_restoring
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
score += config_values.destroyed_should_be_restoring
|
||||
score += config_values.destroyed_should_be_restoring
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
score += config_values.restoring
|
||||
score += config_values.restoring
|
||||
else:
|
||||
pass
|
||||
elif initial_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
@@ -307,9 +322,9 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
score += config_values.restoring_should_be_corrupt
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
score += config_values.destroyed_should_be_corrupt
|
||||
score += config_values.destroyed_should_be_corrupt
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
score += config_values.corrupt
|
||||
score += config_values.corrupt
|
||||
else:
|
||||
pass
|
||||
elif initial_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
@@ -320,9 +335,9 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
score += config_values.restoring_should_be_destroyed
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
score += config_values.corrupt_should_be_destroyed
|
||||
score += config_values.corrupt_should_be_destroyed
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
score += config_values.destroyed
|
||||
score += config_values.destroyed
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
@@ -332,9 +347,9 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
|
||||
if final_node_scanning_state == reference_node_scanning_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# We're scanning the file system which incurs a penalty (as it slows down systems)
|
||||
score += config_values.scanning
|
||||
|
||||
return score
|
||||
return score
|
||||
@@ -1,19 +1,15 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The link class
|
||||
"""
|
||||
"""The link class."""
|
||||
|
||||
from primaite.common.protocol import Protocol
|
||||
|
||||
from common.protocol import Protocol
|
||||
from common.enums import *
|
||||
|
||||
class Link(object):
|
||||
"""
|
||||
Link class
|
||||
"""
|
||||
"""Link class."""
|
||||
|
||||
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The IER id
|
||||
@@ -22,9 +18,8 @@ class Link(object):
|
||||
_dest_node_name: The name of the destination node
|
||||
_protocols: The protocols to add to the link
|
||||
"""
|
||||
|
||||
self.id = _id
|
||||
self.bandwidth = _bandwidth
|
||||
self.bandwidth = _bandwidth
|
||||
self.source_node_name = _source_node_name
|
||||
self.dest_node_name = _dest_node_name
|
||||
self.protocol_list = []
|
||||
@@ -35,72 +30,65 @@ class Link(object):
|
||||
|
||||
def add_protocol(self, _protocol):
|
||||
"""
|
||||
Adds a new protocol to the list of protocols on this link
|
||||
Adds a new protocol to the list of protocols on this link.
|
||||
|
||||
Args:
|
||||
_protocol: The protocol to be added (enum)
|
||||
"""
|
||||
|
||||
self.protocol_list.append(Protocol(_protocol))
|
||||
|
||||
def get_id(self):
|
||||
"""
|
||||
Gets link ID
|
||||
Gets link ID.
|
||||
|
||||
Returns:
|
||||
Link ID
|
||||
"""
|
||||
|
||||
return self.id
|
||||
|
||||
def get_source_node_name(self):
|
||||
"""
|
||||
Gets source node name
|
||||
Gets source node name.
|
||||
|
||||
Returns:
|
||||
Source node name
|
||||
"""
|
||||
|
||||
return self.source_node_name
|
||||
|
||||
def get_dest_node_name(self):
|
||||
"""
|
||||
Gets destination node name
|
||||
Gets destination node name.
|
||||
|
||||
Returns:
|
||||
Destination node name
|
||||
"""
|
||||
|
||||
return self.dest_node_name
|
||||
|
||||
def get_bandwidth(self):
|
||||
"""
|
||||
Gets bandwidth of link
|
||||
Gets bandwidth of link.
|
||||
|
||||
Returns:
|
||||
Link bandwidth (bps)
|
||||
"""
|
||||
|
||||
return self.bandwidth
|
||||
|
||||
def get_protocol_list(self):
|
||||
"""
|
||||
Gets list of protocols on this link
|
||||
Gets list of protocols on this link.
|
||||
|
||||
Returns:
|
||||
List of protocols on this link
|
||||
"""
|
||||
|
||||
return self.protocol_list
|
||||
|
||||
def get_current_load(self):
|
||||
"""
|
||||
Gets current total load on this link
|
||||
Gets current total load on this link.
|
||||
|
||||
Returns:
|
||||
Total load on this link (bps)
|
||||
"""
|
||||
|
||||
total_load = 0
|
||||
for protocol in self.protocol_list:
|
||||
total_load += protocol.get_load()
|
||||
@@ -108,13 +96,12 @@ class Link(object):
|
||||
|
||||
def add_protocol_load(self, _protocol, _load):
|
||||
"""
|
||||
Adds a loading to a protocol on this link
|
||||
Adds a loading to a protocol on this link.
|
||||
|
||||
Args:
|
||||
_protocol: The protocol to load
|
||||
_load: The amount to load (bps)
|
||||
"""
|
||||
|
||||
for protocol in self.protocol_list:
|
||||
if protocol.get_name() == _protocol:
|
||||
protocol.add_load(_load)
|
||||
@@ -122,11 +109,6 @@ class Link(object):
|
||||
pass
|
||||
|
||||
def clear_traffic(self):
|
||||
"""
|
||||
Clears all traffic on this link
|
||||
"""
|
||||
|
||||
"""Clears all traffic on this link."""
|
||||
for protocol in self.protocol_list:
|
||||
protocol.clear_load()
|
||||
|
||||
|
||||
379
src/primaite/main.py
Normal file
379
src/primaite/main.py
Normal file
@@ -0,0 +1,379 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Primaite - main (harness) module.
|
||||
|
||||
Coding Standards: PEP 8
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os.path
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import yaml
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite.common.config_values_main import config_values_main
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.transactions.transactions_to_file import write_transaction_to_file
|
||||
|
||||
# FUNCTIONS #
|
||||
|
||||
|
||||
def run_generic():
|
||||
"""Run against a generic agent."""
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
for step in range(0, config_values.num_steps):
|
||||
# Send the observation space to the agent to get an action
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
action = env.action_space.sample()
|
||||
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(config_values.time_delay / 1000)
|
||||
|
||||
# Reset the environment at the end of the episode
|
||||
env.reset()
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def run_stable_baselines3_ppo():
|
||||
"""Run against a stable_baselines3 PPO agent."""
|
||||
if config_values.load_agent == True:
|
||||
try:
|
||||
agent = PPO.load(
|
||||
config_values.agent_load_file,
|
||||
env,
|
||||
verbose=0,
|
||||
n_steps=config_values.num_steps,
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
"ERROR: Could not load agent at location: "
|
||||
+ config_values.agent_load_file
|
||||
)
|
||||
logging.error("Could not load agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
print("Starting training session...")
|
||||
logging.info("Starting training session...")
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
agent.learn(total_timesteps=1)
|
||||
save_agent(agent)
|
||||
else:
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
logging.info("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def run_stable_baselines3_a2c():
|
||||
"""Run against a stable_baselines3 A2C agent."""
|
||||
if config_values.load_agent == True:
|
||||
try:
|
||||
agent = A2C.load(
|
||||
config_values.agent_load_file,
|
||||
env,
|
||||
verbose=0,
|
||||
n_steps=config_values.num_steps,
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
"ERROR: Could not load agent at location: "
|
||||
+ config_values.agent_load_file
|
||||
)
|
||||
logging.error("Could not load agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
print("Starting training session...")
|
||||
logging.info("Starting training session...")
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
agent.learn(total_timesteps=1)
|
||||
save_agent(agent)
|
||||
else:
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
logging.info("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def save_agent(_agent):
|
||||
"""Persist an agent (only works for stable baselines3 agents at present)."""
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
try:
|
||||
path = "outputs/agents/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
filename = "outputs/agents/agent_saved_" + time
|
||||
_agent.save(filename)
|
||||
logging.info("Trained agent saved as " + filename)
|
||||
except Exception:
|
||||
logging.error("Could not save agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
|
||||
def configure_logging():
|
||||
"""Configures logging."""
|
||||
try:
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
filename = "logs/app_" + time + ".log"
|
||||
path = "logs/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
logging.basicConfig(
|
||||
filename=filename,
|
||||
filemode="w",
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%d-%b-%y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
except Exception:
|
||||
print("ERROR: Could not start logging")
|
||||
|
||||
|
||||
def load_config_values():
|
||||
"""Loads the config values from the main config file into a config object."""
|
||||
try:
|
||||
# Generic
|
||||
config_values.agent_identifier = config_data["agentIdentifier"]
|
||||
config_values.num_episodes = int(config_data["numEpisodes"])
|
||||
config_values.time_delay = int(config_data["timeDelay"])
|
||||
config_values.config_filename_use_case = config_data["configFilename"]
|
||||
config_values.session_type = config_data["sessionType"]
|
||||
config_values.load_agent = bool(config_data["loadAgent"])
|
||||
config_values.agent_load_file = config_data["agentLoadFile"]
|
||||
# Environment
|
||||
config_values.observation_space_high_value = int(
|
||||
config_data["observationSpaceHighValue"]
|
||||
)
|
||||
# Reward values
|
||||
# Generic
|
||||
config_values.all_ok = int(config_data["allOk"])
|
||||
# Node Operating State
|
||||
config_values.off_should_be_on = int(config_data["offShouldBeOn"])
|
||||
config_values.off_should_be_resetting = int(config_data["offShouldBeResetting"])
|
||||
config_values.on_should_be_off = int(config_data["onShouldBeOff"])
|
||||
config_values.on_should_be_resetting = int(config_data["onShouldBeResetting"])
|
||||
config_values.resetting_should_be_on = int(config_data["resettingShouldBeOn"])
|
||||
config_values.resetting_should_be_off = int(config_data["resettingShouldBeOff"])
|
||||
config_values.resetting = int(config_data["resetting"])
|
||||
# Node O/S or Service State
|
||||
config_values.good_should_be_patching = int(config_data["goodShouldBePatching"])
|
||||
config_values.good_should_be_compromised = int(
|
||||
config_data["goodShouldBeCompromised"]
|
||||
)
|
||||
config_values.good_should_be_overwhelmed = int(
|
||||
config_data["goodShouldBeOverwhelmed"]
|
||||
)
|
||||
config_values.patching_should_be_good = int(config_data["patchingShouldBeGood"])
|
||||
config_values.patching_should_be_compromised = int(
|
||||
config_data["patchingShouldBeCompromised"]
|
||||
)
|
||||
config_values.patching_should_be_overwhelmed = int(
|
||||
config_data["patchingShouldBeOverwhelmed"]
|
||||
)
|
||||
config_values.patching = int(config_data["patching"])
|
||||
config_values.compromised_should_be_good = int(
|
||||
config_data["compromisedShouldBeGood"]
|
||||
)
|
||||
config_values.compromised_should_be_patching = int(
|
||||
config_data["compromisedShouldBePatching"]
|
||||
)
|
||||
config_values.compromised_should_be_overwhelmed = int(
|
||||
config_data["compromisedShouldBeOverwhelmed"]
|
||||
)
|
||||
config_values.compromised = int(config_data["compromised"])
|
||||
config_values.overwhelmed_should_be_good = int(
|
||||
config_data["overwhelmedShouldBeGood"]
|
||||
)
|
||||
config_values.overwhelmed_should_be_patching = int(
|
||||
config_data["overwhelmedShouldBePatching"]
|
||||
)
|
||||
config_values.overwhelmed_should_be_compromised = int(
|
||||
config_data["overwhelmedShouldBeCompromised"]
|
||||
)
|
||||
config_values.overwhelmed = int(config_data["overwhelmed"])
|
||||
# Node File System State
|
||||
config_values.good_should_be_repairing = int(
|
||||
config_data["goodShouldBeRepairing"]
|
||||
)
|
||||
config_values.good_should_be_restoring = int(
|
||||
config_data["goodShouldBeRestoring"]
|
||||
)
|
||||
config_values.good_should_be_corrupt = int(config_data["goodShouldBeCorrupt"])
|
||||
config_values.good_should_be_destroyed = int(
|
||||
config_data["goodShouldBeDestroyed"]
|
||||
)
|
||||
config_values.repairing_should_be_good = int(
|
||||
config_data["repairingShouldBeGood"]
|
||||
)
|
||||
config_values.repairing_should_be_restoring = int(
|
||||
config_data["repairingShouldBeRestoring"]
|
||||
)
|
||||
config_values.repairing_should_be_corrupt = int(
|
||||
config_data["repairingShouldBeCorrupt"]
|
||||
)
|
||||
config_values.repairing_should_be_destroyed = int(
|
||||
config_data["repairingShouldBeDestroyed"]
|
||||
)
|
||||
config_values.repairing = int(config_data["repairing"])
|
||||
config_values.restoring_should_be_good = int(
|
||||
config_data["restoringShouldBeGood"]
|
||||
)
|
||||
config_values.restoring_should_be_repairing = int(
|
||||
config_data["restoringShouldBeRepairing"]
|
||||
)
|
||||
config_values.restoring_should_be_corrupt = int(
|
||||
config_data["restoringShouldBeCorrupt"]
|
||||
)
|
||||
config_values.restoring_should_be_destroyed = int(
|
||||
config_data["restoringShouldBeDestroyed"]
|
||||
)
|
||||
config_values.restoring = int(config_data["restoring"])
|
||||
config_values.corrupt_should_be_good = int(config_data["corruptShouldBeGood"])
|
||||
config_values.corrupt_should_be_repairing = int(
|
||||
config_data["corruptShouldBeRepairing"]
|
||||
)
|
||||
config_values.corrupt_should_be_restoring = int(
|
||||
config_data["corruptShouldBeRestoring"]
|
||||
)
|
||||
config_values.corrupt_should_be_destroyed = int(
|
||||
config_data["corruptShouldBeDestroyed"]
|
||||
)
|
||||
config_values.corrupt = int(config_data["corrupt"])
|
||||
config_values.destroyed_should_be_good = int(
|
||||
config_data["destroyedShouldBeGood"]
|
||||
)
|
||||
config_values.destroyed_should_be_repairing = int(
|
||||
config_data["destroyedShouldBeRepairing"]
|
||||
)
|
||||
config_values.destroyed_should_be_restoring = int(
|
||||
config_data["destroyedShouldBeRestoring"]
|
||||
)
|
||||
config_values.destroyed_should_be_corrupt = int(
|
||||
config_data["destroyedShouldBeCorrupt"]
|
||||
)
|
||||
config_values.destroyed = int(config_data["destroyed"])
|
||||
config_values.scanning = int(config_data["scanning"])
|
||||
# IER status
|
||||
config_values.red_ier_running = int(config_data["redIerRunning"])
|
||||
config_values.green_ier_blocked = int(config_data["greenIerBlocked"])
|
||||
# Patching / Reset durations
|
||||
config_values.os_patching_duration = int(config_data["osPatchingDuration"])
|
||||
config_values.node_reset_duration = int(config_data["nodeResetDuration"])
|
||||
config_values.service_patching_duration = int(
|
||||
config_data["servicePatchingDuration"]
|
||||
)
|
||||
config_values.file_system_repairing_limit = int(
|
||||
config_data["fileSystemRepairingLimit"]
|
||||
)
|
||||
config_values.file_system_restoring_limit = int(
|
||||
config_data["fileSystemRestoringLimit"]
|
||||
)
|
||||
config_values.file_system_scanning_limit = int(
|
||||
config_data["fileSystemScanningLimit"]
|
||||
)
|
||||
|
||||
logging.info("Training agent: " + config_values.agent_identifier)
|
||||
logging.info(
|
||||
"Training environment config: " + config_values.config_filename_use_case
|
||||
)
|
||||
logging.info(
|
||||
"Training cycle has " + str(config_values.num_episodes) + " episodes"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logging.error("Could not save load config data")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
|
||||
# MAIN PROCESS #
|
||||
|
||||
# Starting point
|
||||
|
||||
# Welcome message
|
||||
print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
|
||||
|
||||
# Configure logging
|
||||
configure_logging()
|
||||
|
||||
# Open the main config file
|
||||
try:
|
||||
config_file_main = open("config/config_main.yaml", "r")
|
||||
config_data = yaml.safe_load(config_file_main)
|
||||
# Create a config class
|
||||
config_values = config_values_main()
|
||||
# Load in config data
|
||||
load_config_values()
|
||||
except Exception:
|
||||
logging.error("Could not load main config")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
# Create a list of transactions
|
||||
# A transaction is an object holding the:
|
||||
# - episode #
|
||||
# - step #
|
||||
# - initial observation space
|
||||
# - action
|
||||
# - reward
|
||||
# - new observation space
|
||||
transaction_list = []
|
||||
|
||||
# Create the Primaite environment
|
||||
try:
|
||||
env = Primaite(config_values, transaction_list)
|
||||
logging.info("PrimAITE environment created")
|
||||
except Exception:
|
||||
logging.error("Could not create PrimAITE environment")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
config_values.num_steps = env.episode_steps
|
||||
|
||||
# Run environment against an agent
|
||||
if config_values.agent_identifier == "GENERIC":
|
||||
run_generic()
|
||||
elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
|
||||
run_stable_baselines3_ppo()
|
||||
elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
|
||||
run_stable_baselines3_a2c()
|
||||
|
||||
print("Session finished")
|
||||
logging.info("Session finished")
|
||||
|
||||
print("Saving transaction logs...")
|
||||
logging.info("Saving transaction logs...")
|
||||
|
||||
write_transaction_to_file(transaction_list)
|
||||
|
||||
config_file_main.close
|
||||
|
||||
print("Finished")
|
||||
logging.info("Finished")
|
||||
@@ -1,19 +1,26 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
An Active Node (i.e. not an actuator)
|
||||
"""
|
||||
"""An Active Node (i.e. not an actuator)."""
|
||||
from primaite.common.enums import FILE_SYSTEM_STATE, SOFTWARE_STATE
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
from nodes.node import Node
|
||||
from common.enums import *
|
||||
|
||||
class ActiveNode(Node):
|
||||
"""
|
||||
Active Node class
|
||||
"""
|
||||
"""Active Node class."""
|
||||
|
||||
def __init__(self, _id, _name, _type, _priority, _state, _ip_address, _os_state, _file_system_state, _config_values):
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_name,
|
||||
_type,
|
||||
_priority,
|
||||
_state,
|
||||
_ip_address,
|
||||
_os_state,
|
||||
_file_system_state,
|
||||
_config_values,
|
||||
):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The node ID
|
||||
@@ -26,7 +33,6 @@ class ActiveNode(Node):
|
||||
_file_system_state: The node file system state
|
||||
_config_values: The config values
|
||||
"""
|
||||
|
||||
super().__init__(_id, _name, _type, _priority, _state, _config_values)
|
||||
self.ip_address = _ip_address
|
||||
# Related to O/S
|
||||
@@ -39,20 +45,18 @@ class ActiveNode(Node):
|
||||
self.file_system_scanning_count = 0
|
||||
self.file_system_action_count = 0
|
||||
|
||||
|
||||
def set_ip_address(self, _ip_address):
|
||||
"""
|
||||
Sets IP address
|
||||
Sets IP address.
|
||||
|
||||
Args:
|
||||
_ip_address: IP address
|
||||
"""
|
||||
|
||||
self.ip_address = _ip_address
|
||||
|
||||
def get_ip_address(self):
|
||||
"""
|
||||
Gets IP address
|
||||
Gets IP address.
|
||||
|
||||
Returns:
|
||||
IP address
|
||||
@@ -61,24 +65,22 @@ class ActiveNode(Node):
|
||||
|
||||
def set_os_state(self, _os_state):
|
||||
"""
|
||||
Sets operating system state
|
||||
Sets operating system state.
|
||||
|
||||
Args:
|
||||
_os_state: Operating system state
|
||||
"""
|
||||
|
||||
self.os_state = _os_state
|
||||
if _os_state == SOFTWARE_STATE.PATCHING:
|
||||
self.patching_count = self.config_values.os_patching_duration
|
||||
|
||||
def set_os_state_if_not_compromised(self, _os_state):
|
||||
"""
|
||||
Sets operating system state if the node is not compromised
|
||||
Sets operating system state if the node is not compromised.
|
||||
|
||||
Args:
|
||||
_os_state: Operating system state
|
||||
"""
|
||||
|
||||
if self.os_state != SOFTWARE_STATE.COMPROMISED:
|
||||
self.os_state = _os_state
|
||||
if _os_state == SOFTWARE_STATE.PATCHING:
|
||||
@@ -86,19 +88,15 @@ class ActiveNode(Node):
|
||||
|
||||
def get_os_state(self):
|
||||
"""
|
||||
Gets operating system state
|
||||
Gets operating system state.
|
||||
|
||||
Returns:
|
||||
Operating system state
|
||||
"""
|
||||
|
||||
return self.os_state
|
||||
|
||||
def update_os_patching_status(self):
|
||||
"""
|
||||
Updates operating system status based on patching cycle
|
||||
"""
|
||||
|
||||
"""Updates operating system status based on patching cycle."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
self.patching_count = 0
|
||||
@@ -106,87 +104,88 @@ class ActiveNode(Node):
|
||||
|
||||
def set_file_system_state(self, _file_system_state):
|
||||
"""
|
||||
Sets the file system state (actual and observed)
|
||||
Sets the file system state (actual and observed).
|
||||
|
||||
Args:
|
||||
_file_system_state: File system state
|
||||
"""
|
||||
|
||||
self.file_system_state_actual = _file_system_state
|
||||
|
||||
if _file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
self.file_system_action_count = self.config_values.file_system_repairing_limit
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_repairing_limit
|
||||
)
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.REPAIRING
|
||||
elif _file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
self.file_system_action_count = self.config_values.file_system_restoring_limit
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_restoring_limit
|
||||
)
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.RESTORING
|
||||
elif _file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.GOOD
|
||||
|
||||
def set_file_system_state_if_not_compromised(self, _file_system_state):
|
||||
"""
|
||||
Sets the file system state (actual and observed) if not in a compromised state
|
||||
Sets the file system state (actual and observed) if not in a compromised state.
|
||||
|
||||
Use for green PoL to prevent it overturning a compromised state
|
||||
|
||||
Args:
|
||||
_file_system_state: File system state
|
||||
"""
|
||||
|
||||
if self.file_system_state_actual != FILE_SYSTEM_STATE.CORRUPT and self.file_system_state_actual != FILE_SYSTEM_STATE.DESTROYED:
|
||||
if (
|
||||
self.file_system_state_actual != FILE_SYSTEM_STATE.CORRUPT
|
||||
and self.file_system_state_actual != FILE_SYSTEM_STATE.DESTROYED
|
||||
):
|
||||
self.file_system_state_actual = _file_system_state
|
||||
|
||||
if _file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
self.file_system_action_count = self.config_values.file_system_repairing_limit
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_repairing_limit
|
||||
)
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.REPAIRING
|
||||
elif _file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
self.file_system_action_count = self.config_values.file_system_restoring_limit
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_restoring_limit
|
||||
)
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.RESTORING
|
||||
elif _file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.GOOD
|
||||
|
||||
def get_file_system_state_actual(self):
|
||||
"""
|
||||
Gets file system state (actual)
|
||||
Gets file system state (actual).
|
||||
|
||||
Returns:
|
||||
File system state (actual)
|
||||
"""
|
||||
|
||||
return self.file_system_state_actual
|
||||
|
||||
def get_file_system_state_observed(self):
|
||||
"""
|
||||
Gets file system state (observed)
|
||||
Gets file system state (observed).
|
||||
|
||||
Returns:
|
||||
File system state (observed)
|
||||
"""
|
||||
|
||||
return self.file_system_state_observed
|
||||
|
||||
def start_file_system_scan(self):
|
||||
"""
|
||||
Starts a file system scan
|
||||
"""
|
||||
|
||||
"""Starts a file system scan."""
|
||||
self.file_system_scanning = True
|
||||
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
|
||||
|
||||
def is_scanning_file_system(self):
|
||||
"""
|
||||
Gets true/false on whether file system is being scanned
|
||||
Gets true/false on whether file system is being scanned.
|
||||
|
||||
Returns:
|
||||
True if file system is being scanned
|
||||
"""
|
||||
|
||||
return self.file_system_scanning
|
||||
|
||||
def update_file_system_state(self):
|
||||
"""
|
||||
Updates file system status based on scanning / restore / repair cycle
|
||||
"""
|
||||
|
||||
"""Updates file system status based on scanning / restore / repair cycle."""
|
||||
# Deprecate both the action count (for restoring or reparing) and the scanning count
|
||||
self.file_system_action_count -= 1
|
||||
self.file_system_scanning_count -= 1
|
||||
@@ -194,7 +193,10 @@ class ActiveNode(Node):
|
||||
# Reparing / Restoring updates
|
||||
if self.file_system_action_count <= 0:
|
||||
self.file_system_action_count = 0
|
||||
if self.file_system_state_actual == FILE_SYSTEM_STATE.REPAIRING or self.file_system_state_actual == FILE_SYSTEM_STATE.RESTORING:
|
||||
if (
|
||||
self.file_system_state_actual == FILE_SYSTEM_STATE.REPAIRING
|
||||
or self.file_system_state_actual == FILE_SYSTEM_STATE.RESTORING
|
||||
):
|
||||
self.file_system_state_actual = FILE_SYSTEM_STATE.GOOD
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.GOOD
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The base Node class
|
||||
"""
|
||||
"""The base Node class."""
|
||||
from primaite.common.enums import HARDWARE_STATE
|
||||
|
||||
from common.enums import *
|
||||
|
||||
class Node:
|
||||
"""
|
||||
Node class
|
||||
"""
|
||||
"""Node class."""
|
||||
|
||||
def __init__(self, _id, _name, _type, _priority, _state, _config_values):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The node id
|
||||
@@ -21,156 +17,124 @@ class Node:
|
||||
_priority: The priority of the node
|
||||
_state: The state of the node
|
||||
"""
|
||||
|
||||
self.id = _id
|
||||
self.name = _name
|
||||
self.type = _type
|
||||
self.priority = _priority
|
||||
self.operating_state = _state
|
||||
self.resetting_count = 0
|
||||
self.resetting_count = 0
|
||||
self.config_values = _config_values
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Returns the name of the node
|
||||
"""
|
||||
|
||||
"""Returns the name of the node."""
|
||||
return self.name
|
||||
|
||||
def set_id(self, _id):
|
||||
"""
|
||||
Sets the node ID
|
||||
Sets the node ID.
|
||||
|
||||
Args:
|
||||
_id: The node ID
|
||||
"""
|
||||
|
||||
self.id = _id
|
||||
|
||||
def get_id(self):
|
||||
"""
|
||||
Gets the node ID
|
||||
Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
"""
|
||||
|
||||
return self.id
|
||||
|
||||
def set_name(self, _name):
|
||||
"""
|
||||
Sets the node name
|
||||
Sets the node name.
|
||||
|
||||
Args:
|
||||
_name: The node name
|
||||
"""
|
||||
|
||||
self.name = _name
|
||||
|
||||
def get_name(self):
|
||||
"""
|
||||
Gets the node name
|
||||
Gets the node name.
|
||||
|
||||
Returns:
|
||||
The node name
|
||||
"""
|
||||
|
||||
return self.name
|
||||
|
||||
def set_type(self, _type):
|
||||
"""
|
||||
Sets the node type
|
||||
Sets the node type.
|
||||
|
||||
Args:
|
||||
_type: The node type
|
||||
"""
|
||||
|
||||
self.type = _type
|
||||
|
||||
def get_type(self):
|
||||
"""
|
||||
Gets the node type
|
||||
Gets the node type.
|
||||
|
||||
Returns:
|
||||
The node type
|
||||
"""
|
||||
|
||||
return self.type
|
||||
|
||||
def set_priority(self, _priority):
|
||||
"""
|
||||
Sets the node priority
|
||||
Sets the node priority.
|
||||
|
||||
Args:
|
||||
_priority: The node priority
|
||||
"""
|
||||
|
||||
self.priority = _priority
|
||||
|
||||
def get_priority(self):
|
||||
"""
|
||||
Gets the node priority
|
||||
Gets the node priority.
|
||||
|
||||
Returns:
|
||||
The node priority
|
||||
"""
|
||||
|
||||
return self.priority
|
||||
|
||||
def set_state(self, _state):
|
||||
"""
|
||||
Sets the node state
|
||||
Sets the node state.
|
||||
|
||||
Args:
|
||||
_state: The node state
|
||||
"""
|
||||
|
||||
self.operating_state = _state
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Gets the node operating state
|
||||
Gets the node operating state.
|
||||
|
||||
Returns:
|
||||
The node operating state
|
||||
"""
|
||||
|
||||
return self.operating_state
|
||||
|
||||
def turn_on(self):
|
||||
"""
|
||||
Sets the node state to ON
|
||||
"""
|
||||
|
||||
"""Sets the node state to ON."""
|
||||
self.operating_state = HARDWARE_STATE.ON
|
||||
|
||||
def turn_off(self):
|
||||
"""
|
||||
Sets the node state to OFF
|
||||
"""
|
||||
|
||||
"""Sets the node state to OFF."""
|
||||
self.operating_state = HARDWARE_STATE.OFF
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Sets the node state to Resetting and starts the reset count
|
||||
"""
|
||||
|
||||
"""Sets the node state to Resetting and starts the reset count."""
|
||||
self.operating_state = HARDWARE_STATE.RESETTING
|
||||
self.resetting_count = self.config_values.node_reset_duration
|
||||
|
||||
def update_resetting_status(self):
|
||||
"""
|
||||
Updates the resetting count
|
||||
"""
|
||||
|
||||
"""Updates the resetting count."""
|
||||
self.resetting_count -= 1
|
||||
if self.resetting_count <= 0:
|
||||
self.resetting_count = 0
|
||||
self.operating_state = HARDWARE_STATE.ON
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Defines node behaviour for Green PoL
|
||||
"""
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
|
||||
|
||||
class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
The Node State Instruction class
|
||||
"""
|
||||
"""The Node State Instruction class."""
|
||||
|
||||
def __init__(self, _id, _start_step, _end_step, _node_id, _node_pol_type, _service_name, _state):
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_node_id,
|
||||
_node_pol_type,
|
||||
_service_name,
|
||||
_state,
|
||||
):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The node state instruction id
|
||||
@@ -21,72 +27,64 @@ class NodeStateInstructionGreen(object):
|
||||
_service_name: The service name
|
||||
_state: The state (node or service)
|
||||
"""
|
||||
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.node_id = _node_id
|
||||
self.node_pol_type = _node_pol_type
|
||||
self.service_name = _service_name # Not used when not a service instruction
|
||||
self.state = _state
|
||||
self.service_name = _service_name # Not used when not a service instruction
|
||||
self.state = _state
|
||||
|
||||
def get_start_step(self):
|
||||
"""
|
||||
Gets the start step
|
||||
Gets the start step.
|
||||
|
||||
Returns:
|
||||
The start step
|
||||
"""
|
||||
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
"""
|
||||
Gets the end step
|
||||
Gets the end step.
|
||||
|
||||
Returns:
|
||||
The end step
|
||||
"""
|
||||
|
||||
return self.end_step
|
||||
|
||||
def get_node_id(self):
|
||||
"""
|
||||
Gets the node ID
|
||||
Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
"""
|
||||
|
||||
return self.node_id
|
||||
|
||||
def get_node_pol_type(self):
|
||||
"""
|
||||
Gets the node pattern of life type (enum)
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
Returns:
|
||||
The node pattern of life type (enum)
|
||||
"""
|
||||
|
||||
return self.node_pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
"""
|
||||
Gets the service name
|
||||
Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
"""
|
||||
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Gets the state (node or service)
|
||||
Gets the state (node or service).
|
||||
|
||||
Returns:
|
||||
The state (node or service)
|
||||
"""
|
||||
|
||||
return self.state
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Defines node behaviour for Green PoL
|
||||
"""
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
|
||||
|
||||
class NodeStateInstructionRed(object):
|
||||
"""
|
||||
The Node State Instruction class
|
||||
"""
|
||||
"""The Node State Instruction class."""
|
||||
|
||||
def __init__(self, _id, _start_step, _end_step, _target_node_id, _pol_initiator, _pol_type, pol_protocol, _pol_state, _pol_source_node_id, _pol_source_node_service, _pol_source_node_service_state):
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_target_node_id,
|
||||
_pol_initiator,
|
||||
_pol_type,
|
||||
pol_protocol,
|
||||
_pol_state,
|
||||
_pol_source_node_id,
|
||||
_pol_source_node_service,
|
||||
_pol_source_node_service_state,
|
||||
):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The node state instruction id
|
||||
@@ -25,14 +35,13 @@ class NodeStateInstructionRed(object):
|
||||
_pol_source_node_service: The source node service (used for initiator type SERVICE)
|
||||
_pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
|
||||
"""
|
||||
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.target_node_id = _target_node_id
|
||||
self.initiator = _pol_initiator
|
||||
self.pol_type = _pol_type
|
||||
self.service_name = pol_protocol # Not used when not a service instruction
|
||||
self.service_name = pol_protocol # Not used when not a service instruction
|
||||
self.state = _pol_state
|
||||
self.source_node_id = _pol_source_node_id
|
||||
self.source_node_service = _pol_source_node_service
|
||||
@@ -40,101 +49,90 @@ class NodeStateInstructionRed(object):
|
||||
|
||||
def get_start_step(self):
|
||||
"""
|
||||
Gets the start step
|
||||
Gets the start step.
|
||||
|
||||
Returns:
|
||||
The start step
|
||||
"""
|
||||
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
"""
|
||||
Gets the end step
|
||||
Gets the end step.
|
||||
|
||||
Returns:
|
||||
The end step
|
||||
"""
|
||||
|
||||
return self.end_step
|
||||
|
||||
def get_target_node_id(self):
|
||||
"""
|
||||
Gets the node ID
|
||||
Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
"""
|
||||
|
||||
return self.target_node_id
|
||||
|
||||
def get_initiator(self):
|
||||
"""
|
||||
Gets the initiator
|
||||
Gets the initiator.
|
||||
|
||||
Returns:
|
||||
The initiator
|
||||
"""
|
||||
|
||||
return self.initiator
|
||||
|
||||
def get_pol_type(self):
|
||||
"""
|
||||
Gets the node pattern of life type (enum)
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
Returns:
|
||||
The node pattern of life type (enum)
|
||||
"""
|
||||
|
||||
return self.pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
"""
|
||||
Gets the service name
|
||||
Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
"""
|
||||
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Gets the state (node or service)
|
||||
Gets the state (node or service).
|
||||
|
||||
Returns:
|
||||
The state (node or service)
|
||||
"""
|
||||
|
||||
return self.state
|
||||
|
||||
def get_source_node_id(self):
|
||||
"""
|
||||
Gets the source node id (used for initiator type SERVICE)
|
||||
Gets the source node id (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node id
|
||||
"""
|
||||
|
||||
return self.source_node_id
|
||||
|
||||
def get_source_node_service(self):
|
||||
"""
|
||||
Gets the source node service (used for initiator type SERVICE)
|
||||
Gets the source node service (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node service
|
||||
"""
|
||||
|
||||
return self.source_node_service
|
||||
|
||||
def get_source_node_service_state(self):
|
||||
"""
|
||||
Gets the source node service state (used for initiator type SERVICE)
|
||||
Gets the source node service state (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node service state
|
||||
"""
|
||||
|
||||
return self.source_node_service_state
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The Passive Node class (i.e. an actuator)
|
||||
"""
|
||||
"""The Passive Node class (i.e. an actuator)."""
|
||||
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
from nodes.node import Node
|
||||
|
||||
class PassiveNode(Node):
|
||||
"""
|
||||
The Passive Node class
|
||||
"""
|
||||
"""The Passive Node class."""
|
||||
|
||||
def __init__(self, _id, _name, _type, _priority, _state, _config_values):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The node id
|
||||
@@ -21,17 +18,15 @@ class PassiveNode(Node):
|
||||
_priority: The priority of the node
|
||||
_state: The state of the node
|
||||
"""
|
||||
|
||||
# Pass through to Super for now
|
||||
super().__init__(_id, _name, _type, _priority, _state, _config_values)
|
||||
|
||||
def get_ip_address(self):
|
||||
"""
|
||||
Gets the node IP address
|
||||
Gets the node IP address.
|
||||
|
||||
Returns:
|
||||
The node IP address
|
||||
"""
|
||||
|
||||
# No concept of IP address for passive nodes for now
|
||||
return ""
|
||||
return ""
|
||||
@@ -1,19 +1,26 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
A Service Node (i.e. not an actuator)
|
||||
"""
|
||||
"""A Service Node (i.e. not an actuator)."""
|
||||
from primaite.common.enums import SOFTWARE_STATE
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
|
||||
from nodes.active_node import ActiveNode
|
||||
from common.enums import *
|
||||
|
||||
class ServiceNode(ActiveNode):
|
||||
"""
|
||||
ServiceNode class
|
||||
"""
|
||||
"""ServiceNode class."""
|
||||
|
||||
def __init__(self, _id, _name, _type, _priority, _state, _ip_address, _os_state, _file_system_state, _config_values):
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_name,
|
||||
_type,
|
||||
_priority,
|
||||
_state,
|
||||
_ip_address,
|
||||
_os_state,
|
||||
_file_system_state,
|
||||
_config_values,
|
||||
):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The node id
|
||||
@@ -25,38 +32,44 @@ class ServiceNode(ActiveNode):
|
||||
_osState: The operating system state of the node
|
||||
_file_system_state: The file system state of the node
|
||||
"""
|
||||
|
||||
super().__init__(_id, _name, _type, _priority, _state, _ip_address, _os_state, _file_system_state, _config_values)
|
||||
super().__init__(
|
||||
_id,
|
||||
_name,
|
||||
_type,
|
||||
_priority,
|
||||
_state,
|
||||
_ip_address,
|
||||
_os_state,
|
||||
_file_system_state,
|
||||
_config_values,
|
||||
)
|
||||
self.services = {}
|
||||
|
||||
def add_service(self, _service):
|
||||
"""
|
||||
Adds a service to the node
|
||||
Adds a service to the node.
|
||||
|
||||
Args:
|
||||
_service: The service to add
|
||||
"""
|
||||
|
||||
self.services[_service.get_name()] = _service
|
||||
|
||||
def get_services(self):
|
||||
"""
|
||||
Gets the dictionary of services on this node
|
||||
Gets the dictionary of services on this node.
|
||||
|
||||
Returns:
|
||||
Dictionary of services on this node
|
||||
"""
|
||||
|
||||
return self.services
|
||||
|
||||
def has_service(self, _protocol):
|
||||
"""
|
||||
Indicates whether a service is on a node
|
||||
Indicates whether a service is on a node.
|
||||
|
||||
Returns:
|
||||
True if service (protocol) is on the node
|
||||
"""
|
||||
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == _protocol:
|
||||
return True
|
||||
@@ -66,12 +79,11 @@ class ServiceNode(ActiveNode):
|
||||
|
||||
def service_running(self, _protocol):
|
||||
"""
|
||||
Indicates whether a service is in a running state on the node
|
||||
Indicates whether a service is in a running state on the node.
|
||||
|
||||
Returns:
|
||||
True if service (protocol) is in a running state on the node
|
||||
"""
|
||||
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == _protocol:
|
||||
if service_value.get_state() != SOFTWARE_STATE.PATCHING:
|
||||
@@ -84,12 +96,11 @@ class ServiceNode(ActiveNode):
|
||||
|
||||
def service_is_overwhelmed(self, _protocol):
|
||||
"""
|
||||
Indicates whether a service is in an overwhelmed state on the node
|
||||
Indicates whether a service is in an overwhelmed state on the node.
|
||||
|
||||
Returns:
|
||||
True if service (protocol) is in an overwhelmed state on the node
|
||||
"""
|
||||
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == _protocol:
|
||||
if service_value.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
@@ -102,61 +113,61 @@ class ServiceNode(ActiveNode):
|
||||
|
||||
def set_service_state(self, _protocol, _state):
|
||||
"""
|
||||
Sets the state of a service (protocol) on the node
|
||||
Sets the state of a service (protocol) on the node.
|
||||
|
||||
Args:
|
||||
_protocol: The service (protocol)
|
||||
_state: The state value
|
||||
"""
|
||||
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == _protocol:
|
||||
# Can't set to compromised if you're in a patching state
|
||||
if (_state == SOFTWARE_STATE.COMPROMISED and service_value.get_state() != SOFTWARE_STATE.PATCHING) or _state != SOFTWARE_STATE.COMPROMISED:
|
||||
if (
|
||||
_state == SOFTWARE_STATE.COMPROMISED
|
||||
and service_value.get_state() != SOFTWARE_STATE.PATCHING
|
||||
) or _state != SOFTWARE_STATE.COMPROMISED:
|
||||
service_value.set_state(_state)
|
||||
else:
|
||||
# Do nothing
|
||||
pass
|
||||
if _state == SOFTWARE_STATE.PATCHING:
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
service_value.patching_count = (
|
||||
self.config_values.service_patching_duration
|
||||
)
|
||||
else:
|
||||
# Do nothing
|
||||
pass
|
||||
|
||||
def set_service_state_if_not_compromised(self, _protocol, _state):
|
||||
"""
|
||||
Sets the state of a service (protocol) on the node if the operating state is not "compromised"
|
||||
Sets the state of a service (protocol) on the node if the operating state is not "compromised".
|
||||
|
||||
Args:
|
||||
_protocol: The service (protocol)
|
||||
_state: The state value
|
||||
"""
|
||||
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == _protocol:
|
||||
if service_value.get_state() != SOFTWARE_STATE.COMPROMISED:
|
||||
service_value.set_state(_state)
|
||||
if _state == SOFTWARE_STATE.PATCHING:
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
service_value.patching_count = (
|
||||
self.config_values.service_patching_duration
|
||||
)
|
||||
|
||||
def get_service_state(self, _protocol):
|
||||
"""
|
||||
Gets the state of a service
|
||||
Gets the state of a service.
|
||||
|
||||
Returns:
|
||||
The state of the service
|
||||
"""
|
||||
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == _protocol:
|
||||
return service_value.get_state()
|
||||
|
||||
def update_services_patching_status(self):
|
||||
"""
|
||||
Updates the patching counter for any service that are patching
|
||||
"""
|
||||
|
||||
"""Updates the patching counter for any service that are patching."""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_value.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
service_value.reduce_patching_count()
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
@@ -1,19 +1,18 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Implements Pattern of Life on the network (nodes and links)
|
||||
"""
|
||||
"""Implements Pattern of Life on the network (nodes and links)."""
|
||||
|
||||
from networkx import shortest_path
|
||||
|
||||
from common.enums import *
|
||||
from nodes.active_node import ActiveNode
|
||||
from nodes.service_node import ServiceNode
|
||||
from primaite.common.enums import HARDWARE_STATE, NODE_POL_TYPE, SOFTWARE_STATE, TYPE
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
_VERBOSE = False
|
||||
|
||||
|
||||
def apply_iers(network, nodes, links, iers, acl, step):
|
||||
"""
|
||||
Applies IERs to the links (link pattern of life)
|
||||
Applies IERs to the links (link pattern of life).
|
||||
|
||||
Args:
|
||||
network: The network modelled in the environment
|
||||
@@ -21,9 +20,8 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
links: The links within the environment
|
||||
iers: The IERs to apply to the links
|
||||
acl: The Access Control List
|
||||
step: The step number
|
||||
step: The step number.
|
||||
"""
|
||||
|
||||
if _VERBOSE:
|
||||
print("Applying IERs")
|
||||
|
||||
@@ -38,7 +36,7 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
source_node_id = ier_value.get_source_node_id()
|
||||
dest_node_id = ier_value.get_dest_node_id()
|
||||
|
||||
# Need to set the running status to false first for all IERs
|
||||
# Need to set the running status to false first for all IERs
|
||||
ier_value.set_is_running(False)
|
||||
|
||||
source_valid = True
|
||||
@@ -46,8 +44,8 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
acl_block = False
|
||||
|
||||
if step >= start_step and step <= stop_step:
|
||||
# continue --------------------------
|
||||
|
||||
# continue --------------------------
|
||||
|
||||
# Get the source and destination node for this link
|
||||
source_node = nodes[source_node_id]
|
||||
dest_node = nodes[dest_node_id]
|
||||
@@ -55,7 +53,10 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
# 1. Check the source node situation
|
||||
if source_node.get_type() == TYPE.SWITCH:
|
||||
# It's a switch
|
||||
if source_node.get_state() == HARDWARE_STATE.ON and source_node.get_os_state() != SOFTWARE_STATE.PATCHING:
|
||||
if (
|
||||
source_node.get_state() == HARDWARE_STATE.ON
|
||||
and source_node.get_os_state() != SOFTWARE_STATE.PATCHING
|
||||
):
|
||||
source_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
@@ -66,9 +67,14 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
if source_node.get_state() == HARDWARE_STATE.ON and source_node.get_os_state() != SOFTWARE_STATE.PATCHING:
|
||||
if (
|
||||
source_node.get_state() == HARDWARE_STATE.ON
|
||||
and source_node.get_os_state() != SOFTWARE_STATE.PATCHING
|
||||
):
|
||||
if source_node.has_service(protocol):
|
||||
if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol):
|
||||
if source_node.service_running(
|
||||
protocol
|
||||
) and not source_node.service_is_overwhelmed(protocol):
|
||||
source_valid = True
|
||||
else:
|
||||
source_valid = False
|
||||
@@ -80,11 +86,13 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
# Do nothing - IER no longer valid
|
||||
source_valid = False
|
||||
|
||||
|
||||
# 2. Check the dest node situation
|
||||
if dest_node.get_type() == TYPE.SWITCH:
|
||||
# It's a switch
|
||||
if dest_node.get_state() == HARDWARE_STATE.ON and dest_node.get_os_state() != SOFTWARE_STATE.PATCHING:
|
||||
if (
|
||||
dest_node.get_state() == HARDWARE_STATE.ON
|
||||
and dest_node.get_os_state() != SOFTWARE_STATE.PATCHING
|
||||
):
|
||||
dest_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
@@ -94,9 +102,14 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
if dest_node.get_state() == HARDWARE_STATE.ON and dest_node.get_os_state() != SOFTWARE_STATE.PATCHING:
|
||||
if (
|
||||
dest_node.get_state() == HARDWARE_STATE.ON
|
||||
and dest_node.get_os_state() != SOFTWARE_STATE.PATCHING
|
||||
):
|
||||
if dest_node.has_service(protocol):
|
||||
if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol):
|
||||
if dest_node.service_running(
|
||||
protocol
|
||||
) and not dest_node.service_is_overwhelmed(protocol):
|
||||
dest_valid = True
|
||||
else:
|
||||
dest_valid = False
|
||||
@@ -109,10 +122,21 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
dest_valid = False
|
||||
|
||||
# 3. Check that the ACL doesn't block it
|
||||
acl_block = acl.is_blocked(source_node.get_ip_address(), dest_node.get_ip_address(), protocol, port)
|
||||
acl_block = acl.is_blocked(
|
||||
source_node.get_ip_address(), dest_node.get_ip_address(), protocol, port
|
||||
)
|
||||
if acl_block:
|
||||
if _VERBOSE:
|
||||
print("ACL block on source: " + source_node.get_ip_address() + ", dest: " + dest_node.get_ip_address() + ", protocol: " + protocol + ", port: " + port)
|
||||
print(
|
||||
"ACL block on source: "
|
||||
+ source_node.get_ip_address()
|
||||
+ ", dest: "
|
||||
+ dest_node.get_ip_address()
|
||||
+ ", protocol: "
|
||||
+ protocol
|
||||
+ ", port: "
|
||||
+ port
|
||||
)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("No ACL block")
|
||||
@@ -131,20 +155,25 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
|
||||
# We might have a switch in the path, so check all nodes are operational
|
||||
for node in path_node_list:
|
||||
if node.get_state() != HARDWARE_STATE.ON or node.get_os_state() == SOFTWARE_STATE.PATCHING:
|
||||
if (
|
||||
node.get_state() != HARDWARE_STATE.ON
|
||||
or node.get_os_state() == SOFTWARE_STATE.PATCHING
|
||||
):
|
||||
path_valid = False
|
||||
|
||||
|
||||
if path_valid:
|
||||
if _VERBOSE:
|
||||
print("Applying IER to link(s)")
|
||||
count = 0
|
||||
link_capacity_exceeded = False
|
||||
|
||||
|
||||
# Check that the link capacity is not exceeded by the new load
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count+1])
|
||||
link_id = edge_dict[0].get('id')
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
@@ -152,7 +181,7 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
pass
|
||||
count+=1
|
||||
count += 1
|
||||
|
||||
# Check whether the link capacity for any links on this path have been exceeded
|
||||
if link_capacity_exceeded == False:
|
||||
@@ -160,20 +189,22 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
count = 0
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count+1])
|
||||
link_id = edge_dict[0].get('id')
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Add the load from this IER
|
||||
link.add_protocol_load(protocol, load)
|
||||
count+=1
|
||||
count += 1
|
||||
# This IER is now valid, so set it to running
|
||||
ier_value.set_is_running(True)
|
||||
ier_value.set_is_running(True)
|
||||
else:
|
||||
# One of the nodes is not operational
|
||||
if _VERBOSE:
|
||||
print("Path not valid - one or more nodes not operational")
|
||||
pass
|
||||
|
||||
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Source, Dest or ACL were not valid")
|
||||
@@ -183,19 +214,19 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
# Do nothing - IER no longer valid
|
||||
pass
|
||||
|
||||
|
||||
def apply_node_pol(nodes, node_pol, step):
|
||||
"""
|
||||
Applies node pattern of life
|
||||
Applies node pattern of life.
|
||||
|
||||
Args:
|
||||
nodes: The nodes within the environment
|
||||
node_pol: The node pattern of life to apply
|
||||
step: The step number
|
||||
step: The step number.
|
||||
"""
|
||||
|
||||
if _VERBOSE:
|
||||
print("Applying Node PoL")
|
||||
|
||||
|
||||
for key, node_instruction in node_pol.items():
|
||||
start_step = node_instruction.get_start_step()
|
||||
stop_step = node_instruction.get_end_step()
|
||||
@@ -205,7 +236,7 @@ def apply_node_pol(nodes, node_pol, step):
|
||||
state = node_instruction.get_state()
|
||||
|
||||
if step >= start_step and step <= stop_step:
|
||||
# continue --------------------------
|
||||
# continue --------------------------
|
||||
node = nodes[node_id]
|
||||
|
||||
if node_pol_type == NODE_POL_TYPE.OPERATING:
|
||||
@@ -227,4 +258,4 @@ def apply_node_pol(nodes, node_pol, step):
|
||||
node.set_file_system_state_if_not_compromised(state)
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
pass
|
||||
@@ -1,17 +1,29 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Information Exchange Requirements for APE
|
||||
Used to represent an information flow from source to destination
|
||||
Information Exchange Requirements for APE.
|
||||
|
||||
Used to represent an information flow from source to destination.
|
||||
"""
|
||||
|
||||
class IER(object):
|
||||
"""
|
||||
Information Exchange Requirement class
|
||||
"""
|
||||
|
||||
def __init__(self, _id, _start_step, _end_step, _load, _protocol, _port, _source_node_id, _dest_node_id, _mission_criticality, _running=False):
|
||||
class IER(object):
|
||||
"""Information Exchange Requirement class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_load,
|
||||
_protocol,
|
||||
_port,
|
||||
_source_node_id,
|
||||
_dest_node_id,
|
||||
_mission_criticality,
|
||||
_running=False,
|
||||
):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The IER id
|
||||
@@ -25,13 +37,12 @@ class IER(object):
|
||||
_mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
|
||||
_running: Indicates whether the IER is currently running
|
||||
"""
|
||||
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.source_node_id = _source_node_id
|
||||
self.dest_node_id = _dest_node_id
|
||||
self.load = _load
|
||||
self.load = _load
|
||||
self.protocol = _protocol
|
||||
self.port = _port
|
||||
self.mission_criticality = _mission_criticality
|
||||
@@ -39,97 +50,88 @@ class IER(object):
|
||||
|
||||
def get_id(self):
|
||||
"""
|
||||
Gets IER ID
|
||||
Gets IER ID.
|
||||
|
||||
Returns:
|
||||
IER ID
|
||||
"""
|
||||
|
||||
return self.id
|
||||
|
||||
def get_start_step(self):
|
||||
"""
|
||||
Gets IER start step
|
||||
Gets IER start step.
|
||||
|
||||
Returns:
|
||||
IER start step
|
||||
"""
|
||||
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
"""
|
||||
Gets IER end step
|
||||
Gets IER end step.
|
||||
|
||||
Returns:
|
||||
IER end step
|
||||
"""
|
||||
|
||||
return self.end_step
|
||||
|
||||
def get_load(self):
|
||||
"""
|
||||
Gets IER load
|
||||
Gets IER load.
|
||||
|
||||
Returns:
|
||||
IER load
|
||||
"""
|
||||
|
||||
return self.load
|
||||
|
||||
def get_protocol(self):
|
||||
"""
|
||||
Gets IER protocol
|
||||
Gets IER protocol.
|
||||
|
||||
Returns:
|
||||
IER protocol
|
||||
"""
|
||||
|
||||
return self.protocol
|
||||
|
||||
def get_port(self):
|
||||
"""
|
||||
Gets IER port
|
||||
Gets IER port.
|
||||
|
||||
Returns:
|
||||
IER port
|
||||
"""
|
||||
|
||||
return self.port
|
||||
|
||||
def get_source_node_id(self):
|
||||
"""
|
||||
Gets IER source node ID
|
||||
Gets IER source node ID.
|
||||
|
||||
Returns:
|
||||
IER source node ID
|
||||
"""
|
||||
|
||||
return self.source_node_id
|
||||
|
||||
def get_dest_node_id(self):
|
||||
"""
|
||||
Gets IER destination node ID
|
||||
Gets IER destination node ID.
|
||||
|
||||
Returns:
|
||||
IER destination node ID
|
||||
"""
|
||||
|
||||
return self.dest_node_id
|
||||
|
||||
def get_is_running(self):
|
||||
"""
|
||||
Informs whether the IER is currently running
|
||||
Informs whether the IER is currently running.
|
||||
|
||||
Returns:
|
||||
True if running
|
||||
"""
|
||||
|
||||
return self.running
|
||||
|
||||
def set_is_running(self, _value):
|
||||
"""
|
||||
Sets the running state of the IER
|
||||
Sets the running state of the IER.
|
||||
|
||||
Args:
|
||||
_value: running status
|
||||
@@ -138,10 +140,9 @@ class IER(object):
|
||||
|
||||
def get_mission_criticality(self):
|
||||
"""
|
||||
Gets the IER mission criticality (used in the reward function)
|
||||
Gets the IER mission criticality (used in the reward function).
|
||||
|
||||
Returns:
|
||||
Mission criticality value (0 lowest to 5 highest)
|
||||
"""
|
||||
|
||||
return self.mission_criticality
|
||||
return self.mission_criticality
|
||||
@@ -1,19 +1,24 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Implements Pattern of Life on the network (nodes and links) resulting from the red agent attack
|
||||
"""
|
||||
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
|
||||
|
||||
from networkx import shortest_path
|
||||
|
||||
from common.enums import *
|
||||
from nodes.active_node import ActiveNode
|
||||
from nodes.service_node import ServiceNode
|
||||
from primaite.common.enums import (
|
||||
HARDWARE_STATE,
|
||||
NODE_POL_INITIATOR,
|
||||
NODE_POL_TYPE,
|
||||
SOFTWARE_STATE,
|
||||
TYPE,
|
||||
)
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
_VERBOSE = False
|
||||
|
||||
|
||||
def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
"""
|
||||
Applies IERs to the links (link pattern of life) resulting from red agent attack
|
||||
Applies IERs to the links (link POL) resulting from red agent attack.
|
||||
|
||||
Args:
|
||||
network: The network modelled in the environment
|
||||
@@ -21,9 +26,8 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
links: The links within the environment
|
||||
iers: The red agent IERs to apply to the links
|
||||
acl: The Access Control List
|
||||
step: The step number
|
||||
step: The step number.
|
||||
"""
|
||||
|
||||
# Go through each IER and check the conditions for it being applied
|
||||
# If everything is in place, apply the IER protocol load to the relevant links
|
||||
for ier_key, ier_value in iers.items():
|
||||
@@ -35,7 +39,7 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
source_node_id = ier_value.get_source_node_id()
|
||||
dest_node_id = ier_value.get_dest_node_id()
|
||||
|
||||
# Need to set the running status to false first for all IERs
|
||||
# Need to set the running status to false first for all IERs
|
||||
ier_value.set_is_running(False)
|
||||
|
||||
source_valid = True
|
||||
@@ -43,8 +47,8 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
acl_block = False
|
||||
|
||||
if step >= start_step and step <= stop_step:
|
||||
# continue --------------------------
|
||||
|
||||
# continue --------------------------
|
||||
|
||||
# Get the source and destination node for this link
|
||||
source_node = nodes[source_node_id]
|
||||
dest_node = nodes[dest_node_id]
|
||||
@@ -66,7 +70,10 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
if source_node.get_state() == HARDWARE_STATE.ON:
|
||||
if source_node.has_service(protocol):
|
||||
# Red agents IERs can only be valid if the source service is in a compromised state
|
||||
if source_node.get_service_state(protocol) == SOFTWARE_STATE.COMPROMISED:
|
||||
if (
|
||||
source_node.get_service_state(protocol)
|
||||
== SOFTWARE_STATE.COMPROMISED
|
||||
):
|
||||
source_valid = True
|
||||
else:
|
||||
source_valid = False
|
||||
@@ -78,7 +85,6 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
# Do nothing - IER no longer valid
|
||||
source_valid = False
|
||||
|
||||
|
||||
# 2. Check the dest node situation
|
||||
if dest_node.get_type() == TYPE.SWITCH:
|
||||
# It's a switch
|
||||
@@ -105,10 +111,21 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
dest_valid = False
|
||||
|
||||
# 3. Check that the ACL doesn't block it
|
||||
acl_block = acl.is_blocked(source_node.get_ip_address(), dest_node.get_ip_address(), protocol, port)
|
||||
acl_block = acl.is_blocked(
|
||||
source_node.get_ip_address(), dest_node.get_ip_address(), protocol, port
|
||||
)
|
||||
if acl_block:
|
||||
if _VERBOSE:
|
||||
print("ACL block on source: " + source_node.get_ip_address() + ", dest: " + dest_node.get_ip_address() + ", protocol: " + protocol + ", port: " + port)
|
||||
print(
|
||||
"ACL block on source: "
|
||||
+ source_node.get_ip_address()
|
||||
+ ", dest: "
|
||||
+ dest_node.get_ip_address()
|
||||
+ ", protocol: "
|
||||
+ protocol
|
||||
+ ", port: "
|
||||
+ port
|
||||
)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("No ACL block")
|
||||
@@ -130,7 +147,7 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
for node in path_node_list:
|
||||
if node.get_state() != HARDWARE_STATE.ON:
|
||||
path_valid = False
|
||||
|
||||
|
||||
if path_valid:
|
||||
if _VERBOSE:
|
||||
print("Applying IER to link(s)")
|
||||
@@ -140,8 +157,10 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
# Check that the link capacity is not exceeded by the new load
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count+1])
|
||||
link_id = edge_dict[0].get('id')
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
@@ -149,7 +168,7 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
pass
|
||||
count+=1
|
||||
count += 1
|
||||
|
||||
# Check whether the link capacity for any links on this path have been exceeded
|
||||
if link_capacity_exceeded == False:
|
||||
@@ -157,12 +176,14 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
count = 0
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count+1])
|
||||
link_id = edge_dict[0].get('id')
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Add the load from this IER
|
||||
link.add_protocol_load(protocol, load)
|
||||
count+=1
|
||||
count += 1
|
||||
# This IER is now valid, so set it to running
|
||||
ier_value.set_is_running(True)
|
||||
if _VERBOSE:
|
||||
@@ -172,7 +193,7 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
if _VERBOSE:
|
||||
print("Path not valid - one or more nodes not operational")
|
||||
pass
|
||||
|
||||
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Red IER was NOT allowed to run in step " + str(step))
|
||||
@@ -185,20 +206,20 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def apply_red_agent_node_pol(nodes, iers, node_pol, step):
|
||||
"""
|
||||
Applies node pattern of life
|
||||
Applies node pattern of life.
|
||||
|
||||
Args:
|
||||
nodes: The nodes within the environment
|
||||
iers: The red agent IERs
|
||||
node_pol: The red agent node pattern of life to apply
|
||||
step: The step number
|
||||
step: The step number.
|
||||
"""
|
||||
|
||||
if _VERBOSE:
|
||||
print("Applying Node Red Agent PoL")
|
||||
|
||||
|
||||
for key, node_instruction in node_pol.items():
|
||||
start_step = node_instruction.get_start_step()
|
||||
stop_step = node_instruction.get_end_step()
|
||||
@@ -209,12 +230,14 @@ def apply_red_agent_node_pol(nodes, iers, node_pol, step):
|
||||
state = node_instruction.get_state()
|
||||
source_node_id = node_instruction.get_source_node_id()
|
||||
source_node_service_name = node_instruction.get_source_node_service()
|
||||
source_node_service_state_value = node_instruction.get_source_node_service_state()
|
||||
source_node_service_state_value = (
|
||||
node_instruction.get_source_node_service_state()
|
||||
)
|
||||
|
||||
passed_checks = False
|
||||
|
||||
if step >= start_step and step <= stop_step:
|
||||
# continue --------------------------
|
||||
# continue --------------------------
|
||||
target_node = nodes[target_node_id]
|
||||
|
||||
# Based the action taken on the initiator type
|
||||
@@ -228,7 +251,10 @@ def apply_red_agent_node_pol(nodes, iers, node_pol, step):
|
||||
# Need to check the condition of a service on another node
|
||||
source_node = nodes[source_node_id]
|
||||
if source_node.has_service(source_node_service_name):
|
||||
if source_node.get_service_state(source_node_service_name) == SOFTWARE_STATE[source_node_service_state_value]:
|
||||
if (
|
||||
source_node.get_service_state(source_node_service_name)
|
||||
== SOFTWARE_STATE[source_node_service_state_value]
|
||||
):
|
||||
passed_checks = True
|
||||
else:
|
||||
# Do nothing, no matching state value
|
||||
@@ -248,7 +274,9 @@ def apply_red_agent_node_pol(nodes, iers, node_pol, step):
|
||||
target_node.set_state(state)
|
||||
elif pol_type == NODE_POL_TYPE.OS:
|
||||
# Change OS state
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
if isinstance(target_node, ActiveNode) or isinstance(
|
||||
target_node, ServiceNode
|
||||
):
|
||||
target_node.set_os_state(state)
|
||||
elif pol_type == NODE_POL_TYPE.SERVICE:
|
||||
# Change a service state
|
||||
@@ -256,23 +284,34 @@ def apply_red_agent_node_pol(nodes, iers, node_pol, step):
|
||||
target_node.set_service_state(service_name, state)
|
||||
else:
|
||||
# Change the file system status
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
if isinstance(target_node, ActiveNode) or isinstance(
|
||||
target_node, ServiceNode
|
||||
):
|
||||
target_node.set_file_system_state(state)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Node Red Agent PoL not allowed - did not pass checks")
|
||||
print("Node Red Agent PoL not allowed - did not pass checks")
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
|
||||
def is_red_ier_incoming(node, iers, node_pol_type):
|
||||
|
||||
def is_red_ier_incoming(node, iers, node_pol_type):
|
||||
"""
|
||||
Checks if the RED IER is incoming.
|
||||
|
||||
TODO: Write more descriptive docstring with params and returns.
|
||||
"""
|
||||
node_id = node.get_id()
|
||||
|
||||
for ier_key, ier_value in iers.items():
|
||||
for ier_key, ier_value in iers.items():
|
||||
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
|
||||
if node_pol_type == NODE_POL_TYPE.OPERATING or node_pol_type == NODE_POL_TYPE.OS or node_pol_type == NODE_POL_TYPE.FILE:
|
||||
# It's looking to change operating state, file system or O/S state, so valid
|
||||
if (
|
||||
node_pol_type == NODE_POL_TYPE.OPERATING
|
||||
or node_pol_type == NODE_POL_TYPE.OS
|
||||
or node_pol_type == NODE_POL_TYPE.FILE
|
||||
):
|
||||
# It's looking to change operating state, file system or O/S state, so valid
|
||||
return True
|
||||
elif node_pol_type == NODE_POL_TYPE.SERVICE:
|
||||
# Check if the service is present on the node and running
|
||||
@@ -297,5 +336,3 @@ def is_red_ier_incoming(node, iers, node_pol_type):
|
||||
else:
|
||||
# The IER destination is not this node, or the IER is not running
|
||||
return False
|
||||
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
@@ -1,69 +1,57 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The Transaction class
|
||||
"""
|
||||
"""The Transaction class."""
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
"""
|
||||
Transaction class
|
||||
"""
|
||||
"""Transaction class."""
|
||||
|
||||
def __init__(self, _timestamp, _agent_identifier, _episode_number, _step_number):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_timestamp: The time this object was created
|
||||
_agent_identifier: An identifier for the agent in use
|
||||
_episode_number: The episode number
|
||||
_step_number: The step number
|
||||
_step_number: The step number
|
||||
"""
|
||||
|
||||
self.timestamp = _timestamp
|
||||
self.agent_identifier = _agent_identifier
|
||||
self.episode_number = _episode_number
|
||||
self.step_number = _step_number
|
||||
self.step_number = _step_number
|
||||
|
||||
def set_obs_space_pre(self, _obs_space_pre):
|
||||
"""
|
||||
Sets the observation space (pre)
|
||||
Sets the observation space (pre).
|
||||
|
||||
Args:
|
||||
_obs_space_pre: The observation space before any actions are taken
|
||||
"""
|
||||
|
||||
self.obs_space_pre = _obs_space_pre
|
||||
|
||||
def set_obs_space_post(self, _obs_space_post):
|
||||
"""
|
||||
Sets the observation space (post)
|
||||
Sets the observation space (post).
|
||||
|
||||
Args:
|
||||
_obs_space_post: The observation space after any actions are taken
|
||||
"""
|
||||
|
||||
self.obs_space_post = _obs_space_post
|
||||
|
||||
def set_reward(self, _reward):
|
||||
"""
|
||||
Sets the reward
|
||||
Sets the reward.
|
||||
|
||||
Args:
|
||||
_reward: The reward value
|
||||
"""
|
||||
|
||||
self.reward = _reward
|
||||
|
||||
def set_action_space(self, _action_space):
|
||||
"""
|
||||
Sets the action space
|
||||
Sets the action space.
|
||||
|
||||
Args:
|
||||
_action_space: The action space invoked by the agent
|
||||
"""
|
||||
|
||||
self.action_space = _action_space
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,40 +1,35 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Writes the Transaction log list out to file for evaluation to utilse
|
||||
"""
|
||||
"""Writes the Transaction log list out to file for evaluation to utilse."""
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import os.path
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from transactions.transaction import Transaction
|
||||
|
||||
def turn_action_space_to_array(_action_space):
|
||||
"""
|
||||
Turns action space into a string array so it can be saved to csv
|
||||
Turns action space into a string array so it can be saved to csv.
|
||||
|
||||
Args:
|
||||
_action_space: The action space
|
||||
_action_space: The action space.
|
||||
"""
|
||||
|
||||
return_array = []
|
||||
for x in range(len(_action_space)):
|
||||
return_array.append(str(_action_space[x]))
|
||||
|
||||
return return_array
|
||||
|
||||
|
||||
def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
|
||||
"""
|
||||
Turns observation space into a string array so it can be saved to csv
|
||||
Turns observation space into a string array so it can be saved to csv.
|
||||
|
||||
Args:
|
||||
_obs_space: The observation space
|
||||
_obs_assets: The number of assets (i.e. nodes or links) in the observation space
|
||||
_obs_features: The number of features associated with the asset
|
||||
"""
|
||||
|
||||
return_array = []
|
||||
for x in range(_obs_assets):
|
||||
for y in range(_obs_features):
|
||||
@@ -42,15 +37,15 @@ def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
|
||||
|
||||
return return_array
|
||||
|
||||
|
||||
def write_transaction_to_file(_transaction_list):
|
||||
"""
|
||||
Writes transaction logs to file to support training evaluation
|
||||
Writes transaction logs to file to support training evaluation.
|
||||
|
||||
Args:
|
||||
_transaction_list: The list of transactions from all steps and all episodes
|
||||
_num_episodes: The number of episodes that were conducted
|
||||
_num_episodes: The number of episodes that were conducted.
|
||||
"""
|
||||
|
||||
# Get the first transaction and use it to determine the makeup of the observation space and action space
|
||||
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action space as "AS_1"
|
||||
# This will be tied into the PrimAITE Use Case so that they make sense
|
||||
@@ -59,46 +54,56 @@ def write_transaction_to_file(_transaction_list):
|
||||
obs_assets = template_transation.obs_space_post.shape[0]
|
||||
obs_features = template_transation.obs_space_post.shape[1]
|
||||
|
||||
# Create the action space headers array
|
||||
# Create the action space headers array
|
||||
action_header = []
|
||||
for x in range(action_length):
|
||||
action_header.append('AS_' + str(x))
|
||||
action_header.append("AS_" + str(x))
|
||||
|
||||
# Create the observation space headers array
|
||||
obs_header_initial = []
|
||||
obs_header_new = []
|
||||
for x in range(obs_assets):
|
||||
for y in range(obs_features):
|
||||
obs_header_initial.append('OSI_' + str(x) + '_' + str(y))
|
||||
obs_header_new.append('OSN_' + str(x) + '_' + str(y))
|
||||
obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
|
||||
obs_header_new.append("OSN_" + str(x) + "_" + str(y))
|
||||
|
||||
# Open up a csv file
|
||||
header = ['Timestamp', 'Episode', 'Step', 'Reward']
|
||||
header = ["Timestamp", "Episode", "Step", "Reward"]
|
||||
header = header + action_header + obs_header_initial + obs_header_new
|
||||
now = datetime.now() # current date and time
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
try:
|
||||
path = 'outputs/results/'
|
||||
path = "outputs/results/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
|
||||
filename = "outputs/results/all_transactions_" + time + ".csv"
|
||||
csv_file = open(filename, 'w', encoding='UTF8', newline='')
|
||||
csv_file = open(filename, "w", encoding="UTF8", newline="")
|
||||
csv_writer = csv.writer(csv_file)
|
||||
csv_writer.writerow(header)
|
||||
|
||||
for transaction in _transaction_list:
|
||||
csv_data = [str(transaction.timestamp), str(transaction.episode_number), str(transaction.step_number), str(transaction.reward)]
|
||||
csv_data = csv_data + turn_action_space_to_array(transaction.action_space) + \
|
||||
turn_obs_space_to_array(transaction.obs_space_pre, obs_assets, obs_features) + \
|
||||
turn_obs_space_to_array(transaction.obs_space_post, obs_assets, obs_features)
|
||||
csv_data = [
|
||||
str(transaction.timestamp),
|
||||
str(transaction.episode_number),
|
||||
str(transaction.step_number),
|
||||
str(transaction.reward),
|
||||
]
|
||||
csv_data = (
|
||||
csv_data
|
||||
+ turn_action_space_to_array(transaction.action_space)
|
||||
+ turn_obs_space_to_array(
|
||||
transaction.obs_space_pre, obs_assets, obs_features
|
||||
)
|
||||
+ turn_obs_space_to_array(
|
||||
transaction.obs_space_post, obs_assets, obs_features
|
||||
)
|
||||
)
|
||||
csv_writer.writerow(csv_data)
|
||||
|
||||
csv_file.close()
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.error("Could not save the transaction file")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
1
tests/conftest.py
Normal file
1
tests/conftest.py
Normal file
@@ -0,0 +1 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
@@ -1,61 +1,48 @@
|
||||
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Used to tes the ACL functions
|
||||
"""
|
||||
"""Used to tes the ACL functions."""
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
|
||||
from acl.acl_rule import ACLRule
|
||||
from acl.access_control_list import AccessControlList
|
||||
|
||||
def test_acl_address_match_1():
|
||||
"""
|
||||
Test that matching IP addresses produce True
|
||||
"""
|
||||
|
||||
"""Test that matching IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
def test_acl_address_match_2():
|
||||
"""
|
||||
Test that mismatching IP addresses produce False
|
||||
"""
|
||||
|
||||
def test_acl_address_match_2():
|
||||
"""Test that mismatching IP addresses produce False."""
|
||||
acl = AccessControlList()
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False
|
||||
|
||||
def test_acl_address_match_3():
|
||||
"""
|
||||
Test the ANY condition for source IP addresses produce True
|
||||
"""
|
||||
|
||||
def test_acl_address_match_3():
|
||||
"""Test the ANY condition for source IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
|
||||
rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
def test_acl_address_match_4():
|
||||
"""
|
||||
Test the ANY condition for dest IP addresses produce True
|
||||
"""
|
||||
|
||||
def test_acl_address_match_4():
|
||||
"""Test the ANY condition for dest IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
def test_check_acl_block_affirmative():
|
||||
"""
|
||||
Test the block function (affirmative)
|
||||
"""
|
||||
|
||||
def test_check_acl_block_affirmative():
|
||||
"""Test the block function (affirmative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
|
||||
@@ -66,15 +53,19 @@ def test_check_acl_block_affirmative():
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
|
||||
acl.add_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port)
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
)
|
||||
|
||||
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == False
|
||||
|
||||
def test_check_acl_block_negative():
|
||||
"""
|
||||
Test the block function (negative)
|
||||
"""
|
||||
|
||||
def test_check_acl_block_negative():
|
||||
"""Test the block function (negative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
|
||||
@@ -85,21 +76,27 @@ def test_check_acl_block_negative():
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
|
||||
acl.add_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port)
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
)
|
||||
|
||||
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == True
|
||||
|
||||
def test_rule_hash():
|
||||
"""
|
||||
Test the rule hash
|
||||
"""
|
||||
|
||||
def test_rule_hash():
|
||||
"""Test the rule hash."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
|
||||
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_local = hash(rule)
|
||||
|
||||
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_remote = acl.get_dictionary_hash(
|
||||
"DENY", "192.168.1.1", "192.168.1.2", "TCP", "80"
|
||||
)
|
||||
|
||||
assert hash_value_local == hash_value_remote
|
||||
0
tests/test_reward.py
Normal file
0
tests/test_reward.py
Normal file
Reference in New Issue
Block a user