Ran pre-commit hook on all files and performed changes to fix flake8 failures
This commit is contained in:
12
.flake8
Normal file
12
.flake8
Normal file
@@ -0,0 +1,12 @@
|
||||
[flake8]
|
||||
max-line-length=120
|
||||
extend-ignore =
|
||||
D105
|
||||
D107
|
||||
D100
|
||||
D104
|
||||
E203
|
||||
E712
|
||||
D401
|
||||
exclude =
|
||||
docs/source/*
|
||||
@@ -3,12 +3,15 @@
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
import datetime
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
import os
|
||||
import sys
|
||||
import datetime
|
||||
import furo
|
||||
|
||||
import furo # noqa
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../"))
|
||||
|
||||
|
||||
@@ -33,7 +36,6 @@ templates_path = ["_templates"]
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
|
||||
@@ -22,5 +22,3 @@ The latest release of PrimAITE has been tested against the following versions of
|
||||
* gym 0.21.0
|
||||
* matplotlib 3.6.2
|
||||
* stable_baselines_3 1.6.2
|
||||
|
||||
|
||||
|
||||
@@ -84,5 +84,3 @@ In order to execute a session, carry out the following steps:
|
||||
2. Start a console window (type “CMD” in path window, or start a console window first and navigate to “[Install Directory]\\Primaite\\Primaite\\”)
|
||||
3. Type “python main.py”
|
||||
4. The session will start with an output indicating the current episode, and average reward value for the episode
|
||||
|
||||
|
||||
|
||||
1
setup.py
1
setup.py
@@ -10,6 +10,7 @@ class bdist_wheel(_bdist_wheel): # noqa
|
||||
# Source: https://stackoverflow.com/a/45150383
|
||||
self.root_is_pure = False
|
||||
|
||||
|
||||
setup(
|
||||
cmdclass={
|
||||
"bdist_wheel": bdist_wheel,
|
||||
|
||||
@@ -1,25 +1,19 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
A class that implements the access control list implementation for the network
|
||||
"""
|
||||
"""A class that implements the access control list implementation for the network."""
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
|
||||
class AccessControlList():
|
||||
"""
|
||||
Access Control List class
|
||||
"""
|
||||
|
||||
class AccessControlList:
|
||||
"""Access Control List class."""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Init
|
||||
"""
|
||||
|
||||
self.acl = {} # A dictionary of ACL Rules
|
||||
"""Init."""
|
||||
self.acl = {} # A dictionary of ACL Rules
|
||||
|
||||
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
|
||||
"""
|
||||
Checks for IP address matches
|
||||
Checks for IP address matches.
|
||||
|
||||
Args:
|
||||
_rule: The rule being checked
|
||||
@@ -29,18 +23,28 @@ class AccessControlList():
|
||||
Returns:
|
||||
True if match; False otherwise.
|
||||
"""
|
||||
|
||||
if ((_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address) or
|
||||
(_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address) or
|
||||
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY") or
|
||||
(_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")):
|
||||
if (
|
||||
(
|
||||
_rule.get_source_ip() == _source_ip_address
|
||||
and _rule.get_dest_ip() == _dest_ip_address
|
||||
)
|
||||
or (
|
||||
_rule.get_source_ip() == "ANY"
|
||||
and _rule.get_dest_ip() == _dest_ip_address
|
||||
)
|
||||
or (
|
||||
_rule.get_source_ip() == _source_ip_address
|
||||
and _rule.get_dest_ip() == "ANY"
|
||||
)
|
||||
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
|
||||
"""
|
||||
Checks for rules that block a protocol / port
|
||||
Checks for rules that block a protocol / port.
|
||||
|
||||
Args:
|
||||
_source_ip_address: the source IP address to check
|
||||
@@ -51,11 +55,17 @@ class AccessControlList():
|
||||
Returns:
|
||||
Indicates block if all conditions are satisfied.
|
||||
"""
|
||||
|
||||
for rule_key, rule_value in self.acl.items():
|
||||
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
|
||||
if ((rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and
|
||||
(str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY")):
|
||||
if self.check_address_match(
|
||||
rule_value, _source_ip_address, _dest_ip_address
|
||||
):
|
||||
if (
|
||||
rule_value.get_protocol() == _protocol
|
||||
or rule_value.get_protocol() == "ANY"
|
||||
) and (
|
||||
str(rule_value.get_port()) == str(_port)
|
||||
or rule_value.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule_value.get_permission() == "DENY":
|
||||
return True
|
||||
@@ -67,7 +77,7 @@ class AccessControlList():
|
||||
|
||||
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Adds a new rule
|
||||
Adds a new rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
@@ -76,14 +86,13 @@ class AccessControlList():
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(new_rule)
|
||||
self.acl[hash_value] = new_rule
|
||||
|
||||
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Removes a rule
|
||||
Removes a rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
@@ -92,25 +101,21 @@ class AccessControlList():
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(rule)
|
||||
# There will not always be something 'popable' since the agent will be trying random things
|
||||
try:
|
||||
self.acl.pop(hash_value)
|
||||
except:
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def remove_all_rules(self):
|
||||
"""
|
||||
Removes all rules
|
||||
"""
|
||||
|
||||
"""Removes all rules."""
|
||||
self.acl.clear()
|
||||
|
||||
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Produces a hash value for a rule
|
||||
Produces a hash value for a rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
@@ -122,13 +127,6 @@ class AccessControlList():
|
||||
Returns:
|
||||
Hash value based on rule parameters.
|
||||
"""
|
||||
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(rule)
|
||||
return hash_value
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
A class that implements an access control list rule
|
||||
"""
|
||||
"""A class that implements an access control list rule."""
|
||||
|
||||
class ACLRule():
|
||||
"""
|
||||
Access Control List Rule class
|
||||
"""
|
||||
|
||||
class ACLRule:
|
||||
"""Access Control List Rule class."""
|
||||
|
||||
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_permission: The permission (ALLOW or DENY)
|
||||
@@ -19,7 +16,6 @@ class ACLRule():
|
||||
_protocol: The rule protocol
|
||||
_port: The rule port
|
||||
"""
|
||||
|
||||
self.permission = _permission
|
||||
self.source_ip = _source_ip
|
||||
self.dest_ip = _dest_ip
|
||||
@@ -28,47 +24,45 @@ class ACLRule():
|
||||
|
||||
def __hash__(self):
|
||||
"""
|
||||
Override the hash function
|
||||
Override the hash function.
|
||||
|
||||
Returns:
|
||||
Returns hash of core parameters.
|
||||
"""
|
||||
|
||||
return hash((self.permission, self.source_ip, self.dest_ip, self.protocol, self.port))
|
||||
return hash(
|
||||
(self.permission, self.source_ip, self.dest_ip, self.protocol, self.port)
|
||||
)
|
||||
|
||||
def get_permission(self):
|
||||
"""
|
||||
Gets the permission attribute
|
||||
Gets the permission attribute.
|
||||
|
||||
Returns:
|
||||
Returns permission attribute
|
||||
"""
|
||||
|
||||
return self.permission
|
||||
|
||||
def get_source_ip(self):
|
||||
"""
|
||||
Gets the source IP address attribute
|
||||
Gets the source IP address attribute.
|
||||
|
||||
Returns:
|
||||
Returns source IP address attribute
|
||||
"""
|
||||
|
||||
return self.source_ip
|
||||
|
||||
def get_dest_ip(self):
|
||||
"""
|
||||
Gets the desintation IP address attribute
|
||||
Gets the desintation IP address attribute.
|
||||
|
||||
Returns:
|
||||
Returns destination IP address attribute
|
||||
"""
|
||||
|
||||
return self.dest_ip
|
||||
|
||||
def get_protocol(self):
|
||||
"""
|
||||
Gets the protocol attribute
|
||||
Gets the protocol attribute.
|
||||
|
||||
Returns:
|
||||
Returns protocol attribute
|
||||
@@ -77,12 +71,9 @@ class ACLRule():
|
||||
|
||||
def get_port(self):
|
||||
"""
|
||||
Gets the port attribute
|
||||
Gets the port attribute.
|
||||
|
||||
Returns:
|
||||
Returns port attribute
|
||||
"""
|
||||
|
||||
return self.port
|
||||
|
||||
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
@@ -1,28 +1,24 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The config class
|
||||
"""
|
||||
"""The config class."""
|
||||
|
||||
|
||||
class config_values_main(object):
|
||||
"""
|
||||
Class to hold main config values
|
||||
"""
|
||||
"""Class to hold main config values."""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Init
|
||||
"""
|
||||
|
||||
"""Init."""
|
||||
# Generic
|
||||
self.agent_identifier = "" # the agent in use
|
||||
self.num_episodes = 0 # number of episodes to train over
|
||||
self.num_steps = 0 # number of steps in an episode
|
||||
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only
|
||||
self.config_filename_use_case = "" # the filename for the Use Case config file
|
||||
self.session_type = "" # the session type to run (TRAINING or EVALUATION)
|
||||
self.agent_identifier = "" # the agent in use
|
||||
self.num_episodes = 0 # number of episodes to train over
|
||||
self.num_steps = 0 # number of steps in an episode
|
||||
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only
|
||||
self.config_filename_use_case = "" # the filename for the Use Case config file
|
||||
self.session_type = "" # the session type to run (TRAINING or EVALUATION)
|
||||
|
||||
# Environment
|
||||
self.observation_space_high_value = 0 # The high value for the observation space
|
||||
self.observation_space_high_value = (
|
||||
0 # The high value for the observation space
|
||||
)
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
@@ -59,11 +55,15 @@ class config_values_main(object):
|
||||
self.repairing_should_be_good = 0
|
||||
self.repairing_should_be_restoring = 0
|
||||
self.repairing_should_be_corrupt = 0
|
||||
self.repairing_should_be_destroyed = 0 # Repairing does not fix destroyed state - you need to restore
|
||||
self.repairing_should_be_destroyed = (
|
||||
0 # Repairing does not fix destroyed state - you need to restore
|
||||
)
|
||||
self.repairing = 0
|
||||
self.restoring_should_be_good = 0
|
||||
self.restoring_should_be_repairing = 0
|
||||
self.restoring_should_be_corrupt = 0 # Not the optimal method (as repair will fix corruption)
|
||||
self.restoring_should_be_corrupt = (
|
||||
0 # Not the optimal method (as repair will fix corruption)
|
||||
)
|
||||
self.restoring_should_be_destroyed = 0
|
||||
self.restoring = 0
|
||||
self.corrupt_should_be_good = 0
|
||||
@@ -82,10 +82,9 @@ class config_values_main(object):
|
||||
self.green_ier_blocked = 0
|
||||
|
||||
# Patching / Reset
|
||||
self.os_patching_duration = 0 # The time taken to patch the OS
|
||||
self.node_reset_duration = 0 # The time taken to reset a node (hardware)
|
||||
self.service_patching_duration = 0 # The time taken to patch a service
|
||||
self.file_system_repairing_limit = 0 # The time take to repair a file
|
||||
self.file_system_restoring_limit = 0 # The time take to restore a file
|
||||
self.file_system_scanning_limit = 0 # The time taken to scan the file system
|
||||
|
||||
self.os_patching_duration = 0 # The time taken to patch the OS
|
||||
self.node_reset_duration = 0 # The time taken to reset a node (hardware)
|
||||
self.service_patching_duration = 0 # The time taken to patch a service
|
||||
self.file_system_repairing_limit = 0 # The time take to repair a file
|
||||
self.file_system_restoring_limit = 0 # The time take to restore a file
|
||||
self.file_system_scanning_limit = 0 # The time taken to scan the file system
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Enumerations for APE
|
||||
"""
|
||||
"""Enumerations for APE."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TYPE(Enum):
|
||||
"""
|
||||
Node type enumeration
|
||||
"""
|
||||
"""Node type enumeration."""
|
||||
|
||||
CCTV = 1
|
||||
SWITCH = 2
|
||||
@@ -21,10 +18,9 @@ class TYPE(Enum):
|
||||
ACTUATOR = 9
|
||||
SERVER = 10
|
||||
|
||||
|
||||
class PRIORITY(Enum):
|
||||
"""
|
||||
Node priority enumeration
|
||||
"""
|
||||
"""Node priority enumeration."""
|
||||
|
||||
P1 = 1
|
||||
P2 = 2
|
||||
@@ -32,48 +28,43 @@ class PRIORITY(Enum):
|
||||
P4 = 4
|
||||
P5 = 5
|
||||
|
||||
|
||||
class HARDWARE_STATE(Enum):
|
||||
"""
|
||||
Node hardware state enumeration
|
||||
"""
|
||||
"""Node hardware state enumeration."""
|
||||
|
||||
ON = 1
|
||||
OFF = 2
|
||||
RESETTING = 3
|
||||
|
||||
|
||||
class SOFTWARE_STATE(Enum):
|
||||
"""
|
||||
O/S or Service state enumeration
|
||||
"""
|
||||
"""O/S or Service state enumeration."""
|
||||
|
||||
GOOD = 1
|
||||
PATCHING = 2
|
||||
COMPROMISED = 3
|
||||
OVERWHELMED = 4
|
||||
|
||||
|
||||
class NODE_POL_TYPE(Enum):
|
||||
"""
|
||||
Node Pattern of Life type enumeration
|
||||
"""
|
||||
"""Node Pattern of Life type enumeration."""
|
||||
|
||||
OPERATING = 1
|
||||
OS = 2
|
||||
SERVICE = 3
|
||||
FILE = 4
|
||||
|
||||
|
||||
class NODE_POL_INITIATOR(Enum):
|
||||
"""
|
||||
Node Pattern of Life initiator enumeration
|
||||
"""
|
||||
"""Node Pattern of Life initiator enumeration."""
|
||||
|
||||
DIRECT = 1
|
||||
IER = 2
|
||||
SERVICE = 3
|
||||
|
||||
|
||||
class PROTOCOL(Enum):
|
||||
"""
|
||||
Service protocol enumeration
|
||||
"""
|
||||
"""Service protocol enumeration."""
|
||||
|
||||
LDAP = 0
|
||||
FTP = 1
|
||||
@@ -84,18 +75,16 @@ class PROTOCOL(Enum):
|
||||
TCP = 6
|
||||
NONE = 7
|
||||
|
||||
|
||||
class ACTION_TYPE(Enum):
|
||||
"""
|
||||
Action type enumeration
|
||||
"""
|
||||
"""Action type enumeration."""
|
||||
|
||||
NODE = 0
|
||||
ACL = 1
|
||||
|
||||
|
||||
class FILE_SYSTEM_STATE(Enum):
|
||||
"""
|
||||
File System State
|
||||
"""
|
||||
"""File System State."""
|
||||
|
||||
GOOD = 1
|
||||
CORRUPT = 2
|
||||
|
||||
@@ -1,59 +1,47 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The protocol class
|
||||
"""
|
||||
"""The protocol class."""
|
||||
|
||||
|
||||
class Protocol(object):
|
||||
"""
|
||||
Protocol class
|
||||
"""
|
||||
"""Protocol class."""
|
||||
|
||||
def __init__(self, _name):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_name: The protocol name
|
||||
"""
|
||||
|
||||
self.name = _name
|
||||
self.load = 0 # bps
|
||||
self.load = 0 # bps
|
||||
|
||||
def get_name(self):
|
||||
"""
|
||||
Gets the protocol name
|
||||
Gets the protocol name.
|
||||
|
||||
Returns:
|
||||
The protocol name
|
||||
"""
|
||||
|
||||
return self.name
|
||||
|
||||
def get_load(self):
|
||||
"""
|
||||
Gets the protocol load
|
||||
Gets the protocol load.
|
||||
|
||||
Returns:
|
||||
The protocol load (bps)
|
||||
"""
|
||||
|
||||
return self.load
|
||||
|
||||
def add_load(self, _load):
|
||||
"""
|
||||
Adds load to the protocol
|
||||
Adds load to the protocol.
|
||||
|
||||
Args:
|
||||
_load: The load to add
|
||||
"""
|
||||
|
||||
self.load += _load
|
||||
|
||||
def clear_load(self):
|
||||
"""
|
||||
Clears the load on this protocol
|
||||
"""
|
||||
|
||||
"""Clears the load on this protocol."""
|
||||
self.load = 0
|
||||
|
||||
|
||||
|
||||
@@ -1,25 +1,21 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The Service class
|
||||
"""
|
||||
"""The Service class."""
|
||||
|
||||
from primaite.common.enums import SOFTWARE_STATE
|
||||
|
||||
|
||||
class Service(object):
|
||||
"""
|
||||
Service class
|
||||
"""
|
||||
"""Service class."""
|
||||
|
||||
def __init__(self, _name, _port, _state):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_name: The service name
|
||||
_port: The service port
|
||||
_state: The service state
|
||||
"""
|
||||
|
||||
self.name = _name
|
||||
self.port = _port
|
||||
self.state = _state
|
||||
@@ -27,74 +23,61 @@ class Service(object):
|
||||
|
||||
def set_name(self, _name):
|
||||
"""
|
||||
Sets the service name
|
||||
Sets the service name.
|
||||
|
||||
Args:
|
||||
_name: The service name
|
||||
"""
|
||||
|
||||
self.name = _name
|
||||
|
||||
def get_name(self):
|
||||
"""
|
||||
Gets the service name
|
||||
Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
"""
|
||||
|
||||
return self.name
|
||||
|
||||
def set_port(self, _port):
|
||||
"""
|
||||
Sets the service port
|
||||
Sets the service port.
|
||||
|
||||
Args:
|
||||
_port: The service port
|
||||
"""
|
||||
|
||||
self.port = _port
|
||||
|
||||
def get_port(self):
|
||||
"""
|
||||
Gets the service port
|
||||
Gets the service port.
|
||||
|
||||
Returns:
|
||||
The service port
|
||||
"""
|
||||
|
||||
return self.port
|
||||
|
||||
def set_state(self, _state):
|
||||
"""
|
||||
Sets the service state
|
||||
Sets the service state.
|
||||
|
||||
Args:
|
||||
_state: The service state
|
||||
"""
|
||||
|
||||
self.state = _state
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Gets the service state
|
||||
Gets the service state.
|
||||
|
||||
Returns:
|
||||
The service state
|
||||
"""
|
||||
|
||||
return self.state
|
||||
|
||||
def reduce_patching_count(self):
|
||||
"""
|
||||
Reduces the patching count for the service
|
||||
"""
|
||||
|
||||
"""Reduces the patching count for the service."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
self.patching_count = 0
|
||||
self.state = SOFTWARE_STATE.GOOD
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
steps: 128
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '80'
|
||||
- port: '80'
|
||||
- itemType: SERVICES
|
||||
serviceList:
|
||||
- name: TCP
|
||||
- name: TCP
|
||||
- itemType: NODE
|
||||
id: '1'
|
||||
name: PC1
|
||||
@@ -19,9 +19,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '2'
|
||||
name: SERVER
|
||||
@@ -33,9 +33,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '3'
|
||||
name: PC2
|
||||
@@ -47,9 +47,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '4'
|
||||
name: SWITCH1
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
steps: 128
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '80'
|
||||
- port: '80'
|
||||
- itemType: SERVICES
|
||||
serviceList:
|
||||
- name: TCP
|
||||
- name: TCP
|
||||
- itemType: NODE
|
||||
id: '1'
|
||||
name: PC1
|
||||
@@ -19,9 +19,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '2'
|
||||
name: PC2
|
||||
@@ -33,9 +33,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '3'
|
||||
name: PC3
|
||||
@@ -47,9 +47,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '4'
|
||||
name: PC4
|
||||
@@ -61,9 +61,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '5'
|
||||
name: SWITCH1
|
||||
@@ -85,9 +85,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '7'
|
||||
name: SWITCH2
|
||||
@@ -109,9 +109,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '9'
|
||||
name: SERVER1
|
||||
@@ -123,9 +123,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '10'
|
||||
name: SERVER2
|
||||
@@ -137,9 +137,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: LINK
|
||||
id: '11'
|
||||
name: link1
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
steps: 256
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '80'
|
||||
- port: '80'
|
||||
- itemType: SERVICES
|
||||
serviceList:
|
||||
- name: TCP
|
||||
- name: TCP
|
||||
- itemType: NODE
|
||||
id: '1'
|
||||
name: PC1
|
||||
@@ -19,9 +19,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '2'
|
||||
name: PC2
|
||||
@@ -33,9 +33,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '3'
|
||||
name: SWITCH1
|
||||
@@ -57,9 +57,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: LINK
|
||||
id: '5'
|
||||
name: link1
|
||||
|
||||
@@ -4,14 +4,14 @@
|
||||
steps: 256
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '80'
|
||||
- port: '1433'
|
||||
- port: '53'
|
||||
- port: '80'
|
||||
- port: '1433'
|
||||
- port: '53'
|
||||
- itemType: SERVICES
|
||||
serviceList:
|
||||
- name: TCP
|
||||
- name: TCP_SQL
|
||||
- name: UDP
|
||||
- name: TCP
|
||||
- name: TCP_SQL
|
||||
- name: UDP
|
||||
- itemType: NODE
|
||||
id: '1'
|
||||
name: CLIENT_1
|
||||
@@ -23,12 +23,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '2'
|
||||
name: CLIENT_2
|
||||
@@ -40,9 +40,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '3'
|
||||
name: SWITCH_1
|
||||
@@ -64,12 +64,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '5'
|
||||
name: MANAGEMENT_CONSOLE
|
||||
@@ -81,12 +81,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '6'
|
||||
name: SWITCH_2
|
||||
@@ -108,12 +108,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '8'
|
||||
name: DATABASE_SERVER
|
||||
@@ -125,15 +125,15 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '9'
|
||||
name: BACKUP_SERVER
|
||||
@@ -145,9 +145,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: LINK
|
||||
id: '10'
|
||||
name: LINK_1
|
||||
|
||||
@@ -4,14 +4,14 @@
|
||||
steps: 256
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '80'
|
||||
- port: '1433'
|
||||
- port: '53'
|
||||
- port: '80'
|
||||
- port: '1433'
|
||||
- port: '53'
|
||||
- itemType: SERVICES
|
||||
serviceList:
|
||||
- name: TCP
|
||||
- name: TCP_SQL
|
||||
- name: UDP
|
||||
- name: TCP
|
||||
- name: TCP_SQL
|
||||
- name: UDP
|
||||
- itemType: NODE
|
||||
id: '1'
|
||||
name: CLIENT_1
|
||||
@@ -23,12 +23,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '2'
|
||||
name: CLIENT_2
|
||||
@@ -40,9 +40,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '3'
|
||||
name: SWITCH_1
|
||||
@@ -64,12 +64,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '5'
|
||||
name: MANAGEMENT_CONSOLE
|
||||
@@ -81,12 +81,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '6'
|
||||
name: SWITCH_2
|
||||
@@ -108,12 +108,12 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '8'
|
||||
name: DATABASE_SERVER
|
||||
@@ -125,15 +125,15 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- itemType: NODE
|
||||
id: '9'
|
||||
name: BACKUP_SERVER
|
||||
@@ -145,9 +145,9 @@
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- itemType: LINK
|
||||
id: '10'
|
||||
name: LINK_1
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
|
||||
@@ -1,39 +1,45 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Main environment module containing the PRIMmary AI Training Evironment (Primaite) class
|
||||
"""
|
||||
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
|
||||
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
import copy
|
||||
import csv
|
||||
import yaml
|
||||
import os.path
|
||||
import logging
|
||||
|
||||
from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
import os.path
|
||||
from datetime import datetime
|
||||
|
||||
from primaite.common.enums import *
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import yaml
|
||||
from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.enums import (
|
||||
ACTION_TYPE,
|
||||
FILE_SYSTEM_STATE,
|
||||
HARDWARE_STATE,
|
||||
NODE_POL_INITIATOR,
|
||||
NODE_POL_TYPE,
|
||||
PRIORITY,
|
||||
SOFTWARE_STATE,
|
||||
TYPE,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.links.link import Link
|
||||
from primaite.pol.ier import IER
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.common.service import Service
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||
from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
|
||||
class Primaite(Env):
|
||||
"""
|
||||
PRIMmary AI Training Evironment (Primaite) class
|
||||
"""
|
||||
"""PRIMmary AI Training Evironment (Primaite) class."""
|
||||
|
||||
# Observation / Action Space contants
|
||||
OBSERVATION_SPACE_FIXED_PARAMETERS = 4
|
||||
@@ -42,11 +48,11 @@ class Primaite(Env):
|
||||
ACTION_SPACE_ACL_ACTION_VALUES = 3
|
||||
ACTION_SPACE_ACL_PERMISSION_VALUES = 2
|
||||
|
||||
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
|
||||
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
|
||||
|
||||
def __init__(self, _config_values, _transaction_list):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_episode_steps: The number of steps for the episode
|
||||
@@ -54,7 +60,6 @@ class Primaite(Env):
|
||||
_transaction_list: The list of transactions to populate
|
||||
_agent_identifier: Identifier for the agent
|
||||
"""
|
||||
|
||||
super(Primaite, self).__init__()
|
||||
|
||||
# Take a copy of the config values
|
||||
@@ -140,10 +145,12 @@ class Primaite(Env):
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
try:
|
||||
self.config_file = open("config/" + self.config_values.config_filename_use_case, "r")
|
||||
self.config_file = open(
|
||||
"config/" + self.config_values.config_filename_use_case, "r"
|
||||
)
|
||||
self.config_data = yaml.safe_load(self.config_file)
|
||||
self.load_config()
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.error("Could not load the environment configuration")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
@@ -162,17 +169,17 @@ class Primaite(Env):
|
||||
try:
|
||||
plt.tight_layout()
|
||||
nx.draw_networkx(self.network, with_labels=True)
|
||||
now = datetime.now() # current date and time
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
path = 'outputs/diagrams'
|
||||
path = "outputs/diagrams"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
filename = "outputs/diagrams/network_" + time + ".png"
|
||||
plt.savefig(filename, format="PNG")
|
||||
plt.clf()
|
||||
except Exception as a:
|
||||
except Exception:
|
||||
logging.error("Could not save network diagram")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
print("Could not save network diagram")
|
||||
@@ -194,16 +201,22 @@ class Primaite(Env):
|
||||
# - service F state | service F loading
|
||||
# - service G state | service G loading
|
||||
|
||||
# Calculate the number of items that need to be included in the observation space
|
||||
# Calculate the number of items that need to be included in the
|
||||
# observation space
|
||||
num_items = self.num_links + self.num_nodes
|
||||
# Set the number of observation parameters, being # of services plus id, operating state, file system state and O/S state (i.e. 4)
|
||||
self.num_observation_parameters = self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
|
||||
# Set the number of observation parameters, being # of services plus id,
|
||||
# operating state, file system state and O/S state (i.e. 4)
|
||||
self.num_observation_parameters = (
|
||||
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
|
||||
)
|
||||
# Define the observation shape
|
||||
self.observation_shape = (num_items, self.num_observation_parameters)
|
||||
self.observation_space = spaces.Box(low=0,
|
||||
high=self.config_values.observation_space_high_value,
|
||||
shape=self.observation_shape,
|
||||
dtype=np.int64)
|
||||
self.observation_space = spaces.Box(
|
||||
low=0,
|
||||
high=self.config_values.observation_space_high_value,
|
||||
shape=self.observation_shape,
|
||||
dtype=np.int64,
|
||||
)
|
||||
|
||||
# This is the observation that is sent back via the rest and step functions
|
||||
self.env_obs = np.zeros(self.observation_shape, dtype=np.int64)
|
||||
@@ -216,7 +229,14 @@ class Primaite(Env):
|
||||
# [0, 4] - what property it's acting on (0 = nothing, state, o/s state, service state, file system state)
|
||||
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore)
|
||||
# [0, num services] - resolves to service ID (0 = nothing, resolves to service)
|
||||
self.action_space = spaces.MultiDiscrete([self.num_nodes, self.ACTION_SPACE_NODE_PROPERTY_VALUES, self.ACTION_SPACE_NODE_ACTION_VALUES, self.num_services])
|
||||
self.action_space = spaces.MultiDiscrete(
|
||||
[
|
||||
self.num_nodes,
|
||||
self.ACTION_SPACE_NODE_PROPERTY_VALUES,
|
||||
self.ACTION_SPACE_NODE_ACTION_VALUES,
|
||||
self.num_services,
|
||||
]
|
||||
)
|
||||
else:
|
||||
logging.info("Action space type ACL selected")
|
||||
# Terms (for ACL action space):
|
||||
@@ -226,35 +246,45 @@ class Primaite(Env):
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
self.action_space = spaces.MultiDiscrete([self.ACTION_SPACE_ACL_ACTION_VALUES, self.ACTION_SPACE_ACL_PERMISSION_VALUES, self.num_nodes + 1, self.num_nodes + 1, self.num_services + 1, self.num_ports + 1])
|
||||
self.action_space = spaces.MultiDiscrete(
|
||||
[
|
||||
self.ACTION_SPACE_ACL_ACTION_VALUES,
|
||||
self.ACTION_SPACE_ACL_PERMISSION_VALUES,
|
||||
self.num_nodes + 1,
|
||||
self.num_nodes + 1,
|
||||
self.num_services + 1,
|
||||
self.num_ports + 1,
|
||||
]
|
||||
)
|
||||
|
||||
# Set up a csv to store the results of the training
|
||||
try:
|
||||
now = datetime.now() # current date and time
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
header = ['Episode', 'Average Reward']
|
||||
header = ["Episode", "Average Reward"]
|
||||
|
||||
# Check whether the output/rerults folder exists (doesn't exist by default install)
|
||||
path = 'outputs/results/'
|
||||
path = "outputs/results/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
filename = "outputs/results/average_reward_per_episode_" + time + ".csv"
|
||||
self.csv_file = open(filename, 'w', encoding='UTF8', newline='')
|
||||
self.csv_file = open(filename, "w", encoding="UTF8", newline="")
|
||||
self.csv_writer = csv.writer(self.csv_file)
|
||||
self.csv_writer.writerow(header)
|
||||
except Exception as e:
|
||||
logging.error("Could not create csv file to hold average reward per episode")
|
||||
except Exception:
|
||||
logging.error(
|
||||
"Could not create csv file to hold average reward per episode"
|
||||
)
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
AI Gym Reset function
|
||||
AI Gym Reset function.
|
||||
|
||||
Returns:
|
||||
Environment observation space (reset)
|
||||
"""
|
||||
|
||||
csv_data = self.episode_count, self.average_reward
|
||||
self.csv_writer.writerow(csv_data)
|
||||
|
||||
@@ -280,7 +310,7 @@ class Primaite(Env):
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
AI Gym Step function
|
||||
AI Gym Step function.
|
||||
|
||||
Args:
|
||||
action: Action space from agent
|
||||
@@ -291,7 +321,6 @@ class Primaite(Env):
|
||||
done: Indicates episode is complete if True
|
||||
step_info: Additional information relating to this step
|
||||
"""
|
||||
|
||||
if self.step_count == 0:
|
||||
print("Episode: " + str(self.episode_count) + " running")
|
||||
|
||||
@@ -299,14 +328,16 @@ class Primaite(Env):
|
||||
done = False
|
||||
|
||||
self.step_count += 1
|
||||
#print("Episode step: " + str(self.stepCount))
|
||||
# print("Episode step: " + str(self.stepCount))
|
||||
|
||||
# Need to clear traffic on all links first
|
||||
for link_key, link_value in self.links.items():
|
||||
link_value.clear_traffic()
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(datetime.now(), self.agent_identifier, self.episode_count, self.step_count)
|
||||
transaction = Transaction(
|
||||
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
|
||||
)
|
||||
# Load the initial observation space into the transaction
|
||||
transaction.set_obs_space_pre(copy.deepcopy(self.env_obs))
|
||||
# Load the action space into the transaction
|
||||
@@ -316,18 +347,41 @@ class Primaite(Env):
|
||||
self.apply_time_based_updates()
|
||||
|
||||
# 2. Apply PoL
|
||||
apply_node_pol(self.nodes, self.node_pol, self.step_count) # Node PoL
|
||||
apply_iers(self.network, self.nodes, self.links, self.green_iers, self.acl, self.step_count) # Network PoL
|
||||
apply_node_pol(self.nodes, self.node_pol, self.step_count) # Node PoL
|
||||
apply_iers(
|
||||
self.network,
|
||||
self.nodes,
|
||||
self.links,
|
||||
self.green_iers,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
) # Network PoL
|
||||
# Take snapshots of nodes and links
|
||||
self.nodes_post_pol = copy.deepcopy(self.nodes)
|
||||
self.links_post_pol = copy.deepcopy(self.links)
|
||||
# Reference
|
||||
apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL
|
||||
apply_iers(self.network_reference, self.nodes_reference, self.links_reference, self.green_iers, self.acl, self.step_count) # Network PoL
|
||||
apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL
|
||||
apply_iers(
|
||||
self.network_reference,
|
||||
self.nodes_reference,
|
||||
self.links_reference,
|
||||
self.green_iers,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
) # Network PoL
|
||||
|
||||
# 3. Implement Red Action
|
||||
apply_red_agent_iers(self.network, self.nodes, self.links, self.red_iers, self.acl, self.step_count)
|
||||
apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count)
|
||||
apply_red_agent_iers(
|
||||
self.network,
|
||||
self.nodes,
|
||||
self.links,
|
||||
self.red_iers,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
)
|
||||
apply_red_agent_node_pol(
|
||||
self.nodes, self.red_iers, self.red_node_pol, self.step_count
|
||||
)
|
||||
# Take snapshots of nodes and links
|
||||
self.nodes_post_red = copy.deepcopy(self.nodes)
|
||||
self.links_post_red = copy.deepcopy(self.links)
|
||||
@@ -335,31 +389,55 @@ class Primaite(Env):
|
||||
# 4. Implement Blue Action
|
||||
self.interpret_action_and_apply(action)
|
||||
|
||||
# 5. Reapply normal and Red agent IER PoL, as we need to see what effect the blue agent action has had (if any) on link status
|
||||
# 5. Reapply normal and Red agent IER PoL, as we need to see what
|
||||
# effect the blue agent action has had (if any) on link status
|
||||
# Need to clear traffic on all links first
|
||||
for link_key, link_value in self.links.items():
|
||||
link_value.clear_traffic()
|
||||
apply_iers(self.network, self.nodes, self.links, self.green_iers, self.acl, self.step_count)
|
||||
apply_red_agent_iers(self.network, self.nodes, self.links, self.red_iers, self.acl, self.step_count)
|
||||
apply_iers(
|
||||
self.network,
|
||||
self.nodes,
|
||||
self.links,
|
||||
self.green_iers,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
)
|
||||
apply_red_agent_iers(
|
||||
self.network,
|
||||
self.nodes,
|
||||
self.links,
|
||||
self.red_iers,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
)
|
||||
# Take snapshots of nodes and links
|
||||
self.nodes_post_blue = copy.deepcopy(self.nodes)
|
||||
self.links_post_blue = copy.deepcopy(self.links)
|
||||
|
||||
# 6. Calculate reward signal (for RL)
|
||||
reward = calculate_reward_function(self.nodes_post_pol, self.nodes_post_blue, self.nodes_reference, self.green_iers, self.red_iers, self.step_count, self.config_values)
|
||||
#print("Step reward: " + str(reward))
|
||||
reward = calculate_reward_function(
|
||||
self.nodes_post_pol,
|
||||
self.nodes_post_blue,
|
||||
self.nodes_reference,
|
||||
self.green_iers,
|
||||
self.red_iers,
|
||||
self.step_count,
|
||||
self.config_values,
|
||||
)
|
||||
# print("Step reward: " + str(reward))
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
if self.config_values.session_type == "EVALUATION":
|
||||
# For evaluation, need to trigger the done value = True when step count is reached in order to prevent neverending episode
|
||||
# For evaluation, need to trigger the done value = True when
|
||||
# step count is reached in order to prevent neverending episode
|
||||
done = True
|
||||
print("Average reward: " + str(self.average_reward))
|
||||
# Load the reward into the transaction
|
||||
transaction.set_reward(reward)
|
||||
|
||||
# 7. Output Verbose
|
||||
#self.output_link_status()
|
||||
# self.output_link_status()
|
||||
|
||||
# 8. Update env_obs
|
||||
self.update_environent_obs()
|
||||
@@ -373,38 +451,33 @@ class Primaite(Env):
|
||||
return self.env_obs, reward, done, self.step_info
|
||||
|
||||
def __close__(self):
|
||||
"""
|
||||
Override close function
|
||||
"""
|
||||
|
||||
"""Override close function."""
|
||||
self.csv_file.close()
|
||||
self.config_file.close()
|
||||
|
||||
def init_acl(self):
|
||||
"""
|
||||
Initialise the Access Control List
|
||||
"""
|
||||
|
||||
"""Initialise the Access Control List."""
|
||||
self.acl.remove_all_rules()
|
||||
|
||||
def output_link_status(self):
|
||||
"""
|
||||
Output the link status of all links to the console
|
||||
"""
|
||||
|
||||
"""Output the link status of all links to the console."""
|
||||
for link_key, link_value in self.links.items():
|
||||
print("Link ID: " + link_value.get_id())
|
||||
for protocol in link_value.get_protocol_list():
|
||||
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
|
||||
print(
|
||||
" Protocol: "
|
||||
+ protocol.get_name().name
|
||||
+ ", Load: "
|
||||
+ str(protocol.get_load())
|
||||
)
|
||||
|
||||
def interpret_action_and_apply(self, _action):
|
||||
"""
|
||||
Applies agent actions to the nodes and Access Control List
|
||||
Applies agent actions to the nodes and Access Control List.
|
||||
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
|
||||
# At the moment, actions are only affecting nodes
|
||||
if self.action_type == ACTION_TYPE.NODE:
|
||||
self.apply_actions_to_nodes(_action)
|
||||
@@ -413,12 +486,11 @@ class Primaite(Env):
|
||||
|
||||
def apply_actions_to_nodes(self, _action):
|
||||
"""
|
||||
Applies agent actions to the nodes
|
||||
Applies agent actions to the nodes.
|
||||
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
|
||||
node_id = _action[0]
|
||||
node_property = _action[1]
|
||||
property_action = _action[2]
|
||||
@@ -427,7 +499,7 @@ class Primaite(Env):
|
||||
# Check that the action is requesting a valid node
|
||||
try:
|
||||
node = self.nodes[str(node_id)]
|
||||
except:
|
||||
except Exception:
|
||||
return
|
||||
|
||||
if node_property == 0:
|
||||
@@ -472,7 +544,9 @@ class Primaite(Env):
|
||||
return
|
||||
elif property_action == 1:
|
||||
# Patch (valid action if it's good or compromised)
|
||||
node.set_service_state(self.services_list[service_index], SOFTWARE_STATE.PATCHING)
|
||||
node.set_service_state(
|
||||
self.services_list[service_index], SOFTWARE_STATE.PATCHING
|
||||
)
|
||||
else:
|
||||
# Node is not of Service Type
|
||||
return
|
||||
@@ -488,7 +562,10 @@ class Primaite(Env):
|
||||
elif property_action == 2:
|
||||
# Repair
|
||||
# You cannot repair a destroyed file system - it needs restoring
|
||||
if node.get_file_system_state_actual() != FILE_SYSTEM_STATE.DESTROYED:
|
||||
if (
|
||||
node.get_file_system_state_actual()
|
||||
!= FILE_SYSTEM_STATE.DESTROYED
|
||||
):
|
||||
node.set_file_system_state(FILE_SYSTEM_STATE.REPAIRING)
|
||||
elif property_action == 3:
|
||||
# Restore
|
||||
@@ -501,12 +578,11 @@ class Primaite(Env):
|
||||
|
||||
def apply_actions_to_acl(self, _action):
|
||||
"""
|
||||
Applies agent actions to the Access Control List [TO DO]
|
||||
Applies agent actions to the Access Control List [TO DO].
|
||||
|
||||
Args:
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
|
||||
action_decision = _action[0]
|
||||
action_permission = _action[1]
|
||||
action_source_ip = _action[2]
|
||||
@@ -556,18 +632,31 @@ class Primaite(Env):
|
||||
# Now add or remove
|
||||
if action_decision == 1:
|
||||
# Add the rule
|
||||
self.acl.add_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port)
|
||||
self.acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
)
|
||||
elif action_decision == 2:
|
||||
# Remove the rule
|
||||
self.acl.remove_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port)
|
||||
self.acl.remove_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
def apply_time_based_updates(self):
|
||||
"""
|
||||
Updates anything that needs to count down and then change state (e.g. reset / patching status)
|
||||
"""
|
||||
Updates anything that needs to count down and then change state.
|
||||
|
||||
e.g. reset / patching status
|
||||
"""
|
||||
for node_key, node in self.nodes.items():
|
||||
if node.get_state() == HARDWARE_STATE.RESETTING:
|
||||
node.update_resetting_status()
|
||||
@@ -605,10 +694,7 @@ class Primaite(Env):
|
||||
pass
|
||||
|
||||
def update_environent_obs(self):
|
||||
"""
|
||||
# Updates the observation space based on the node and link status
|
||||
"""
|
||||
|
||||
"""Updates the observation space based on the node and link status."""
|
||||
item_index = 0
|
||||
|
||||
# Do nodes first
|
||||
@@ -617,7 +703,9 @@ class Primaite(Env):
|
||||
self.env_obs[item_index][1] = node.get_state().value
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
self.env_obs[item_index][2] = node.get_os_state().value
|
||||
self.env_obs[item_index][3] = node.get_file_system_state_observed().value
|
||||
self.env_obs[item_index][
|
||||
3
|
||||
] = node.get_file_system_state_observed().value
|
||||
else:
|
||||
self.env_obs[item_index][2] = 0
|
||||
self.env_obs[item_index][3] = 0
|
||||
@@ -625,7 +713,9 @@ class Primaite(Env):
|
||||
if isinstance(node, ServiceNode):
|
||||
for service in self.services_list:
|
||||
if node.has_service(service):
|
||||
self.env_obs[item_index][service_index] = node.get_service_state(service).value
|
||||
self.env_obs[item_index][
|
||||
service_index
|
||||
] = node.get_service_state(service).value
|
||||
else:
|
||||
self.env_obs[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
@@ -650,10 +740,7 @@ class Primaite(Env):
|
||||
item_index += 1
|
||||
|
||||
def load_config(self):
|
||||
"""
|
||||
# Loads config data in order to build the environment configuration
|
||||
"""
|
||||
|
||||
"""Loads config data in order to build the environment configuration."""
|
||||
for item in self.config_data:
|
||||
if item["itemType"] == "NODE":
|
||||
# Create a node
|
||||
@@ -697,12 +784,11 @@ class Primaite(Env):
|
||||
|
||||
def create_node(self, item):
|
||||
"""
|
||||
Creates a node from config data
|
||||
Creates a node from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
# All nodes have these parameters
|
||||
node_id = item["id"]
|
||||
node_name = item["name"]
|
||||
@@ -712,19 +798,46 @@ class Primaite(Env):
|
||||
node_hardware_state = HARDWARE_STATE[item["hardwareState"]]
|
||||
|
||||
if node_base_type == "PASSIVE":
|
||||
node = PassiveNode(node_id, node_name, node_type, node_priority, node_hardware_state, self.config_values)
|
||||
node = PassiveNode(
|
||||
node_id,
|
||||
node_name,
|
||||
node_type,
|
||||
node_priority,
|
||||
node_hardware_state,
|
||||
self.config_values,
|
||||
)
|
||||
elif node_base_type == "ACTIVE":
|
||||
# Active nodes have IP address, operating system state and file system state
|
||||
node_ip_address = item["ipAddress"]
|
||||
node_software_state = SOFTWARE_STATE[item["softwareState"]]
|
||||
node_file_system_state = FILE_SYSTEM_STATE[item["fileSystemState"]]
|
||||
node = ActiveNode(node_id, node_name, node_type, node_priority, node_hardware_state, node_ip_address, node_software_state, node_file_system_state, self.config_values)
|
||||
node = ActiveNode(
|
||||
node_id,
|
||||
node_name,
|
||||
node_type,
|
||||
node_priority,
|
||||
node_hardware_state,
|
||||
node_ip_address,
|
||||
node_software_state,
|
||||
node_file_system_state,
|
||||
self.config_values,
|
||||
)
|
||||
elif node_base_type == "SERVICE":
|
||||
# Service nodes have IP address, operating system state, file system state and list of services
|
||||
node_ip_address = item["ipAddress"]
|
||||
node_software_state = SOFTWARE_STATE[item["softwareState"]]
|
||||
node_file_system_state = FILE_SYSTEM_STATE[item["fileSystemState"]]
|
||||
node = ServiceNode(node_id, node_name, node_type, node_priority, node_hardware_state, node_ip_address, node_software_state, node_file_system_state, self.config_values)
|
||||
node = ServiceNode(
|
||||
node_id,
|
||||
node_name,
|
||||
node_type,
|
||||
node_priority,
|
||||
node_hardware_state,
|
||||
node_ip_address,
|
||||
node_software_state,
|
||||
node_file_system_state,
|
||||
self.config_values,
|
||||
)
|
||||
node_services = item["services"]
|
||||
for service in node_services:
|
||||
service_protocol = service["name"]
|
||||
@@ -752,12 +865,11 @@ class Primaite(Env):
|
||||
|
||||
def create_link(self, item):
|
||||
"""
|
||||
Creates a link from config data
|
||||
Creates a link from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
link_id = item["id"]
|
||||
link_name = item["name"]
|
||||
link_bandwidth = item["bandwidth"]
|
||||
@@ -771,7 +883,13 @@ class Primaite(Env):
|
||||
self.network.add_edge(source_node, dest_node, id=link_name)
|
||||
|
||||
# Add link to link dictionary
|
||||
self.links[link_name] = Link(link_id, link_bandwidth, source_node.get_name(), dest_node.get_name(), self.services_list)
|
||||
self.links[link_name] = Link(
|
||||
link_id,
|
||||
link_bandwidth,
|
||||
source_node.get_name(),
|
||||
dest_node.get_name(),
|
||||
self.services_list,
|
||||
)
|
||||
|
||||
# Reference
|
||||
source_node_ref = self.nodes_reference[link_source]
|
||||
@@ -781,16 +899,21 @@ class Primaite(Env):
|
||||
self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name)
|
||||
|
||||
# Add link to link dictionary (reference)
|
||||
self.links_reference[link_name] = Link(link_id, link_bandwidth, source_node_ref.get_name(), dest_node_ref.get_name(), self.services_list)
|
||||
self.links_reference[link_name] = Link(
|
||||
link_id,
|
||||
link_bandwidth,
|
||||
source_node_ref.get_name(),
|
||||
dest_node_ref.get_name(),
|
||||
self.services_list,
|
||||
)
|
||||
|
||||
def create_green_ier(self, item):
|
||||
"""
|
||||
Creates a green IER from config data
|
||||
Creates a green IER from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
ier_id = item["id"]
|
||||
ier_start_step = item["startStep"]
|
||||
ier_end_step = item["endStep"]
|
||||
@@ -802,16 +925,25 @@ class Primaite(Env):
|
||||
ier_mission_criticality = item["missionCriticality"]
|
||||
|
||||
# Create IER and add to green IER dictionary
|
||||
self.green_iers[ier_id] = IER(ier_id, ier_start_step, ier_end_step, ier_load, ier_protocol, ier_port, ier_source, ier_destination, ier_mission_criticality)
|
||||
self.green_iers[ier_id] = IER(
|
||||
ier_id,
|
||||
ier_start_step,
|
||||
ier_end_step,
|
||||
ier_load,
|
||||
ier_protocol,
|
||||
ier_port,
|
||||
ier_source,
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_red_ier(self, item):
|
||||
"""
|
||||
Creates a red IER from config data
|
||||
Creates a red IER from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
ier_id = item["id"]
|
||||
ier_start_step = item["startStep"]
|
||||
ier_end_step = item["endStep"]
|
||||
@@ -823,16 +955,25 @@ class Primaite(Env):
|
||||
ier_mission_criticality = item["missionCriticality"]
|
||||
|
||||
# Create IER and add to red IER dictionary
|
||||
self.red_iers[ier_id] = IER(ier_id, ier_start_step, ier_end_step, ier_load, ier_protocol, ier_port, ier_source, ier_destination, ier_mission_criticality)
|
||||
self.red_iers[ier_id] = IER(
|
||||
ier_id,
|
||||
ier_start_step,
|
||||
ier_end_step,
|
||||
ier_load,
|
||||
ier_protocol,
|
||||
ier_port,
|
||||
ier_source,
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_green_pol(self, item):
|
||||
"""
|
||||
Creates a green PoL object from config data
|
||||
Creates a green PoL object from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
pol_id = item["id"]
|
||||
pol_start_step = item["startStep"]
|
||||
pol_end_step = item["endStep"]
|
||||
@@ -850,16 +991,23 @@ class Primaite(Env):
|
||||
pol_protocol = item["protocol"]
|
||||
pol_state = SOFTWARE_STATE[item["state"]]
|
||||
|
||||
self.node_pol[pol_id] = NodeStateInstructionGreen(pol_id, pol_start_step, pol_end_step, pol_node, pol_type, pol_protocol, pol_state)
|
||||
self.node_pol[pol_id] = NodeStateInstructionGreen(
|
||||
pol_id,
|
||||
pol_start_step,
|
||||
pol_end_step,
|
||||
pol_node,
|
||||
pol_type,
|
||||
pol_protocol,
|
||||
pol_state,
|
||||
)
|
||||
|
||||
def create_red_pol(self, item):
|
||||
"""
|
||||
Creates a red PoL object from config data
|
||||
Creates a red PoL object from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
pol_id = item["id"]
|
||||
pol_start_step = item["startStep"]
|
||||
pol_end_step = item["endStep"]
|
||||
@@ -880,32 +1028,48 @@ class Primaite(Env):
|
||||
pol_source_node_service = item["sourceNodeService"]
|
||||
pol_source_node_service_state = item["sourceNodeServiceState"]
|
||||
|
||||
self.red_node_pol[pol_id] = NodeStateInstructionRed(pol_id, pol_start_step, pol_end_step, pol_target_node_id, pol_initiator, pol_type, pol_protocol, pol_state, pol_source_node_id, pol_source_node_service, pol_source_node_service_state)
|
||||
self.red_node_pol[pol_id] = NodeStateInstructionRed(
|
||||
pol_id,
|
||||
pol_start_step,
|
||||
pol_end_step,
|
||||
pol_target_node_id,
|
||||
pol_initiator,
|
||||
pol_type,
|
||||
pol_protocol,
|
||||
pol_state,
|
||||
pol_source_node_id,
|
||||
pol_source_node_service,
|
||||
pol_source_node_service_state,
|
||||
)
|
||||
|
||||
def create_acl_rule(self, item):
|
||||
"""
|
||||
Creates an ACL rule from config data
|
||||
Creates an ACL rule from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
acl_rule_permission = item["permission"]
|
||||
acl_rule_source = item["source"]
|
||||
acl_rule_destination = item["destination"]
|
||||
acl_rule_protocol = item["protocol"]
|
||||
acl_rule_port = item["port"]
|
||||
|
||||
self.acl.add_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port)
|
||||
self.acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
)
|
||||
|
||||
def create_services_list(self, services):
|
||||
"""
|
||||
Creates a list of services (enum) from config data
|
||||
Creates a list of services (enum) from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item representing the services
|
||||
"""
|
||||
|
||||
service_list = services["serviceList"]
|
||||
|
||||
for service in service_list:
|
||||
@@ -917,12 +1081,11 @@ class Primaite(Env):
|
||||
|
||||
def create_ports_list(self, ports):
|
||||
"""
|
||||
Creates a list of ports from config data
|
||||
Creates a list of ports from config data.
|
||||
|
||||
Args:
|
||||
item: A config data item representing the ports
|
||||
"""
|
||||
|
||||
ports_list = ports["portsList"]
|
||||
|
||||
for port in ports_list:
|
||||
@@ -934,31 +1097,30 @@ class Primaite(Env):
|
||||
|
||||
def get_action_info(self, action_info):
|
||||
"""
|
||||
Extracts action_info
|
||||
Extracts action_info.
|
||||
|
||||
Args:
|
||||
item: A config data item representing action info
|
||||
"""
|
||||
|
||||
self.action_type = ACTION_TYPE[action_info["type"]]
|
||||
|
||||
|
||||
def get_steps_info(self, steps_info):
|
||||
"""
|
||||
Extracts steps_info
|
||||
Extracts steps_info.
|
||||
|
||||
Args:
|
||||
item: A config data item representing steps info
|
||||
"""
|
||||
|
||||
self.episode_steps = int(steps_info["steps"])
|
||||
logging.info("Training episodes have " + str(self.episode_steps) + " steps")
|
||||
|
||||
def reset_environment(self):
|
||||
"""
|
||||
# Resets environment using config data config data in order to build the environment configuration
|
||||
"""
|
||||
# Resets environment.
|
||||
|
||||
Uses config data config data in order to build the environment
|
||||
configuration.
|
||||
"""
|
||||
for item in self.config_data:
|
||||
if item["itemType"] == "NODE":
|
||||
# Reset a node's state (normal and reference)
|
||||
@@ -970,7 +1132,6 @@ class Primaite(Env):
|
||||
# Do nothing (bad formatting or not relevant to reset)
|
||||
pass
|
||||
|
||||
|
||||
# Reset the IER status so they are not running initially
|
||||
# Green IERs
|
||||
for ier_key, ier_value in self.green_iers.items():
|
||||
@@ -981,12 +1142,11 @@ class Primaite(Env):
|
||||
|
||||
def reset_node(self, item):
|
||||
"""
|
||||
Resets the statuses of a node
|
||||
Resets the statuses of a node.
|
||||
|
||||
Args:
|
||||
item: A config data item
|
||||
"""
|
||||
|
||||
# All nodes have these parameters
|
||||
node_id = item["id"]
|
||||
node_base_type = item["baseType"]
|
||||
@@ -1027,10 +1187,3 @@ class Primaite(Env):
|
||||
else:
|
||||
# Bad formatting
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Implements reward function
|
||||
"""
|
||||
|
||||
from primaite.common.enums import *
|
||||
"""Implements reward function."""
|
||||
from primaite.common.enums import FILE_SYSTEM_STATE, HARDWARE_STATE, SOFTWARE_STATE
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
def calculate_reward_function(initial_nodes, final_nodes, reference_nodes, green_iers, red_iers, step_count, config_values):
|
||||
|
||||
def calculate_reward_function(
|
||||
initial_nodes,
|
||||
final_nodes,
|
||||
reference_nodes,
|
||||
green_iers,
|
||||
red_iers,
|
||||
step_count,
|
||||
config_values,
|
||||
):
|
||||
"""
|
||||
Compares the states of the initial and final nodes/links to get a reward
|
||||
Compares the states of the initial and final nodes/links to get a reward.
|
||||
|
||||
Args:
|
||||
initial_nodes: The nodes before red and blue agents take effect
|
||||
@@ -20,7 +26,6 @@ def calculate_reward_function(initial_nodes, final_nodes, reference_nodes, green
|
||||
step_count: current step
|
||||
config_values: Config values
|
||||
"""
|
||||
|
||||
reward_value = 0
|
||||
|
||||
# For each node, compare operating state, o/s operating state, service states
|
||||
@@ -29,19 +34,27 @@ def calculate_reward_function(initial_nodes, final_nodes, reference_nodes, green
|
||||
reference_node = reference_nodes[node_key]
|
||||
|
||||
# Operating State
|
||||
reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)
|
||||
reward_value += score_node_operating_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
|
||||
# Operating System State
|
||||
if (isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode)):
|
||||
reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
|
||||
reward_value += score_node_os_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
|
||||
# Service State
|
||||
if (isinstance(final_node, ServiceNode)):
|
||||
reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)
|
||||
if isinstance(final_node, ServiceNode):
|
||||
reward_value += score_node_service_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
|
||||
# File System State
|
||||
if isinstance(final_node, ActiveNode):
|
||||
reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)
|
||||
reward_value += score_node_file_system(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
|
||||
# Go through each red IER - penalise if it is running
|
||||
for ier_key, ier_value in red_iers.items():
|
||||
@@ -57,14 +70,17 @@ def calculate_reward_function(initial_nodes, final_nodes, reference_nodes, green
|
||||
stop_step = ier_value.get_end_step()
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
if not ier_value.get_is_running():
|
||||
reward_value += config_values.green_ier_blocked * ier_value.get_mission_criticality()
|
||||
reward_value += (
|
||||
config_values.green_ier_blocked
|
||||
* ier_value.get_mission_criticality()
|
||||
)
|
||||
|
||||
return reward_value
|
||||
|
||||
|
||||
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the operating state of a node
|
||||
Calculates score relating to the operating state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
@@ -72,7 +88,6 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
|
||||
score = 0
|
||||
final_node_operating_state = final_node.get_state()
|
||||
initial_node_operating_state = initial_node.get_state()
|
||||
@@ -112,9 +127,10 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def score_node_os_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the operating system state of a node
|
||||
Calculates score relating to the operating system state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
@@ -122,7 +138,6 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
|
||||
score = 0
|
||||
final_node_os_state = final_node.get_os_state()
|
||||
initial_node_os_state = initial_node.get_os_state()
|
||||
@@ -164,9 +179,10 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def score_node_service_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
@@ -174,7 +190,6 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
|
||||
score = 0
|
||||
final_node_services = final_node.get_services()
|
||||
initial_node_services = initial_node.get_services()
|
||||
@@ -237,16 +252,16 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def score_node_file_system(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the file system state of a node
|
||||
Calculates score relating to the file system state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
initial_node: The node before red and blue agents take effect
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
"""
|
||||
|
||||
score = 0
|
||||
final_node_file_system_state = final_node.get_file_system_state_actual()
|
||||
initial_node_file_system_state = initial_node.get_file_system_state_actual()
|
||||
|
||||
@@ -1,19 +1,15 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The link class
|
||||
"""
|
||||
"""The link class."""
|
||||
|
||||
from primaite.common.protocol import Protocol
|
||||
from primaite.common.enums import *
|
||||
|
||||
|
||||
class Link(object):
|
||||
"""
|
||||
Link class
|
||||
"""
|
||||
"""Link class."""
|
||||
|
||||
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The IER id
|
||||
@@ -22,7 +18,6 @@ class Link(object):
|
||||
_dest_node_name: The name of the destination node
|
||||
_protocols: The protocols to add to the link
|
||||
"""
|
||||
|
||||
self.id = _id
|
||||
self.bandwidth = _bandwidth
|
||||
self.source_node_name = _source_node_name
|
||||
@@ -35,72 +30,65 @@ class Link(object):
|
||||
|
||||
def add_protocol(self, _protocol):
|
||||
"""
|
||||
Adds a new protocol to the list of protocols on this link
|
||||
Adds a new protocol to the list of protocols on this link.
|
||||
|
||||
Args:
|
||||
_protocol: The protocol to be added (enum)
|
||||
"""
|
||||
|
||||
self.protocol_list.append(Protocol(_protocol))
|
||||
|
||||
def get_id(self):
|
||||
"""
|
||||
Gets link ID
|
||||
Gets link ID.
|
||||
|
||||
Returns:
|
||||
Link ID
|
||||
"""
|
||||
|
||||
return self.id
|
||||
|
||||
def get_source_node_name(self):
|
||||
"""
|
||||
Gets source node name
|
||||
Gets source node name.
|
||||
|
||||
Returns:
|
||||
Source node name
|
||||
"""
|
||||
|
||||
return self.source_node_name
|
||||
|
||||
def get_dest_node_name(self):
|
||||
"""
|
||||
Gets destination node name
|
||||
Gets destination node name.
|
||||
|
||||
Returns:
|
||||
Destination node name
|
||||
"""
|
||||
|
||||
return self.dest_node_name
|
||||
|
||||
def get_bandwidth(self):
|
||||
"""
|
||||
Gets bandwidth of link
|
||||
Gets bandwidth of link.
|
||||
|
||||
Returns:
|
||||
Link bandwidth (bps)
|
||||
"""
|
||||
|
||||
return self.bandwidth
|
||||
|
||||
def get_protocol_list(self):
|
||||
"""
|
||||
Gets list of protocols on this link
|
||||
Gets list of protocols on this link.
|
||||
|
||||
Returns:
|
||||
List of protocols on this link
|
||||
"""
|
||||
|
||||
return self.protocol_list
|
||||
|
||||
def get_current_load(self):
|
||||
"""
|
||||
Gets current total load on this link
|
||||
Gets current total load on this link.
|
||||
|
||||
Returns:
|
||||
Total load on this link (bps)
|
||||
"""
|
||||
|
||||
total_load = 0
|
||||
for protocol in self.protocol_list:
|
||||
total_load += protocol.get_load()
|
||||
@@ -108,13 +96,12 @@ class Link(object):
|
||||
|
||||
def add_protocol_load(self, _protocol, _load):
|
||||
"""
|
||||
Adds a loading to a protocol on this link
|
||||
Adds a loading to a protocol on this link.
|
||||
|
||||
Args:
|
||||
_protocol: The protocol to load
|
||||
_load: The amount to load (bps)
|
||||
"""
|
||||
|
||||
for protocol in self.protocol_list:
|
||||
if protocol.get_name() == _protocol:
|
||||
protocol.add_load(_load)
|
||||
@@ -122,11 +109,6 @@ class Link(object):
|
||||
pass
|
||||
|
||||
def clear_traffic(self):
|
||||
"""
|
||||
Clears all traffic on this link
|
||||
"""
|
||||
|
||||
"""Clears all traffic on this link."""
|
||||
for protocol in self.protocol_list:
|
||||
protocol.clear_load()
|
||||
|
||||
|
||||
|
||||
@@ -1,37 +1,31 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Primaite - main (harness) module
|
||||
Primaite - main (harness) module.
|
||||
|
||||
Coding Standards: PEP 8
|
||||
"""
|
||||
|
||||
from sys import exc_info
|
||||
import time
|
||||
import yaml
|
||||
import os.path
|
||||
import logging
|
||||
import os.path
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import yaml
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite.common.config_values_main import config_values_main
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.transactions.transactions_to_file import write_transaction_to_file
|
||||
from primaite.common.config_values_main import config_values_main
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
# FUNCTIONS #
|
||||
|
||||
################################# FUNCTIONS ######################################
|
||||
|
||||
def run_generic():
|
||||
"""
|
||||
Run against a generic agent
|
||||
"""
|
||||
|
||||
"""Run against a generic agent."""
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
for step in range(0, config_values.num_steps):
|
||||
|
||||
# Send the observation space to the agent to get an action
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
@@ -54,15 +48,20 @@ def run_generic():
|
||||
|
||||
|
||||
def run_stable_baselines3_ppo():
|
||||
"""
|
||||
Run against a stable_baselines3 PPO agent
|
||||
"""
|
||||
|
||||
"""Run against a stable_baselines3 PPO agent."""
|
||||
if config_values.load_agent == True:
|
||||
try:
|
||||
agent = PPO.load(config_values.agent_load_file, env, verbose=0, n_steps=config_values.num_steps)
|
||||
except:
|
||||
print("ERROR: Could not load agent at location: " + config_values.agent_load_file)
|
||||
agent = PPO.load(
|
||||
config_values.agent_load_file,
|
||||
env,
|
||||
verbose=0,
|
||||
n_steps=config_values.num_steps,
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
"ERROR: Could not load agent at location: "
|
||||
+ config_values.agent_load_file
|
||||
)
|
||||
logging.error("Could not load agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
@@ -83,16 +82,22 @@ def run_stable_baselines3_ppo():
|
||||
|
||||
env.close()
|
||||
|
||||
def run_stable_baselines3_a2c():
|
||||
"""
|
||||
Run against a stable_baselines3 A2C agent
|
||||
"""
|
||||
|
||||
def run_stable_baselines3_a2c():
|
||||
"""Run against a stable_baselines3 A2C agent."""
|
||||
if config_values.load_agent == True:
|
||||
try:
|
||||
agent = A2C.load(config_values.agent_load_file, env, verbose=0, n_steps=config_values.num_steps)
|
||||
except:
|
||||
print("ERROR: Could not load agent at location: " + config_values.agent_load_file)
|
||||
agent = A2C.load(
|
||||
config_values.agent_load_file,
|
||||
env,
|
||||
verbose=0,
|
||||
n_steps=config_values.num_steps,
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
"ERROR: Could not load agent at location: "
|
||||
+ config_values.agent_load_file
|
||||
)
|
||||
logging.error("Could not load agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
@@ -113,133 +118,203 @@ def run_stable_baselines3_a2c():
|
||||
|
||||
env.close()
|
||||
|
||||
def save_agent(_agent):
|
||||
"""
|
||||
Persist an agent (only works for stable baselines3 agents at present)
|
||||
"""
|
||||
|
||||
now = datetime.now() # current date and time
|
||||
def save_agent(_agent):
|
||||
"""Persist an agent (only works for stable baselines3 agents at present)."""
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
try:
|
||||
path = 'outputs/agents/'
|
||||
path = "outputs/agents/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
filename = "outputs/agents/agent_saved_" + time
|
||||
_agent.save(filename)
|
||||
logging.info("Trained agent saved as " + filename)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.error("Could not save agent")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
def configure_logging():
|
||||
"""
|
||||
Configures logging
|
||||
"""
|
||||
|
||||
def configure_logging():
|
||||
"""Configures logging."""
|
||||
try:
|
||||
now = datetime.now() # current date and time
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
filename = "logs/app_" + time + ".log"
|
||||
path = 'logs/'
|
||||
path = "logs/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
logging.basicConfig(filename=filename, filemode='w', format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO)
|
||||
except:
|
||||
logging.basicConfig(
|
||||
filename=filename,
|
||||
filemode="w",
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%d-%b-%y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
except Exception:
|
||||
print("ERROR: Could not start logging")
|
||||
|
||||
def load_config_values():
|
||||
"""
|
||||
Loads the config values from the main config file into a config object
|
||||
"""
|
||||
|
||||
def load_config_values():
|
||||
"""Loads the config values from the main config file into a config object."""
|
||||
try:
|
||||
# Generic
|
||||
config_values.agent_identifier = config_data['agentIdentifier']
|
||||
config_values.num_episodes = int(config_data['numEpisodes'])
|
||||
config_values.time_delay = int(config_data['timeDelay'])
|
||||
config_values.config_filename_use_case = config_data['configFilename']
|
||||
config_values.session_type = config_data['sessionType']
|
||||
config_values.load_agent = bool(config_data['loadAgent'])
|
||||
config_values.agent_load_file = config_data['agentLoadFile']
|
||||
config_values.agent_identifier = config_data["agentIdentifier"]
|
||||
config_values.num_episodes = int(config_data["numEpisodes"])
|
||||
config_values.time_delay = int(config_data["timeDelay"])
|
||||
config_values.config_filename_use_case = config_data["configFilename"]
|
||||
config_values.session_type = config_data["sessionType"]
|
||||
config_values.load_agent = bool(config_data["loadAgent"])
|
||||
config_values.agent_load_file = config_data["agentLoadFile"]
|
||||
# Environment
|
||||
config_values.observation_space_high_value = int(config_data['observationSpaceHighValue'])
|
||||
config_values.observation_space_high_value = int(
|
||||
config_data["observationSpaceHighValue"]
|
||||
)
|
||||
# Reward values
|
||||
# Generic
|
||||
config_values.all_ok = int(config_data['allOk'])
|
||||
config_values.all_ok = int(config_data["allOk"])
|
||||
# Node Operating State
|
||||
config_values.off_should_be_on = int(config_data['offShouldBeOn'])
|
||||
config_values.off_should_be_resetting = int(config_data['offShouldBeResetting'])
|
||||
config_values.on_should_be_off = int(config_data['onShouldBeOff'])
|
||||
config_values.on_should_be_resetting = int(config_data['onShouldBeResetting'])
|
||||
config_values.resetting_should_be_on = int(config_data['resettingShouldBeOn'])
|
||||
config_values.resetting_should_be_off = int(config_data['resettingShouldBeOff'])
|
||||
config_values.resetting = int(config_data['resetting'])
|
||||
config_values.off_should_be_on = int(config_data["offShouldBeOn"])
|
||||
config_values.off_should_be_resetting = int(config_data["offShouldBeResetting"])
|
||||
config_values.on_should_be_off = int(config_data["onShouldBeOff"])
|
||||
config_values.on_should_be_resetting = int(config_data["onShouldBeResetting"])
|
||||
config_values.resetting_should_be_on = int(config_data["resettingShouldBeOn"])
|
||||
config_values.resetting_should_be_off = int(config_data["resettingShouldBeOff"])
|
||||
config_values.resetting = int(config_data["resetting"])
|
||||
# Node O/S or Service State
|
||||
config_values.good_should_be_patching = int(config_data['goodShouldBePatching'])
|
||||
config_values.good_should_be_compromised = int(config_data['goodShouldBeCompromised'])
|
||||
config_values.good_should_be_overwhelmed = int(config_data['goodShouldBeOverwhelmed'])
|
||||
config_values.patching_should_be_good = int(config_data['patchingShouldBeGood'])
|
||||
config_values.patching_should_be_compromised = int(config_data['patchingShouldBeCompromised'])
|
||||
config_values.patching_should_be_overwhelmed = int(config_data['patchingShouldBeOverwhelmed'])
|
||||
config_values.patching = int(config_data['patching'])
|
||||
config_values.compromised_should_be_good = int(config_data['compromisedShouldBeGood'])
|
||||
config_values.compromised_should_be_patching = int(config_data['compromisedShouldBePatching'])
|
||||
config_values.compromised_should_be_overwhelmed = int(config_data['compromisedShouldBeOverwhelmed'])
|
||||
config_values.compromised = int(config_data['compromised'])
|
||||
config_values.overwhelmed_should_be_good = int(config_data['overwhelmedShouldBeGood'])
|
||||
config_values.overwhelmed_should_be_patching = int(config_data['overwhelmedShouldBePatching'])
|
||||
config_values.overwhelmed_should_be_compromised = int(config_data['overwhelmedShouldBeCompromised'])
|
||||
config_values.overwhelmed = int(config_data['overwhelmed'])
|
||||
config_values.good_should_be_patching = int(config_data["goodShouldBePatching"])
|
||||
config_values.good_should_be_compromised = int(
|
||||
config_data["goodShouldBeCompromised"]
|
||||
)
|
||||
config_values.good_should_be_overwhelmed = int(
|
||||
config_data["goodShouldBeOverwhelmed"]
|
||||
)
|
||||
config_values.patching_should_be_good = int(config_data["patchingShouldBeGood"])
|
||||
config_values.patching_should_be_compromised = int(
|
||||
config_data["patchingShouldBeCompromised"]
|
||||
)
|
||||
config_values.patching_should_be_overwhelmed = int(
|
||||
config_data["patchingShouldBeOverwhelmed"]
|
||||
)
|
||||
config_values.patching = int(config_data["patching"])
|
||||
config_values.compromised_should_be_good = int(
|
||||
config_data["compromisedShouldBeGood"]
|
||||
)
|
||||
config_values.compromised_should_be_patching = int(
|
||||
config_data["compromisedShouldBePatching"]
|
||||
)
|
||||
config_values.compromised_should_be_overwhelmed = int(
|
||||
config_data["compromisedShouldBeOverwhelmed"]
|
||||
)
|
||||
config_values.compromised = int(config_data["compromised"])
|
||||
config_values.overwhelmed_should_be_good = int(
|
||||
config_data["overwhelmedShouldBeGood"]
|
||||
)
|
||||
config_values.overwhelmed_should_be_patching = int(
|
||||
config_data["overwhelmedShouldBePatching"]
|
||||
)
|
||||
config_values.overwhelmed_should_be_compromised = int(
|
||||
config_data["overwhelmedShouldBeCompromised"]
|
||||
)
|
||||
config_values.overwhelmed = int(config_data["overwhelmed"])
|
||||
# Node File System State
|
||||
config_values.good_should_be_repairing = int(config_data['goodShouldBeRepairing'])
|
||||
config_values.good_should_be_restoring = int(config_data['goodShouldBeRestoring'])
|
||||
config_values.good_should_be_corrupt = int(config_data['goodShouldBeCorrupt'])
|
||||
config_values.good_should_be_destroyed = int(config_data['goodShouldBeDestroyed'])
|
||||
config_values.repairing_should_be_good = int(config_data['repairingShouldBeGood'])
|
||||
config_values.repairing_should_be_restoring = int(config_data['repairingShouldBeRestoring'])
|
||||
config_values.repairing_should_be_corrupt = int(config_data['repairingShouldBeCorrupt'])
|
||||
config_values.repairing_should_be_destroyed = int(config_data['repairingShouldBeDestroyed'])
|
||||
config_values.repairing = int(config_data['repairing'])
|
||||
config_values.restoring_should_be_good = int(config_data['restoringShouldBeGood'])
|
||||
config_values.restoring_should_be_repairing = int(config_data['restoringShouldBeRepairing'])
|
||||
config_values.restoring_should_be_corrupt = int(config_data['restoringShouldBeCorrupt'])
|
||||
config_values.restoring_should_be_destroyed = int(config_data['restoringShouldBeDestroyed'])
|
||||
config_values.restoring = int(config_data['restoring'])
|
||||
config_values.corrupt_should_be_good = int(config_data['corruptShouldBeGood'])
|
||||
config_values.corrupt_should_be_repairing = int(config_data['corruptShouldBeRepairing'])
|
||||
config_values.corrupt_should_be_restoring = int(config_data['corruptShouldBeRestoring'])
|
||||
config_values.corrupt_should_be_destroyed = int(config_data['corruptShouldBeDestroyed'])
|
||||
config_values.corrupt = int(config_data['corrupt'])
|
||||
config_values.destroyed_should_be_good = int(config_data['destroyedShouldBeGood'])
|
||||
config_values.destroyed_should_be_repairing = int(config_data['destroyedShouldBeRepairing'])
|
||||
config_values.destroyed_should_be_restoring = int(config_data['destroyedShouldBeRestoring'])
|
||||
config_values.destroyed_should_be_corrupt = int(config_data['destroyedShouldBeCorrupt'])
|
||||
config_values.destroyed = int(config_data['destroyed'])
|
||||
config_values.scanning = int(config_data['scanning'])
|
||||
config_values.good_should_be_repairing = int(
|
||||
config_data["goodShouldBeRepairing"]
|
||||
)
|
||||
config_values.good_should_be_restoring = int(
|
||||
config_data["goodShouldBeRestoring"]
|
||||
)
|
||||
config_values.good_should_be_corrupt = int(config_data["goodShouldBeCorrupt"])
|
||||
config_values.good_should_be_destroyed = int(
|
||||
config_data["goodShouldBeDestroyed"]
|
||||
)
|
||||
config_values.repairing_should_be_good = int(
|
||||
config_data["repairingShouldBeGood"]
|
||||
)
|
||||
config_values.repairing_should_be_restoring = int(
|
||||
config_data["repairingShouldBeRestoring"]
|
||||
)
|
||||
config_values.repairing_should_be_corrupt = int(
|
||||
config_data["repairingShouldBeCorrupt"]
|
||||
)
|
||||
config_values.repairing_should_be_destroyed = int(
|
||||
config_data["repairingShouldBeDestroyed"]
|
||||
)
|
||||
config_values.repairing = int(config_data["repairing"])
|
||||
config_values.restoring_should_be_good = int(
|
||||
config_data["restoringShouldBeGood"]
|
||||
)
|
||||
config_values.restoring_should_be_repairing = int(
|
||||
config_data["restoringShouldBeRepairing"]
|
||||
)
|
||||
config_values.restoring_should_be_corrupt = int(
|
||||
config_data["restoringShouldBeCorrupt"]
|
||||
)
|
||||
config_values.restoring_should_be_destroyed = int(
|
||||
config_data["restoringShouldBeDestroyed"]
|
||||
)
|
||||
config_values.restoring = int(config_data["restoring"])
|
||||
config_values.corrupt_should_be_good = int(config_data["corruptShouldBeGood"])
|
||||
config_values.corrupt_should_be_repairing = int(
|
||||
config_data["corruptShouldBeRepairing"]
|
||||
)
|
||||
config_values.corrupt_should_be_restoring = int(
|
||||
config_data["corruptShouldBeRestoring"]
|
||||
)
|
||||
config_values.corrupt_should_be_destroyed = int(
|
||||
config_data["corruptShouldBeDestroyed"]
|
||||
)
|
||||
config_values.corrupt = int(config_data["corrupt"])
|
||||
config_values.destroyed_should_be_good = int(
|
||||
config_data["destroyedShouldBeGood"]
|
||||
)
|
||||
config_values.destroyed_should_be_repairing = int(
|
||||
config_data["destroyedShouldBeRepairing"]
|
||||
)
|
||||
config_values.destroyed_should_be_restoring = int(
|
||||
config_data["destroyedShouldBeRestoring"]
|
||||
)
|
||||
config_values.destroyed_should_be_corrupt = int(
|
||||
config_data["destroyedShouldBeCorrupt"]
|
||||
)
|
||||
config_values.destroyed = int(config_data["destroyed"])
|
||||
config_values.scanning = int(config_data["scanning"])
|
||||
# IER status
|
||||
config_values.red_ier_running = int(config_data['redIerRunning'])
|
||||
config_values.green_ier_blocked = int(config_data['greenIerBlocked'])
|
||||
config_values.red_ier_running = int(config_data["redIerRunning"])
|
||||
config_values.green_ier_blocked = int(config_data["greenIerBlocked"])
|
||||
# Patching / Reset durations
|
||||
config_values.os_patching_duration = int(config_data['osPatchingDuration'])
|
||||
config_values.node_reset_duration = int(config_data['nodeResetDuration'])
|
||||
config_values.service_patching_duration = int(config_data['servicePatchingDuration'])
|
||||
config_values.file_system_repairing_limit = int(config_data['fileSystemRepairingLimit'])
|
||||
config_values.file_system_restoring_limit = int(config_data['fileSystemRestoringLimit'])
|
||||
config_values.file_system_scanning_limit = int(config_data['fileSystemScanningLimit'])
|
||||
config_values.os_patching_duration = int(config_data["osPatchingDuration"])
|
||||
config_values.node_reset_duration = int(config_data["nodeResetDuration"])
|
||||
config_values.service_patching_duration = int(
|
||||
config_data["servicePatchingDuration"]
|
||||
)
|
||||
config_values.file_system_repairing_limit = int(
|
||||
config_data["fileSystemRepairingLimit"]
|
||||
)
|
||||
config_values.file_system_restoring_limit = int(
|
||||
config_data["fileSystemRestoringLimit"]
|
||||
)
|
||||
config_values.file_system_scanning_limit = int(
|
||||
config_data["fileSystemScanningLimit"]
|
||||
)
|
||||
|
||||
logging.info("Training agent: " + config_values.agent_identifier)
|
||||
logging.info("Training environment config: " + config_values.config_filename_use_case)
|
||||
logging.info("Training cycle has " + str(config_values.num_episodes) + " episodes")
|
||||
logging.info(
|
||||
"Training environment config: " + config_values.config_filename_use_case
|
||||
)
|
||||
logging.info(
|
||||
"Training cycle has " + str(config_values.num_episodes) + " episodes"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.error("Could not save load config data")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
|
||||
################################# MAIN PROCESS ############################################
|
||||
# MAIN PROCESS #
|
||||
|
||||
# Starting point
|
||||
|
||||
@@ -257,7 +332,7 @@ try:
|
||||
config_values = config_values_main()
|
||||
# Load in config data
|
||||
load_config_values()
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.error("Could not load main config")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
@@ -275,7 +350,7 @@ transaction_list = []
|
||||
try:
|
||||
env = Primaite(config_values, transaction_list)
|
||||
logging.info("PrimAITE environment created")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.error("Could not create PrimAITE environment")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
@@ -302,11 +377,3 @@ config_file_main.close
|
||||
|
||||
print("Finished")
|
||||
logging.info("Finished")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,19 +1,26 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
An Active Node (i.e. not an actuator)
|
||||
"""
|
||||
|
||||
"""An Active Node (i.e. not an actuator)."""
|
||||
from primaite.common.enums import FILE_SYSTEM_STATE, SOFTWARE_STATE
|
||||
from primaite.nodes.node import Node
|
||||
from primaite.common.enums import *
|
||||
|
||||
|
||||
class ActiveNode(Node):
|
||||
"""
|
||||
Active Node class
|
||||
"""
|
||||
"""Active Node class."""
|
||||
|
||||
def __init__(self, _id, _name, _type, _priority, _state, _ip_address, _os_state, _file_system_state, _config_values):
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_name,
|
||||
_type,
|
||||
_priority,
|
||||
_state,
|
||||
_ip_address,
|
||||
_os_state,
|
||||
_file_system_state,
|
||||
_config_values,
|
||||
):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The node ID
|
||||
@@ -26,7 +33,6 @@ class ActiveNode(Node):
|
||||
_file_system_state: The node file system state
|
||||
_config_values: The config values
|
||||
"""
|
||||
|
||||
super().__init__(_id, _name, _type, _priority, _state, _config_values)
|
||||
self.ip_address = _ip_address
|
||||
# Related to O/S
|
||||
@@ -39,20 +45,18 @@ class ActiveNode(Node):
|
||||
self.file_system_scanning_count = 0
|
||||
self.file_system_action_count = 0
|
||||
|
||||
|
||||
def set_ip_address(self, _ip_address):
|
||||
"""
|
||||
Sets IP address
|
||||
Sets IP address.
|
||||
|
||||
Args:
|
||||
_ip_address: IP address
|
||||
"""
|
||||
|
||||
self.ip_address = _ip_address
|
||||
|
||||
def get_ip_address(self):
|
||||
"""
|
||||
Gets IP address
|
||||
Gets IP address.
|
||||
|
||||
Returns:
|
||||
IP address
|
||||
@@ -61,24 +65,22 @@ class ActiveNode(Node):
|
||||
|
||||
def set_os_state(self, _os_state):
|
||||
"""
|
||||
Sets operating system state
|
||||
Sets operating system state.
|
||||
|
||||
Args:
|
||||
_os_state: Operating system state
|
||||
"""
|
||||
|
||||
self.os_state = _os_state
|
||||
if _os_state == SOFTWARE_STATE.PATCHING:
|
||||
self.patching_count = self.config_values.os_patching_duration
|
||||
|
||||
def set_os_state_if_not_compromised(self, _os_state):
|
||||
"""
|
||||
Sets operating system state if the node is not compromised
|
||||
Sets operating system state if the node is not compromised.
|
||||
|
||||
Args:
|
||||
_os_state: Operating system state
|
||||
"""
|
||||
|
||||
if self.os_state != SOFTWARE_STATE.COMPROMISED:
|
||||
self.os_state = _os_state
|
||||
if _os_state == SOFTWARE_STATE.PATCHING:
|
||||
@@ -86,19 +88,15 @@ class ActiveNode(Node):
|
||||
|
||||
def get_os_state(self):
|
||||
"""
|
||||
Gets operating system state
|
||||
Gets operating system state.
|
||||
|
||||
Returns:
|
||||
Operating system state
|
||||
"""
|
||||
|
||||
return self.os_state
|
||||
|
||||
def update_os_patching_status(self):
|
||||
"""
|
||||
Updates operating system status based on patching cycle
|
||||
"""
|
||||
|
||||
"""Updates operating system status based on patching cycle."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
self.patching_count = 0
|
||||
@@ -106,87 +104,88 @@ class ActiveNode(Node):
|
||||
|
||||
def set_file_system_state(self, _file_system_state):
|
||||
"""
|
||||
Sets the file system state (actual and observed)
|
||||
Sets the file system state (actual and observed).
|
||||
|
||||
Args:
|
||||
_file_system_state: File system state
|
||||
"""
|
||||
|
||||
self.file_system_state_actual = _file_system_state
|
||||
|
||||
if _file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
self.file_system_action_count = self.config_values.file_system_repairing_limit
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_repairing_limit
|
||||
)
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.REPAIRING
|
||||
elif _file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
self.file_system_action_count = self.config_values.file_system_restoring_limit
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_restoring_limit
|
||||
)
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.RESTORING
|
||||
elif _file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.GOOD
|
||||
|
||||
def set_file_system_state_if_not_compromised(self, _file_system_state):
|
||||
"""
|
||||
Sets the file system state (actual and observed) if not in a compromised state
|
||||
Sets the file system state (actual and observed) if not in a compromised state.
|
||||
|
||||
Use for green PoL to prevent it overturning a compromised state
|
||||
|
||||
Args:
|
||||
_file_system_state: File system state
|
||||
"""
|
||||
|
||||
if self.file_system_state_actual != FILE_SYSTEM_STATE.CORRUPT and self.file_system_state_actual != FILE_SYSTEM_STATE.DESTROYED:
|
||||
if (
|
||||
self.file_system_state_actual != FILE_SYSTEM_STATE.CORRUPT
|
||||
and self.file_system_state_actual != FILE_SYSTEM_STATE.DESTROYED
|
||||
):
|
||||
self.file_system_state_actual = _file_system_state
|
||||
|
||||
if _file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
self.file_system_action_count = self.config_values.file_system_repairing_limit
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_repairing_limit
|
||||
)
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.REPAIRING
|
||||
elif _file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
self.file_system_action_count = self.config_values.file_system_restoring_limit
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_restoring_limit
|
||||
)
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.RESTORING
|
||||
elif _file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.GOOD
|
||||
|
||||
def get_file_system_state_actual(self):
|
||||
"""
|
||||
Gets file system state (actual)
|
||||
Gets file system state (actual).
|
||||
|
||||
Returns:
|
||||
File system state (actual)
|
||||
"""
|
||||
|
||||
return self.file_system_state_actual
|
||||
|
||||
def get_file_system_state_observed(self):
|
||||
"""
|
||||
Gets file system state (observed)
|
||||
Gets file system state (observed).
|
||||
|
||||
Returns:
|
||||
File system state (observed)
|
||||
"""
|
||||
|
||||
return self.file_system_state_observed
|
||||
|
||||
def start_file_system_scan(self):
|
||||
"""
|
||||
Starts a file system scan
|
||||
"""
|
||||
|
||||
"""Starts a file system scan."""
|
||||
self.file_system_scanning = True
|
||||
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
|
||||
|
||||
def is_scanning_file_system(self):
|
||||
"""
|
||||
Gets true/false on whether file system is being scanned
|
||||
Gets true/false on whether file system is being scanned.
|
||||
|
||||
Returns:
|
||||
True if file system is being scanned
|
||||
"""
|
||||
|
||||
return self.file_system_scanning
|
||||
|
||||
def update_file_system_state(self):
|
||||
"""
|
||||
Updates file system status based on scanning / restore / repair cycle
|
||||
"""
|
||||
|
||||
"""Updates file system status based on scanning / restore / repair cycle."""
|
||||
# Deprecate both the action count (for restoring or reparing) and the scanning count
|
||||
self.file_system_action_count -= 1
|
||||
self.file_system_scanning_count -= 1
|
||||
@@ -194,7 +193,10 @@ class ActiveNode(Node):
|
||||
# Reparing / Restoring updates
|
||||
if self.file_system_action_count <= 0:
|
||||
self.file_system_action_count = 0
|
||||
if self.file_system_state_actual == FILE_SYSTEM_STATE.REPAIRING or self.file_system_state_actual == FILE_SYSTEM_STATE.RESTORING:
|
||||
if (
|
||||
self.file_system_state_actual == FILE_SYSTEM_STATE.REPAIRING
|
||||
or self.file_system_state_actual == FILE_SYSTEM_STATE.RESTORING
|
||||
):
|
||||
self.file_system_state_actual = FILE_SYSTEM_STATE.GOOD
|
||||
self.file_system_state_observed = FILE_SYSTEM_STATE.GOOD
|
||||
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The base Node class
|
||||
"""
|
||||
"""The base Node class."""
|
||||
from primaite.common.enums import HARDWARE_STATE
|
||||
|
||||
from primaite.common.enums import *
|
||||
|
||||
class Node:
|
||||
"""
|
||||
Node class
|
||||
"""
|
||||
"""Node class."""
|
||||
|
||||
def __init__(self, _id, _name, _type, _priority, _state, _config_values):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The node id
|
||||
@@ -21,7 +17,6 @@ class Node:
|
||||
_priority: The priority of the node
|
||||
_state: The state of the node
|
||||
"""
|
||||
|
||||
self.id = _id
|
||||
self.name = _name
|
||||
self.type = _type
|
||||
@@ -31,146 +26,115 @@ class Node:
|
||||
self.config_values = _config_values
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Returns the name of the node
|
||||
"""
|
||||
|
||||
"""Returns the name of the node."""
|
||||
return self.name
|
||||
|
||||
def set_id(self, _id):
|
||||
"""
|
||||
Sets the node ID
|
||||
Sets the node ID.
|
||||
|
||||
Args:
|
||||
_id: The node ID
|
||||
"""
|
||||
|
||||
self.id = _id
|
||||
|
||||
def get_id(self):
|
||||
"""
|
||||
Gets the node ID
|
||||
Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
"""
|
||||
|
||||
return self.id
|
||||
|
||||
def set_name(self, _name):
|
||||
"""
|
||||
Sets the node name
|
||||
Sets the node name.
|
||||
|
||||
Args:
|
||||
_name: The node name
|
||||
"""
|
||||
|
||||
self.name = _name
|
||||
|
||||
def get_name(self):
|
||||
"""
|
||||
Gets the node name
|
||||
Gets the node name.
|
||||
|
||||
Returns:
|
||||
The node name
|
||||
"""
|
||||
|
||||
return self.name
|
||||
|
||||
def set_type(self, _type):
|
||||
"""
|
||||
Sets the node type
|
||||
Sets the node type.
|
||||
|
||||
Args:
|
||||
_type: The node type
|
||||
"""
|
||||
|
||||
self.type = _type
|
||||
|
||||
def get_type(self):
|
||||
"""
|
||||
Gets the node type
|
||||
Gets the node type.
|
||||
|
||||
Returns:
|
||||
The node type
|
||||
"""
|
||||
|
||||
return self.type
|
||||
|
||||
def set_priority(self, _priority):
|
||||
"""
|
||||
Sets the node priority
|
||||
Sets the node priority.
|
||||
|
||||
Args:
|
||||
_priority: The node priority
|
||||
"""
|
||||
|
||||
self.priority = _priority
|
||||
|
||||
def get_priority(self):
|
||||
"""
|
||||
Gets the node priority
|
||||
Gets the node priority.
|
||||
|
||||
Returns:
|
||||
The node priority
|
||||
"""
|
||||
|
||||
return self.priority
|
||||
|
||||
def set_state(self, _state):
|
||||
"""
|
||||
Sets the node state
|
||||
Sets the node state.
|
||||
|
||||
Args:
|
||||
_state: The node state
|
||||
"""
|
||||
|
||||
self.operating_state = _state
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Gets the node operating state
|
||||
Gets the node operating state.
|
||||
|
||||
Returns:
|
||||
The node operating state
|
||||
"""
|
||||
|
||||
return self.operating_state
|
||||
|
||||
def turn_on(self):
|
||||
"""
|
||||
Sets the node state to ON
|
||||
"""
|
||||
|
||||
"""Sets the node state to ON."""
|
||||
self.operating_state = HARDWARE_STATE.ON
|
||||
|
||||
def turn_off(self):
|
||||
"""
|
||||
Sets the node state to OFF
|
||||
"""
|
||||
|
||||
"""Sets the node state to OFF."""
|
||||
self.operating_state = HARDWARE_STATE.OFF
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Sets the node state to Resetting and starts the reset count
|
||||
"""
|
||||
|
||||
"""Sets the node state to Resetting and starts the reset count."""
|
||||
self.operating_state = HARDWARE_STATE.RESETTING
|
||||
self.resetting_count = self.config_values.node_reset_duration
|
||||
|
||||
def update_resetting_status(self):
|
||||
"""
|
||||
Updates the resetting count
|
||||
"""
|
||||
|
||||
"""Updates the resetting count."""
|
||||
self.resetting_count -= 1
|
||||
if self.resetting_count <= 0:
|
||||
self.resetting_count = 0
|
||||
self.operating_state = HARDWARE_STATE.ON
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Defines node behaviour for Green PoL
|
||||
"""
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
|
||||
|
||||
class NodeStateInstructionGreen(object):
|
||||
"""
|
||||
The Node State Instruction class
|
||||
"""
|
||||
"""The Node State Instruction class."""
|
||||
|
||||
def __init__(self, _id, _start_step, _end_step, _node_id, _node_pol_type, _service_name, _state):
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_node_id,
|
||||
_node_pol_type,
|
||||
_service_name,
|
||||
_state,
|
||||
):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The node state instruction id
|
||||
@@ -21,72 +27,64 @@ class NodeStateInstructionGreen(object):
|
||||
_service_name: The service name
|
||||
_state: The state (node or service)
|
||||
"""
|
||||
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.node_id = _node_id
|
||||
self.node_pol_type = _node_pol_type
|
||||
self.service_name = _service_name # Not used when not a service instruction
|
||||
self.service_name = _service_name # Not used when not a service instruction
|
||||
self.state = _state
|
||||
|
||||
def get_start_step(self):
|
||||
"""
|
||||
Gets the start step
|
||||
Gets the start step.
|
||||
|
||||
Returns:
|
||||
The start step
|
||||
"""
|
||||
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
"""
|
||||
Gets the end step
|
||||
Gets the end step.
|
||||
|
||||
Returns:
|
||||
The end step
|
||||
"""
|
||||
|
||||
return self.end_step
|
||||
|
||||
def get_node_id(self):
|
||||
"""
|
||||
Gets the node ID
|
||||
Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
"""
|
||||
|
||||
return self.node_id
|
||||
|
||||
def get_node_pol_type(self):
|
||||
"""
|
||||
Gets the node pattern of life type (enum)
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
Returns:
|
||||
The node pattern of life type (enum)
|
||||
"""
|
||||
|
||||
return self.node_pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
"""
|
||||
Gets the service name
|
||||
Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
"""
|
||||
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Gets the state (node or service)
|
||||
Gets the state (node or service).
|
||||
|
||||
Returns:
|
||||
The state (node or service)
|
||||
"""
|
||||
|
||||
return self.state
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Defines node behaviour for Green PoL
|
||||
"""
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
|
||||
|
||||
class NodeStateInstructionRed(object):
|
||||
"""
|
||||
The Node State Instruction class
|
||||
"""
|
||||
"""The Node State Instruction class."""
|
||||
|
||||
def __init__(self, _id, _start_step, _end_step, _target_node_id, _pol_initiator, _pol_type, pol_protocol, _pol_state, _pol_source_node_id, _pol_source_node_service, _pol_source_node_service_state):
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_target_node_id,
|
||||
_pol_initiator,
|
||||
_pol_type,
|
||||
pol_protocol,
|
||||
_pol_state,
|
||||
_pol_source_node_id,
|
||||
_pol_source_node_service,
|
||||
_pol_source_node_service_state,
|
||||
):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The node state instruction id
|
||||
@@ -25,14 +35,13 @@ class NodeStateInstructionRed(object):
|
||||
_pol_source_node_service: The source node service (used for initiator type SERVICE)
|
||||
_pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
|
||||
"""
|
||||
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.target_node_id = _target_node_id
|
||||
self.initiator = _pol_initiator
|
||||
self.pol_type = _pol_type
|
||||
self.service_name = pol_protocol # Not used when not a service instruction
|
||||
self.service_name = pol_protocol # Not used when not a service instruction
|
||||
self.state = _pol_state
|
||||
self.source_node_id = _pol_source_node_id
|
||||
self.source_node_service = _pol_source_node_service
|
||||
@@ -40,101 +49,90 @@ class NodeStateInstructionRed(object):
|
||||
|
||||
def get_start_step(self):
|
||||
"""
|
||||
Gets the start step
|
||||
Gets the start step.
|
||||
|
||||
Returns:
|
||||
The start step
|
||||
"""
|
||||
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
"""
|
||||
Gets the end step
|
||||
Gets the end step.
|
||||
|
||||
Returns:
|
||||
The end step
|
||||
"""
|
||||
|
||||
return self.end_step
|
||||
|
||||
def get_target_node_id(self):
|
||||
"""
|
||||
Gets the node ID
|
||||
Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
"""
|
||||
|
||||
return self.target_node_id
|
||||
|
||||
def get_initiator(self):
|
||||
"""
|
||||
Gets the initiator
|
||||
Gets the initiator.
|
||||
|
||||
Returns:
|
||||
The initiator
|
||||
"""
|
||||
|
||||
return self.initiator
|
||||
|
||||
def get_pol_type(self):
|
||||
"""
|
||||
Gets the node pattern of life type (enum)
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
Returns:
|
||||
The node pattern of life type (enum)
|
||||
"""
|
||||
|
||||
return self.pol_type
|
||||
|
||||
def get_service_name(self):
|
||||
"""
|
||||
Gets the service name
|
||||
Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
"""
|
||||
|
||||
return self.service_name
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Gets the state (node or service)
|
||||
Gets the state (node or service).
|
||||
|
||||
Returns:
|
||||
The state (node or service)
|
||||
"""
|
||||
|
||||
return self.state
|
||||
|
||||
def get_source_node_id(self):
|
||||
"""
|
||||
Gets the source node id (used for initiator type SERVICE)
|
||||
Gets the source node id (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node id
|
||||
"""
|
||||
|
||||
return self.source_node_id
|
||||
|
||||
def get_source_node_service(self):
|
||||
"""
|
||||
Gets the source node service (used for initiator type SERVICE)
|
||||
Gets the source node service (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node service
|
||||
"""
|
||||
|
||||
return self.source_node_service
|
||||
|
||||
def get_source_node_service_state(self):
|
||||
"""
|
||||
Gets the source node service state (used for initiator type SERVICE)
|
||||
Gets the source node service state (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node service state
|
||||
"""
|
||||
|
||||
return self.source_node_service_state
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The Passive Node class (i.e. an actuator)
|
||||
"""
|
||||
"""The Passive Node class (i.e. an actuator)."""
|
||||
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
|
||||
class PassiveNode(Node):
|
||||
"""
|
||||
The Passive Node class
|
||||
"""
|
||||
"""The Passive Node class."""
|
||||
|
||||
def __init__(self, _id, _name, _type, _priority, _state, _config_values):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The node id
|
||||
@@ -21,17 +18,15 @@ class PassiveNode(Node):
|
||||
_priority: The priority of the node
|
||||
_state: The state of the node
|
||||
"""
|
||||
|
||||
# Pass through to Super for now
|
||||
super().__init__(_id, _name, _type, _priority, _state, _config_values)
|
||||
|
||||
def get_ip_address(self):
|
||||
"""
|
||||
Gets the node IP address
|
||||
Gets the node IP address.
|
||||
|
||||
Returns:
|
||||
The node IP address
|
||||
"""
|
||||
|
||||
# No concept of IP address for passive nodes for now
|
||||
return ""
|
||||
@@ -1,19 +1,26 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
A Service Node (i.e. not an actuator)
|
||||
"""
|
||||
|
||||
"""A Service Node (i.e. not an actuator)."""
|
||||
from primaite.common.enums import SOFTWARE_STATE
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.common.enums import *
|
||||
|
||||
|
||||
class ServiceNode(ActiveNode):
|
||||
"""
|
||||
ServiceNode class
|
||||
"""
|
||||
"""ServiceNode class."""
|
||||
|
||||
def __init__(self, _id, _name, _type, _priority, _state, _ip_address, _os_state, _file_system_state, _config_values):
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_name,
|
||||
_type,
|
||||
_priority,
|
||||
_state,
|
||||
_ip_address,
|
||||
_os_state,
|
||||
_file_system_state,
|
||||
_config_values,
|
||||
):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The node id
|
||||
@@ -25,38 +32,44 @@ class ServiceNode(ActiveNode):
|
||||
_osState: The operating system state of the node
|
||||
_file_system_state: The file system state of the node
|
||||
"""
|
||||
|
||||
super().__init__(_id, _name, _type, _priority, _state, _ip_address, _os_state, _file_system_state, _config_values)
|
||||
super().__init__(
|
||||
_id,
|
||||
_name,
|
||||
_type,
|
||||
_priority,
|
||||
_state,
|
||||
_ip_address,
|
||||
_os_state,
|
||||
_file_system_state,
|
||||
_config_values,
|
||||
)
|
||||
self.services = {}
|
||||
|
||||
def add_service(self, _service):
|
||||
"""
|
||||
Adds a service to the node
|
||||
Adds a service to the node.
|
||||
|
||||
Args:
|
||||
_service: The service to add
|
||||
"""
|
||||
|
||||
self.services[_service.get_name()] = _service
|
||||
|
||||
def get_services(self):
|
||||
"""
|
||||
Gets the dictionary of services on this node
|
||||
Gets the dictionary of services on this node.
|
||||
|
||||
Returns:
|
||||
Dictionary of services on this node
|
||||
"""
|
||||
|
||||
return self.services
|
||||
|
||||
def has_service(self, _protocol):
|
||||
"""
|
||||
Indicates whether a service is on a node
|
||||
Indicates whether a service is on a node.
|
||||
|
||||
Returns:
|
||||
True if service (protocol) is on the node
|
||||
"""
|
||||
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == _protocol:
|
||||
return True
|
||||
@@ -66,12 +79,11 @@ class ServiceNode(ActiveNode):
|
||||
|
||||
def service_running(self, _protocol):
|
||||
"""
|
||||
Indicates whether a service is in a running state on the node
|
||||
Indicates whether a service is in a running state on the node.
|
||||
|
||||
Returns:
|
||||
True if service (protocol) is in a running state on the node
|
||||
"""
|
||||
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == _protocol:
|
||||
if service_value.get_state() != SOFTWARE_STATE.PATCHING:
|
||||
@@ -84,12 +96,11 @@ class ServiceNode(ActiveNode):
|
||||
|
||||
def service_is_overwhelmed(self, _protocol):
|
||||
"""
|
||||
Indicates whether a service is in an overwhelmed state on the node
|
||||
Indicates whether a service is in an overwhelmed state on the node.
|
||||
|
||||
Returns:
|
||||
True if service (protocol) is in an overwhelmed state on the node
|
||||
"""
|
||||
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == _protocol:
|
||||
if service_value.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
@@ -102,61 +113,61 @@ class ServiceNode(ActiveNode):
|
||||
|
||||
def set_service_state(self, _protocol, _state):
|
||||
"""
|
||||
Sets the state of a service (protocol) on the node
|
||||
Sets the state of a service (protocol) on the node.
|
||||
|
||||
Args:
|
||||
_protocol: The service (protocol)
|
||||
_state: The state value
|
||||
"""
|
||||
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == _protocol:
|
||||
# Can't set to compromised if you're in a patching state
|
||||
if (_state == SOFTWARE_STATE.COMPROMISED and service_value.get_state() != SOFTWARE_STATE.PATCHING) or _state != SOFTWARE_STATE.COMPROMISED:
|
||||
if (
|
||||
_state == SOFTWARE_STATE.COMPROMISED
|
||||
and service_value.get_state() != SOFTWARE_STATE.PATCHING
|
||||
) or _state != SOFTWARE_STATE.COMPROMISED:
|
||||
service_value.set_state(_state)
|
||||
else:
|
||||
# Do nothing
|
||||
pass
|
||||
if _state == SOFTWARE_STATE.PATCHING:
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
service_value.patching_count = (
|
||||
self.config_values.service_patching_duration
|
||||
)
|
||||
else:
|
||||
# Do nothing
|
||||
pass
|
||||
|
||||
def set_service_state_if_not_compromised(self, _protocol, _state):
|
||||
"""
|
||||
Sets the state of a service (protocol) on the node if the operating state is not "compromised"
|
||||
Sets the state of a service (protocol) on the node if the operating state is not "compromised".
|
||||
|
||||
Args:
|
||||
_protocol: The service (protocol)
|
||||
_state: The state value
|
||||
"""
|
||||
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == _protocol:
|
||||
if service_value.get_state() != SOFTWARE_STATE.COMPROMISED:
|
||||
service_value.set_state(_state)
|
||||
if _state == SOFTWARE_STATE.PATCHING:
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
service_value.patching_count = (
|
||||
self.config_values.service_patching_duration
|
||||
)
|
||||
|
||||
def get_service_state(self, _protocol):
|
||||
"""
|
||||
Gets the state of a service
|
||||
Gets the state of a service.
|
||||
|
||||
Returns:
|
||||
The state of the service
|
||||
"""
|
||||
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == _protocol:
|
||||
return service_value.get_state()
|
||||
|
||||
def update_services_patching_status(self):
|
||||
"""
|
||||
Updates the patching counter for any service that are patching
|
||||
"""
|
||||
|
||||
"""Updates the patching counter for any service that are patching."""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_value.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
service_value.reduce_patching_count()
|
||||
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Implements Pattern of Life on the network (nodes and links)
|
||||
"""
|
||||
"""Implements Pattern of Life on the network (nodes and links)."""
|
||||
|
||||
from networkx import shortest_path
|
||||
|
||||
from primaite.common.enums import *
|
||||
from primaite.common.enums import HARDWARE_STATE, NODE_POL_TYPE, SOFTWARE_STATE, TYPE
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
_VERBOSE = False
|
||||
|
||||
|
||||
def apply_iers(network, nodes, links, iers, acl, step):
|
||||
"""
|
||||
Applies IERs to the links (link pattern of life)
|
||||
Applies IERs to the links (link pattern of life).
|
||||
|
||||
Args:
|
||||
network: The network modelled in the environment
|
||||
@@ -21,9 +20,8 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
links: The links within the environment
|
||||
iers: The IERs to apply to the links
|
||||
acl: The Access Control List
|
||||
step: The step number
|
||||
step: The step number.
|
||||
"""
|
||||
|
||||
if _VERBOSE:
|
||||
print("Applying IERs")
|
||||
|
||||
@@ -55,7 +53,10 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
# 1. Check the source node situation
|
||||
if source_node.get_type() == TYPE.SWITCH:
|
||||
# It's a switch
|
||||
if source_node.get_state() == HARDWARE_STATE.ON and source_node.get_os_state() != SOFTWARE_STATE.PATCHING:
|
||||
if (
|
||||
source_node.get_state() == HARDWARE_STATE.ON
|
||||
and source_node.get_os_state() != SOFTWARE_STATE.PATCHING
|
||||
):
|
||||
source_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
@@ -66,9 +67,14 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
if source_node.get_state() == HARDWARE_STATE.ON and source_node.get_os_state() != SOFTWARE_STATE.PATCHING:
|
||||
if (
|
||||
source_node.get_state() == HARDWARE_STATE.ON
|
||||
and source_node.get_os_state() != SOFTWARE_STATE.PATCHING
|
||||
):
|
||||
if source_node.has_service(protocol):
|
||||
if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol):
|
||||
if source_node.service_running(
|
||||
protocol
|
||||
) and not source_node.service_is_overwhelmed(protocol):
|
||||
source_valid = True
|
||||
else:
|
||||
source_valid = False
|
||||
@@ -80,11 +86,13 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
# Do nothing - IER no longer valid
|
||||
source_valid = False
|
||||
|
||||
|
||||
# 2. Check the dest node situation
|
||||
if dest_node.get_type() == TYPE.SWITCH:
|
||||
# It's a switch
|
||||
if dest_node.get_state() == HARDWARE_STATE.ON and dest_node.get_os_state() != SOFTWARE_STATE.PATCHING:
|
||||
if (
|
||||
dest_node.get_state() == HARDWARE_STATE.ON
|
||||
and dest_node.get_os_state() != SOFTWARE_STATE.PATCHING
|
||||
):
|
||||
dest_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
@@ -94,9 +102,14 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
if dest_node.get_state() == HARDWARE_STATE.ON and dest_node.get_os_state() != SOFTWARE_STATE.PATCHING:
|
||||
if (
|
||||
dest_node.get_state() == HARDWARE_STATE.ON
|
||||
and dest_node.get_os_state() != SOFTWARE_STATE.PATCHING
|
||||
):
|
||||
if dest_node.has_service(protocol):
|
||||
if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol):
|
||||
if dest_node.service_running(
|
||||
protocol
|
||||
) and not dest_node.service_is_overwhelmed(protocol):
|
||||
dest_valid = True
|
||||
else:
|
||||
dest_valid = False
|
||||
@@ -109,10 +122,21 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
dest_valid = False
|
||||
|
||||
# 3. Check that the ACL doesn't block it
|
||||
acl_block = acl.is_blocked(source_node.get_ip_address(), dest_node.get_ip_address(), protocol, port)
|
||||
acl_block = acl.is_blocked(
|
||||
source_node.get_ip_address(), dest_node.get_ip_address(), protocol, port
|
||||
)
|
||||
if acl_block:
|
||||
if _VERBOSE:
|
||||
print("ACL block on source: " + source_node.get_ip_address() + ", dest: " + dest_node.get_ip_address() + ", protocol: " + protocol + ", port: " + port)
|
||||
print(
|
||||
"ACL block on source: "
|
||||
+ source_node.get_ip_address()
|
||||
+ ", dest: "
|
||||
+ dest_node.get_ip_address()
|
||||
+ ", protocol: "
|
||||
+ protocol
|
||||
+ ", port: "
|
||||
+ port
|
||||
)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("No ACL block")
|
||||
@@ -131,7 +155,10 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
|
||||
# We might have a switch in the path, so check all nodes are operational
|
||||
for node in path_node_list:
|
||||
if node.get_state() != HARDWARE_STATE.ON or node.get_os_state() == SOFTWARE_STATE.PATCHING:
|
||||
if (
|
||||
node.get_state() != HARDWARE_STATE.ON
|
||||
or node.get_os_state() == SOFTWARE_STATE.PATCHING
|
||||
):
|
||||
path_valid = False
|
||||
|
||||
if path_valid:
|
||||
@@ -143,8 +170,10 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
# Check that the link capacity is not exceeded by the new load
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count+1])
|
||||
link_id = edge_dict[0].get('id')
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
@@ -152,7 +181,7 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
pass
|
||||
count+=1
|
||||
count += 1
|
||||
|
||||
# Check whether the link capacity for any links on this path have been exceeded
|
||||
if link_capacity_exceeded == False:
|
||||
@@ -160,12 +189,14 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
count = 0
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count+1])
|
||||
link_id = edge_dict[0].get('id')
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Add the load from this IER
|
||||
link.add_protocol_load(protocol, load)
|
||||
count+=1
|
||||
count += 1
|
||||
# This IER is now valid, so set it to running
|
||||
ier_value.set_is_running(True)
|
||||
else:
|
||||
@@ -183,16 +214,16 @@ def apply_iers(network, nodes, links, iers, acl, step):
|
||||
# Do nothing - IER no longer valid
|
||||
pass
|
||||
|
||||
|
||||
def apply_node_pol(nodes, node_pol, step):
|
||||
"""
|
||||
Applies node pattern of life
|
||||
Applies node pattern of life.
|
||||
|
||||
Args:
|
||||
nodes: The nodes within the environment
|
||||
node_pol: The node pattern of life to apply
|
||||
step: The step number
|
||||
step: The step number.
|
||||
"""
|
||||
|
||||
if _VERBOSE:
|
||||
print("Applying Node PoL")
|
||||
|
||||
|
||||
@@ -1,17 +1,29 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Information Exchange Requirements for APE
|
||||
Used to represent an information flow from source to destination
|
||||
Information Exchange Requirements for APE.
|
||||
|
||||
Used to represent an information flow from source to destination.
|
||||
"""
|
||||
|
||||
class IER(object):
|
||||
"""
|
||||
Information Exchange Requirement class
|
||||
"""
|
||||
|
||||
def __init__(self, _id, _start_step, _end_step, _load, _protocol, _port, _source_node_id, _dest_node_id, _mission_criticality, _running=False):
|
||||
class IER(object):
|
||||
"""Information Exchange Requirement class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_load,
|
||||
_protocol,
|
||||
_port,
|
||||
_source_node_id,
|
||||
_dest_node_id,
|
||||
_mission_criticality,
|
||||
_running=False,
|
||||
):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_id: The IER id
|
||||
@@ -25,7 +37,6 @@ class IER(object):
|
||||
_mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
|
||||
_running: Indicates whether the IER is currently running
|
||||
"""
|
||||
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
@@ -39,97 +50,88 @@ class IER(object):
|
||||
|
||||
def get_id(self):
|
||||
"""
|
||||
Gets IER ID
|
||||
Gets IER ID.
|
||||
|
||||
Returns:
|
||||
IER ID
|
||||
"""
|
||||
|
||||
return self.id
|
||||
|
||||
def get_start_step(self):
|
||||
"""
|
||||
Gets IER start step
|
||||
Gets IER start step.
|
||||
|
||||
Returns:
|
||||
IER start step
|
||||
"""
|
||||
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self):
|
||||
"""
|
||||
Gets IER end step
|
||||
Gets IER end step.
|
||||
|
||||
Returns:
|
||||
IER end step
|
||||
"""
|
||||
|
||||
return self.end_step
|
||||
|
||||
def get_load(self):
|
||||
"""
|
||||
Gets IER load
|
||||
Gets IER load.
|
||||
|
||||
Returns:
|
||||
IER load
|
||||
"""
|
||||
|
||||
return self.load
|
||||
|
||||
def get_protocol(self):
|
||||
"""
|
||||
Gets IER protocol
|
||||
Gets IER protocol.
|
||||
|
||||
Returns:
|
||||
IER protocol
|
||||
"""
|
||||
|
||||
return self.protocol
|
||||
|
||||
def get_port(self):
|
||||
"""
|
||||
Gets IER port
|
||||
Gets IER port.
|
||||
|
||||
Returns:
|
||||
IER port
|
||||
"""
|
||||
|
||||
return self.port
|
||||
|
||||
def get_source_node_id(self):
|
||||
"""
|
||||
Gets IER source node ID
|
||||
Gets IER source node ID.
|
||||
|
||||
Returns:
|
||||
IER source node ID
|
||||
"""
|
||||
|
||||
return self.source_node_id
|
||||
|
||||
def get_dest_node_id(self):
|
||||
"""
|
||||
Gets IER destination node ID
|
||||
Gets IER destination node ID.
|
||||
|
||||
Returns:
|
||||
IER destination node ID
|
||||
"""
|
||||
|
||||
return self.dest_node_id
|
||||
|
||||
def get_is_running(self):
|
||||
"""
|
||||
Informs whether the IER is currently running
|
||||
Informs whether the IER is currently running.
|
||||
|
||||
Returns:
|
||||
True if running
|
||||
"""
|
||||
|
||||
return self.running
|
||||
|
||||
def set_is_running(self, _value):
|
||||
"""
|
||||
Sets the running state of the IER
|
||||
Sets the running state of the IER.
|
||||
|
||||
Args:
|
||||
_value: running status
|
||||
@@ -138,10 +140,9 @@ class IER(object):
|
||||
|
||||
def get_mission_criticality(self):
|
||||
"""
|
||||
Gets the IER mission criticality (used in the reward function)
|
||||
Gets the IER mission criticality (used in the reward function).
|
||||
|
||||
Returns:
|
||||
Mission criticality value (0 lowest to 5 highest)
|
||||
"""
|
||||
|
||||
return self.mission_criticality
|
||||
@@ -1,19 +1,24 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Implements Pattern of Life on the network (nodes and links) resulting from the red agent attack
|
||||
"""
|
||||
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
|
||||
|
||||
from networkx import shortest_path
|
||||
|
||||
from primaite.common.enums import *
|
||||
from primaite.common.enums import (
|
||||
HARDWARE_STATE,
|
||||
NODE_POL_INITIATOR,
|
||||
NODE_POL_TYPE,
|
||||
SOFTWARE_STATE,
|
||||
TYPE,
|
||||
)
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
_VERBOSE = False
|
||||
|
||||
|
||||
def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
"""
|
||||
Applies IERs to the links (link pattern of life) resulting from red agent attack
|
||||
Applies IERs to the links (link POL) resulting from red agent attack.
|
||||
|
||||
Args:
|
||||
network: The network modelled in the environment
|
||||
@@ -21,9 +26,8 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
links: The links within the environment
|
||||
iers: The red agent IERs to apply to the links
|
||||
acl: The Access Control List
|
||||
step: The step number
|
||||
step: The step number.
|
||||
"""
|
||||
|
||||
# Go through each IER and check the conditions for it being applied
|
||||
# If everything is in place, apply the IER protocol load to the relevant links
|
||||
for ier_key, ier_value in iers.items():
|
||||
@@ -66,7 +70,10 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
if source_node.get_state() == HARDWARE_STATE.ON:
|
||||
if source_node.has_service(protocol):
|
||||
# Red agents IERs can only be valid if the source service is in a compromised state
|
||||
if source_node.get_service_state(protocol) == SOFTWARE_STATE.COMPROMISED:
|
||||
if (
|
||||
source_node.get_service_state(protocol)
|
||||
== SOFTWARE_STATE.COMPROMISED
|
||||
):
|
||||
source_valid = True
|
||||
else:
|
||||
source_valid = False
|
||||
@@ -78,7 +85,6 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
# Do nothing - IER no longer valid
|
||||
source_valid = False
|
||||
|
||||
|
||||
# 2. Check the dest node situation
|
||||
if dest_node.get_type() == TYPE.SWITCH:
|
||||
# It's a switch
|
||||
@@ -105,10 +111,21 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
dest_valid = False
|
||||
|
||||
# 3. Check that the ACL doesn't block it
|
||||
acl_block = acl.is_blocked(source_node.get_ip_address(), dest_node.get_ip_address(), protocol, port)
|
||||
acl_block = acl.is_blocked(
|
||||
source_node.get_ip_address(), dest_node.get_ip_address(), protocol, port
|
||||
)
|
||||
if acl_block:
|
||||
if _VERBOSE:
|
||||
print("ACL block on source: " + source_node.get_ip_address() + ", dest: " + dest_node.get_ip_address() + ", protocol: " + protocol + ", port: " + port)
|
||||
print(
|
||||
"ACL block on source: "
|
||||
+ source_node.get_ip_address()
|
||||
+ ", dest: "
|
||||
+ dest_node.get_ip_address()
|
||||
+ ", protocol: "
|
||||
+ protocol
|
||||
+ ", port: "
|
||||
+ port
|
||||
)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("No ACL block")
|
||||
@@ -140,8 +157,10 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
# Check that the link capacity is not exceeded by the new load
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count+1])
|
||||
link_id = edge_dict[0].get('id')
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
@@ -149,7 +168,7 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
pass
|
||||
count+=1
|
||||
count += 1
|
||||
|
||||
# Check whether the link capacity for any links on this path have been exceeded
|
||||
if link_capacity_exceeded == False:
|
||||
@@ -157,12 +176,14 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
count = 0
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count+1])
|
||||
link_id = edge_dict[0].get('id')
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Add the load from this IER
|
||||
link.add_protocol_load(protocol, load)
|
||||
count+=1
|
||||
count += 1
|
||||
# This IER is now valid, so set it to running
|
||||
ier_value.set_is_running(True)
|
||||
if _VERBOSE:
|
||||
@@ -185,17 +206,17 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def apply_red_agent_node_pol(nodes, iers, node_pol, step):
|
||||
"""
|
||||
Applies node pattern of life
|
||||
Applies node pattern of life.
|
||||
|
||||
Args:
|
||||
nodes: The nodes within the environment
|
||||
iers: The red agent IERs
|
||||
node_pol: The red agent node pattern of life to apply
|
||||
step: The step number
|
||||
step: The step number.
|
||||
"""
|
||||
|
||||
if _VERBOSE:
|
||||
print("Applying Node Red Agent PoL")
|
||||
|
||||
@@ -209,7 +230,9 @@ def apply_red_agent_node_pol(nodes, iers, node_pol, step):
|
||||
state = node_instruction.get_state()
|
||||
source_node_id = node_instruction.get_source_node_id()
|
||||
source_node_service_name = node_instruction.get_source_node_service()
|
||||
source_node_service_state_value = node_instruction.get_source_node_service_state()
|
||||
source_node_service_state_value = (
|
||||
node_instruction.get_source_node_service_state()
|
||||
)
|
||||
|
||||
passed_checks = False
|
||||
|
||||
@@ -228,7 +251,10 @@ def apply_red_agent_node_pol(nodes, iers, node_pol, step):
|
||||
# Need to check the condition of a service on another node
|
||||
source_node = nodes[source_node_id]
|
||||
if source_node.has_service(source_node_service_name):
|
||||
if source_node.get_service_state(source_node_service_name) == SOFTWARE_STATE[source_node_service_state_value]:
|
||||
if (
|
||||
source_node.get_service_state(source_node_service_name)
|
||||
== SOFTWARE_STATE[source_node_service_state_value]
|
||||
):
|
||||
passed_checks = True
|
||||
else:
|
||||
# Do nothing, no matching state value
|
||||
@@ -248,7 +274,9 @@ def apply_red_agent_node_pol(nodes, iers, node_pol, step):
|
||||
target_node.set_state(state)
|
||||
elif pol_type == NODE_POL_TYPE.OS:
|
||||
# Change OS state
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
if isinstance(target_node, ActiveNode) or isinstance(
|
||||
target_node, ServiceNode
|
||||
):
|
||||
target_node.set_os_state(state)
|
||||
elif pol_type == NODE_POL_TYPE.SERVICE:
|
||||
# Change a service state
|
||||
@@ -256,22 +284,33 @@ def apply_red_agent_node_pol(nodes, iers, node_pol, step):
|
||||
target_node.set_service_state(service_name, state)
|
||||
else:
|
||||
# Change the file system status
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
if isinstance(target_node, ActiveNode) or isinstance(
|
||||
target_node, ServiceNode
|
||||
):
|
||||
target_node.set_file_system_state(state)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Node Red Agent PoL not allowed - did not pass checks")
|
||||
print("Node Red Agent PoL not allowed - did not pass checks")
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
|
||||
def is_red_ier_incoming(node, iers, node_pol_type):
|
||||
|
||||
def is_red_ier_incoming(node, iers, node_pol_type):
|
||||
"""
|
||||
Checks if the RED IER is incoming.
|
||||
|
||||
TODO: Write more descriptive docstring with params and returns.
|
||||
"""
|
||||
node_id = node.get_id()
|
||||
|
||||
for ier_key, ier_value in iers.items():
|
||||
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
|
||||
if node_pol_type == NODE_POL_TYPE.OPERATING or node_pol_type == NODE_POL_TYPE.OS or node_pol_type == NODE_POL_TYPE.FILE:
|
||||
if (
|
||||
node_pol_type == NODE_POL_TYPE.OPERATING
|
||||
or node_pol_type == NODE_POL_TYPE.OS
|
||||
or node_pol_type == NODE_POL_TYPE.FILE
|
||||
):
|
||||
# It's looking to change operating state, file system or O/S state, so valid
|
||||
return True
|
||||
elif node_pol_type == NODE_POL_TYPE.SERVICE:
|
||||
@@ -297,5 +336,3 @@ def is_red_ier_incoming(node, iers, node_pol_type):
|
||||
else:
|
||||
# The IER destination is not this node, or the IER is not running
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
The Transaction class
|
||||
"""
|
||||
"""The Transaction class."""
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
"""
|
||||
Transaction class
|
||||
"""
|
||||
"""Transaction class."""
|
||||
|
||||
def __init__(self, _timestamp, _agent_identifier, _episode_number, _step_number):
|
||||
"""
|
||||
Init
|
||||
Init.
|
||||
|
||||
Args:
|
||||
_timestamp: The time this object was created
|
||||
@@ -17,7 +15,6 @@ class Transaction(object):
|
||||
_episode_number: The episode number
|
||||
_step_number: The step number
|
||||
"""
|
||||
|
||||
self.timestamp = _timestamp
|
||||
self.agent_identifier = _agent_identifier
|
||||
self.episode_number = _episode_number
|
||||
@@ -25,45 +22,36 @@ class Transaction(object):
|
||||
|
||||
def set_obs_space_pre(self, _obs_space_pre):
|
||||
"""
|
||||
Sets the observation space (pre)
|
||||
Sets the observation space (pre).
|
||||
|
||||
Args:
|
||||
_obs_space_pre: The observation space before any actions are taken
|
||||
"""
|
||||
|
||||
self.obs_space_pre = _obs_space_pre
|
||||
|
||||
def set_obs_space_post(self, _obs_space_post):
|
||||
"""
|
||||
Sets the observation space (post)
|
||||
Sets the observation space (post).
|
||||
|
||||
Args:
|
||||
_obs_space_post: The observation space after any actions are taken
|
||||
"""
|
||||
|
||||
self.obs_space_post = _obs_space_post
|
||||
|
||||
def set_reward(self, _reward):
|
||||
"""
|
||||
Sets the reward
|
||||
Sets the reward.
|
||||
|
||||
Args:
|
||||
_reward: The reward value
|
||||
"""
|
||||
|
||||
self.reward = _reward
|
||||
|
||||
def set_action_space(self, _action_space):
|
||||
"""
|
||||
Sets the action space
|
||||
Sets the action space.
|
||||
|
||||
Args:
|
||||
_action_space: The action space invoked by the agent
|
||||
"""
|
||||
|
||||
self.action_space = _action_space
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,40 +1,35 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Writes the Transaction log list out to file for evaluation to utilse
|
||||
"""
|
||||
"""Writes the Transaction log list out to file for evaluation to utilse."""
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import os.path
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
def turn_action_space_to_array(_action_space):
|
||||
"""
|
||||
Turns action space into a string array so it can be saved to csv
|
||||
Turns action space into a string array so it can be saved to csv.
|
||||
|
||||
Args:
|
||||
_action_space: The action space
|
||||
_action_space: The action space.
|
||||
"""
|
||||
|
||||
return_array = []
|
||||
for x in range(len(_action_space)):
|
||||
return_array.append(str(_action_space[x]))
|
||||
|
||||
return return_array
|
||||
|
||||
|
||||
def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
|
||||
"""
|
||||
Turns observation space into a string array so it can be saved to csv
|
||||
Turns observation space into a string array so it can be saved to csv.
|
||||
|
||||
Args:
|
||||
_obs_space: The observation space
|
||||
_obs_assets: The number of assets (i.e. nodes or links) in the observation space
|
||||
_obs_features: The number of features associated with the asset
|
||||
"""
|
||||
|
||||
return_array = []
|
||||
for x in range(_obs_assets):
|
||||
for y in range(_obs_features):
|
||||
@@ -42,15 +37,15 @@ def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
|
||||
|
||||
return return_array
|
||||
|
||||
|
||||
def write_transaction_to_file(_transaction_list):
|
||||
"""
|
||||
Writes transaction logs to file to support training evaluation
|
||||
Writes transaction logs to file to support training evaluation.
|
||||
|
||||
Args:
|
||||
_transaction_list: The list of transactions from all steps and all episodes
|
||||
_num_episodes: The number of episodes that were conducted
|
||||
_num_episodes: The number of episodes that were conducted.
|
||||
"""
|
||||
|
||||
# Get the first transaction and use it to determine the makeup of the observation space and action space
|
||||
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action space as "AS_1"
|
||||
# This will be tied into the PrimAITE Use Case so that they make sense
|
||||
@@ -62,43 +57,53 @@ def write_transaction_to_file(_transaction_list):
|
||||
# Create the action space headers array
|
||||
action_header = []
|
||||
for x in range(action_length):
|
||||
action_header.append('AS_' + str(x))
|
||||
action_header.append("AS_" + str(x))
|
||||
|
||||
# Create the observation space headers array
|
||||
obs_header_initial = []
|
||||
obs_header_new = []
|
||||
for x in range(obs_assets):
|
||||
for y in range(obs_features):
|
||||
obs_header_initial.append('OSI_' + str(x) + '_' + str(y))
|
||||
obs_header_new.append('OSN_' + str(x) + '_' + str(y))
|
||||
obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
|
||||
obs_header_new.append("OSN_" + str(x) + "_" + str(y))
|
||||
|
||||
# Open up a csv file
|
||||
header = ['Timestamp', 'Episode', 'Step', 'Reward']
|
||||
header = ["Timestamp", "Episode", "Step", "Reward"]
|
||||
header = header + action_header + obs_header_initial + obs_header_new
|
||||
now = datetime.now() # current date and time
|
||||
now = datetime.now() # current date and time
|
||||
time = now.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
try:
|
||||
path = 'outputs/results/'
|
||||
path = "outputs/results/"
|
||||
is_dir = os.path.isdir(path)
|
||||
if not is_dir:
|
||||
os.makedirs(path)
|
||||
|
||||
filename = "outputs/results/all_transactions_" + time + ".csv"
|
||||
csv_file = open(filename, 'w', encoding='UTF8', newline='')
|
||||
csv_file = open(filename, "w", encoding="UTF8", newline="")
|
||||
csv_writer = csv.writer(csv_file)
|
||||
csv_writer.writerow(header)
|
||||
|
||||
for transaction in _transaction_list:
|
||||
csv_data = [str(transaction.timestamp), str(transaction.episode_number), str(transaction.step_number), str(transaction.reward)]
|
||||
csv_data = csv_data + turn_action_space_to_array(transaction.action_space) + \
|
||||
turn_obs_space_to_array(transaction.obs_space_pre, obs_assets, obs_features) + \
|
||||
turn_obs_space_to_array(transaction.obs_space_post, obs_assets, obs_features)
|
||||
csv_data = [
|
||||
str(transaction.timestamp),
|
||||
str(transaction.episode_number),
|
||||
str(transaction.step_number),
|
||||
str(transaction.reward),
|
||||
]
|
||||
csv_data = (
|
||||
csv_data
|
||||
+ turn_action_space_to_array(transaction.action_space)
|
||||
+ turn_obs_space_to_array(
|
||||
transaction.obs_space_pre, obs_assets, obs_features
|
||||
)
|
||||
+ turn_obs_space_to_array(
|
||||
transaction.obs_space_post, obs_assets, obs_features
|
||||
)
|
||||
)
|
||||
csv_writer.writerow(csv_data)
|
||||
|
||||
csv_file.close()
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.error("Could not save the transaction file")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
|
||||
|
||||
|
||||
@@ -1,61 +1,48 @@
|
||||
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Used to tes the ACL functions
|
||||
"""
|
||||
"""Used to tes the ACL functions."""
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
|
||||
|
||||
def test_acl_address_match_1():
|
||||
"""
|
||||
Test that matching IP addresses produce True
|
||||
"""
|
||||
|
||||
"""Test that matching IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
def test_acl_address_match_2():
|
||||
"""
|
||||
Test that mismatching IP addresses produce False
|
||||
"""
|
||||
|
||||
def test_acl_address_match_2():
|
||||
"""Test that mismatching IP addresses produce False."""
|
||||
acl = AccessControlList()
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False
|
||||
|
||||
def test_acl_address_match_3():
|
||||
"""
|
||||
Test the ANY condition for source IP addresses produce True
|
||||
"""
|
||||
|
||||
def test_acl_address_match_3():
|
||||
"""Test the ANY condition for source IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
|
||||
rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
def test_acl_address_match_4():
|
||||
"""
|
||||
Test the ANY condition for dest IP addresses produce True
|
||||
"""
|
||||
|
||||
def test_acl_address_match_4():
|
||||
"""Test the ANY condition for dest IP addresses produce True."""
|
||||
acl = AccessControlList()
|
||||
|
||||
rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80")
|
||||
|
||||
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
|
||||
|
||||
def test_check_acl_block_affirmative():
|
||||
"""
|
||||
Test the block function (affirmative)
|
||||
"""
|
||||
|
||||
def test_check_acl_block_affirmative():
|
||||
"""Test the block function (affirmative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
|
||||
@@ -66,15 +53,19 @@ def test_check_acl_block_affirmative():
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
|
||||
acl.add_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port)
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
)
|
||||
|
||||
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == False
|
||||
|
||||
def test_check_acl_block_negative():
|
||||
"""
|
||||
Test the block function (negative)
|
||||
"""
|
||||
|
||||
def test_check_acl_block_negative():
|
||||
"""Test the block function (negative)."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
|
||||
@@ -85,21 +76,27 @@ def test_check_acl_block_negative():
|
||||
acl_rule_protocol = "TCP"
|
||||
acl_rule_port = "80"
|
||||
|
||||
acl.add_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port)
|
||||
acl.add_rule(
|
||||
acl_rule_permission,
|
||||
acl_rule_source,
|
||||
acl_rule_destination,
|
||||
acl_rule_protocol,
|
||||
acl_rule_port,
|
||||
)
|
||||
|
||||
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == True
|
||||
|
||||
def test_rule_hash():
|
||||
"""
|
||||
Test the rule hash
|
||||
"""
|
||||
|
||||
def test_rule_hash():
|
||||
"""Test the rule hash."""
|
||||
# Create the Access Control List
|
||||
acl = AccessControlList()
|
||||
|
||||
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_local = hash(rule)
|
||||
|
||||
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_remote = acl.get_dictionary_hash(
|
||||
"DENY", "192.168.1.1", "192.168.1.2", "TCP", "80"
|
||||
)
|
||||
|
||||
assert hash_value_local == hash_value_remote
|
||||
|
||||
Reference in New Issue
Block a user