Ran pre-commit hook on all files and performed changes to fix flake8 failures

This commit is contained in:
Chris McCarthy
2023-05-25 11:42:19 +01:00
parent aa8284897a
commit 4eb1658966
42 changed files with 1371 additions and 1170 deletions

View File

@@ -3,4 +3,4 @@ Index-servers =
PrimAITE PrimAITE
[PrimAITE] [PrimAITE]
Repository = https://pkgs.dev.azure.com/ma-dev-uk/PrimAITE/_packaging/PrimAITE/pypi/upload/ Repository = https://pkgs.dev.azure.com/ma-dev-uk/PrimAITE/_packaging/PrimAITE/pypi/upload/

12
.flake8 Normal file
View File

@@ -0,0 +1,12 @@
[flake8]
max-line-length=120
extend-ignore =
D105
D107
D100
D104
E203
E712
D401
exclude =
docs/source/*

View File

@@ -1 +1 @@
include src/primaite/config/*.yaml include src/primaite/config/*.yaml

View File

@@ -1 +1 @@
# PrimAITE # PrimAITE

View File

@@ -3,12 +3,15 @@
# For the full list of built-in configuration values, see the documentation: # For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html # https://www.sphinx-doc.org/en/master/usage/configuration.html
import datetime
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
import os import os
import sys import sys
import datetime
import furo import furo # noqa
sys.path.insert(0, os.path.abspath("../")) sys.path.insert(0, os.path.abspath("../"))
@@ -33,7 +36,6 @@ templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# -- Options for HTML output ------------------------------------------------- # -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

View File

@@ -105,7 +105,7 @@ The status changes that can be made to a node are as follows:
* ON * ON
* OFF * OFF
* RESETTING - when a status of resetting is entered, the node will automatically exit this state after a number of steps (as defined by the nodeResetDuration configuration item) after which it returns to an ON state * RESETTING - when a status of resetting is entered, the node will automatically exit this state after a number of steps (as defined by the nodeResetDuration configuration item) after which it returns to an ON state
* Active Nodes and Service Nodes: * Active Nodes and Service Nodes:
@@ -194,7 +194,7 @@ An example observation space is provided below:
:widths: 25 25 25 25 25 25 25 :widths: 25 25 25 25 25 25 25
:header-rows: 1 :header-rows: 1
* - * -
- ID - ID
- Operating State - Operating State
- O/S State - O/S State
@@ -326,8 +326,8 @@ A reward value is presented back to the blue agent on the conclusion of every st
**Node and service status** **Node and service status**
On every step, the status of each node is compared against both a reference environment (simulating the situation if the red and blue agents had not impacted the environment) On every step, the status of each node is compared against both a reference environment (simulating the situation if the red and blue agents had not impacted the environment)
and the before and after state of the environment. If the comparison against the reference environment shows no difference, then the score provided is "AllOK". If there is a and the before and after state of the environment. If the comparison against the reference environment shows no difference, then the score provided is "AllOK". If there is a
difference with respect to the reference environment, the before and after states are compared, and a score determined. See :ref:`config` for details of reward values. difference with respect to the reference environment, the before and after states are compared, and a score determined. See :ref:`config` for details of reward values.
**IER status** **IER status**

View File

@@ -66,83 +66,83 @@ The config_main.yaml file consists of the following attributes:
The score to give when the node should be resetting, but is off The score to give when the node should be resetting, but is off
* **Node Operating State [onShouldBeOff]** [int] * **Node Operating State [onShouldBeOff]** [int]
The score to give when the node should be off, but is on The score to give when the node should be off, but is on
* **Node Operating State [onShouldBeResetting]** [int] * **Node Operating State [onShouldBeResetting]** [int]
The score to give when the node should be resetting, but is on The score to give when the node should be resetting, but is on
* **Node Operating State [resettingShouldBeOn]** [int] * **Node Operating State [resettingShouldBeOn]** [int]
The score to give when the node should be on, but is resetting The score to give when the node should be on, but is resetting
* **Node Operating State [resettingShouldBeOff]** [int] * **Node Operating State [resettingShouldBeOff]** [int]
The score to give when the node should be off, but is resetting The score to give when the node should be off, but is resetting
* **Node Operating State [resetting]** [int] * **Node Operating State [resetting]** [int]
The score to give when the node is resetting The score to give when the node is resetting
* **Node Operating System or Service State [goodShouldBePatching]** [int] * **Node Operating System or Service State [goodShouldBePatching]** [int]
The score to give when the state should be patching, but is good The score to give when the state should be patching, but is good
* **Node Operating System or Service State [goodShouldBeCompromised]** [int] * **Node Operating System or Service State [goodShouldBeCompromised]** [int]
The score to give when the state should be compromised, but is good The score to give when the state should be compromised, but is good
* **Node Operating System or Service State [goodShouldBeOverwhelmed]** [int] * **Node Operating System or Service State [goodShouldBeOverwhelmed]** [int]
The score to give when the state should be overwhelmed, but is good The score to give when the state should be overwhelmed, but is good
* **Node Operating System or Service State [patchingShouldBeGood]** [int] * **Node Operating System or Service State [patchingShouldBeGood]** [int]
The score to give when the state should be good, but is patching The score to give when the state should be good, but is patching
* **Node Operating System or Service State [patchingShouldBeCompromised]** [int] * **Node Operating System or Service State [patchingShouldBeCompromised]** [int]
The score to give when the state should be compromised, but is patching The score to give when the state should be compromised, but is patching
* **Node Operating System or Service State [patchingShouldBeOverwhelmed]** [int] * **Node Operating System or Service State [patchingShouldBeOverwhelmed]** [int]
The score to give when the state should be overwhelmed, but is patching The score to give when the state should be overwhelmed, but is patching
* **Node Operating System or Service State [patching]** [int] * **Node Operating System or Service State [patching]** [int]
The score to give when the state is patching The score to give when the state is patching
* **Node Operating System or Service State [compromisedShouldBeGood]** [int] * **Node Operating System or Service State [compromisedShouldBeGood]** [int]
The score to give when the state should be good, but is compromised The score to give when the state should be good, but is compromised
* **Node Operating System or Service State [compromisedShouldBePatching]** [int] * **Node Operating System or Service State [compromisedShouldBePatching]** [int]
The score to give when the state should be patching, but is compromised The score to give when the state should be patching, but is compromised
* **Node Operating System or Service State [compromisedShouldBeOverwhelmed]** [int] * **Node Operating System or Service State [compromisedShouldBeOverwhelmed]** [int]
The score to give when the state should be overwhelmed, but is compromised The score to give when the state should be overwhelmed, but is compromised
* **Node Operating System or Service State [compromised]** [int] * **Node Operating System or Service State [compromised]** [int]
The score to give when the state is compromised The score to give when the state is compromised
* **Node Operating System or Service State [overwhelmedShouldBeGood]** [int] * **Node Operating System or Service State [overwhelmedShouldBeGood]** [int]
The score to give when the state should be good, but is overwhelmed The score to give when the state should be good, but is overwhelmed
* **Node Operating System or Service State [overwhelmedShouldBePatching]** [int] * **Node Operating System or Service State [overwhelmedShouldBePatching]** [int]
The score to give when the state should be patching, but is overwhelmed The score to give when the state should be patching, but is overwhelmed
* **Node Operating System or Service State [overwhelmedShouldBeCompromised]** [int] * **Node Operating System or Service State [overwhelmedShouldBeCompromised]** [int]
The score to give when the state should be compromised, but is overwhelmed The score to give when the state should be compromised, but is overwhelmed
* **Node Operating System or Service State [overwhelmed]** [int] * **Node Operating System or Service State [overwhelmed]** [int]
The score to give when the state is overwhelmed The score to give when the state is overwhelmed
* **Node File System State [goodShouldBeRepairing]** [int] * **Node File System State [goodShouldBeRepairing]** [int]
@@ -246,11 +246,11 @@ The config_main.yaml file consists of the following attributes:
The score to give when the state is scanning The score to give when the state is scanning
* **IER Status [redIerRunning]** [int] * **IER Status [redIerRunning]** [int]
The score to give when a red agent IER is permitted to run The score to give when a red agent IER is permitted to run
* **IER Status [greenIerBlocked]** [int] * **IER Status [greenIerBlocked]** [int]
The score to give when a green agent IER is prevented from running The score to give when a green agent IER is prevented from running
**Patching / Reset Durations** **Patching / Reset Durations**
@@ -260,14 +260,14 @@ The config_main.yaml file consists of the following attributes:
The number of steps to take when patching an Operating System The number of steps to take when patching an Operating System
* **nodeResetDuration** [int] * **nodeResetDuration** [int]
The number of steps to take when resetting a node's operating state The number of steps to take when resetting a node's operating state
* **servicePatchingDuration** [int] * **servicePatchingDuration** [int]
The number of steps to take when patching a service The number of steps to take when patching a service
* **fileSystemRepairingLimit** [int]: * **fileSystemRepairingLimit** [int]:
The number of steps to take when repairing the file system The number of steps to take when repairing the file system
@@ -285,23 +285,23 @@ config_[name].yaml:
The config_[name].yaml file consists of the following attributes: The config_[name].yaml file consists of the following attributes:
* **itemType: ACTIONS** [enum] * **itemType: ACTIONS** [enum]
Determines whether a NODE or ACL action space format is adopted for the session Determines whether a NODE or ACL action space format is adopted for the session
* **itemType: STEPS** [int] * **itemType: STEPS** [int]
Determines the number of steps to run in each episode of the session Determines the number of steps to run in each episode of the session
* **itemType: PORTS** [int] * **itemType: PORTS** [int]
Provides a list of ports modelled in this session Provides a list of ports modelled in this session
* **itemType: SERVICES** [freetext] * **itemType: SERVICES** [freetext]
Provides a list of services modelled in this session Provides a list of services modelled in this session
* **itemType: NODE** * **itemType: NODE**
Defines a node included in the system laydown being simulated. It should consist of the following attributes: Defines a node included in the system laydown being simulated. It should consist of the following attributes:
* **id** [int]: Unique ID for this YAML item * **id** [int]: Unique ID for this YAML item
@@ -318,9 +318,9 @@ The config_[name].yaml file consists of the following attributes:
* **name** [freetext]: Free-text name of the service, but must match one of the services defined for the system in the services list * **name** [freetext]: Free-text name of the service, but must match one of the services defined for the system in the services list
* **port** [int]: Integer value of the port related to this service, but must match one of the ports defined for the system in the ports list * **port** [int]: Integer value of the port related to this service, but must match one of the ports defined for the system in the ports list
* **state** [enum]: The initial state of the service. Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED * **state** [enum]: The initial state of the service. Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
* **itemType: LINK** * **itemType: LINK**
Defines a link included in the system laydown being simulated. It should consist of the following attributes: Defines a link included in the system laydown being simulated. It should consist of the following attributes:
* **id** [int]: Unique ID for this YAML item * **id** [int]: Unique ID for this YAML item
@@ -344,7 +344,7 @@ The config_[name].yaml file consists of the following attributes:
* **missionCriticality** [enum]: The mission criticality of this IER (with 5 being highest, 1 lowest) * **missionCriticality** [enum]: The mission criticality of this IER (with 5 being highest, 1 lowest)
* **itemType: RED_IER** * **itemType: RED_IER**
Defines a red agent Information Exchange Requirement (IER). It should consist of: Defines a red agent Information Exchange Requirement (IER). It should consist of:
* **id** [int]: Unique ID for this YAML item * **id** [int]: Unique ID for this YAML item
@@ -358,7 +358,7 @@ The config_[name].yaml file consists of the following attributes:
* **missionCriticality** [enum]: Not currently used. Default to 0 * **missionCriticality** [enum]: Not currently used. Default to 0
* **itemType: GREEN_POL** * **itemType: GREEN_POL**
Defines a green agent pattern-of-life instruction. It should consist of: Defines a green agent pattern-of-life instruction. It should consist of:
* **id** [int]: Unique ID for this YAML item * **id** [int]: Unique ID for this YAML item
@@ -370,7 +370,7 @@ The config_[name].yaml file consists of the following attributes:
* **state** [enuum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for operating system state) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state) * **state** [enuum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for operating system state) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state)
* **itemType: RED_POL** * **itemType: RED_POL**
Defines a red agent pattern-of-life instruction. It should consist of: Defines a red agent pattern-of-life instruction. It should consist of:
* **id** [int]: Unique ID for this YAML item * **id** [int]: Unique ID for this YAML item
@@ -386,7 +386,7 @@ The config_[name].yaml file consists of the following attributes:
* **sourceNodeServiceState** [enum]: The state of the source node service to check (used for SERVICE initiator). Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED * **sourceNodeServiceState** [enum]: The state of the source node service to check (used for SERVICE initiator). Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
* **itemType: ACL_RULE** * **itemType: ACL_RULE**
Defines an initial Access Control List (ACL) rule. It should consist of: Defines an initial Access Control List (ACL) rule. It should consist of:
* **id** [int]: Unique ID for this YAML item * **id** [int]: Unique ID for this YAML item
@@ -394,4 +394,4 @@ The config_[name].yaml file consists of the following attributes:
* **source** [IP address]: Defines the source IP address for the rule in xxx.xxx.xxx.xxx format * **source** [IP address]: Defines the source IP address for the rule in xxx.xxx.xxx.xxx format
* **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format * **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format
* **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list * **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list
* **port** [int]: Defines the port for the rule. Must match a value in the ports list * **port** [int]: Defines the port for the rule. Must match a value in the ports list

View File

@@ -10,7 +10,7 @@ PrimAITE is built with the following versions of dependencies:
* numpy 1.23.5 * numpy 1.23.5
* networkx 2.8.8 * networkx 2.8.8
* gym 0.21.0 * gym 0.21.0
* matplotlib 3.6.2 * matplotlib 3.6.2
* stable_baselines_3 1.6.2 * stable_baselines_3 1.6.2
The latest release of PrimAITE has been tested against the following versions of dependencies: The latest release of PrimAITE has been tested against the following versions of dependencies:
@@ -20,7 +20,5 @@ The latest release of PrimAITE has been tested against the following versions of
* numpy 1.23.5 * numpy 1.23.5
* networkx 2.8.8 * networkx 2.8.8
* gym 0.21.0 * gym 0.21.0
* matplotlib 3.6.2 * matplotlib 3.6.2
* stable_baselines_3 1.6.2 * stable_baselines_3 1.6.2

View File

@@ -39,4 +39,4 @@ For each training session, assuming the agent being trained implements the *save
**Logging** **Logging**
PrimAITE also provides output logs (for diagnosis) using the Python Logging package. These can be found in the *[Install Directory]\\Primaite\\Primaite\\logs* directory PrimAITE also provides output logs (for diagnosis) using the Python Logging package. These can be found in the *[Install Directory]\\Primaite\\Primaite\\logs* directory

View File

@@ -24,7 +24,7 @@ Integrating a blue agent with PrimAITE requires some modification of the code wi
* Stable Baselines 3 PPO (run_stable_baselines3_ppo) * Stable Baselines 3 PPO (run_stable_baselines3_ppo)
* Stable Baselines 3 A2C (run_stable_baselines3_a2c) * Stable Baselines 3 A2C (run_stable_baselines3_a2c)
The selection of which agent type to use is made via the config_main.yaml file. In order to train a user generated agent, The selection of which agent type to use is made via the config_main.yaml file. In order to train a user generated agent,
the run_generic function should be selected, and should be modified (typically) to be: the run_generic function should be selected, and should be modified (typically) to be:
.. code:: python .. code:: python
@@ -46,7 +46,7 @@ Where:
* the *save_agent()* assumes that a *save()* function has been defined in the user created agent. If not, this line can be ommitted (although it is encouraged, since it will allow the agent to be saved and ported) * the *save_agent()* assumes that a *save()* function has been defined in the user created agent. If not, this line can be ommitted (although it is encouraged, since it will allow the agent to be saved and ported)
The code below provides a suggested format for the learn() function within the user created agent. The code below provides a suggested format for the learn() function within the user created agent.
It's important to include the *self.environment.reset()* call within the episode loop in order that the It's important to include the *self.environment.reset()* call within the episode loop in order that the
environment is reset between episodes. Note that the example below should not be considered exhaustive. environment is reset between episodes. Note that the example below should not be considered exhaustive.
.. code:: python .. code:: python
@@ -58,7 +58,7 @@ environment is reset between episodes. Note that the example below should not be
# reset the environment # reset the environment
self.environment.reset() self.environment.reset()
done = False done = False
for step in range(max_steps): for step in range(max_steps):
# calculate the action # calculate the action
action = ... action = ...
@@ -77,12 +77,10 @@ environment is reset between episodes. Note that the example below should not be
break break
**Running the session** **Running the session**
In order to execute a session, carry out the following steps: In order to execute a session, carry out the following steps:
1. Navigate to "[Install directory]\\Primaite\\Primaite\\” 1. Navigate to "[Install directory]\\Primaite\\Primaite\\”
2. Start a console window (type “CMD” in path window, or start a console window first and navigate to “[Install Directory]\\Primaite\\Primaite\\”) 2. Start a console window (type “CMD” in path window, or start a console window first and navigate to “[Install Directory]\\Primaite\\Primaite\\”)
3. Type “python main.py” 3. Type “python main.py”
4. The session will start with an output indicating the current episode, and average reward value for the episode 4. The session will start with an output indicating the current episode, and average reward value for the episode

View File

@@ -10,8 +10,9 @@ class bdist_wheel(_bdist_wheel): # noqa
# Source: https://stackoverflow.com/a/45150383 # Source: https://stackoverflow.com/a/45150383
self.root_is_pure = False self.root_is_pure = False
setup( setup(
cmdclass={ cmdclass={
"bdist_wheel": bdist_wheel, "bdist_wheel": bdist_wheel,
} }
) )

View File

@@ -1 +1 @@
1.2.0 1.2.0

View File

@@ -1,25 +1,19 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """A class that implements the access control list implementation for the network."""
A class that implements the access control list implementation for the network
"""
from primaite.acl.acl_rule import ACLRule from primaite.acl.acl_rule import ACLRule
class AccessControlList():
""" class AccessControlList:
Access Control List class """Access Control List class."""
"""
def __init__(self): def __init__(self):
""" """Init."""
Init self.acl = {} # A dictionary of ACL Rules
"""
self.acl = {} # A dictionary of ACL Rules
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
""" """
Checks for IP address matches Checks for IP address matches.
Args: Args:
_rule: The rule being checked _rule: The rule being checked
@@ -29,18 +23,28 @@ class AccessControlList():
Returns: Returns:
True if match; False otherwise. True if match; False otherwise.
""" """
if (
if ((_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address) or (
(_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address) or _rule.get_source_ip() == _source_ip_address
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY") or and _rule.get_dest_ip() == _dest_ip_address
(_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")): )
or (
_rule.get_source_ip() == "ANY"
and _rule.get_dest_ip() == _dest_ip_address
)
or (
_rule.get_source_ip() == _source_ip_address
and _rule.get_dest_ip() == "ANY"
)
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
):
return True return True
else: else:
return False return False
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port): def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
""" """
Checks for rules that block a protocol / port Checks for rules that block a protocol / port.
Args: Args:
_source_ip_address: the source IP address to check _source_ip_address: the source IP address to check
@@ -51,11 +55,17 @@ class AccessControlList():
Returns: Returns:
Indicates block if all conditions are satisfied. Indicates block if all conditions are satisfied.
""" """
for rule_key, rule_value in self.acl.items(): for rule_key, rule_value in self.acl.items():
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address): if self.check_address_match(
if ((rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and rule_value, _source_ip_address, _dest_ip_address
(str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY")): ):
if (
rule_value.get_protocol() == _protocol
or rule_value.get_protocol() == "ANY"
) and (
str(rule_value.get_port()) == str(_port)
or rule_value.get_port() == "ANY"
):
# There's a matching rule. Get the permission # There's a matching rule. Get the permission
if rule_value.get_permission() == "DENY": if rule_value.get_permission() == "DENY":
return True return True
@@ -67,7 +77,7 @@ class AccessControlList():
def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): def add_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
""" """
Adds a new rule Adds a new rule.
Args: Args:
_permission: the permission value (e.g. "ALLOW" or "DENY") _permission: the permission value (e.g. "ALLOW" or "DENY")
@@ -76,14 +86,13 @@ class AccessControlList():
_protocol: the protocol _protocol: the protocol
_port: the port _port: the port
""" """
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(new_rule) hash_value = hash(new_rule)
self.acl[hash_value] = new_rule self.acl[hash_value] = new_rule
def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port): def remove_rule(self, _permission, _source_ip, _dest_ip, _protocol, _port):
""" """
Removes a rule Removes a rule.
Args: Args:
_permission: the permission value (e.g. "ALLOW" or "DENY") _permission: the permission value (e.g. "ALLOW" or "DENY")
@@ -92,25 +101,21 @@ class AccessControlList():
_protocol: the protocol _protocol: the protocol
_port: the port _port: the port
""" """
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(rule) hash_value = hash(rule)
# There will not always be something 'popable' since the agent will be trying random things # There will not always be something 'popable' since the agent will be trying random things
try: try:
self.acl.pop(hash_value) self.acl.pop(hash_value)
except: except Exception:
return return
def remove_all_rules(self): def remove_all_rules(self):
""" """Removes all rules."""
Removes all rules
"""
self.acl.clear() self.acl.clear()
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port): def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
""" """
Produces a hash value for a rule Produces a hash value for a rule.
Args: Args:
_permission: the permission value (e.g. "ALLOW" or "DENY") _permission: the permission value (e.g. "ALLOW" or "DENY")
@@ -122,13 +127,6 @@ class AccessControlList():
Returns: Returns:
Hash value based on rule parameters. Hash value based on rule parameters.
""" """
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(rule) hash_value = hash(rule)
return hash_value return hash_value

View File

@@ -1,16 +1,13 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """A class that implements an access control list rule."""
A class that implements an access control list rule
"""
class ACLRule():
""" class ACLRule:
Access Control List Rule class """Access Control List Rule class."""
"""
def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port): def __init__(self, _permission, _source_ip, _dest_ip, _protocol, _port):
""" """
Init Init.
Args: Args:
_permission: The permission (ALLOW or DENY) _permission: The permission (ALLOW or DENY)
@@ -19,7 +16,6 @@ class ACLRule():
_protocol: The rule protocol _protocol: The rule protocol
_port: The rule port _port: The rule port
""" """
self.permission = _permission self.permission = _permission
self.source_ip = _source_ip self.source_ip = _source_ip
self.dest_ip = _dest_ip self.dest_ip = _dest_ip
@@ -28,47 +24,45 @@ class ACLRule():
def __hash__(self): def __hash__(self):
""" """
Override the hash function Override the hash function.
Returns: Returns:
Returns hash of core parameters. Returns hash of core parameters.
""" """
return hash(
return hash((self.permission, self.source_ip, self.dest_ip, self.protocol, self.port)) (self.permission, self.source_ip, self.dest_ip, self.protocol, self.port)
)
def get_permission(self): def get_permission(self):
""" """
Gets the permission attribute Gets the permission attribute.
Returns: Returns:
Returns permission attribute Returns permission attribute
""" """
return self.permission return self.permission
def get_source_ip(self): def get_source_ip(self):
""" """
Gets the source IP address attribute Gets the source IP address attribute.
Returns: Returns:
Returns source IP address attribute Returns source IP address attribute
""" """
return self.source_ip return self.source_ip
def get_dest_ip(self): def get_dest_ip(self):
""" """
Gets the desintation IP address attribute Gets the desintation IP address attribute.
Returns: Returns:
Returns destination IP address attribute Returns destination IP address attribute
""" """
return self.dest_ip return self.dest_ip
def get_protocol(self): def get_protocol(self):
""" """
Gets the protocol attribute Gets the protocol attribute.
Returns: Returns:
Returns protocol attribute Returns protocol attribute
@@ -77,12 +71,9 @@ class ACLRule():
def get_port(self): def get_port(self):
""" """
Gets the port attribute Gets the port attribute.
Returns: Returns:
Returns port attribute Returns port attribute
""" """
return self.port return self.port

View File

@@ -1,2 +1 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.

View File

@@ -1,39 +1,35 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """The config class."""
The config class
"""
class config_values_main(object): class config_values_main(object):
""" """Class to hold main config values."""
Class to hold main config values
"""
def __init__(self): def __init__(self):
""" """Init."""
Init
"""
# Generic # Generic
self.agent_identifier = "" # the agent in use self.agent_identifier = "" # the agent in use
self.num_episodes = 0 # number of episodes to train over self.num_episodes = 0 # number of episodes to train over
self.num_steps = 0 # number of steps in an episode self.num_steps = 0 # number of steps in an episode
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only self.time_delay = 0 # delay between steps (ms) - applies to generic agents only
self.config_filename_use_case = "" # the filename for the Use Case config file self.config_filename_use_case = "" # the filename for the Use Case config file
self.session_type = "" # the session type to run (TRAINING or EVALUATION) self.session_type = "" # the session type to run (TRAINING or EVALUATION)
# Environment # Environment
self.observation_space_high_value = 0 # The high value for the observation space self.observation_space_high_value = (
0 # The high value for the observation space
)
# Reward values # Reward values
# Generic # Generic
self.all_ok = 0 self.all_ok = 0
# Node Operating State # Node Operating State
self.off_should_be_on = 0 self.off_should_be_on = 0
self.off_should_be_resetting = 0 self.off_should_be_resetting = 0
self.on_should_be_off = 0 self.on_should_be_off = 0
self.on_should_be_resetting = 0 self.on_should_be_resetting = 0
self.resetting_should_be_on = 0 self.resetting_should_be_on = 0
self.resetting_should_be_off = 0 self.resetting_should_be_off = 0
self.resetting = 0 self.resetting = 0
# Node O/S or Service State # Node O/S or Service State
self.good_should_be_patching = 0 self.good_should_be_patching = 0
@@ -46,7 +42,7 @@ class config_values_main(object):
self.compromised_should_be_good = 0 self.compromised_should_be_good = 0
self.compromised_should_be_patching = 0 self.compromised_should_be_patching = 0
self.compromised_should_be_overwhelmed = 0 self.compromised_should_be_overwhelmed = 0
self.compromised = 0 self.compromised = 0
self.overwhelmed_should_be_good = 0 self.overwhelmed_should_be_good = 0
self.overwhelmed_should_be_patching = 0 self.overwhelmed_should_be_patching = 0
self.overwhelmed_should_be_compromised = 0 self.overwhelmed_should_be_compromised = 0
@@ -59,11 +55,15 @@ class config_values_main(object):
self.repairing_should_be_good = 0 self.repairing_should_be_good = 0
self.repairing_should_be_restoring = 0 self.repairing_should_be_restoring = 0
self.repairing_should_be_corrupt = 0 self.repairing_should_be_corrupt = 0
self.repairing_should_be_destroyed = 0 # Repairing does not fix destroyed state - you need to restore self.repairing_should_be_destroyed = (
0 # Repairing does not fix destroyed state - you need to restore
)
self.repairing = 0 self.repairing = 0
self.restoring_should_be_good = 0 self.restoring_should_be_good = 0
self.restoring_should_be_repairing = 0 self.restoring_should_be_repairing = 0
self.restoring_should_be_corrupt = 0 # Not the optimal method (as repair will fix corruption) self.restoring_should_be_corrupt = (
0 # Not the optimal method (as repair will fix corruption)
)
self.restoring_should_be_destroyed = 0 self.restoring_should_be_destroyed = 0
self.restoring = 0 self.restoring = 0
self.corrupt_should_be_good = 0 self.corrupt_should_be_good = 0
@@ -82,10 +82,9 @@ class config_values_main(object):
self.green_ier_blocked = 0 self.green_ier_blocked = 0
# Patching / Reset # Patching / Reset
self.os_patching_duration = 0 # The time taken to patch the OS self.os_patching_duration = 0 # The time taken to patch the OS
self.node_reset_duration = 0 # The time taken to reset a node (hardware) self.node_reset_duration = 0 # The time taken to reset a node (hardware)
self.service_patching_duration = 0 # The time taken to patch a service self.service_patching_duration = 0 # The time taken to patch a service
self.file_system_repairing_limit = 0 # The time take to repair a file self.file_system_repairing_limit = 0 # The time take to repair a file
self.file_system_restoring_limit = 0 # The time take to restore a file self.file_system_restoring_limit = 0 # The time take to restore a file
self.file_system_scanning_limit = 0 # The time taken to scan the file system self.file_system_scanning_limit = 0 # The time taken to scan the file system

View File

@@ -1,14 +1,11 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """Enumerations for APE."""
Enumerations for APE
"""
from enum import Enum from enum import Enum
class TYPE(Enum): class TYPE(Enum):
""" """Node type enumeration."""
Node type enumeration
"""
CCTV = 1 CCTV = 1
SWITCH = 2 SWITCH = 2
@@ -21,10 +18,9 @@ class TYPE(Enum):
ACTUATOR = 9 ACTUATOR = 9
SERVER = 10 SERVER = 10
class PRIORITY(Enum): class PRIORITY(Enum):
""" """Node priority enumeration."""
Node priority enumeration
"""
P1 = 1 P1 = 1
P2 = 2 P2 = 2
@@ -32,48 +28,43 @@ class PRIORITY(Enum):
P4 = 4 P4 = 4
P5 = 5 P5 = 5
class HARDWARE_STATE(Enum): class HARDWARE_STATE(Enum):
""" """Node hardware state enumeration."""
Node hardware state enumeration
"""
ON = 1 ON = 1
OFF = 2 OFF = 2
RESETTING = 3 RESETTING = 3
class SOFTWARE_STATE(Enum): class SOFTWARE_STATE(Enum):
""" """O/S or Service state enumeration."""
O/S or Service state enumeration
"""
GOOD = 1 GOOD = 1
PATCHING = 2 PATCHING = 2
COMPROMISED = 3 COMPROMISED = 3
OVERWHELMED = 4 OVERWHELMED = 4
class NODE_POL_TYPE(Enum): class NODE_POL_TYPE(Enum):
""" """Node Pattern of Life type enumeration."""
Node Pattern of Life type enumeration
"""
OPERATING = 1 OPERATING = 1
OS = 2 OS = 2
SERVICE = 3 SERVICE = 3
FILE = 4 FILE = 4
class NODE_POL_INITIATOR(Enum): class NODE_POL_INITIATOR(Enum):
""" """Node Pattern of Life initiator enumeration."""
Node Pattern of Life initiator enumeration
"""
DIRECT = 1 DIRECT = 1
IER = 2 IER = 2
SERVICE = 3 SERVICE = 3
class PROTOCOL(Enum): class PROTOCOL(Enum):
""" """Service protocol enumeration."""
Service protocol enumeration
"""
LDAP = 0 LDAP = 0
FTP = 1 FTP = 1
@@ -84,18 +75,16 @@ class PROTOCOL(Enum):
TCP = 6 TCP = 6
NONE = 7 NONE = 7
class ACTION_TYPE(Enum): class ACTION_TYPE(Enum):
""" """Action type enumeration."""
Action type enumeration
"""
NODE = 0 NODE = 0
ACL = 1 ACL = 1
class FILE_SYSTEM_STATE(Enum): class FILE_SYSTEM_STATE(Enum):
""" """File System State."""
File System State
"""
GOOD = 1 GOOD = 1
CORRUPT = 2 CORRUPT = 2

View File

@@ -1,59 +1,47 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """The protocol class."""
The protocol class
"""
class Protocol(object): class Protocol(object):
""" """Protocol class."""
Protocol class
"""
def __init__(self, _name): def __init__(self, _name):
""" """
Init Init.
Args: Args:
_name: The protocol name _name: The protocol name
""" """
self.name = _name self.name = _name
self.load = 0 # bps self.load = 0 # bps
def get_name(self): def get_name(self):
""" """
Gets the protocol name Gets the protocol name.
Returns: Returns:
The protocol name The protocol name
""" """
return self.name return self.name
def get_load(self): def get_load(self):
""" """
Gets the protocol load Gets the protocol load.
Returns: Returns:
The protocol load (bps) The protocol load (bps)
""" """
return self.load return self.load
def add_load(self, _load): def add_load(self, _load):
""" """
Adds load to the protocol Adds load to the protocol.
Args: Args:
_load: The load to add _load: The load to add
""" """
self.load += _load self.load += _load
def clear_load(self): def clear_load(self):
""" """Clears the load on this protocol."""
Clears the load on this protocol
"""
self.load = 0 self.load = 0

View File

@@ -1,25 +1,21 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """The Service class."""
The Service class
"""
from primaite.common.enums import SOFTWARE_STATE from primaite.common.enums import SOFTWARE_STATE
class Service(object): class Service(object):
""" """Service class."""
Service class
"""
def __init__(self, _name, _port, _state): def __init__(self, _name, _port, _state):
""" """
Init Init.
Args: Args:
_name: The service name _name: The service name
_port: The service port _port: The service port
_state: The service state _state: The service state
""" """
self.name = _name self.name = _name
self.port = _port self.port = _port
self.state = _state self.state = _state
@@ -27,74 +23,61 @@ class Service(object):
def set_name(self, _name): def set_name(self, _name):
""" """
Sets the service name Sets the service name.
Args: Args:
_name: The service name _name: The service name
""" """
self.name = _name self.name = _name
def get_name(self): def get_name(self):
""" """
Gets the service name Gets the service name.
Returns: Returns:
The service name The service name
""" """
return self.name return self.name
def set_port(self, _port): def set_port(self, _port):
""" """
Sets the service port Sets the service port.
Args: Args:
_port: The service port _port: The service port
""" """
self.port = _port self.port = _port
def get_port(self): def get_port(self):
""" """
Gets the service port Gets the service port.
Returns: Returns:
The service port The service port
""" """
return self.port return self.port
def set_state(self, _state): def set_state(self, _state):
""" """
Sets the service state Sets the service state.
Args: Args:
_state: The service state _state: The service state
""" """
self.state = _state self.state = _state
def get_state(self): def get_state(self):
""" """
Gets the service state Gets the service state.
Returns: Returns:
The service state The service state
""" """
return self.state return self.state
def reduce_patching_count(self): def reduce_patching_count(self):
""" """Reduces the patching count for the service."""
Reduces the patching count for the service
"""
self.patching_count -= 1 self.patching_count -= 1
if self.patching_count <= 0: if self.patching_count <= 0:
self.patching_count = 0 self.patching_count = 0
self.state = SOFTWARE_STATE.GOOD self.state = SOFTWARE_STATE.GOOD

View File

@@ -4,10 +4,10 @@
steps: 128 steps: 128
- itemType: PORTS - itemType: PORTS
portsList: portsList:
- port: '80' - port: '80'
- itemType: SERVICES - itemType: SERVICES
serviceList: serviceList:
- name: TCP - name: TCP
- itemType: NODE - itemType: NODE
id: '1' id: '1'
name: PC1 name: PC1
@@ -19,9 +19,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '2' id: '2'
name: SERVER name: SERVER
@@ -33,9 +33,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '3' id: '3'
name: PC2 name: PC2
@@ -47,9 +47,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '4' id: '4'
name: SWITCH1 name: SWITCH1

View File

@@ -4,10 +4,10 @@
steps: 128 steps: 128
- itemType: PORTS - itemType: PORTS
portsList: portsList:
- port: '80' - port: '80'
- itemType: SERVICES - itemType: SERVICES
serviceList: serviceList:
- name: TCP - name: TCP
- itemType: NODE - itemType: NODE
id: '1' id: '1'
name: PC1 name: PC1
@@ -19,9 +19,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '2' id: '2'
name: PC2 name: PC2
@@ -33,9 +33,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '3' id: '3'
name: PC3 name: PC3
@@ -47,9 +47,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '4' id: '4'
name: PC4 name: PC4
@@ -61,9 +61,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '5' id: '5'
name: SWITCH1 name: SWITCH1
@@ -85,9 +85,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '7' id: '7'
name: SWITCH2 name: SWITCH2
@@ -109,9 +109,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '9' id: '9'
name: SERVER1 name: SERVER1
@@ -123,9 +123,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '10' id: '10'
name: SERVER2 name: SERVER2
@@ -137,9 +137,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: LINK - itemType: LINK
id: '11' id: '11'
name: link1 name: link1

View File

@@ -4,10 +4,10 @@
steps: 256 steps: 256
- itemType: PORTS - itemType: PORTS
portsList: portsList:
- port: '80' - port: '80'
- itemType: SERVICES - itemType: SERVICES
serviceList: serviceList:
- name: TCP - name: TCP
- itemType: NODE - itemType: NODE
id: '1' id: '1'
name: PC1 name: PC1
@@ -19,9 +19,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '2' id: '2'
name: PC2 name: PC2
@@ -33,9 +33,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '3' id: '3'
name: SWITCH1 name: SWITCH1
@@ -57,9 +57,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: LINK - itemType: LINK
id: '5' id: '5'
name: link1 name: link1

View File

@@ -4,14 +4,14 @@
steps: 256 steps: 256
- itemType: PORTS - itemType: PORTS
portsList: portsList:
- port: '80' - port: '80'
- port: '1433' - port: '1433'
- port: '53' - port: '53'
- itemType: SERVICES - itemType: SERVICES
serviceList: serviceList:
- name: TCP - name: TCP
- name: TCP_SQL - name: TCP_SQL
- name: UDP - name: UDP
- itemType: NODE - itemType: NODE
id: '1' id: '1'
name: CLIENT_1 name: CLIENT_1
@@ -23,12 +23,12 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: UDP - name: UDP
port: '53' port: '53'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '2' id: '2'
name: CLIENT_2 name: CLIENT_2
@@ -40,9 +40,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '3' id: '3'
name: SWITCH_1 name: SWITCH_1
@@ -64,12 +64,12 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: UDP - name: UDP
port: '53' port: '53'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '5' id: '5'
name: MANAGEMENT_CONSOLE name: MANAGEMENT_CONSOLE
@@ -81,12 +81,12 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: UDP - name: UDP
port: '53' port: '53'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '6' id: '6'
name: SWITCH_2 name: SWITCH_2
@@ -108,12 +108,12 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: TCP_SQL - name: TCP_SQL
port: '1433' port: '1433'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '8' id: '8'
name: DATABASE_SERVER name: DATABASE_SERVER
@@ -125,15 +125,15 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: TCP_SQL - name: TCP_SQL
port: '1433' port: '1433'
state: GOOD state: GOOD
- name: UDP - name: UDP
port: '53' port: '53'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '9' id: '9'
name: BACKUP_SERVER name: BACKUP_SERVER
@@ -145,9 +145,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: LINK - itemType: LINK
id: '10' id: '10'
name: LINK_1 name: LINK_1
@@ -529,5 +529,5 @@
protocol: TCP protocol: TCP
state: OVERWHELMED state: OVERWHELMED
sourceNodeId: '8' sourceNodeId: '8'
sourceNodeService: TCP_SQL sourceNodeService: TCP_SQL
sourceNodeServiceState: COMPROMISED sourceNodeServiceState: COMPROMISED

View File

@@ -4,14 +4,14 @@
steps: 256 steps: 256
- itemType: PORTS - itemType: PORTS
portsList: portsList:
- port: '80' - port: '80'
- port: '1433' - port: '1433'
- port: '53' - port: '53'
- itemType: SERVICES - itemType: SERVICES
serviceList: serviceList:
- name: TCP - name: TCP
- name: TCP_SQL - name: TCP_SQL
- name: UDP - name: UDP
- itemType: NODE - itemType: NODE
id: '1' id: '1'
name: CLIENT_1 name: CLIENT_1
@@ -23,12 +23,12 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: UDP - name: UDP
port: '53' port: '53'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '2' id: '2'
name: CLIENT_2 name: CLIENT_2
@@ -40,9 +40,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '3' id: '3'
name: SWITCH_1 name: SWITCH_1
@@ -64,12 +64,12 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: UDP - name: UDP
port: '53' port: '53'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '5' id: '5'
name: MANAGEMENT_CONSOLE name: MANAGEMENT_CONSOLE
@@ -81,12 +81,12 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: UDP - name: UDP
port: '53' port: '53'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '6' id: '6'
name: SWITCH_2 name: SWITCH_2
@@ -108,12 +108,12 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: TCP_SQL - name: TCP_SQL
port: '1433' port: '1433'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '8' id: '8'
name: DATABASE_SERVER name: DATABASE_SERVER
@@ -125,15 +125,15 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: TCP_SQL - name: TCP_SQL
port: '1433' port: '1433'
state: GOOD state: GOOD
- name: UDP - name: UDP
port: '53' port: '53'
state: GOOD state: GOOD
- itemType: NODE - itemType: NODE
id: '9' id: '9'
name: BACKUP_SERVER name: BACKUP_SERVER
@@ -145,9 +145,9 @@
softwareState: GOOD softwareState: GOOD
fileSystemState: GOOD fileSystemState: GOOD
services: services:
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- itemType: LINK - itemType: LINK
id: '10' id: '10'
name: LINK_1 name: LINK_1
@@ -529,5 +529,5 @@
protocol: TCP protocol: TCP
state: OVERWHELMED state: OVERWHELMED
sourceNodeId: '8' sourceNodeId: '8'
sourceNodeService: TCP_SQL sourceNodeService: TCP_SQL
sourceNodeServiceState: COMPROMISED sourceNodeServiceState: COMPROMISED

View File

@@ -21,18 +21,18 @@ agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
# Environment config values # Environment config values
# The high value for the observation space # The high value for the observation space
observationSpaceHighValue: 1000000000 observationSpaceHighValue: 1000000000
# Reward values # Reward values
# Generic # Generic
allOk: 0 allOk: 0
# Node Operating State # Node Operating State
offShouldBeOn: -10 offShouldBeOn: -10
offShouldBeResetting: -5 offShouldBeResetting: -5
onShouldBeOff: -2 onShouldBeOff: -2
onShouldBeResetting: -5 onShouldBeResetting: -5
resettingShouldBeOn: -5 resettingShouldBeOn: -5
resettingShouldBeOff: -2 resettingShouldBeOff: -2
resetting: -3 resetting: -3
# Node O/S or Service State # Node O/S or Service State
goodShouldBePatching: 2 goodShouldBePatching: 2
@@ -45,7 +45,7 @@ patching: -3
compromisedShouldBeGood: -20 compromisedShouldBeGood: -20
compromisedShouldBePatching: -20 compromisedShouldBePatching: -20
compromisedShouldBeOverwhelmed: -20 compromisedShouldBeOverwhelmed: -20
compromised: -20 compromised: -20
overwhelmedShouldBeGood: -20 overwhelmedShouldBeGood: -20
overwhelmedShouldBePatching: -20 overwhelmedShouldBePatching: -20
overwhelmedShouldBeCompromised: -20 overwhelmedShouldBeCompromised: -20
@@ -62,7 +62,7 @@ repairingShouldBeDestroyed: 0
repairing: -3 repairing: -3
restoringShouldBeGood: -10 restoringShouldBeGood: -10
restoringShouldBeRepairing: -2 restoringShouldBeRepairing: -2
restoringShouldBeCorrupt: 1 restoringShouldBeCorrupt: 1
restoringShouldBeDestroyed: 2 restoringShouldBeDestroyed: 2
restoring: -6 restoring: -6
corruptShouldBeGood: -10 corruptShouldBeGood: -10

View File

@@ -1,2 +1 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.

View File

@@ -1,39 +1,45 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
Main environment module containing the PRIMmary AI Training Evironment (Primaite) class
"""
import numpy as np
import networkx as nx
import copy import copy
import csv import csv
import yaml
import os.path
import logging import logging
import os.path
from gym import Env, spaces
from matplotlib import pyplot as plt
from datetime import datetime from datetime import datetime
from primaite.common.enums import * import networkx as nx
import numpy as np
import yaml
from gym import Env, spaces
from matplotlib import pyplot as plt
from primaite.acl.access_control_list import AccessControlList
from primaite.common.enums import (
ACTION_TYPE,
FILE_SYSTEM_STATE,
HARDWARE_STATE,
NODE_POL_INITIATOR,
NODE_POL_TYPE,
PRIORITY,
SOFTWARE_STATE,
TYPE,
)
from primaite.common.service import Service
from primaite.environment.reward import calculate_reward_function
from primaite.links.link import Link from primaite.links.link import Link
from primaite.pol.ier import IER from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.pol.green_pol import apply_iers, apply_node_pol
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.passive_node import PassiveNode from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode from primaite.nodes.service_node import ServiceNode
from primaite.common.service import Service from primaite.pol.green_pol import apply_iers, apply_node_pol
from primaite.acl.access_control_list import AccessControlList from primaite.pol.ier import IER
from primaite.environment.reward import calculate_reward_function from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction from primaite.transactions.transaction import Transaction
class Primaite(Env): class Primaite(Env):
""" """PRIMmary AI Training Evironment (Primaite) class."""
PRIMmary AI Training Evironment (Primaite) class
"""
# Observation / Action Space contants # Observation / Action Space contants
OBSERVATION_SPACE_FIXED_PARAMETERS = 4 OBSERVATION_SPACE_FIXED_PARAMETERS = 4
@@ -42,11 +48,11 @@ class Primaite(Env):
ACTION_SPACE_ACL_ACTION_VALUES = 3 ACTION_SPACE_ACL_ACTION_VALUES = 3
ACTION_SPACE_ACL_PERMISSION_VALUES = 2 ACTION_SPACE_ACL_PERMISSION_VALUES = 2
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
def __init__(self, _config_values, _transaction_list): def __init__(self, _config_values, _transaction_list):
""" """
Init Init.
Args: Args:
_episode_steps: The number of steps for the episode _episode_steps: The number of steps for the episode
@@ -54,7 +60,6 @@ class Primaite(Env):
_transaction_list: The list of transactions to populate _transaction_list: The list of transactions to populate
_agent_identifier: Identifier for the agent _agent_identifier: Identifier for the agent
""" """
super(Primaite, self).__init__() super(Primaite, self).__init__()
# Take a copy of the config values # Take a copy of the config values
@@ -140,10 +145,12 @@ class Primaite(Env):
# Open the config file and build the environment laydown # Open the config file and build the environment laydown
try: try:
self.config_file = open("config/" + self.config_values.config_filename_use_case, "r") self.config_file = open(
"config/" + self.config_values.config_filename_use_case, "r"
)
self.config_data = yaml.safe_load(self.config_file) self.config_data = yaml.safe_load(self.config_file)
self.load_config() self.load_config()
except Exception as e: except Exception:
logging.error("Could not load the environment configuration") logging.error("Could not load the environment configuration")
logging.error("Exception occured", exc_info=True) logging.error("Exception occured", exc_info=True)
@@ -162,17 +169,17 @@ class Primaite(Env):
try: try:
plt.tight_layout() plt.tight_layout()
nx.draw_networkx(self.network, with_labels=True) nx.draw_networkx(self.network, with_labels=True)
now = datetime.now() # current date and time now = datetime.now() # current date and time
time = now.strftime("%Y%m%d_%H%M%S") time = now.strftime("%Y%m%d_%H%M%S")
path = 'outputs/diagrams' path = "outputs/diagrams"
is_dir = os.path.isdir(path) is_dir = os.path.isdir(path)
if not is_dir: if not is_dir:
os.makedirs(path) os.makedirs(path)
filename = "outputs/diagrams/network_" + time + ".png" filename = "outputs/diagrams/network_" + time + ".png"
plt.savefig(filename, format="PNG") plt.savefig(filename, format="PNG")
plt.clf() plt.clf()
except Exception as a: except Exception:
logging.error("Could not save network diagram") logging.error("Could not save network diagram")
logging.error("Exception occured", exc_info=True) logging.error("Exception occured", exc_info=True)
print("Could not save network diagram") print("Could not save network diagram")
@@ -194,16 +201,22 @@ class Primaite(Env):
# - service F state | service F loading # - service F state | service F loading
# - service G state | service G loading # - service G state | service G loading
# Calculate the number of items that need to be included in the observation space # Calculate the number of items that need to be included in the
# observation space
num_items = self.num_links + self.num_nodes num_items = self.num_links + self.num_nodes
# Set the number of observation parameters, being # of services plus id, operating state, file system state and O/S state (i.e. 4) # Set the number of observation parameters, being # of services plus id,
self.num_observation_parameters = self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS # operating state, file system state and O/S state (i.e. 4)
self.num_observation_parameters = (
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
)
# Define the observation shape # Define the observation shape
self.observation_shape = (num_items, self.num_observation_parameters) self.observation_shape = (num_items, self.num_observation_parameters)
self.observation_space = spaces.Box(low=0, self.observation_space = spaces.Box(
high=self.config_values.observation_space_high_value, low=0,
shape=self.observation_shape, high=self.config_values.observation_space_high_value,
dtype=np.int64) shape=self.observation_shape,
dtype=np.int64,
)
# This is the observation that is sent back via the rest and step functions # This is the observation that is sent back via the rest and step functions
self.env_obs = np.zeros(self.observation_shape, dtype=np.int64) self.env_obs = np.zeros(self.observation_shape, dtype=np.int64)
@@ -216,7 +229,14 @@ class Primaite(Env):
# [0, 4] - what property it's acting on (0 = nothing, state, o/s state, service state, file system state) # [0, 4] - what property it's acting on (0 = nothing, state, o/s state, service state, file system state)
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore)
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # [0, num services] - resolves to service ID (0 = nothing, resolves to service)
self.action_space = spaces.MultiDiscrete([self.num_nodes, self.ACTION_SPACE_NODE_PROPERTY_VALUES, self.ACTION_SPACE_NODE_ACTION_VALUES, self.num_services]) self.action_space = spaces.MultiDiscrete(
[
self.num_nodes,
self.ACTION_SPACE_NODE_PROPERTY_VALUES,
self.ACTION_SPACE_NODE_ACTION_VALUES,
self.num_services,
]
)
else: else:
logging.info("Action space type ACL selected") logging.info("Action space type ACL selected")
# Terms (for ACL action space): # Terms (for ACL action space):
@@ -226,42 +246,52 @@ class Primaite(Env):
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) # [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol) # [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port) # [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
self.action_space = spaces.MultiDiscrete([self.ACTION_SPACE_ACL_ACTION_VALUES, self.ACTION_SPACE_ACL_PERMISSION_VALUES, self.num_nodes + 1, self.num_nodes + 1, self.num_services + 1, self.num_ports + 1]) self.action_space = spaces.MultiDiscrete(
[
self.ACTION_SPACE_ACL_ACTION_VALUES,
self.ACTION_SPACE_ACL_PERMISSION_VALUES,
self.num_nodes + 1,
self.num_nodes + 1,
self.num_services + 1,
self.num_ports + 1,
]
)
# Set up a csv to store the results of the training # Set up a csv to store the results of the training
try: try:
now = datetime.now() # current date and time now = datetime.now() # current date and time
time = now.strftime("%Y%m%d_%H%M%S") time = now.strftime("%Y%m%d_%H%M%S")
header = ['Episode', 'Average Reward'] header = ["Episode", "Average Reward"]
# Check whether the output/rerults folder exists (doesn't exist by default install) # Check whether the output/rerults folder exists (doesn't exist by default install)
path = 'outputs/results/' path = "outputs/results/"
is_dir = os.path.isdir(path) is_dir = os.path.isdir(path)
if not is_dir: if not is_dir:
os.makedirs(path) os.makedirs(path)
filename = "outputs/results/average_reward_per_episode_" + time + ".csv" filename = "outputs/results/average_reward_per_episode_" + time + ".csv"
self.csv_file = open(filename, 'w', encoding='UTF8', newline='') self.csv_file = open(filename, "w", encoding="UTF8", newline="")
self.csv_writer = csv.writer(self.csv_file) self.csv_writer = csv.writer(self.csv_file)
self.csv_writer.writerow(header) self.csv_writer.writerow(header)
except Exception as e: except Exception:
logging.error("Could not create csv file to hold average reward per episode") logging.error(
"Could not create csv file to hold average reward per episode"
)
logging.error("Exception occured", exc_info=True) logging.error("Exception occured", exc_info=True)
def reset(self): def reset(self):
""" """
AI Gym Reset function AI Gym Reset function.
Returns: Returns:
Environment observation space (reset) Environment observation space (reset)
""" """
csv_data = self.episode_count, self.average_reward csv_data = self.episode_count, self.average_reward
self.csv_writer.writerow(csv_data) self.csv_writer.writerow(csv_data)
self.episode_count += 1 self.episode_count += 1
# Don't need to reset links, as they are cleared and recalculated every step # Don't need to reset links, as they are cleared and recalculated every step
# Clear the ACL # Clear the ACL
self.init_acl() self.init_acl()
@@ -280,7 +310,7 @@ class Primaite(Env):
def step(self, action): def step(self, action):
""" """
AI Gym Step function AI Gym Step function.
Args: Args:
action: Action space from agent action: Action space from agent
@@ -291,7 +321,6 @@ class Primaite(Env):
done: Indicates episode is complete if True done: Indicates episode is complete if True
step_info: Additional information relating to this step step_info: Additional information relating to this step
""" """
if self.step_count == 0: if self.step_count == 0:
print("Episode: " + str(self.episode_count) + " running") print("Episode: " + str(self.episode_count) + " running")
@@ -299,14 +328,16 @@ class Primaite(Env):
done = False done = False
self.step_count += 1 self.step_count += 1
#print("Episode step: " + str(self.stepCount)) # print("Episode step: " + str(self.stepCount))
# Need to clear traffic on all links first # Need to clear traffic on all links first
for link_key, link_value in self.links.items(): for link_key, link_value in self.links.items():
link_value.clear_traffic() link_value.clear_traffic()
# Create a Transaction (metric) object for this step # Create a Transaction (metric) object for this step
transaction = Transaction(datetime.now(), self.agent_identifier, self.episode_count, self.step_count) transaction = Transaction(
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
)
# Load the initial observation space into the transaction # Load the initial observation space into the transaction
transaction.set_obs_space_pre(copy.deepcopy(self.env_obs)) transaction.set_obs_space_pre(copy.deepcopy(self.env_obs))
# Load the action space into the transaction # Load the action space into the transaction
@@ -316,50 +347,97 @@ class Primaite(Env):
self.apply_time_based_updates() self.apply_time_based_updates()
# 2. Apply PoL # 2. Apply PoL
apply_node_pol(self.nodes, self.node_pol, self.step_count) # Node PoL apply_node_pol(self.nodes, self.node_pol, self.step_count) # Node PoL
apply_iers(self.network, self.nodes, self.links, self.green_iers, self.acl, self.step_count) # Network PoL apply_iers(
self.network,
self.nodes,
self.links,
self.green_iers,
self.acl,
self.step_count,
) # Network PoL
# Take snapshots of nodes and links # Take snapshots of nodes and links
self.nodes_post_pol = copy.deepcopy(self.nodes) self.nodes_post_pol = copy.deepcopy(self.nodes)
self.links_post_pol = copy.deepcopy(self.links) self.links_post_pol = copy.deepcopy(self.links)
# Reference # Reference
apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL
apply_iers(self.network_reference, self.nodes_reference, self.links_reference, self.green_iers, self.acl, self.step_count) # Network PoL apply_iers(
self.network_reference,
self.nodes_reference,
self.links_reference,
self.green_iers,
self.acl,
self.step_count,
) # Network PoL
# 3. Implement Red Action # 3. Implement Red Action
apply_red_agent_iers(self.network, self.nodes, self.links, self.red_iers, self.acl, self.step_count) apply_red_agent_iers(
apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count) self.network,
self.nodes,
self.links,
self.red_iers,
self.acl,
self.step_count,
)
apply_red_agent_node_pol(
self.nodes, self.red_iers, self.red_node_pol, self.step_count
)
# Take snapshots of nodes and links # Take snapshots of nodes and links
self.nodes_post_red = copy.deepcopy(self.nodes) self.nodes_post_red = copy.deepcopy(self.nodes)
self.links_post_red = copy.deepcopy(self.links) self.links_post_red = copy.deepcopy(self.links)
# 4. Implement Blue Action # 4. Implement Blue Action
self.interpret_action_and_apply(action) self.interpret_action_and_apply(action)
# 5. Reapply normal and Red agent IER PoL, as we need to see what effect the blue agent action has had (if any) on link status # 5. Reapply normal and Red agent IER PoL, as we need to see what
# effect the blue agent action has had (if any) on link status
# Need to clear traffic on all links first # Need to clear traffic on all links first
for link_key, link_value in self.links.items(): for link_key, link_value in self.links.items():
link_value.clear_traffic() link_value.clear_traffic()
apply_iers(self.network, self.nodes, self.links, self.green_iers, self.acl, self.step_count) apply_iers(
apply_red_agent_iers(self.network, self.nodes, self.links, self.red_iers, self.acl, self.step_count) self.network,
self.nodes,
self.links,
self.green_iers,
self.acl,
self.step_count,
)
apply_red_agent_iers(
self.network,
self.nodes,
self.links,
self.red_iers,
self.acl,
self.step_count,
)
# Take snapshots of nodes and links # Take snapshots of nodes and links
self.nodes_post_blue = copy.deepcopy(self.nodes) self.nodes_post_blue = copy.deepcopy(self.nodes)
self.links_post_blue = copy.deepcopy(self.links) self.links_post_blue = copy.deepcopy(self.links)
# 6. Calculate reward signal (for RL) # 6. Calculate reward signal (for RL)
reward = calculate_reward_function(self.nodes_post_pol, self.nodes_post_blue, self.nodes_reference, self.green_iers, self.red_iers, self.step_count, self.config_values) reward = calculate_reward_function(
#print("Step reward: " + str(reward)) self.nodes_post_pol,
self.nodes_post_blue,
self.nodes_reference,
self.green_iers,
self.red_iers,
self.step_count,
self.config_values,
)
# print("Step reward: " + str(reward))
self.total_reward += reward self.total_reward += reward
if self.step_count == self.episode_steps: if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count self.average_reward = self.total_reward / self.step_count
if self.config_values.session_type == "EVALUATION": if self.config_values.session_type == "EVALUATION":
# For evaluation, need to trigger the done value = True when step count is reached in order to prevent neverending episode # For evaluation, need to trigger the done value = True when
# step count is reached in order to prevent neverending episode
done = True done = True
print("Average reward: " + str(self.average_reward)) print("Average reward: " + str(self.average_reward))
# Load the reward into the transaction # Load the reward into the transaction
transaction.set_reward(reward) transaction.set_reward(reward)
# 7. Output Verbose # 7. Output Verbose
#self.output_link_status() # self.output_link_status()
# 8. Update env_obs # 8. Update env_obs
self.update_environent_obs() self.update_environent_obs()
@@ -373,38 +451,33 @@ class Primaite(Env):
return self.env_obs, reward, done, self.step_info return self.env_obs, reward, done, self.step_info
def __close__(self): def __close__(self):
""" """Override close function."""
Override close function
"""
self.csv_file.close() self.csv_file.close()
self.config_file.close() self.config_file.close()
def init_acl(self): def init_acl(self):
""" """Initialise the Access Control List."""
Initialise the Access Control List self.acl.remove_all_rules()
"""
self.acl.remove_all_rules()
def output_link_status(self): def output_link_status(self):
""" """Output the link status of all links to the console."""
Output the link status of all links to the console
"""
for link_key, link_value in self.links.items(): for link_key, link_value in self.links.items():
print("Link ID: " + link_value.get_id()) print("Link ID: " + link_value.get_id())
for protocol in link_value.get_protocol_list(): for protocol in link_value.get_protocol_list():
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load())) print(
" Protocol: "
+ protocol.get_name().name
+ ", Load: "
+ str(protocol.get_load())
)
def interpret_action_and_apply(self, _action): def interpret_action_and_apply(self, _action):
""" """
Applies agent actions to the nodes and Access Control List Applies agent actions to the nodes and Access Control List.
Args: Args:
_action: The action space from the agent _action: The action space from the agent
""" """
# At the moment, actions are only affecting nodes # At the moment, actions are only affecting nodes
if self.action_type == ACTION_TYPE.NODE: if self.action_type == ACTION_TYPE.NODE:
self.apply_actions_to_nodes(_action) self.apply_actions_to_nodes(_action)
@@ -413,12 +486,11 @@ class Primaite(Env):
def apply_actions_to_nodes(self, _action): def apply_actions_to_nodes(self, _action):
""" """
Applies agent actions to the nodes Applies agent actions to the nodes.
Args: Args:
_action: The action space from the agent _action: The action space from the agent
""" """
node_id = _action[0] node_id = _action[0]
node_property = _action[1] node_property = _action[1]
property_action = _action[2] property_action = _action[2]
@@ -427,7 +499,7 @@ class Primaite(Env):
# Check that the action is requesting a valid node # Check that the action is requesting a valid node
try: try:
node = self.nodes[str(node_id)] node = self.nodes[str(node_id)]
except: except Exception:
return return
if node_property == 0: if node_property == 0:
@@ -472,7 +544,9 @@ class Primaite(Env):
return return
elif property_action == 1: elif property_action == 1:
# Patch (valid action if it's good or compromised) # Patch (valid action if it's good or compromised)
node.set_service_state(self.services_list[service_index], SOFTWARE_STATE.PATCHING) node.set_service_state(
self.services_list[service_index], SOFTWARE_STATE.PATCHING
)
else: else:
# Node is not of Service Type # Node is not of Service Type
return return
@@ -488,7 +562,10 @@ class Primaite(Env):
elif property_action == 2: elif property_action == 2:
# Repair # Repair
# You cannot repair a destroyed file system - it needs restoring # You cannot repair a destroyed file system - it needs restoring
if node.get_file_system_state_actual() != FILE_SYSTEM_STATE.DESTROYED: if (
node.get_file_system_state_actual()
!= FILE_SYSTEM_STATE.DESTROYED
):
node.set_file_system_state(FILE_SYSTEM_STATE.REPAIRING) node.set_file_system_state(FILE_SYSTEM_STATE.REPAIRING)
elif property_action == 3: elif property_action == 3:
# Restore # Restore
@@ -501,12 +578,11 @@ class Primaite(Env):
def apply_actions_to_acl(self, _action): def apply_actions_to_acl(self, _action):
""" """
Applies agent actions to the Access Control List [TO DO] Applies agent actions to the Access Control List [TO DO].
Args: Args:
_action: The action space from the agent _action: The action space from the agent
""" """
action_decision = _action[0] action_decision = _action[0]
action_permission = _action[1] action_permission = _action[1]
action_source_ip = _action[2] action_source_ip = _action[2]
@@ -517,7 +593,7 @@ class Primaite(Env):
if action_decision == 0: if action_decision == 0:
# It's decided to do nothing # It's decided to do nothing
return return
else: else:
# It's decided to create a new ACL rule or remove an existing rule # It's decided to create a new ACL rule or remove an existing rule
# Permission value # Permission value
if action_permission == 0: if action_permission == 0:
@@ -556,18 +632,31 @@ class Primaite(Env):
# Now add or remove # Now add or remove
if action_decision == 1: if action_decision == 1:
# Add the rule # Add the rule
self.acl.add_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port) self.acl.add_rule(
acl_rule_permission,
acl_rule_source,
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
)
elif action_decision == 2: elif action_decision == 2:
# Remove the rule # Remove the rule
self.acl.remove_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port) self.acl.remove_rule(
acl_rule_permission,
acl_rule_source,
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
)
else: else:
return return
def apply_time_based_updates(self): def apply_time_based_updates(self):
""" """
Updates anything that needs to count down and then change state (e.g. reset / patching status) Updates anything that needs to count down and then change state.
e.g. reset / patching status
""" """
for node_key, node in self.nodes.items(): for node_key, node in self.nodes.items():
if node.get_state() == HARDWARE_STATE.RESETTING: if node.get_state() == HARDWARE_STATE.RESETTING:
node.update_resetting_status() node.update_resetting_status()
@@ -605,10 +694,7 @@ class Primaite(Env):
pass pass
def update_environent_obs(self): def update_environent_obs(self):
""" """Updates the observation space based on the node and link status."""
# Updates the observation space based on the node and link status
"""
item_index = 0 item_index = 0
# Do nodes first # Do nodes first
@@ -617,15 +703,19 @@ class Primaite(Env):
self.env_obs[item_index][1] = node.get_state().value self.env_obs[item_index][1] = node.get_state().value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.env_obs[item_index][2] = node.get_os_state().value self.env_obs[item_index][2] = node.get_os_state().value
self.env_obs[item_index][3] = node.get_file_system_state_observed().value self.env_obs[item_index][
3
] = node.get_file_system_state_observed().value
else: else:
self.env_obs[item_index][2] = 0 self.env_obs[item_index][2] = 0
self.env_obs[item_index][3] = 0 self.env_obs[item_index][3] = 0
service_index = 4 service_index = 4
if isinstance(node, ServiceNode): if isinstance(node, ServiceNode):
for service in self.services_list: for service in self.services_list:
if node.has_service(service): if node.has_service(service):
self.env_obs[item_index][service_index] = node.get_service_state(service).value self.env_obs[item_index][
service_index
] = node.get_service_state(service).value
else: else:
self.env_obs[item_index][service_index] = 0 self.env_obs[item_index][service_index] = 0
service_index += 1 service_index += 1
@@ -650,17 +740,14 @@ class Primaite(Env):
item_index += 1 item_index += 1
def load_config(self): def load_config(self):
""" """Loads config data in order to build the environment configuration."""
# Loads config data in order to build the environment configuration
"""
for item in self.config_data: for item in self.config_data:
if item["itemType"] == "NODE": if item["itemType"] == "NODE":
# Create a node # Create a node
self.create_node(item) self.create_node(item)
elif item["itemType"] == "LINK": elif item["itemType"] == "LINK":
# Create a link # Create a link
self.create_link(item) self.create_link(item)
elif item["itemType"] == "GREEN_IER": elif item["itemType"] == "GREEN_IER":
# Create a Green IER # Create a Green IER
self.create_green_ier(item) self.create_green_ier(item)
@@ -697,12 +784,11 @@ class Primaite(Env):
def create_node(self, item): def create_node(self, item):
""" """
Creates a node from config data Creates a node from config data.
Args: Args:
item: A config data item item: A config data item
""" """
# All nodes have these parameters # All nodes have these parameters
node_id = item["id"] node_id = item["id"]
node_name = item["name"] node_name = item["name"]
@@ -712,19 +798,46 @@ class Primaite(Env):
node_hardware_state = HARDWARE_STATE[item["hardwareState"]] node_hardware_state = HARDWARE_STATE[item["hardwareState"]]
if node_base_type == "PASSIVE": if node_base_type == "PASSIVE":
node = PassiveNode(node_id, node_name, node_type, node_priority, node_hardware_state, self.config_values) node = PassiveNode(
node_id,
node_name,
node_type,
node_priority,
node_hardware_state,
self.config_values,
)
elif node_base_type == "ACTIVE": elif node_base_type == "ACTIVE":
# Active nodes have IP address, operating system state and file system state # Active nodes have IP address, operating system state and file system state
node_ip_address = item["ipAddress"] node_ip_address = item["ipAddress"]
node_software_state = SOFTWARE_STATE[item["softwareState"]] node_software_state = SOFTWARE_STATE[item["softwareState"]]
node_file_system_state = FILE_SYSTEM_STATE[item["fileSystemState"]] node_file_system_state = FILE_SYSTEM_STATE[item["fileSystemState"]]
node = ActiveNode(node_id, node_name, node_type, node_priority, node_hardware_state, node_ip_address, node_software_state, node_file_system_state, self.config_values) node = ActiveNode(
node_id,
node_name,
node_type,
node_priority,
node_hardware_state,
node_ip_address,
node_software_state,
node_file_system_state,
self.config_values,
)
elif node_base_type == "SERVICE": elif node_base_type == "SERVICE":
# Service nodes have IP address, operating system state, file system state and list of services # Service nodes have IP address, operating system state, file system state and list of services
node_ip_address = item["ipAddress"] node_ip_address = item["ipAddress"]
node_software_state = SOFTWARE_STATE[item["softwareState"]] node_software_state = SOFTWARE_STATE[item["softwareState"]]
node_file_system_state = FILE_SYSTEM_STATE[item["fileSystemState"]] node_file_system_state = FILE_SYSTEM_STATE[item["fileSystemState"]]
node = ServiceNode(node_id, node_name, node_type, node_priority, node_hardware_state, node_ip_address, node_software_state, node_file_system_state, self.config_values) node = ServiceNode(
node_id,
node_name,
node_type,
node_priority,
node_hardware_state,
node_ip_address,
node_software_state,
node_file_system_state,
self.config_values,
)
node_services = item["services"] node_services = item["services"]
for service in node_services: for service in node_services:
service_protocol = service["name"] service_protocol = service["name"]
@@ -752,12 +865,11 @@ class Primaite(Env):
def create_link(self, item): def create_link(self, item):
""" """
Creates a link from config data Creates a link from config data.
Args: Args:
item: A config data item item: A config data item
""" """
link_id = item["id"] link_id = item["id"]
link_name = item["name"] link_name = item["name"]
link_bandwidth = item["bandwidth"] link_bandwidth = item["bandwidth"]
@@ -771,7 +883,13 @@ class Primaite(Env):
self.network.add_edge(source_node, dest_node, id=link_name) self.network.add_edge(source_node, dest_node, id=link_name)
# Add link to link dictionary # Add link to link dictionary
self.links[link_name] = Link(link_id, link_bandwidth, source_node.get_name(), dest_node.get_name(), self.services_list) self.links[link_name] = Link(
link_id,
link_bandwidth,
source_node.get_name(),
dest_node.get_name(),
self.services_list,
)
# Reference # Reference
source_node_ref = self.nodes_reference[link_source] source_node_ref = self.nodes_reference[link_source]
@@ -781,16 +899,21 @@ class Primaite(Env):
self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name) self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name)
# Add link to link dictionary (reference) # Add link to link dictionary (reference)
self.links_reference[link_name] = Link(link_id, link_bandwidth, source_node_ref.get_name(), dest_node_ref.get_name(), self.services_list) self.links_reference[link_name] = Link(
link_id,
link_bandwidth,
source_node_ref.get_name(),
dest_node_ref.get_name(),
self.services_list,
)
def create_green_ier(self, item): def create_green_ier(self, item):
""" """
Creates a green IER from config data Creates a green IER from config data.
Args: Args:
item: A config data item item: A config data item
""" """
ier_id = item["id"] ier_id = item["id"]
ier_start_step = item["startStep"] ier_start_step = item["startStep"]
ier_end_step = item["endStep"] ier_end_step = item["endStep"]
@@ -802,16 +925,25 @@ class Primaite(Env):
ier_mission_criticality = item["missionCriticality"] ier_mission_criticality = item["missionCriticality"]
# Create IER and add to green IER dictionary # Create IER and add to green IER dictionary
self.green_iers[ier_id] = IER(ier_id, ier_start_step, ier_end_step, ier_load, ier_protocol, ier_port, ier_source, ier_destination, ier_mission_criticality) self.green_iers[ier_id] = IER(
ier_id,
ier_start_step,
ier_end_step,
ier_load,
ier_protocol,
ier_port,
ier_source,
ier_destination,
ier_mission_criticality,
)
def create_red_ier(self, item): def create_red_ier(self, item):
""" """
Creates a red IER from config data Creates a red IER from config data.
Args: Args:
item: A config data item item: A config data item
""" """
ier_id = item["id"] ier_id = item["id"]
ier_start_step = item["startStep"] ier_start_step = item["startStep"]
ier_end_step = item["endStep"] ier_end_step = item["endStep"]
@@ -823,21 +955,30 @@ class Primaite(Env):
ier_mission_criticality = item["missionCriticality"] ier_mission_criticality = item["missionCriticality"]
# Create IER and add to red IER dictionary # Create IER and add to red IER dictionary
self.red_iers[ier_id] = IER(ier_id, ier_start_step, ier_end_step, ier_load, ier_protocol, ier_port, ier_source, ier_destination, ier_mission_criticality) self.red_iers[ier_id] = IER(
ier_id,
ier_start_step,
ier_end_step,
ier_load,
ier_protocol,
ier_port,
ier_source,
ier_destination,
ier_mission_criticality,
)
def create_green_pol(self, item): def create_green_pol(self, item):
""" """
Creates a green PoL object from config data Creates a green PoL object from config data.
Args: Args:
item: A config data item item: A config data item
""" """
pol_id = item["id"] pol_id = item["id"]
pol_start_step = item["startStep"] pol_start_step = item["startStep"]
pol_end_step = item["endStep"] pol_end_step = item["endStep"]
pol_node = item["nodeId"] pol_node = item["nodeId"]
pol_type = NODE_POL_TYPE[item["type"]] pol_type = NODE_POL_TYPE[item["type"]]
# State depends on whether this is Operating, O/S, file system or Service PoL type # State depends on whether this is Operating, O/S, file system or Service PoL type
if pol_type == NODE_POL_TYPE.OPERATING: if pol_type == NODE_POL_TYPE.OPERATING:
@@ -850,16 +991,23 @@ class Primaite(Env):
pol_protocol = item["protocol"] pol_protocol = item["protocol"]
pol_state = SOFTWARE_STATE[item["state"]] pol_state = SOFTWARE_STATE[item["state"]]
self.node_pol[pol_id] = NodeStateInstructionGreen(pol_id, pol_start_step, pol_end_step, pol_node, pol_type, pol_protocol, pol_state) self.node_pol[pol_id] = NodeStateInstructionGreen(
pol_id,
pol_start_step,
pol_end_step,
pol_node,
pol_type,
pol_protocol,
pol_state,
)
def create_red_pol(self, item): def create_red_pol(self, item):
""" """
Creates a red PoL object from config data Creates a red PoL object from config data.
Args: Args:
item: A config data item item: A config data item
""" """
pol_id = item["id"] pol_id = item["id"]
pol_start_step = item["startStep"] pol_start_step = item["startStep"]
pol_end_step = item["endStep"] pol_end_step = item["endStep"]
@@ -880,32 +1028,48 @@ class Primaite(Env):
pol_source_node_service = item["sourceNodeService"] pol_source_node_service = item["sourceNodeService"]
pol_source_node_service_state = item["sourceNodeServiceState"] pol_source_node_service_state = item["sourceNodeServiceState"]
self.red_node_pol[pol_id] = NodeStateInstructionRed(pol_id, pol_start_step, pol_end_step, pol_target_node_id, pol_initiator, pol_type, pol_protocol, pol_state, pol_source_node_id, pol_source_node_service, pol_source_node_service_state) self.red_node_pol[pol_id] = NodeStateInstructionRed(
pol_id,
pol_start_step,
pol_end_step,
pol_target_node_id,
pol_initiator,
pol_type,
pol_protocol,
pol_state,
pol_source_node_id,
pol_source_node_service,
pol_source_node_service_state,
)
def create_acl_rule(self, item): def create_acl_rule(self, item):
""" """
Creates an ACL rule from config data Creates an ACL rule from config data.
Args: Args:
item: A config data item item: A config data item
""" """
acl_rule_permission = item["permission"] acl_rule_permission = item["permission"]
acl_rule_source = item["source"] acl_rule_source = item["source"]
acl_rule_destination = item["destination"] acl_rule_destination = item["destination"]
acl_rule_protocol = item["protocol"] acl_rule_protocol = item["protocol"]
acl_rule_port = item["port"] acl_rule_port = item["port"]
self.acl.add_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port) self.acl.add_rule(
acl_rule_permission,
acl_rule_source,
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
)
def create_services_list(self, services): def create_services_list(self, services):
""" """
Creates a list of services (enum) from config data Creates a list of services (enum) from config data.
Args: Args:
item: A config data item representing the services item: A config data item representing the services
""" """
service_list = services["serviceList"] service_list = services["serviceList"]
for service in service_list: for service in service_list:
@@ -917,12 +1081,11 @@ class Primaite(Env):
def create_ports_list(self, ports): def create_ports_list(self, ports):
""" """
Creates a list of ports from config data Creates a list of ports from config data.
Args: Args:
item: A config data item representing the ports item: A config data item representing the ports
""" """
ports_list = ports["portsList"] ports_list = ports["portsList"]
for port in ports_list: for port in ports_list:
@@ -934,35 +1097,34 @@ class Primaite(Env):
def get_action_info(self, action_info): def get_action_info(self, action_info):
""" """
Extracts action_info Extracts action_info.
Args: Args:
item: A config data item representing action info item: A config data item representing action info
""" """
self.action_type = ACTION_TYPE[action_info["type"]] self.action_type = ACTION_TYPE[action_info["type"]]
def get_steps_info(self, steps_info): def get_steps_info(self, steps_info):
""" """
Extracts steps_info Extracts steps_info.
Args: Args:
item: A config data item representing steps info item: A config data item representing steps info
""" """
self.episode_steps = int(steps_info["steps"]) self.episode_steps = int(steps_info["steps"])
logging.info("Training episodes have " + str(self.episode_steps) + " steps") logging.info("Training episodes have " + str(self.episode_steps) + " steps")
def reset_environment(self): def reset_environment(self):
""" """
# Resets environment using config data config data in order to build the environment configuration # Resets environment.
"""
Uses config data config data in order to build the environment
configuration.
"""
for item in self.config_data: for item in self.config_data:
if item["itemType"] == "NODE": if item["itemType"] == "NODE":
# Reset a node's state (normal and reference) # Reset a node's state (normal and reference)
self.reset_node(item) self.reset_node(item)
elif item["itemType"] == "ACL_RULE": elif item["itemType"] == "ACL_RULE":
# Create an ACL rule (these are cleared on reset, so just need to recreate them) # Create an ACL rule (these are cleared on reset, so just need to recreate them)
self.create_acl_rule(item) self.create_acl_rule(item)
@@ -970,7 +1132,6 @@ class Primaite(Env):
# Do nothing (bad formatting or not relevant to reset) # Do nothing (bad formatting or not relevant to reset)
pass pass
# Reset the IER status so they are not running initially # Reset the IER status so they are not running initially
# Green IERs # Green IERs
for ier_key, ier_value in self.green_iers.items(): for ier_key, ier_value in self.green_iers.items():
@@ -981,12 +1142,11 @@ class Primaite(Env):
def reset_node(self, item): def reset_node(self, item):
""" """
Resets the statuses of a node Resets the statuses of a node.
Args: Args:
item: A config data item item: A config data item
""" """
# All nodes have these parameters # All nodes have these parameters
node_id = item["id"] node_id = item["id"]
node_base_type = item["baseType"] node_base_type = item["baseType"]
@@ -1027,10 +1187,3 @@ class Primaite(Env):
else: else:
# Bad formatting # Bad formatting
pass pass

View File

@@ -1,15 +1,21 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """Implements reward function."""
Implements reward function from primaite.common.enums import FILE_SYSTEM_STATE, HARDWARE_STATE, SOFTWARE_STATE
"""
from primaite.common.enums import *
from primaite.nodes.active_node import ActiveNode from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode from primaite.nodes.service_node import ServiceNode
def calculate_reward_function(initial_nodes, final_nodes, reference_nodes, green_iers, red_iers, step_count, config_values):
def calculate_reward_function(
initial_nodes,
final_nodes,
reference_nodes,
green_iers,
red_iers,
step_count,
config_values,
):
""" """
Compares the states of the initial and final nodes/links to get a reward Compares the states of the initial and final nodes/links to get a reward.
Args: Args:
initial_nodes: The nodes before red and blue agents take effect initial_nodes: The nodes before red and blue agents take effect
@@ -20,29 +26,36 @@ def calculate_reward_function(initial_nodes, final_nodes, reference_nodes, green
step_count: current step step_count: current step
config_values: Config values config_values: Config values
""" """
reward_value = 0 reward_value = 0
# For each node, compare operating state, o/s operating state, service states # For each node, compare operating state, o/s operating state, service states
for node_key, final_node in final_nodes.items(): for node_key, final_node in final_nodes.items():
initial_node = initial_nodes[node_key] initial_node = initial_nodes[node_key]
reference_node = reference_nodes[node_key] reference_node = reference_nodes[node_key]
# Operating State # Operating State
reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values) reward_value += score_node_operating_state(
final_node, initial_node, reference_node, config_values
)
# Operating System State # Operating System State
if (isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode)): if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values) reward_value += score_node_os_state(
final_node, initial_node, reference_node, config_values
)
# Service State # Service State
if (isinstance(final_node, ServiceNode)): if isinstance(final_node, ServiceNode):
reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values) reward_value += score_node_service_state(
final_node, initial_node, reference_node, config_values
)
# File System State # File System State
if isinstance(final_node, ActiveNode): if isinstance(final_node, ActiveNode):
reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values) reward_value += score_node_file_system(
final_node, initial_node, reference_node, config_values
)
# Go through each red IER - penalise if it is running # Go through each red IER - penalise if it is running
for ier_key, ier_value in red_iers.items(): for ier_key, ier_value in red_iers.items():
start_step = ier_value.get_start_step() start_step = ier_value.get_start_step()
@@ -57,14 +70,17 @@ def calculate_reward_function(initial_nodes, final_nodes, reference_nodes, green
stop_step = ier_value.get_end_step() stop_step = ier_value.get_end_step()
if step_count >= start_step and step_count <= stop_step: if step_count >= start_step and step_count <= stop_step:
if not ier_value.get_is_running(): if not ier_value.get_is_running():
reward_value += config_values.green_ier_blocked * ier_value.get_mission_criticality() reward_value += (
config_values.green_ier_blocked
* ier_value.get_mission_criticality()
)
return reward_value return reward_value
def score_node_operating_state(final_node, initial_node, reference_node, config_values): def score_node_operating_state(final_node, initial_node, reference_node, config_values):
""" """
Calculates score relating to the operating state of a node Calculates score relating to the operating state of a node.
Args: Args:
final_node: The node after red and blue agents take effect final_node: The node after red and blue agents take effect
@@ -72,8 +88,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
reference_node: The node if there had been no red or blue effect reference_node: The node if there had been no red or blue effect
config_values: Config values config_values: Config values
""" """
score = 0
score = 0
final_node_operating_state = final_node.get_state() final_node_operating_state = final_node.get_state()
initial_node_operating_state = initial_node.get_state() initial_node_operating_state = initial_node.get_state()
reference_node_operating_state = reference_node.get_state() reference_node_operating_state = reference_node.get_state()
@@ -81,7 +96,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
if final_node_operating_state == reference_node_operating_state: if final_node_operating_state == reference_node_operating_state:
# All is well - we're no different from the reference situation # All is well - we're no different from the reference situation
score += config_values.all_ok score += config_values.all_ok
else: else:
# We're different from the reference situation # We're different from the reference situation
# Need to compare initial and final state of node (i.e. after red and blue actions) # Need to compare initial and final state of node (i.e. after red and blue actions)
if initial_node_operating_state == HARDWARE_STATE.ON: if initial_node_operating_state == HARDWARE_STATE.ON:
@@ -95,7 +110,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
if final_node_operating_state == HARDWARE_STATE.ON: if final_node_operating_state == HARDWARE_STATE.ON:
score += config_values.on_should_be_off score += config_values.on_should_be_off
elif final_node_operating_state == HARDWARE_STATE.RESETTING: elif final_node_operating_state == HARDWARE_STATE.RESETTING:
score += config_values.resetting_should_be_off score += config_values.resetting_should_be_off
else: else:
pass pass
elif initial_node_operating_state == HARDWARE_STATE.RESETTING: elif initial_node_operating_state == HARDWARE_STATE.RESETTING:
@@ -112,9 +127,10 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
return score return score
def score_node_os_state(final_node, initial_node, reference_node, config_values): def score_node_os_state(final_node, initial_node, reference_node, config_values):
""" """
Calculates score relating to the operating system state of a node Calculates score relating to the operating system state of a node.
Args: Args:
final_node: The node after red and blue agents take effect final_node: The node after red and blue agents take effect
@@ -122,8 +138,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
reference_node: The node if there had been no red or blue effect reference_node: The node if there had been no red or blue effect
config_values: Config values config_values: Config values
""" """
score = 0
score = 0
final_node_os_state = final_node.get_os_state() final_node_os_state = final_node.get_os_state()
initial_node_os_state = initial_node.get_os_state() initial_node_os_state = initial_node.get_os_state()
reference_node_os_state = reference_node.get_os_state() reference_node_os_state = reference_node.get_os_state()
@@ -131,7 +146,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
if final_node_os_state == reference_node_os_state: if final_node_os_state == reference_node_os_state:
# All is well - we're no different from the reference situation # All is well - we're no different from the reference situation
score += config_values.all_ok score += config_values.all_ok
else: else:
# We're different from the reference situation # We're different from the reference situation
# Need to compare initial and final state of node (i.e. after red and blue actions) # Need to compare initial and final state of node (i.e. after red and blue actions)
if initial_node_os_state == SOFTWARE_STATE.GOOD: if initial_node_os_state == SOFTWARE_STATE.GOOD:
@@ -145,18 +160,18 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
if final_node_os_state == SOFTWARE_STATE.GOOD: if final_node_os_state == SOFTWARE_STATE.GOOD:
score += config_values.good_should_be_patching score += config_values.good_should_be_patching
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED: elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
score += config_values.compromised_should_be_patching score += config_values.compromised_should_be_patching
elif final_node_os_state == SOFTWARE_STATE.PATCHING: elif final_node_os_state == SOFTWARE_STATE.PATCHING:
score += config_values.patching score += config_values.patching
else: else:
pass pass
elif initial_node_os_state == SOFTWARE_STATE.COMPROMISED: elif initial_node_os_state == SOFTWARE_STATE.COMPROMISED:
if final_node_os_state == SOFTWARE_STATE.GOOD: if final_node_os_state == SOFTWARE_STATE.GOOD:
score += config_values.good_should_be_compromised score += config_values.good_should_be_compromised
elif final_node_os_state == SOFTWARE_STATE.PATCHING: elif final_node_os_state == SOFTWARE_STATE.PATCHING:
score += config_values.patching_should_be_compromised score += config_values.patching_should_be_compromised
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED: elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
score += config_values.compromised score += config_values.compromised
else: else:
pass pass
else: else:
@@ -164,9 +179,10 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
return score return score
def score_node_service_state(final_node, initial_node, reference_node, config_values): def score_node_service_state(final_node, initial_node, reference_node, config_values):
""" """
Calculates score relating to the service state(s) of a node Calculates score relating to the service state(s) of a node.
Args: Args:
final_node: The node after red and blue agents take effect final_node: The node after red and blue agents take effect
@@ -174,12 +190,11 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
reference_node: The node if there had been no red or blue effect reference_node: The node if there had been no red or blue effect
config_values: Config values config_values: Config values
""" """
score = 0
score = 0
final_node_services = final_node.get_services() final_node_services = final_node.get_services()
initial_node_services = initial_node.get_services() initial_node_services = initial_node.get_services()
reference_node_services = reference_node.get_services() reference_node_services = reference_node.get_services()
for service_key, final_service in final_node_services.items(): for service_key, final_service in final_node_services.items():
reference_service = reference_node_services[service_key] reference_service = reference_node_services[service_key]
initial_service = initial_node_services[service_key] initial_service = initial_node_services[service_key]
@@ -203,11 +218,11 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
if final_service.get_state() == SOFTWARE_STATE.GOOD: if final_service.get_state() == SOFTWARE_STATE.GOOD:
score += config_values.good_should_be_patching score += config_values.good_should_be_patching
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED: elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
score += config_values.compromised_should_be_patching score += config_values.compromised_should_be_patching
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED: elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
score += config_values.overwhelmed_should_be_patching score += config_values.overwhelmed_should_be_patching
elif final_service.get_state() == SOFTWARE_STATE.PATCHING: elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
score += config_values.patching score += config_values.patching
else: else:
pass pass
elif initial_service.get_state() == SOFTWARE_STATE.COMPROMISED: elif initial_service.get_state() == SOFTWARE_STATE.COMPROMISED:
@@ -216,9 +231,9 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
elif final_service.get_state() == SOFTWARE_STATE.PATCHING: elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
score += config_values.patching_should_be_compromised score += config_values.patching_should_be_compromised
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED: elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
score += config_values.compromised score += config_values.compromised
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED: elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
score += config_values.overwhelmed_should_be_compromised score += config_values.overwhelmed_should_be_compromised
else: else:
pass pass
elif initial_service.get_state() == SOFTWARE_STATE.OVERWHELMED: elif initial_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
@@ -227,9 +242,9 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
elif final_service.get_state() == SOFTWARE_STATE.PATCHING: elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
score += config_values.patching_should_be_overwhelmed score += config_values.patching_should_be_overwhelmed
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED: elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
score += config_values.compromised_should_be_overwhelmed score += config_values.compromised_should_be_overwhelmed
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED: elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
score += config_values.overwhelmed score += config_values.overwhelmed
else: else:
pass pass
else: else:
@@ -237,17 +252,17 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
return score return score
def score_node_file_system(final_node, initial_node, reference_node, config_values): def score_node_file_system(final_node, initial_node, reference_node, config_values):
""" """
Calculates score relating to the file system state of a node Calculates score relating to the file system state of a node.
Args: Args:
final_node: The node after red and blue agents take effect final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect reference_node: The node if there had been no red or blue effect
""" """
score = 0
score = 0
final_node_file_system_state = final_node.get_file_system_state_actual() final_node_file_system_state = final_node.get_file_system_state_actual()
initial_node_file_system_state = initial_node.get_file_system_state_actual() initial_node_file_system_state = initial_node.get_file_system_state_actual()
reference_node_file_system_state = reference_node.get_file_system_state_actual() reference_node_file_system_state = reference_node.get_file_system_state_actual()
@@ -259,7 +274,7 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
if final_node_file_system_state == reference_node_file_system_state: if final_node_file_system_state == reference_node_file_system_state:
# All is well - we're no different from the reference situation # All is well - we're no different from the reference situation
score += config_values.all_ok score += config_values.all_ok
else: else:
# We're different from the reference situation # We're different from the reference situation
# Need to compare initial and final state of node (i.e. after red and blue actions) # Need to compare initial and final state of node (i.e. after red and blue actions)
if initial_node_file_system_state == FILE_SYSTEM_STATE.GOOD: if initial_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
@@ -277,15 +292,15 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD: if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
score += config_values.good_should_be_repairing score += config_values.good_should_be_repairing
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
score += config_values.restoring_should_be_repairing score += config_values.restoring_should_be_repairing
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
score += config_values.corrupt_should_be_repairing score += config_values.corrupt_should_be_repairing
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
score += config_values.destroyed_should_be_repairing score += config_values.destroyed_should_be_repairing
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING: elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
score += config_values.repairing score += config_values.repairing
else: else:
pass pass
elif initial_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: elif initial_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD: if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
score += config_values.good_should_be_restoring score += config_values.good_should_be_restoring
@@ -294,9 +309,9 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
score += config_values.corrupt_should_be_restoring score += config_values.corrupt_should_be_restoring
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
score += config_values.destroyed_should_be_restoring score += config_values.destroyed_should_be_restoring
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
score += config_values.restoring score += config_values.restoring
else: else:
pass pass
elif initial_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: elif initial_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
@@ -307,9 +322,9 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
score += config_values.restoring_should_be_corrupt score += config_values.restoring_should_be_corrupt
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
score += config_values.destroyed_should_be_corrupt score += config_values.destroyed_should_be_corrupt
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
score += config_values.corrupt score += config_values.corrupt
else: else:
pass pass
elif initial_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: elif initial_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
@@ -320,9 +335,9 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING: elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
score += config_values.restoring_should_be_destroyed score += config_values.restoring_should_be_destroyed
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT: elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
score += config_values.corrupt_should_be_destroyed score += config_values.corrupt_should_be_destroyed
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED: elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
score += config_values.destroyed score += config_values.destroyed
else: else:
pass pass
else: else:
@@ -332,9 +347,9 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
if final_node_scanning_state == reference_node_scanning_state: if final_node_scanning_state == reference_node_scanning_state:
# All is well - we're no different from the reference situation # All is well - we're no different from the reference situation
score += config_values.all_ok score += config_values.all_ok
else: else:
# We're different from the reference situation # We're different from the reference situation
# We're scanning the file system which incurs a penalty (as it slows down systems) # We're scanning the file system which incurs a penalty (as it slows down systems)
score += config_values.scanning score += config_values.scanning
return score return score

View File

@@ -1,19 +1,15 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """The link class."""
The link class
"""
from primaite.common.protocol import Protocol from primaite.common.protocol import Protocol
from primaite.common.enums import *
class Link(object): class Link(object):
""" """Link class."""
Link class
"""
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services): def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
""" """
Init Init.
Args: Args:
_id: The IER id _id: The IER id
@@ -22,9 +18,8 @@ class Link(object):
_dest_node_name: The name of the destination node _dest_node_name: The name of the destination node
_protocols: The protocols to add to the link _protocols: The protocols to add to the link
""" """
self.id = _id self.id = _id
self.bandwidth = _bandwidth self.bandwidth = _bandwidth
self.source_node_name = _source_node_name self.source_node_name = _source_node_name
self.dest_node_name = _dest_node_name self.dest_node_name = _dest_node_name
self.protocol_list = [] self.protocol_list = []
@@ -35,72 +30,65 @@ class Link(object):
def add_protocol(self, _protocol): def add_protocol(self, _protocol):
""" """
Adds a new protocol to the list of protocols on this link Adds a new protocol to the list of protocols on this link.
Args: Args:
_protocol: The protocol to be added (enum) _protocol: The protocol to be added (enum)
""" """
self.protocol_list.append(Protocol(_protocol)) self.protocol_list.append(Protocol(_protocol))
def get_id(self): def get_id(self):
""" """
Gets link ID Gets link ID.
Returns: Returns:
Link ID Link ID
""" """
return self.id return self.id
def get_source_node_name(self): def get_source_node_name(self):
""" """
Gets source node name Gets source node name.
Returns: Returns:
Source node name Source node name
""" """
return self.source_node_name return self.source_node_name
def get_dest_node_name(self): def get_dest_node_name(self):
""" """
Gets destination node name Gets destination node name.
Returns: Returns:
Destination node name Destination node name
""" """
return self.dest_node_name return self.dest_node_name
def get_bandwidth(self): def get_bandwidth(self):
""" """
Gets bandwidth of link Gets bandwidth of link.
Returns: Returns:
Link bandwidth (bps) Link bandwidth (bps)
""" """
return self.bandwidth return self.bandwidth
def get_protocol_list(self): def get_protocol_list(self):
""" """
Gets list of protocols on this link Gets list of protocols on this link.
Returns: Returns:
List of protocols on this link List of protocols on this link
""" """
return self.protocol_list return self.protocol_list
def get_current_load(self): def get_current_load(self):
""" """
Gets current total load on this link Gets current total load on this link.
Returns: Returns:
Total load on this link (bps) Total load on this link (bps)
""" """
total_load = 0 total_load = 0
for protocol in self.protocol_list: for protocol in self.protocol_list:
total_load += protocol.get_load() total_load += protocol.get_load()
@@ -108,13 +96,12 @@ class Link(object):
def add_protocol_load(self, _protocol, _load): def add_protocol_load(self, _protocol, _load):
""" """
Adds a loading to a protocol on this link Adds a loading to a protocol on this link.
Args: Args:
_protocol: The protocol to load _protocol: The protocol to load
_load: The amount to load (bps) _load: The amount to load (bps)
""" """
for protocol in self.protocol_list: for protocol in self.protocol_list:
if protocol.get_name() == _protocol: if protocol.get_name() == _protocol:
protocol.add_load(_load) protocol.add_load(_load)
@@ -122,11 +109,6 @@ class Link(object):
pass pass
def clear_traffic(self): def clear_traffic(self):
""" """Clears all traffic on this link."""
Clears all traffic on this link
"""
for protocol in self.protocol_list: for protocol in self.protocol_list:
protocol.clear_load() protocol.clear_load()

View File

@@ -1,37 +1,31 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """
Primaite - main (harness) module Primaite - main (harness) module.
Coding Standards: PEP 8 Coding Standards: PEP 8
""" """
from sys import exc_info
import time
import yaml
import os.path
import logging import logging
import os.path
import time
from datetime import datetime from datetime import datetime
import yaml
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite.common.config_values_main import config_values_main
from primaite.environment.primaite_env import Primaite from primaite.environment.primaite_env import Primaite
from primaite.transactions.transactions_to_file import write_transaction_to_file from primaite.transactions.transactions_to_file import write_transaction_to_file
from primaite.common.config_values_main import config_values_main
from stable_baselines3 import PPO # FUNCTIONS #
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
################################# FUNCTIONS ######################################
def run_generic(): def run_generic():
""" """Run against a generic agent."""
Run against a generic agent
"""
for episode in range(0, config_values.num_episodes): for episode in range(0, config_values.num_episodes):
for step in range(0, config_values.num_steps): for step in range(0, config_values.num_steps):
# Send the observation space to the agent to get an action # Send the observation space to the agent to get an action
# TEMP - random action for now # TEMP - random action for now
# action = env.blue_agent_action(obs) # action = env.blue_agent_action(obs)
@@ -54,15 +48,20 @@ def run_generic():
def run_stable_baselines3_ppo(): def run_stable_baselines3_ppo():
""" """Run against a stable_baselines3 PPO agent."""
Run against a stable_baselines3 PPO agent
"""
if config_values.load_agent == True: if config_values.load_agent == True:
try: try:
agent = PPO.load(config_values.agent_load_file, env, verbose=0, n_steps=config_values.num_steps) agent = PPO.load(
except: config_values.agent_load_file,
print("ERROR: Could not load agent at location: " + config_values.agent_load_file) env,
verbose=0,
n_steps=config_values.num_steps,
)
except Exception:
print(
"ERROR: Could not load agent at location: "
+ config_values.agent_load_file
)
logging.error("Could not load agent") logging.error("Could not load agent")
logging.error("Exception occured", exc_info=True) logging.error("Exception occured", exc_info=True)
else: else:
@@ -73,7 +72,7 @@ def run_stable_baselines3_ppo():
print("Starting training session...") print("Starting training session...")
logging.info("Starting training session...") logging.info("Starting training session...")
for episode in range(0, config_values.num_episodes): for episode in range(0, config_values.num_episodes):
agent.learn(total_timesteps=1) agent.learn(total_timesteps=1)
save_agent(agent) save_agent(agent)
else: else:
# Default to being in an evaluation session # Default to being in an evaluation session
@@ -83,16 +82,22 @@ def run_stable_baselines3_ppo():
env.close() env.close()
def run_stable_baselines3_a2c():
"""
Run against a stable_baselines3 A2C agent
"""
def run_stable_baselines3_a2c():
"""Run against a stable_baselines3 A2C agent."""
if config_values.load_agent == True: if config_values.load_agent == True:
try: try:
agent = A2C.load(config_values.agent_load_file, env, verbose=0, n_steps=config_values.num_steps) agent = A2C.load(
except: config_values.agent_load_file,
print("ERROR: Could not load agent at location: " + config_values.agent_load_file) env,
verbose=0,
n_steps=config_values.num_steps,
)
except Exception:
print(
"ERROR: Could not load agent at location: "
+ config_values.agent_load_file
)
logging.error("Could not load agent") logging.error("Could not load agent")
logging.error("Exception occured", exc_info=True) logging.error("Exception occured", exc_info=True)
else: else:
@@ -103,143 +108,213 @@ def run_stable_baselines3_a2c():
print("Starting training session...") print("Starting training session...")
logging.info("Starting training session...") logging.info("Starting training session...")
for episode in range(0, config_values.num_episodes): for episode in range(0, config_values.num_episodes):
agent.learn(total_timesteps=1) agent.learn(total_timesteps=1)
save_agent(agent) save_agent(agent)
else: else:
# Default to being in an evaluation session # Default to being in an evaluation session
print("Starting evaluation session...") print("Starting evaluation session...")
logging.info("Starting evaluation session...") logging.info("Starting evaluation session...")
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
env.close() env.close()
def save_agent(_agent):
"""
Persist an agent (only works for stable baselines3 agents at present)
"""
now = datetime.now() # current date and time def save_agent(_agent):
"""Persist an agent (only works for stable baselines3 agents at present)."""
now = datetime.now() # current date and time
time = now.strftime("%Y%m%d_%H%M%S") time = now.strftime("%Y%m%d_%H%M%S")
try: try:
path = 'outputs/agents/' path = "outputs/agents/"
is_dir = os.path.isdir(path) is_dir = os.path.isdir(path)
if not is_dir: if not is_dir:
os.makedirs(path) os.makedirs(path)
filename = "outputs/agents/agent_saved_" + time filename = "outputs/agents/agent_saved_" + time
_agent.save(filename) _agent.save(filename)
logging.info("Trained agent saved as " + filename) logging.info("Trained agent saved as " + filename)
except Exception as e: except Exception:
logging.error("Could not save agent") logging.error("Could not save agent")
logging.error("Exception occured", exc_info=True) logging.error("Exception occured", exc_info=True)
def configure_logging():
"""
Configures logging
"""
def configure_logging():
"""Configures logging."""
try: try:
now = datetime.now() # current date and time now = datetime.now() # current date and time
time = now.strftime("%Y%m%d_%H%M%S") time = now.strftime("%Y%m%d_%H%M%S")
filename = "logs/app_" + time + ".log" filename = "logs/app_" + time + ".log"
path = 'logs/' path = "logs/"
is_dir = os.path.isdir(path) is_dir = os.path.isdir(path)
if not is_dir: if not is_dir:
os.makedirs(path) os.makedirs(path)
logging.basicConfig(filename=filename, filemode='w', format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO) logging.basicConfig(
except: filename=filename,
filemode="w",
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%d-%b-%y %H:%M:%S",
level=logging.INFO,
)
except Exception:
print("ERROR: Could not start logging") print("ERROR: Could not start logging")
def load_config_values():
"""
Loads the config values from the main config file into a config object
"""
def load_config_values():
"""Loads the config values from the main config file into a config object."""
try: try:
# Generic # Generic
config_values.agent_identifier = config_data['agentIdentifier'] config_values.agent_identifier = config_data["agentIdentifier"]
config_values.num_episodes = int(config_data['numEpisodes']) config_values.num_episodes = int(config_data["numEpisodes"])
config_values.time_delay = int(config_data['timeDelay']) config_values.time_delay = int(config_data["timeDelay"])
config_values.config_filename_use_case = config_data['configFilename'] config_values.config_filename_use_case = config_data["configFilename"]
config_values.session_type = config_data['sessionType'] config_values.session_type = config_data["sessionType"]
config_values.load_agent = bool(config_data['loadAgent']) config_values.load_agent = bool(config_data["loadAgent"])
config_values.agent_load_file = config_data['agentLoadFile'] config_values.agent_load_file = config_data["agentLoadFile"]
# Environment # Environment
config_values.observation_space_high_value = int(config_data['observationSpaceHighValue']) config_values.observation_space_high_value = int(
config_data["observationSpaceHighValue"]
)
# Reward values # Reward values
# Generic # Generic
config_values.all_ok = int(config_data['allOk']) config_values.all_ok = int(config_data["allOk"])
# Node Operating State # Node Operating State
config_values.off_should_be_on = int(config_data['offShouldBeOn']) config_values.off_should_be_on = int(config_data["offShouldBeOn"])
config_values.off_should_be_resetting = int(config_data['offShouldBeResetting']) config_values.off_should_be_resetting = int(config_data["offShouldBeResetting"])
config_values.on_should_be_off = int(config_data['onShouldBeOff']) config_values.on_should_be_off = int(config_data["onShouldBeOff"])
config_values.on_should_be_resetting = int(config_data['onShouldBeResetting']) config_values.on_should_be_resetting = int(config_data["onShouldBeResetting"])
config_values.resetting_should_be_on = int(config_data['resettingShouldBeOn']) config_values.resetting_should_be_on = int(config_data["resettingShouldBeOn"])
config_values.resetting_should_be_off = int(config_data['resettingShouldBeOff']) config_values.resetting_should_be_off = int(config_data["resettingShouldBeOff"])
config_values.resetting = int(config_data['resetting']) config_values.resetting = int(config_data["resetting"])
# Node O/S or Service State # Node O/S or Service State
config_values.good_should_be_patching = int(config_data['goodShouldBePatching']) config_values.good_should_be_patching = int(config_data["goodShouldBePatching"])
config_values.good_should_be_compromised = int(config_data['goodShouldBeCompromised']) config_values.good_should_be_compromised = int(
config_values.good_should_be_overwhelmed = int(config_data['goodShouldBeOverwhelmed']) config_data["goodShouldBeCompromised"]
config_values.patching_should_be_good = int(config_data['patchingShouldBeGood']) )
config_values.patching_should_be_compromised = int(config_data['patchingShouldBeCompromised']) config_values.good_should_be_overwhelmed = int(
config_values.patching_should_be_overwhelmed = int(config_data['patchingShouldBeOverwhelmed']) config_data["goodShouldBeOverwhelmed"]
config_values.patching = int(config_data['patching']) )
config_values.compromised_should_be_good = int(config_data['compromisedShouldBeGood']) config_values.patching_should_be_good = int(config_data["patchingShouldBeGood"])
config_values.compromised_should_be_patching = int(config_data['compromisedShouldBePatching']) config_values.patching_should_be_compromised = int(
config_values.compromised_should_be_overwhelmed = int(config_data['compromisedShouldBeOverwhelmed']) config_data["patchingShouldBeCompromised"]
config_values.compromised = int(config_data['compromised']) )
config_values.overwhelmed_should_be_good = int(config_data['overwhelmedShouldBeGood']) config_values.patching_should_be_overwhelmed = int(
config_values.overwhelmed_should_be_patching = int(config_data['overwhelmedShouldBePatching']) config_data["patchingShouldBeOverwhelmed"]
config_values.overwhelmed_should_be_compromised = int(config_data['overwhelmedShouldBeCompromised']) )
config_values.overwhelmed = int(config_data['overwhelmed']) config_values.patching = int(config_data["patching"])
config_values.compromised_should_be_good = int(
config_data["compromisedShouldBeGood"]
)
config_values.compromised_should_be_patching = int(
config_data["compromisedShouldBePatching"]
)
config_values.compromised_should_be_overwhelmed = int(
config_data["compromisedShouldBeOverwhelmed"]
)
config_values.compromised = int(config_data["compromised"])
config_values.overwhelmed_should_be_good = int(
config_data["overwhelmedShouldBeGood"]
)
config_values.overwhelmed_should_be_patching = int(
config_data["overwhelmedShouldBePatching"]
)
config_values.overwhelmed_should_be_compromised = int(
config_data["overwhelmedShouldBeCompromised"]
)
config_values.overwhelmed = int(config_data["overwhelmed"])
# Node File System State # Node File System State
config_values.good_should_be_repairing = int(config_data['goodShouldBeRepairing']) config_values.good_should_be_repairing = int(
config_values.good_should_be_restoring = int(config_data['goodShouldBeRestoring']) config_data["goodShouldBeRepairing"]
config_values.good_should_be_corrupt = int(config_data['goodShouldBeCorrupt']) )
config_values.good_should_be_destroyed = int(config_data['goodShouldBeDestroyed']) config_values.good_should_be_restoring = int(
config_values.repairing_should_be_good = int(config_data['repairingShouldBeGood']) config_data["goodShouldBeRestoring"]
config_values.repairing_should_be_restoring = int(config_data['repairingShouldBeRestoring']) )
config_values.repairing_should_be_corrupt = int(config_data['repairingShouldBeCorrupt']) config_values.good_should_be_corrupt = int(config_data["goodShouldBeCorrupt"])
config_values.repairing_should_be_destroyed = int(config_data['repairingShouldBeDestroyed']) config_values.good_should_be_destroyed = int(
config_values.repairing = int(config_data['repairing']) config_data["goodShouldBeDestroyed"]
config_values.restoring_should_be_good = int(config_data['restoringShouldBeGood']) )
config_values.restoring_should_be_repairing = int(config_data['restoringShouldBeRepairing']) config_values.repairing_should_be_good = int(
config_values.restoring_should_be_corrupt = int(config_data['restoringShouldBeCorrupt']) config_data["repairingShouldBeGood"]
config_values.restoring_should_be_destroyed = int(config_data['restoringShouldBeDestroyed']) )
config_values.restoring = int(config_data['restoring']) config_values.repairing_should_be_restoring = int(
config_values.corrupt_should_be_good = int(config_data['corruptShouldBeGood']) config_data["repairingShouldBeRestoring"]
config_values.corrupt_should_be_repairing = int(config_data['corruptShouldBeRepairing']) )
config_values.corrupt_should_be_restoring = int(config_data['corruptShouldBeRestoring']) config_values.repairing_should_be_corrupt = int(
config_values.corrupt_should_be_destroyed = int(config_data['corruptShouldBeDestroyed']) config_data["repairingShouldBeCorrupt"]
config_values.corrupt = int(config_data['corrupt']) )
config_values.destroyed_should_be_good = int(config_data['destroyedShouldBeGood']) config_values.repairing_should_be_destroyed = int(
config_values.destroyed_should_be_repairing = int(config_data['destroyedShouldBeRepairing']) config_data["repairingShouldBeDestroyed"]
config_values.destroyed_should_be_restoring = int(config_data['destroyedShouldBeRestoring']) )
config_values.destroyed_should_be_corrupt = int(config_data['destroyedShouldBeCorrupt']) config_values.repairing = int(config_data["repairing"])
config_values.destroyed = int(config_data['destroyed']) config_values.restoring_should_be_good = int(
config_values.scanning = int(config_data['scanning']) config_data["restoringShouldBeGood"]
)
config_values.restoring_should_be_repairing = int(
config_data["restoringShouldBeRepairing"]
)
config_values.restoring_should_be_corrupt = int(
config_data["restoringShouldBeCorrupt"]
)
config_values.restoring_should_be_destroyed = int(
config_data["restoringShouldBeDestroyed"]
)
config_values.restoring = int(config_data["restoring"])
config_values.corrupt_should_be_good = int(config_data["corruptShouldBeGood"])
config_values.corrupt_should_be_repairing = int(
config_data["corruptShouldBeRepairing"]
)
config_values.corrupt_should_be_restoring = int(
config_data["corruptShouldBeRestoring"]
)
config_values.corrupt_should_be_destroyed = int(
config_data["corruptShouldBeDestroyed"]
)
config_values.corrupt = int(config_data["corrupt"])
config_values.destroyed_should_be_good = int(
config_data["destroyedShouldBeGood"]
)
config_values.destroyed_should_be_repairing = int(
config_data["destroyedShouldBeRepairing"]
)
config_values.destroyed_should_be_restoring = int(
config_data["destroyedShouldBeRestoring"]
)
config_values.destroyed_should_be_corrupt = int(
config_data["destroyedShouldBeCorrupt"]
)
config_values.destroyed = int(config_data["destroyed"])
config_values.scanning = int(config_data["scanning"])
# IER status # IER status
config_values.red_ier_running = int(config_data['redIerRunning']) config_values.red_ier_running = int(config_data["redIerRunning"])
config_values.green_ier_blocked = int(config_data['greenIerBlocked']) config_values.green_ier_blocked = int(config_data["greenIerBlocked"])
# Patching / Reset durations # Patching / Reset durations
config_values.os_patching_duration = int(config_data['osPatchingDuration']) config_values.os_patching_duration = int(config_data["osPatchingDuration"])
config_values.node_reset_duration = int(config_data['nodeResetDuration']) config_values.node_reset_duration = int(config_data["nodeResetDuration"])
config_values.service_patching_duration = int(config_data['servicePatchingDuration']) config_values.service_patching_duration = int(
config_values.file_system_repairing_limit = int(config_data['fileSystemRepairingLimit']) config_data["servicePatchingDuration"]
config_values.file_system_restoring_limit = int(config_data['fileSystemRestoringLimit']) )
config_values.file_system_scanning_limit = int(config_data['fileSystemScanningLimit']) config_values.file_system_repairing_limit = int(
config_data["fileSystemRepairingLimit"]
logging.info("Training agent: " + config_values.agent_identifier) )
logging.info("Training environment config: " + config_values.config_filename_use_case) config_values.file_system_restoring_limit = int(
logging.info("Training cycle has " + str(config_values.num_episodes) + " episodes") config_data["fileSystemRestoringLimit"]
)
config_values.file_system_scanning_limit = int(
config_data["fileSystemScanningLimit"]
)
except Exception as e: logging.info("Training agent: " + config_values.agent_identifier)
logging.info(
"Training environment config: " + config_values.config_filename_use_case
)
logging.info(
"Training cycle has " + str(config_values.num_episodes) + " episodes"
)
except Exception:
logging.error("Could not save load config data") logging.error("Could not save load config data")
logging.error("Exception occured", exc_info=True) logging.error("Exception occured", exc_info=True)
################################# MAIN PROCESS ############################################ # MAIN PROCESS #
# Starting point # Starting point
@@ -257,25 +332,25 @@ try:
config_values = config_values_main() config_values = config_values_main()
# Load in config data # Load in config data
load_config_values() load_config_values()
except Exception as e: except Exception:
logging.error("Could not load main config") logging.error("Could not load main config")
logging.error("Exception occured", exc_info=True) logging.error("Exception occured", exc_info=True)
# Create a list of transactions # Create a list of transactions
# A transaction is an object holding the: # A transaction is an object holding the:
# - episode # # - episode #
# - step # # - step #
# - initial observation space # - initial observation space
# - action # - action
# - reward # - reward
# - new observation space # - new observation space
transaction_list = [] transaction_list = []
# Create the Primaite environment # Create the Primaite environment
try: try:
env = Primaite(config_values, transaction_list) env = Primaite(config_values, transaction_list)
logging.info("PrimAITE environment created") logging.info("PrimAITE environment created")
except Exception as e: except Exception:
logging.error("Could not create PrimAITE environment") logging.error("Could not create PrimAITE environment")
logging.error("Exception occured", exc_info=True) logging.error("Exception occured", exc_info=True)
@@ -302,11 +377,3 @@ config_file_main.close
print("Finished") print("Finished")
logging.info("Finished") logging.info("Finished")

View File

@@ -1,19 +1,26 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """An Active Node (i.e. not an actuator)."""
An Active Node (i.e. not an actuator) from primaite.common.enums import FILE_SYSTEM_STATE, SOFTWARE_STATE
"""
from primaite.nodes.node import Node from primaite.nodes.node import Node
from primaite.common.enums import *
class ActiveNode(Node): class ActiveNode(Node):
""" """Active Node class."""
Active Node class
"""
def __init__(self, _id, _name, _type, _priority, _state, _ip_address, _os_state, _file_system_state, _config_values): def __init__(
self,
_id,
_name,
_type,
_priority,
_state,
_ip_address,
_os_state,
_file_system_state,
_config_values,
):
""" """
Init Init.
Args: Args:
_id: The node ID _id: The node ID
@@ -26,7 +33,6 @@ class ActiveNode(Node):
_file_system_state: The node file system state _file_system_state: The node file system state
_config_values: The config values _config_values: The config values
""" """
super().__init__(_id, _name, _type, _priority, _state, _config_values) super().__init__(_id, _name, _type, _priority, _state, _config_values)
self.ip_address = _ip_address self.ip_address = _ip_address
# Related to O/S # Related to O/S
@@ -39,20 +45,18 @@ class ActiveNode(Node):
self.file_system_scanning_count = 0 self.file_system_scanning_count = 0
self.file_system_action_count = 0 self.file_system_action_count = 0
def set_ip_address(self, _ip_address): def set_ip_address(self, _ip_address):
""" """
Sets IP address Sets IP address.
Args: Args:
_ip_address: IP address _ip_address: IP address
""" """
self.ip_address = _ip_address self.ip_address = _ip_address
def get_ip_address(self): def get_ip_address(self):
""" """
Gets IP address Gets IP address.
Returns: Returns:
IP address IP address
@@ -61,24 +65,22 @@ class ActiveNode(Node):
def set_os_state(self, _os_state): def set_os_state(self, _os_state):
""" """
Sets operating system state Sets operating system state.
Args: Args:
_os_state: Operating system state _os_state: Operating system state
""" """
self.os_state = _os_state self.os_state = _os_state
if _os_state == SOFTWARE_STATE.PATCHING: if _os_state == SOFTWARE_STATE.PATCHING:
self.patching_count = self.config_values.os_patching_duration self.patching_count = self.config_values.os_patching_duration
def set_os_state_if_not_compromised(self, _os_state): def set_os_state_if_not_compromised(self, _os_state):
""" """
Sets operating system state if the node is not compromised Sets operating system state if the node is not compromised.
Args: Args:
_os_state: Operating system state _os_state: Operating system state
""" """
if self.os_state != SOFTWARE_STATE.COMPROMISED: if self.os_state != SOFTWARE_STATE.COMPROMISED:
self.os_state = _os_state self.os_state = _os_state
if _os_state == SOFTWARE_STATE.PATCHING: if _os_state == SOFTWARE_STATE.PATCHING:
@@ -86,19 +88,15 @@ class ActiveNode(Node):
def get_os_state(self): def get_os_state(self):
""" """
Gets operating system state Gets operating system state.
Returns: Returns:
Operating system state Operating system state
""" """
return self.os_state return self.os_state
def update_os_patching_status(self): def update_os_patching_status(self):
""" """Updates operating system status based on patching cycle."""
Updates operating system status based on patching cycle
"""
self.patching_count -= 1 self.patching_count -= 1
if self.patching_count <= 0: if self.patching_count <= 0:
self.patching_count = 0 self.patching_count = 0
@@ -106,87 +104,88 @@ class ActiveNode(Node):
def set_file_system_state(self, _file_system_state): def set_file_system_state(self, _file_system_state):
""" """
Sets the file system state (actual and observed) Sets the file system state (actual and observed).
Args: Args:
_file_system_state: File system state _file_system_state: File system state
""" """
self.file_system_state_actual = _file_system_state self.file_system_state_actual = _file_system_state
if _file_system_state == FILE_SYSTEM_STATE.REPAIRING: if _file_system_state == FILE_SYSTEM_STATE.REPAIRING:
self.file_system_action_count = self.config_values.file_system_repairing_limit self.file_system_action_count = (
self.config_values.file_system_repairing_limit
)
self.file_system_state_observed = FILE_SYSTEM_STATE.REPAIRING self.file_system_state_observed = FILE_SYSTEM_STATE.REPAIRING
elif _file_system_state == FILE_SYSTEM_STATE.RESTORING: elif _file_system_state == FILE_SYSTEM_STATE.RESTORING:
self.file_system_action_count = self.config_values.file_system_restoring_limit self.file_system_action_count = (
self.config_values.file_system_restoring_limit
)
self.file_system_state_observed = FILE_SYSTEM_STATE.RESTORING self.file_system_state_observed = FILE_SYSTEM_STATE.RESTORING
elif _file_system_state == FILE_SYSTEM_STATE.GOOD: elif _file_system_state == FILE_SYSTEM_STATE.GOOD:
self.file_system_state_observed = FILE_SYSTEM_STATE.GOOD self.file_system_state_observed = FILE_SYSTEM_STATE.GOOD
def set_file_system_state_if_not_compromised(self, _file_system_state): def set_file_system_state_if_not_compromised(self, _file_system_state):
""" """
Sets the file system state (actual and observed) if not in a compromised state Sets the file system state (actual and observed) if not in a compromised state.
Use for green PoL to prevent it overturning a compromised state Use for green PoL to prevent it overturning a compromised state
Args: Args:
_file_system_state: File system state _file_system_state: File system state
""" """
if (
if self.file_system_state_actual != FILE_SYSTEM_STATE.CORRUPT and self.file_system_state_actual != FILE_SYSTEM_STATE.DESTROYED: self.file_system_state_actual != FILE_SYSTEM_STATE.CORRUPT
and self.file_system_state_actual != FILE_SYSTEM_STATE.DESTROYED
):
self.file_system_state_actual = _file_system_state self.file_system_state_actual = _file_system_state
if _file_system_state == FILE_SYSTEM_STATE.REPAIRING: if _file_system_state == FILE_SYSTEM_STATE.REPAIRING:
self.file_system_action_count = self.config_values.file_system_repairing_limit self.file_system_action_count = (
self.config_values.file_system_repairing_limit
)
self.file_system_state_observed = FILE_SYSTEM_STATE.REPAIRING self.file_system_state_observed = FILE_SYSTEM_STATE.REPAIRING
elif _file_system_state == FILE_SYSTEM_STATE.RESTORING: elif _file_system_state == FILE_SYSTEM_STATE.RESTORING:
self.file_system_action_count = self.config_values.file_system_restoring_limit self.file_system_action_count = (
self.config_values.file_system_restoring_limit
)
self.file_system_state_observed = FILE_SYSTEM_STATE.RESTORING self.file_system_state_observed = FILE_SYSTEM_STATE.RESTORING
elif _file_system_state == FILE_SYSTEM_STATE.GOOD: elif _file_system_state == FILE_SYSTEM_STATE.GOOD:
self.file_system_state_observed = FILE_SYSTEM_STATE.GOOD self.file_system_state_observed = FILE_SYSTEM_STATE.GOOD
def get_file_system_state_actual(self): def get_file_system_state_actual(self):
""" """
Gets file system state (actual) Gets file system state (actual).
Returns: Returns:
File system state (actual) File system state (actual)
""" """
return self.file_system_state_actual return self.file_system_state_actual
def get_file_system_state_observed(self): def get_file_system_state_observed(self):
""" """
Gets file system state (observed) Gets file system state (observed).
Returns: Returns:
File system state (observed) File system state (observed)
""" """
return self.file_system_state_observed return self.file_system_state_observed
def start_file_system_scan(self): def start_file_system_scan(self):
""" """Starts a file system scan."""
Starts a file system scan
"""
self.file_system_scanning = True self.file_system_scanning = True
self.file_system_scanning_count = self.config_values.file_system_scanning_limit self.file_system_scanning_count = self.config_values.file_system_scanning_limit
def is_scanning_file_system(self): def is_scanning_file_system(self):
""" """
Gets true/false on whether file system is being scanned Gets true/false on whether file system is being scanned.
Returns: Returns:
True if file system is being scanned True if file system is being scanned
""" """
return self.file_system_scanning return self.file_system_scanning
def update_file_system_state(self): def update_file_system_state(self):
""" """Updates file system status based on scanning / restore / repair cycle."""
Updates file system status based on scanning / restore / repair cycle
"""
# Deprecate both the action count (for restoring or reparing) and the scanning count # Deprecate both the action count (for restoring or reparing) and the scanning count
self.file_system_action_count -= 1 self.file_system_action_count -= 1
self.file_system_scanning_count -= 1 self.file_system_scanning_count -= 1
@@ -194,7 +193,10 @@ class ActiveNode(Node):
# Reparing / Restoring updates # Reparing / Restoring updates
if self.file_system_action_count <= 0: if self.file_system_action_count <= 0:
self.file_system_action_count = 0 self.file_system_action_count = 0
if self.file_system_state_actual == FILE_SYSTEM_STATE.REPAIRING or self.file_system_state_actual == FILE_SYSTEM_STATE.RESTORING: if (
self.file_system_state_actual == FILE_SYSTEM_STATE.REPAIRING
or self.file_system_state_actual == FILE_SYSTEM_STATE.RESTORING
):
self.file_system_state_actual = FILE_SYSTEM_STATE.GOOD self.file_system_state_actual = FILE_SYSTEM_STATE.GOOD
self.file_system_state_observed = FILE_SYSTEM_STATE.GOOD self.file_system_state_observed = FILE_SYSTEM_STATE.GOOD

View File

@@ -1,18 +1,14 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """The base Node class."""
The base Node class from primaite.common.enums import HARDWARE_STATE
"""
from primaite.common.enums import *
class Node: class Node:
""" """Node class."""
Node class
"""
def __init__(self, _id, _name, _type, _priority, _state, _config_values): def __init__(self, _id, _name, _type, _priority, _state, _config_values):
""" """
Init Init.
Args: Args:
_id: The node id _id: The node id
@@ -21,156 +17,124 @@ class Node:
_priority: The priority of the node _priority: The priority of the node
_state: The state of the node _state: The state of the node
""" """
self.id = _id self.id = _id
self.name = _name self.name = _name
self.type = _type self.type = _type
self.priority = _priority self.priority = _priority
self.operating_state = _state self.operating_state = _state
self.resetting_count = 0 self.resetting_count = 0
self.config_values = _config_values self.config_values = _config_values
def __repr__(self): def __repr__(self):
""" """Returns the name of the node."""
Returns the name of the node
"""
return self.name return self.name
def set_id(self, _id): def set_id(self, _id):
""" """
Sets the node ID Sets the node ID.
Args: Args:
_id: The node ID _id: The node ID
""" """
self.id = _id self.id = _id
def get_id(self): def get_id(self):
""" """
Gets the node ID Gets the node ID.
Returns: Returns:
The node ID The node ID
""" """
return self.id return self.id
def set_name(self, _name): def set_name(self, _name):
""" """
Sets the node name Sets the node name.
Args: Args:
_name: The node name _name: The node name
""" """
self.name = _name self.name = _name
def get_name(self): def get_name(self):
""" """
Gets the node name Gets the node name.
Returns: Returns:
The node name The node name
""" """
return self.name return self.name
def set_type(self, _type): def set_type(self, _type):
""" """
Sets the node type Sets the node type.
Args: Args:
_type: The node type _type: The node type
""" """
self.type = _type self.type = _type
def get_type(self): def get_type(self):
""" """
Gets the node type Gets the node type.
Returns: Returns:
The node type The node type
""" """
return self.type return self.type
def set_priority(self, _priority): def set_priority(self, _priority):
""" """
Sets the node priority Sets the node priority.
Args: Args:
_priority: The node priority _priority: The node priority
""" """
self.priority = _priority self.priority = _priority
def get_priority(self): def get_priority(self):
""" """
Gets the node priority Gets the node priority.
Returns: Returns:
The node priority The node priority
""" """
return self.priority return self.priority
def set_state(self, _state): def set_state(self, _state):
""" """
Sets the node state Sets the node state.
Args: Args:
_state: The node state _state: The node state
""" """
self.operating_state = _state self.operating_state = _state
def get_state(self): def get_state(self):
""" """
Gets the node operating state Gets the node operating state.
Returns: Returns:
The node operating state The node operating state
""" """
return self.operating_state return self.operating_state
def turn_on(self): def turn_on(self):
""" """Sets the node state to ON."""
Sets the node state to ON
"""
self.operating_state = HARDWARE_STATE.ON self.operating_state = HARDWARE_STATE.ON
def turn_off(self): def turn_off(self):
""" """Sets the node state to OFF."""
Sets the node state to OFF
"""
self.operating_state = HARDWARE_STATE.OFF self.operating_state = HARDWARE_STATE.OFF
def reset(self): def reset(self):
""" """Sets the node state to Resetting and starts the reset count."""
Sets the node state to Resetting and starts the reset count
"""
self.operating_state = HARDWARE_STATE.RESETTING self.operating_state = HARDWARE_STATE.RESETTING
self.resetting_count = self.config_values.node_reset_duration self.resetting_count = self.config_values.node_reset_duration
def update_resetting_status(self): def update_resetting_status(self):
""" """Updates the resetting count."""
Updates the resetting count
"""
self.resetting_count -= 1 self.resetting_count -= 1
if self.resetting_count <= 0: if self.resetting_count <= 0:
self.resetting_count = 0 self.resetting_count = 0
self.operating_state = HARDWARE_STATE.ON self.operating_state = HARDWARE_STATE.ON

View File

@@ -1,16 +1,22 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """Defines node behaviour for Green PoL."""
Defines node behaviour for Green PoL
"""
class NodeStateInstructionGreen(object): class NodeStateInstructionGreen(object):
""" """The Node State Instruction class."""
The Node State Instruction class
"""
def __init__(self, _id, _start_step, _end_step, _node_id, _node_pol_type, _service_name, _state): def __init__(
self,
_id,
_start_step,
_end_step,
_node_id,
_node_pol_type,
_service_name,
_state,
):
""" """
Init Init.
Args: Args:
_id: The node state instruction id _id: The node state instruction id
@@ -21,72 +27,64 @@ class NodeStateInstructionGreen(object):
_service_name: The service name _service_name: The service name
_state: The state (node or service) _state: The state (node or service)
""" """
self.id = _id self.id = _id
self.start_step = _start_step self.start_step = _start_step
self.end_step = _end_step self.end_step = _end_step
self.node_id = _node_id self.node_id = _node_id
self.node_pol_type = _node_pol_type self.node_pol_type = _node_pol_type
self.service_name = _service_name # Not used when not a service instruction self.service_name = _service_name # Not used when not a service instruction
self.state = _state self.state = _state
def get_start_step(self): def get_start_step(self):
""" """
Gets the start step Gets the start step.
Returns: Returns:
The start step The start step
""" """
return self.start_step return self.start_step
def get_end_step(self): def get_end_step(self):
""" """
Gets the end step Gets the end step.
Returns: Returns:
The end step The end step
""" """
return self.end_step return self.end_step
def get_node_id(self): def get_node_id(self):
""" """
Gets the node ID Gets the node ID.
Returns: Returns:
The node ID The node ID
""" """
return self.node_id return self.node_id
def get_node_pol_type(self): def get_node_pol_type(self):
""" """
Gets the node pattern of life type (enum) Gets the node pattern of life type (enum).
Returns: Returns:
The node pattern of life type (enum) The node pattern of life type (enum)
""" """
return self.node_pol_type return self.node_pol_type
def get_service_name(self): def get_service_name(self):
""" """
Gets the service name Gets the service name.
Returns: Returns:
The service name The service name
""" """
return self.service_name return self.service_name
def get_state(self): def get_state(self):
""" """
Gets the state (node or service) Gets the state (node or service).
Returns: Returns:
The state (node or service) The state (node or service)
""" """
return self.state return self.state

View File

@@ -1,16 +1,26 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """Defines node behaviour for Green PoL."""
Defines node behaviour for Green PoL
"""
class NodeStateInstructionRed(object): class NodeStateInstructionRed(object):
""" """The Node State Instruction class."""
The Node State Instruction class
"""
def __init__(self, _id, _start_step, _end_step, _target_node_id, _pol_initiator, _pol_type, pol_protocol, _pol_state, _pol_source_node_id, _pol_source_node_service, _pol_source_node_service_state): def __init__(
self,
_id,
_start_step,
_end_step,
_target_node_id,
_pol_initiator,
_pol_type,
pol_protocol,
_pol_state,
_pol_source_node_id,
_pol_source_node_service,
_pol_source_node_service_state,
):
""" """
Init Init.
Args: Args:
_id: The node state instruction id _id: The node state instruction id
@@ -25,14 +35,13 @@ class NodeStateInstructionRed(object):
_pol_source_node_service: The source node service (used for initiator type SERVICE) _pol_source_node_service: The source node service (used for initiator type SERVICE)
_pol_source_node_service_state: The source node service state (used for initiator type SERVICE) _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
""" """
self.id = _id self.id = _id
self.start_step = _start_step self.start_step = _start_step
self.end_step = _end_step self.end_step = _end_step
self.target_node_id = _target_node_id self.target_node_id = _target_node_id
self.initiator = _pol_initiator self.initiator = _pol_initiator
self.pol_type = _pol_type self.pol_type = _pol_type
self.service_name = pol_protocol # Not used when not a service instruction self.service_name = pol_protocol # Not used when not a service instruction
self.state = _pol_state self.state = _pol_state
self.source_node_id = _pol_source_node_id self.source_node_id = _pol_source_node_id
self.source_node_service = _pol_source_node_service self.source_node_service = _pol_source_node_service
@@ -40,101 +49,90 @@ class NodeStateInstructionRed(object):
def get_start_step(self): def get_start_step(self):
""" """
Gets the start step Gets the start step.
Returns: Returns:
The start step The start step
""" """
return self.start_step return self.start_step
def get_end_step(self): def get_end_step(self):
""" """
Gets the end step Gets the end step.
Returns: Returns:
The end step The end step
""" """
return self.end_step return self.end_step
def get_target_node_id(self): def get_target_node_id(self):
""" """
Gets the node ID Gets the node ID.
Returns: Returns:
The node ID The node ID
""" """
return self.target_node_id return self.target_node_id
def get_initiator(self): def get_initiator(self):
""" """
Gets the initiator Gets the initiator.
Returns: Returns:
The initiator The initiator
""" """
return self.initiator return self.initiator
def get_pol_type(self): def get_pol_type(self):
""" """
Gets the node pattern of life type (enum) Gets the node pattern of life type (enum).
Returns: Returns:
The node pattern of life type (enum) The node pattern of life type (enum)
""" """
return self.pol_type return self.pol_type
def get_service_name(self): def get_service_name(self):
""" """
Gets the service name Gets the service name.
Returns: Returns:
The service name The service name
""" """
return self.service_name return self.service_name
def get_state(self): def get_state(self):
""" """
Gets the state (node or service) Gets the state (node or service).
Returns: Returns:
The state (node or service) The state (node or service)
""" """
return self.state return self.state
def get_source_node_id(self): def get_source_node_id(self):
""" """
Gets the source node id (used for initiator type SERVICE) Gets the source node id (used for initiator type SERVICE).
Returns: Returns:
The source node id The source node id
""" """
return self.source_node_id return self.source_node_id
def get_source_node_service(self): def get_source_node_service(self):
""" """
Gets the source node service (used for initiator type SERVICE) Gets the source node service (used for initiator type SERVICE).
Returns: Returns:
The source node service The source node service
""" """
return self.source_node_service return self.source_node_service
def get_source_node_service_state(self): def get_source_node_service_state(self):
""" """
Gets the source node service state (used for initiator type SERVICE) Gets the source node service state (used for initiator type SERVICE).
Returns: Returns:
The source node service state The source node service state
""" """
return self.source_node_service_state return self.source_node_service_state

View File

@@ -1,18 +1,15 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """The Passive Node class (i.e. an actuator)."""
The Passive Node class (i.e. an actuator)
"""
from primaite.nodes.node import Node from primaite.nodes.node import Node
class PassiveNode(Node): class PassiveNode(Node):
""" """The Passive Node class."""
The Passive Node class
"""
def __init__(self, _id, _name, _type, _priority, _state, _config_values): def __init__(self, _id, _name, _type, _priority, _state, _config_values):
""" """
Init Init.
Args: Args:
_id: The node id _id: The node id
@@ -21,17 +18,15 @@ class PassiveNode(Node):
_priority: The priority of the node _priority: The priority of the node
_state: The state of the node _state: The state of the node
""" """
# Pass through to Super for now # Pass through to Super for now
super().__init__(_id, _name, _type, _priority, _state, _config_values) super().__init__(_id, _name, _type, _priority, _state, _config_values)
def get_ip_address(self): def get_ip_address(self):
""" """
Gets the node IP address Gets the node IP address.
Returns: Returns:
The node IP address The node IP address
""" """
# No concept of IP address for passive nodes for now # No concept of IP address for passive nodes for now
return "" return ""

View File

@@ -1,19 +1,26 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """A Service Node (i.e. not an actuator)."""
A Service Node (i.e. not an actuator) from primaite.common.enums import SOFTWARE_STATE
"""
from primaite.nodes.active_node import ActiveNode from primaite.nodes.active_node import ActiveNode
from primaite.common.enums import *
class ServiceNode(ActiveNode): class ServiceNode(ActiveNode):
""" """ServiceNode class."""
ServiceNode class
"""
def __init__(self, _id, _name, _type, _priority, _state, _ip_address, _os_state, _file_system_state, _config_values): def __init__(
self,
_id,
_name,
_type,
_priority,
_state,
_ip_address,
_os_state,
_file_system_state,
_config_values,
):
""" """
Init Init.
Args: Args:
_id: The node id _id: The node id
@@ -25,38 +32,44 @@ class ServiceNode(ActiveNode):
_osState: The operating system state of the node _osState: The operating system state of the node
_file_system_state: The file system state of the node _file_system_state: The file system state of the node
""" """
super().__init__(
super().__init__(_id, _name, _type, _priority, _state, _ip_address, _os_state, _file_system_state, _config_values) _id,
_name,
_type,
_priority,
_state,
_ip_address,
_os_state,
_file_system_state,
_config_values,
)
self.services = {} self.services = {}
def add_service(self, _service): def add_service(self, _service):
""" """
Adds a service to the node Adds a service to the node.
Args: Args:
_service: The service to add _service: The service to add
""" """
self.services[_service.get_name()] = _service self.services[_service.get_name()] = _service
def get_services(self): def get_services(self):
""" """
Gets the dictionary of services on this node Gets the dictionary of services on this node.
Returns: Returns:
Dictionary of services on this node Dictionary of services on this node
""" """
return self.services return self.services
def has_service(self, _protocol): def has_service(self, _protocol):
""" """
Indicates whether a service is on a node Indicates whether a service is on a node.
Returns: Returns:
True if service (protocol) is on the node True if service (protocol) is on the node
""" """
for service_key, service_value in self.services.items(): for service_key, service_value in self.services.items():
if service_key == _protocol: if service_key == _protocol:
return True return True
@@ -66,12 +79,11 @@ class ServiceNode(ActiveNode):
def service_running(self, _protocol): def service_running(self, _protocol):
""" """
Indicates whether a service is in a running state on the node Indicates whether a service is in a running state on the node.
Returns: Returns:
True if service (protocol) is in a running state on the node True if service (protocol) is in a running state on the node
""" """
for service_key, service_value in self.services.items(): for service_key, service_value in self.services.items():
if service_key == _protocol: if service_key == _protocol:
if service_value.get_state() != SOFTWARE_STATE.PATCHING: if service_value.get_state() != SOFTWARE_STATE.PATCHING:
@@ -84,12 +96,11 @@ class ServiceNode(ActiveNode):
def service_is_overwhelmed(self, _protocol): def service_is_overwhelmed(self, _protocol):
""" """
Indicates whether a service is in an overwhelmed state on the node Indicates whether a service is in an overwhelmed state on the node.
Returns: Returns:
True if service (protocol) is in an overwhelmed state on the node True if service (protocol) is in an overwhelmed state on the node
""" """
for service_key, service_value in self.services.items(): for service_key, service_value in self.services.items():
if service_key == _protocol: if service_key == _protocol:
if service_value.get_state() == SOFTWARE_STATE.OVERWHELMED: if service_value.get_state() == SOFTWARE_STATE.OVERWHELMED:
@@ -102,61 +113,61 @@ class ServiceNode(ActiveNode):
def set_service_state(self, _protocol, _state): def set_service_state(self, _protocol, _state):
""" """
Sets the state of a service (protocol) on the node Sets the state of a service (protocol) on the node.
Args: Args:
_protocol: The service (protocol) _protocol: The service (protocol)
_state: The state value _state: The state value
""" """
for service_key, service_value in self.services.items(): for service_key, service_value in self.services.items():
if service_key == _protocol: if service_key == _protocol:
# Can't set to compromised if you're in a patching state # Can't set to compromised if you're in a patching state
if (_state == SOFTWARE_STATE.COMPROMISED and service_value.get_state() != SOFTWARE_STATE.PATCHING) or _state != SOFTWARE_STATE.COMPROMISED: if (
_state == SOFTWARE_STATE.COMPROMISED
and service_value.get_state() != SOFTWARE_STATE.PATCHING
) or _state != SOFTWARE_STATE.COMPROMISED:
service_value.set_state(_state) service_value.set_state(_state)
else: else:
# Do nothing # Do nothing
pass pass
if _state == SOFTWARE_STATE.PATCHING: if _state == SOFTWARE_STATE.PATCHING:
service_value.patching_count = self.config_values.service_patching_duration service_value.patching_count = (
self.config_values.service_patching_duration
)
else: else:
# Do nothing # Do nothing
pass pass
def set_service_state_if_not_compromised(self, _protocol, _state): def set_service_state_if_not_compromised(self, _protocol, _state):
""" """
Sets the state of a service (protocol) on the node if the operating state is not "compromised" Sets the state of a service (protocol) on the node if the operating state is not "compromised".
Args: Args:
_protocol: The service (protocol) _protocol: The service (protocol)
_state: The state value _state: The state value
""" """
for service_key, service_value in self.services.items(): for service_key, service_value in self.services.items():
if service_key == _protocol: if service_key == _protocol:
if service_value.get_state() != SOFTWARE_STATE.COMPROMISED: if service_value.get_state() != SOFTWARE_STATE.COMPROMISED:
service_value.set_state(_state) service_value.set_state(_state)
if _state == SOFTWARE_STATE.PATCHING: if _state == SOFTWARE_STATE.PATCHING:
service_value.patching_count = self.config_values.service_patching_duration service_value.patching_count = (
self.config_values.service_patching_duration
)
def get_service_state(self, _protocol): def get_service_state(self, _protocol):
""" """
Gets the state of a service Gets the state of a service.
Returns: Returns:
The state of the service The state of the service
""" """
for service_key, service_value in self.services.items(): for service_key, service_value in self.services.items():
if service_key == _protocol: if service_key == _protocol:
return service_value.get_state() return service_value.get_state()
def update_services_patching_status(self): def update_services_patching_status(self):
""" """Updates the patching counter for any service that are patching."""
Updates the patching counter for any service that are patching
"""
for service_key, service_value in self.services.items(): for service_key, service_value in self.services.items():
if service_value.get_state() == SOFTWARE_STATE.PATCHING: if service_value.get_state() == SOFTWARE_STATE.PATCHING:
service_value.reduce_patching_count() service_value.reduce_patching_count()

View File

@@ -1,19 +1,18 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """Implements Pattern of Life on the network (nodes and links)."""
Implements Pattern of Life on the network (nodes and links)
"""
from networkx import shortest_path from networkx import shortest_path
from primaite.common.enums import * from primaite.common.enums import HARDWARE_STATE, NODE_POL_TYPE, SOFTWARE_STATE, TYPE
from primaite.nodes.active_node import ActiveNode from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode from primaite.nodes.service_node import ServiceNode
_VERBOSE = False _VERBOSE = False
def apply_iers(network, nodes, links, iers, acl, step): def apply_iers(network, nodes, links, iers, acl, step):
""" """
Applies IERs to the links (link pattern of life) Applies IERs to the links (link pattern of life).
Args: Args:
network: The network modelled in the environment network: The network modelled in the environment
@@ -21,9 +20,8 @@ def apply_iers(network, nodes, links, iers, acl, step):
links: The links within the environment links: The links within the environment
iers: The IERs to apply to the links iers: The IERs to apply to the links
acl: The Access Control List acl: The Access Control List
step: The step number step: The step number.
""" """
if _VERBOSE: if _VERBOSE:
print("Applying IERs") print("Applying IERs")
@@ -38,7 +36,7 @@ def apply_iers(network, nodes, links, iers, acl, step):
source_node_id = ier_value.get_source_node_id() source_node_id = ier_value.get_source_node_id()
dest_node_id = ier_value.get_dest_node_id() dest_node_id = ier_value.get_dest_node_id()
# Need to set the running status to false first for all IERs # Need to set the running status to false first for all IERs
ier_value.set_is_running(False) ier_value.set_is_running(False)
source_valid = True source_valid = True
@@ -46,8 +44,8 @@ def apply_iers(network, nodes, links, iers, acl, step):
acl_block = False acl_block = False
if step >= start_step and step <= stop_step: if step >= start_step and step <= stop_step:
# continue -------------------------- # continue --------------------------
# Get the source and destination node for this link # Get the source and destination node for this link
source_node = nodes[source_node_id] source_node = nodes[source_node_id]
dest_node = nodes[dest_node_id] dest_node = nodes[dest_node_id]
@@ -55,7 +53,10 @@ def apply_iers(network, nodes, links, iers, acl, step):
# 1. Check the source node situation # 1. Check the source node situation
if source_node.get_type() == TYPE.SWITCH: if source_node.get_type() == TYPE.SWITCH:
# It's a switch # It's a switch
if source_node.get_state() == HARDWARE_STATE.ON and source_node.get_os_state() != SOFTWARE_STATE.PATCHING: if (
source_node.get_state() == HARDWARE_STATE.ON
and source_node.get_os_state() != SOFTWARE_STATE.PATCHING
):
source_valid = True source_valid = True
else: else:
# IER no longer valid # IER no longer valid
@@ -66,9 +67,14 @@ def apply_iers(network, nodes, links, iers, acl, step):
pass pass
else: else:
# It's not a switch or an actuator (so active node) # It's not a switch or an actuator (so active node)
if source_node.get_state() == HARDWARE_STATE.ON and source_node.get_os_state() != SOFTWARE_STATE.PATCHING: if (
source_node.get_state() == HARDWARE_STATE.ON
and source_node.get_os_state() != SOFTWARE_STATE.PATCHING
):
if source_node.has_service(protocol): if source_node.has_service(protocol):
if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol): if source_node.service_running(
protocol
) and not source_node.service_is_overwhelmed(protocol):
source_valid = True source_valid = True
else: else:
source_valid = False source_valid = False
@@ -80,11 +86,13 @@ def apply_iers(network, nodes, links, iers, acl, step):
# Do nothing - IER no longer valid # Do nothing - IER no longer valid
source_valid = False source_valid = False
# 2. Check the dest node situation # 2. Check the dest node situation
if dest_node.get_type() == TYPE.SWITCH: if dest_node.get_type() == TYPE.SWITCH:
# It's a switch # It's a switch
if dest_node.get_state() == HARDWARE_STATE.ON and dest_node.get_os_state() != SOFTWARE_STATE.PATCHING: if (
dest_node.get_state() == HARDWARE_STATE.ON
and dest_node.get_os_state() != SOFTWARE_STATE.PATCHING
):
dest_valid = True dest_valid = True
else: else:
# IER no longer valid # IER no longer valid
@@ -94,9 +102,14 @@ def apply_iers(network, nodes, links, iers, acl, step):
pass pass
else: else:
# It's not a switch or an actuator (so active node) # It's not a switch or an actuator (so active node)
if dest_node.get_state() == HARDWARE_STATE.ON and dest_node.get_os_state() != SOFTWARE_STATE.PATCHING: if (
dest_node.get_state() == HARDWARE_STATE.ON
and dest_node.get_os_state() != SOFTWARE_STATE.PATCHING
):
if dest_node.has_service(protocol): if dest_node.has_service(protocol):
if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol): if dest_node.service_running(
protocol
) and not dest_node.service_is_overwhelmed(protocol):
dest_valid = True dest_valid = True
else: else:
dest_valid = False dest_valid = False
@@ -109,10 +122,21 @@ def apply_iers(network, nodes, links, iers, acl, step):
dest_valid = False dest_valid = False
# 3. Check that the ACL doesn't block it # 3. Check that the ACL doesn't block it
acl_block = acl.is_blocked(source_node.get_ip_address(), dest_node.get_ip_address(), protocol, port) acl_block = acl.is_blocked(
source_node.get_ip_address(), dest_node.get_ip_address(), protocol, port
)
if acl_block: if acl_block:
if _VERBOSE: if _VERBOSE:
print("ACL block on source: " + source_node.get_ip_address() + ", dest: " + dest_node.get_ip_address() + ", protocol: " + protocol + ", port: " + port) print(
"ACL block on source: "
+ source_node.get_ip_address()
+ ", dest: "
+ dest_node.get_ip_address()
+ ", protocol: "
+ protocol
+ ", port: "
+ port
)
else: else:
if _VERBOSE: if _VERBOSE:
print("No ACL block") print("No ACL block")
@@ -131,20 +155,25 @@ def apply_iers(network, nodes, links, iers, acl, step):
# We might have a switch in the path, so check all nodes are operational # We might have a switch in the path, so check all nodes are operational
for node in path_node_list: for node in path_node_list:
if node.get_state() != HARDWARE_STATE.ON or node.get_os_state() == SOFTWARE_STATE.PATCHING: if (
node.get_state() != HARDWARE_STATE.ON
or node.get_os_state() == SOFTWARE_STATE.PATCHING
):
path_valid = False path_valid = False
if path_valid: if path_valid:
if _VERBOSE: if _VERBOSE:
print("Applying IER to link(s)") print("Applying IER to link(s)")
count = 0 count = 0
link_capacity_exceeded = False link_capacity_exceeded = False
# Check that the link capacity is not exceeded by the new load # Check that the link capacity is not exceeded by the new load
while count < path_node_list_length - 1: while count < path_node_list_length - 1:
# Get the link between the next two nodes # Get the link between the next two nodes
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count+1]) edge_dict = network.get_edge_data(
link_id = edge_dict[0].get('id') path_node_list[count], path_node_list[count + 1]
)
link_id = edge_dict[0].get("id")
link = links[link_id] link = links[link_id]
# Check whether the new load exceeds the bandwidth # Check whether the new load exceeds the bandwidth
if (link.get_current_load() + load) > link.get_bandwidth(): if (link.get_current_load() + load) > link.get_bandwidth():
@@ -152,7 +181,7 @@ def apply_iers(network, nodes, links, iers, acl, step):
if _VERBOSE: if _VERBOSE:
print("Link capacity exceeded") print("Link capacity exceeded")
pass pass
count+=1 count += 1
# Check whether the link capacity for any links on this path have been exceeded # Check whether the link capacity for any links on this path have been exceeded
if link_capacity_exceeded == False: if link_capacity_exceeded == False:
@@ -160,20 +189,22 @@ def apply_iers(network, nodes, links, iers, acl, step):
count = 0 count = 0
while count < path_node_list_length - 1: while count < path_node_list_length - 1:
# Get the link between the next two nodes # Get the link between the next two nodes
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count+1]) edge_dict = network.get_edge_data(
link_id = edge_dict[0].get('id') path_node_list[count], path_node_list[count + 1]
)
link_id = edge_dict[0].get("id")
link = links[link_id] link = links[link_id]
# Add the load from this IER # Add the load from this IER
link.add_protocol_load(protocol, load) link.add_protocol_load(protocol, load)
count+=1 count += 1
# This IER is now valid, so set it to running # This IER is now valid, so set it to running
ier_value.set_is_running(True) ier_value.set_is_running(True)
else: else:
# One of the nodes is not operational # One of the nodes is not operational
if _VERBOSE: if _VERBOSE:
print("Path not valid - one or more nodes not operational") print("Path not valid - one or more nodes not operational")
pass pass
else: else:
if _VERBOSE: if _VERBOSE:
print("Source, Dest or ACL were not valid") print("Source, Dest or ACL were not valid")
@@ -183,19 +214,19 @@ def apply_iers(network, nodes, links, iers, acl, step):
# Do nothing - IER no longer valid # Do nothing - IER no longer valid
pass pass
def apply_node_pol(nodes, node_pol, step): def apply_node_pol(nodes, node_pol, step):
""" """
Applies node pattern of life Applies node pattern of life.
Args: Args:
nodes: The nodes within the environment nodes: The nodes within the environment
node_pol: The node pattern of life to apply node_pol: The node pattern of life to apply
step: The step number step: The step number.
""" """
if _VERBOSE: if _VERBOSE:
print("Applying Node PoL") print("Applying Node PoL")
for key, node_instruction in node_pol.items(): for key, node_instruction in node_pol.items():
start_step = node_instruction.get_start_step() start_step = node_instruction.get_start_step()
stop_step = node_instruction.get_end_step() stop_step = node_instruction.get_end_step()
@@ -205,7 +236,7 @@ def apply_node_pol(nodes, node_pol, step):
state = node_instruction.get_state() state = node_instruction.get_state()
if step >= start_step and step <= stop_step: if step >= start_step and step <= stop_step:
# continue -------------------------- # continue --------------------------
node = nodes[node_id] node = nodes[node_id]
if node_pol_type == NODE_POL_TYPE.OPERATING: if node_pol_type == NODE_POL_TYPE.OPERATING:
@@ -227,4 +258,4 @@ def apply_node_pol(nodes, node_pol, step):
node.set_file_system_state_if_not_compromised(state) node.set_file_system_state_if_not_compromised(state)
else: else:
# PoL is not valid in this time step # PoL is not valid in this time step
pass pass

View File

@@ -1,17 +1,29 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """
Information Exchange Requirements for APE Information Exchange Requirements for APE.
Used to represent an information flow from source to destination
Used to represent an information flow from source to destination.
""" """
class IER(object):
"""
Information Exchange Requirement class
"""
def __init__(self, _id, _start_step, _end_step, _load, _protocol, _port, _source_node_id, _dest_node_id, _mission_criticality, _running=False): class IER(object):
"""Information Exchange Requirement class."""
def __init__(
self,
_id,
_start_step,
_end_step,
_load,
_protocol,
_port,
_source_node_id,
_dest_node_id,
_mission_criticality,
_running=False,
):
""" """
Init Init.
Args: Args:
_id: The IER id _id: The IER id
@@ -25,13 +37,12 @@ class IER(object):
_mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical) _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
_running: Indicates whether the IER is currently running _running: Indicates whether the IER is currently running
""" """
self.id = _id self.id = _id
self.start_step = _start_step self.start_step = _start_step
self.end_step = _end_step self.end_step = _end_step
self.source_node_id = _source_node_id self.source_node_id = _source_node_id
self.dest_node_id = _dest_node_id self.dest_node_id = _dest_node_id
self.load = _load self.load = _load
self.protocol = _protocol self.protocol = _protocol
self.port = _port self.port = _port
self.mission_criticality = _mission_criticality self.mission_criticality = _mission_criticality
@@ -39,97 +50,88 @@ class IER(object):
def get_id(self): def get_id(self):
""" """
Gets IER ID Gets IER ID.
Returns: Returns:
IER ID IER ID
""" """
return self.id return self.id
def get_start_step(self): def get_start_step(self):
""" """
Gets IER start step Gets IER start step.
Returns: Returns:
IER start step IER start step
""" """
return self.start_step return self.start_step
def get_end_step(self): def get_end_step(self):
""" """
Gets IER end step Gets IER end step.
Returns: Returns:
IER end step IER end step
""" """
return self.end_step return self.end_step
def get_load(self): def get_load(self):
""" """
Gets IER load Gets IER load.
Returns: Returns:
IER load IER load
""" """
return self.load return self.load
def get_protocol(self): def get_protocol(self):
""" """
Gets IER protocol Gets IER protocol.
Returns: Returns:
IER protocol IER protocol
""" """
return self.protocol return self.protocol
def get_port(self): def get_port(self):
""" """
Gets IER port Gets IER port.
Returns: Returns:
IER port IER port
""" """
return self.port return self.port
def get_source_node_id(self): def get_source_node_id(self):
""" """
Gets IER source node ID Gets IER source node ID.
Returns: Returns:
IER source node ID IER source node ID
""" """
return self.source_node_id return self.source_node_id
def get_dest_node_id(self): def get_dest_node_id(self):
""" """
Gets IER destination node ID Gets IER destination node ID.
Returns: Returns:
IER destination node ID IER destination node ID
""" """
return self.dest_node_id return self.dest_node_id
def get_is_running(self): def get_is_running(self):
""" """
Informs whether the IER is currently running Informs whether the IER is currently running.
Returns: Returns:
True if running True if running
""" """
return self.running return self.running
def set_is_running(self, _value): def set_is_running(self, _value):
""" """
Sets the running state of the IER Sets the running state of the IER.
Args: Args:
_value: running status _value: running status
@@ -138,10 +140,9 @@ class IER(object):
def get_mission_criticality(self): def get_mission_criticality(self):
""" """
Gets the IER mission criticality (used in the reward function) Gets the IER mission criticality (used in the reward function).
Returns: Returns:
Mission criticality value (0 lowest to 5 highest) Mission criticality value (0 lowest to 5 highest)
""" """
return self.mission_criticality
return self.mission_criticality

View File

@@ -1,19 +1,24 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """Implements POL on the network (nodes and links) resulting from the red agent attack."""
Implements Pattern of Life on the network (nodes and links) resulting from the red agent attack
"""
from networkx import shortest_path from networkx import shortest_path
from primaite.common.enums import * from primaite.common.enums import (
HARDWARE_STATE,
NODE_POL_INITIATOR,
NODE_POL_TYPE,
SOFTWARE_STATE,
TYPE,
)
from primaite.nodes.active_node import ActiveNode from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode from primaite.nodes.service_node import ServiceNode
_VERBOSE = False _VERBOSE = False
def apply_red_agent_iers(network, nodes, links, iers, acl, step): def apply_red_agent_iers(network, nodes, links, iers, acl, step):
""" """
Applies IERs to the links (link pattern of life) resulting from red agent attack Applies IERs to the links (link POL) resulting from red agent attack.
Args: Args:
network: The network modelled in the environment network: The network modelled in the environment
@@ -21,9 +26,8 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
links: The links within the environment links: The links within the environment
iers: The red agent IERs to apply to the links iers: The red agent IERs to apply to the links
acl: The Access Control List acl: The Access Control List
step: The step number step: The step number.
""" """
# Go through each IER and check the conditions for it being applied # Go through each IER and check the conditions for it being applied
# If everything is in place, apply the IER protocol load to the relevant links # If everything is in place, apply the IER protocol load to the relevant links
for ier_key, ier_value in iers.items(): for ier_key, ier_value in iers.items():
@@ -35,7 +39,7 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
source_node_id = ier_value.get_source_node_id() source_node_id = ier_value.get_source_node_id()
dest_node_id = ier_value.get_dest_node_id() dest_node_id = ier_value.get_dest_node_id()
# Need to set the running status to false first for all IERs # Need to set the running status to false first for all IERs
ier_value.set_is_running(False) ier_value.set_is_running(False)
source_valid = True source_valid = True
@@ -43,8 +47,8 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
acl_block = False acl_block = False
if step >= start_step and step <= stop_step: if step >= start_step and step <= stop_step:
# continue -------------------------- # continue --------------------------
# Get the source and destination node for this link # Get the source and destination node for this link
source_node = nodes[source_node_id] source_node = nodes[source_node_id]
dest_node = nodes[dest_node_id] dest_node = nodes[dest_node_id]
@@ -66,7 +70,10 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
if source_node.get_state() == HARDWARE_STATE.ON: if source_node.get_state() == HARDWARE_STATE.ON:
if source_node.has_service(protocol): if source_node.has_service(protocol):
# Red agents IERs can only be valid if the source service is in a compromised state # Red agents IERs can only be valid if the source service is in a compromised state
if source_node.get_service_state(protocol) == SOFTWARE_STATE.COMPROMISED: if (
source_node.get_service_state(protocol)
== SOFTWARE_STATE.COMPROMISED
):
source_valid = True source_valid = True
else: else:
source_valid = False source_valid = False
@@ -78,7 +85,6 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
# Do nothing - IER no longer valid # Do nothing - IER no longer valid
source_valid = False source_valid = False
# 2. Check the dest node situation # 2. Check the dest node situation
if dest_node.get_type() == TYPE.SWITCH: if dest_node.get_type() == TYPE.SWITCH:
# It's a switch # It's a switch
@@ -105,10 +111,21 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
dest_valid = False dest_valid = False
# 3. Check that the ACL doesn't block it # 3. Check that the ACL doesn't block it
acl_block = acl.is_blocked(source_node.get_ip_address(), dest_node.get_ip_address(), protocol, port) acl_block = acl.is_blocked(
source_node.get_ip_address(), dest_node.get_ip_address(), protocol, port
)
if acl_block: if acl_block:
if _VERBOSE: if _VERBOSE:
print("ACL block on source: " + source_node.get_ip_address() + ", dest: " + dest_node.get_ip_address() + ", protocol: " + protocol + ", port: " + port) print(
"ACL block on source: "
+ source_node.get_ip_address()
+ ", dest: "
+ dest_node.get_ip_address()
+ ", protocol: "
+ protocol
+ ", port: "
+ port
)
else: else:
if _VERBOSE: if _VERBOSE:
print("No ACL block") print("No ACL block")
@@ -130,7 +147,7 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
for node in path_node_list: for node in path_node_list:
if node.get_state() != HARDWARE_STATE.ON: if node.get_state() != HARDWARE_STATE.ON:
path_valid = False path_valid = False
if path_valid: if path_valid:
if _VERBOSE: if _VERBOSE:
print("Applying IER to link(s)") print("Applying IER to link(s)")
@@ -140,8 +157,10 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
# Check that the link capacity is not exceeded by the new load # Check that the link capacity is not exceeded by the new load
while count < path_node_list_length - 1: while count < path_node_list_length - 1:
# Get the link between the next two nodes # Get the link between the next two nodes
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count+1]) edge_dict = network.get_edge_data(
link_id = edge_dict[0].get('id') path_node_list[count], path_node_list[count + 1]
)
link_id = edge_dict[0].get("id")
link = links[link_id] link = links[link_id]
# Check whether the new load exceeds the bandwidth # Check whether the new load exceeds the bandwidth
if (link.get_current_load() + load) > link.get_bandwidth(): if (link.get_current_load() + load) > link.get_bandwidth():
@@ -149,7 +168,7 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
if _VERBOSE: if _VERBOSE:
print("Link capacity exceeded") print("Link capacity exceeded")
pass pass
count+=1 count += 1
# Check whether the link capacity for any links on this path have been exceeded # Check whether the link capacity for any links on this path have been exceeded
if link_capacity_exceeded == False: if link_capacity_exceeded == False:
@@ -157,12 +176,14 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
count = 0 count = 0
while count < path_node_list_length - 1: while count < path_node_list_length - 1:
# Get the link between the next two nodes # Get the link between the next two nodes
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count+1]) edge_dict = network.get_edge_data(
link_id = edge_dict[0].get('id') path_node_list[count], path_node_list[count + 1]
)
link_id = edge_dict[0].get("id")
link = links[link_id] link = links[link_id]
# Add the load from this IER # Add the load from this IER
link.add_protocol_load(protocol, load) link.add_protocol_load(protocol, load)
count+=1 count += 1
# This IER is now valid, so set it to running # This IER is now valid, so set it to running
ier_value.set_is_running(True) ier_value.set_is_running(True)
if _VERBOSE: if _VERBOSE:
@@ -172,7 +193,7 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
if _VERBOSE: if _VERBOSE:
print("Path not valid - one or more nodes not operational") print("Path not valid - one or more nodes not operational")
pass pass
else: else:
if _VERBOSE: if _VERBOSE:
print("Red IER was NOT allowed to run in step " + str(step)) print("Red IER was NOT allowed to run in step " + str(step))
@@ -185,20 +206,20 @@ def apply_red_agent_iers(network, nodes, links, iers, acl, step):
pass pass
def apply_red_agent_node_pol(nodes, iers, node_pol, step): def apply_red_agent_node_pol(nodes, iers, node_pol, step):
""" """
Applies node pattern of life Applies node pattern of life.
Args: Args:
nodes: The nodes within the environment nodes: The nodes within the environment
iers: The red agent IERs iers: The red agent IERs
node_pol: The red agent node pattern of life to apply node_pol: The red agent node pattern of life to apply
step: The step number step: The step number.
""" """
if _VERBOSE: if _VERBOSE:
print("Applying Node Red Agent PoL") print("Applying Node Red Agent PoL")
for key, node_instruction in node_pol.items(): for key, node_instruction in node_pol.items():
start_step = node_instruction.get_start_step() start_step = node_instruction.get_start_step()
stop_step = node_instruction.get_end_step() stop_step = node_instruction.get_end_step()
@@ -209,12 +230,14 @@ def apply_red_agent_node_pol(nodes, iers, node_pol, step):
state = node_instruction.get_state() state = node_instruction.get_state()
source_node_id = node_instruction.get_source_node_id() source_node_id = node_instruction.get_source_node_id()
source_node_service_name = node_instruction.get_source_node_service() source_node_service_name = node_instruction.get_source_node_service()
source_node_service_state_value = node_instruction.get_source_node_service_state() source_node_service_state_value = (
node_instruction.get_source_node_service_state()
)
passed_checks = False passed_checks = False
if step >= start_step and step <= stop_step: if step >= start_step and step <= stop_step:
# continue -------------------------- # continue --------------------------
target_node = nodes[target_node_id] target_node = nodes[target_node_id]
# Based the action taken on the initiator type # Based the action taken on the initiator type
@@ -228,7 +251,10 @@ def apply_red_agent_node_pol(nodes, iers, node_pol, step):
# Need to check the condition of a service on another node # Need to check the condition of a service on another node
source_node = nodes[source_node_id] source_node = nodes[source_node_id]
if source_node.has_service(source_node_service_name): if source_node.has_service(source_node_service_name):
if source_node.get_service_state(source_node_service_name) == SOFTWARE_STATE[source_node_service_state_value]: if (
source_node.get_service_state(source_node_service_name)
== SOFTWARE_STATE[source_node_service_state_value]
):
passed_checks = True passed_checks = True
else: else:
# Do nothing, no matching state value # Do nothing, no matching state value
@@ -248,7 +274,9 @@ def apply_red_agent_node_pol(nodes, iers, node_pol, step):
target_node.set_state(state) target_node.set_state(state)
elif pol_type == NODE_POL_TYPE.OS: elif pol_type == NODE_POL_TYPE.OS:
# Change OS state # Change OS state
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode): if isinstance(target_node, ActiveNode) or isinstance(
target_node, ServiceNode
):
target_node.set_os_state(state) target_node.set_os_state(state)
elif pol_type == NODE_POL_TYPE.SERVICE: elif pol_type == NODE_POL_TYPE.SERVICE:
# Change a service state # Change a service state
@@ -256,23 +284,34 @@ def apply_red_agent_node_pol(nodes, iers, node_pol, step):
target_node.set_service_state(service_name, state) target_node.set_service_state(service_name, state)
else: else:
# Change the file system status # Change the file system status
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode): if isinstance(target_node, ActiveNode) or isinstance(
target_node, ServiceNode
):
target_node.set_file_system_state(state) target_node.set_file_system_state(state)
else: else:
if _VERBOSE: if _VERBOSE:
print("Node Red Agent PoL not allowed - did not pass checks") print("Node Red Agent PoL not allowed - did not pass checks")
else: else:
# PoL is not valid in this time step # PoL is not valid in this time step
pass pass
def is_red_ier_incoming(node, iers, node_pol_type):
def is_red_ier_incoming(node, iers, node_pol_type):
"""
Checks if the RED IER is incoming.
TODO: Write more descriptive docstring with params and returns.
"""
node_id = node.get_id() node_id = node.get_id()
for ier_key, ier_value in iers.items(): for ier_key, ier_value in iers.items():
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id: if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
if node_pol_type == NODE_POL_TYPE.OPERATING or node_pol_type == NODE_POL_TYPE.OS or node_pol_type == NODE_POL_TYPE.FILE: if (
# It's looking to change operating state, file system or O/S state, so valid node_pol_type == NODE_POL_TYPE.OPERATING
or node_pol_type == NODE_POL_TYPE.OS
or node_pol_type == NODE_POL_TYPE.FILE
):
# It's looking to change operating state, file system or O/S state, so valid
return True return True
elif node_pol_type == NODE_POL_TYPE.SERVICE: elif node_pol_type == NODE_POL_TYPE.SERVICE:
# Check if the service is present on the node and running # Check if the service is present on the node and running
@@ -297,5 +336,3 @@ def is_red_ier_incoming(node, iers, node_pol_type):
else: else:
# The IER destination is not this node, or the IER is not running # The IER destination is not this node, or the IER is not running
return False return False

View File

@@ -1,69 +1,57 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """The Transaction class."""
The Transaction class
"""
class Transaction(object): class Transaction(object):
""" """Transaction class."""
Transaction class
"""
def __init__(self, _timestamp, _agent_identifier, _episode_number, _step_number): def __init__(self, _timestamp, _agent_identifier, _episode_number, _step_number):
""" """
Init Init.
Args: Args:
_timestamp: The time this object was created _timestamp: The time this object was created
_agent_identifier: An identifier for the agent in use _agent_identifier: An identifier for the agent in use
_episode_number: The episode number _episode_number: The episode number
_step_number: The step number _step_number: The step number
""" """
self.timestamp = _timestamp self.timestamp = _timestamp
self.agent_identifier = _agent_identifier self.agent_identifier = _agent_identifier
self.episode_number = _episode_number self.episode_number = _episode_number
self.step_number = _step_number self.step_number = _step_number
def set_obs_space_pre(self, _obs_space_pre): def set_obs_space_pre(self, _obs_space_pre):
""" """
Sets the observation space (pre) Sets the observation space (pre).
Args: Args:
_obs_space_pre: The observation space before any actions are taken _obs_space_pre: The observation space before any actions are taken
""" """
self.obs_space_pre = _obs_space_pre self.obs_space_pre = _obs_space_pre
def set_obs_space_post(self, _obs_space_post): def set_obs_space_post(self, _obs_space_post):
""" """
Sets the observation space (post) Sets the observation space (post).
Args: Args:
_obs_space_post: The observation space after any actions are taken _obs_space_post: The observation space after any actions are taken
""" """
self.obs_space_post = _obs_space_post self.obs_space_post = _obs_space_post
def set_reward(self, _reward): def set_reward(self, _reward):
""" """
Sets the reward Sets the reward.
Args: Args:
_reward: The reward value _reward: The reward value
""" """
self.reward = _reward self.reward = _reward
def set_action_space(self, _action_space): def set_action_space(self, _action_space):
""" """
Sets the action space Sets the action space.
Args: Args:
_action_space: The action space invoked by the agent _action_space: The action space invoked by the agent
""" """
self.action_space = _action_space self.action_space = _action_space

View File

@@ -1,40 +1,35 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """Writes the Transaction log list out to file for evaluation to utilse."""
Writes the Transaction log list out to file for evaluation to utilse
"""
import csv import csv
import logging import logging
import os.path import os.path
from datetime import datetime from datetime import datetime
from primaite.transactions.transaction import Transaction
def turn_action_space_to_array(_action_space): def turn_action_space_to_array(_action_space):
""" """
Turns action space into a string array so it can be saved to csv Turns action space into a string array so it can be saved to csv.
Args: Args:
_action_space: The action space _action_space: The action space.
""" """
return_array = [] return_array = []
for x in range(len(_action_space)): for x in range(len(_action_space)):
return_array.append(str(_action_space[x])) return_array.append(str(_action_space[x]))
return return_array return return_array
def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features): def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
""" """
Turns observation space into a string array so it can be saved to csv Turns observation space into a string array so it can be saved to csv.
Args: Args:
_obs_space: The observation space _obs_space: The observation space
_obs_assets: The number of assets (i.e. nodes or links) in the observation space _obs_assets: The number of assets (i.e. nodes or links) in the observation space
_obs_features: The number of features associated with the asset _obs_features: The number of features associated with the asset
""" """
return_array = [] return_array = []
for x in range(_obs_assets): for x in range(_obs_assets):
for y in range(_obs_features): for y in range(_obs_features):
@@ -42,15 +37,15 @@ def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
return return_array return return_array
def write_transaction_to_file(_transaction_list): def write_transaction_to_file(_transaction_list):
""" """
Writes transaction logs to file to support training evaluation Writes transaction logs to file to support training evaluation.
Args: Args:
_transaction_list: The list of transactions from all steps and all episodes _transaction_list: The list of transactions from all steps and all episodes
_num_episodes: The number of episodes that were conducted _num_episodes: The number of episodes that were conducted.
""" """
# Get the first transaction and use it to determine the makeup of the observation space and action space # Get the first transaction and use it to determine the makeup of the observation space and action space
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action space as "AS_1" # Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action space as "AS_1"
# This will be tied into the PrimAITE Use Case so that they make sense # This will be tied into the PrimAITE Use Case so that they make sense
@@ -59,46 +54,56 @@ def write_transaction_to_file(_transaction_list):
obs_assets = template_transation.obs_space_post.shape[0] obs_assets = template_transation.obs_space_post.shape[0]
obs_features = template_transation.obs_space_post.shape[1] obs_features = template_transation.obs_space_post.shape[1]
# Create the action space headers array # Create the action space headers array
action_header = [] action_header = []
for x in range(action_length): for x in range(action_length):
action_header.append('AS_' + str(x)) action_header.append("AS_" + str(x))
# Create the observation space headers array # Create the observation space headers array
obs_header_initial = [] obs_header_initial = []
obs_header_new = [] obs_header_new = []
for x in range(obs_assets): for x in range(obs_assets):
for y in range(obs_features): for y in range(obs_features):
obs_header_initial.append('OSI_' + str(x) + '_' + str(y)) obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
obs_header_new.append('OSN_' + str(x) + '_' + str(y)) obs_header_new.append("OSN_" + str(x) + "_" + str(y))
# Open up a csv file # Open up a csv file
header = ['Timestamp', 'Episode', 'Step', 'Reward'] header = ["Timestamp", "Episode", "Step", "Reward"]
header = header + action_header + obs_header_initial + obs_header_new header = header + action_header + obs_header_initial + obs_header_new
now = datetime.now() # current date and time now = datetime.now() # current date and time
time = now.strftime("%Y%m%d_%H%M%S") time = now.strftime("%Y%m%d_%H%M%S")
try: try:
path = 'outputs/results/' path = "outputs/results/"
is_dir = os.path.isdir(path) is_dir = os.path.isdir(path)
if not is_dir: if not is_dir:
os.makedirs(path) os.makedirs(path)
filename = "outputs/results/all_transactions_" + time + ".csv" filename = "outputs/results/all_transactions_" + time + ".csv"
csv_file = open(filename, 'w', encoding='UTF8', newline='') csv_file = open(filename, "w", encoding="UTF8", newline="")
csv_writer = csv.writer(csv_file) csv_writer = csv.writer(csv_file)
csv_writer.writerow(header) csv_writer.writerow(header)
for transaction in _transaction_list: for transaction in _transaction_list:
csv_data = [str(transaction.timestamp), str(transaction.episode_number), str(transaction.step_number), str(transaction.reward)] csv_data = [
csv_data = csv_data + turn_action_space_to_array(transaction.action_space) + \ str(transaction.timestamp),
turn_obs_space_to_array(transaction.obs_space_pre, obs_assets, obs_features) + \ str(transaction.episode_number),
turn_obs_space_to_array(transaction.obs_space_post, obs_assets, obs_features) str(transaction.step_number),
str(transaction.reward),
]
csv_data = (
csv_data
+ turn_action_space_to_array(transaction.action_space)
+ turn_obs_space_to_array(
transaction.obs_space_pre, obs_assets, obs_features
)
+ turn_obs_space_to_array(
transaction.obs_space_post, obs_assets, obs_features
)
)
csv_writer.writerow(csv_data) csv_writer.writerow(csv_data)
csv_file.close() csv_file.close()
except Exception as e: except Exception:
logging.error("Could not save the transaction file") logging.error("Could not save the transaction file")
logging.error("Exception occured", exc_info=True) logging.error("Exception occured", exc_info=True)

View File

@@ -1,61 +1,48 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
""" """Used to tes the ACL functions."""
Used to tes the ACL functions
"""
from primaite.acl.acl_rule import ACLRule
from primaite.acl.access_control_list import AccessControlList from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
def test_acl_address_match_1(): def test_acl_address_match_1():
""" """Test that matching IP addresses produce True."""
Test that matching IP addresses produce True
"""
acl = AccessControlList() acl = AccessControlList()
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80") rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
def test_acl_address_match_2():
"""
Test that mismatching IP addresses produce False
"""
def test_acl_address_match_2():
"""Test that mismatching IP addresses produce False."""
acl = AccessControlList() acl = AccessControlList()
rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80") rule = ACLRule("ALLOW", "192.168.1.1", "192.168.1.2", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.3") == False
def test_acl_address_match_3():
"""
Test the ANY condition for source IP addresses produce True
"""
def test_acl_address_match_3():
"""Test the ANY condition for source IP addresses produce True."""
acl = AccessControlList() acl = AccessControlList()
rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80") rule = ACLRule("ALLOW", "ANY", "192.168.1.2", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
def test_acl_address_match_4():
"""
Test the ANY condition for dest IP addresses produce True
"""
def test_acl_address_match_4():
"""Test the ANY condition for dest IP addresses produce True."""
acl = AccessControlList() acl = AccessControlList()
rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80") rule = ACLRule("ALLOW", "192.168.1.1", "ANY", "TCP", "80")
assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True assert acl.check_address_match(rule, "192.168.1.1", "192.168.1.2") == True
def test_check_acl_block_affirmative():
"""
Test the block function (affirmative)
"""
def test_check_acl_block_affirmative():
"""Test the block function (affirmative)."""
# Create the Access Control List # Create the Access Control List
acl = AccessControlList() acl = AccessControlList()
@@ -66,15 +53,19 @@ def test_check_acl_block_affirmative():
acl_rule_protocol = "TCP" acl_rule_protocol = "TCP"
acl_rule_port = "80" acl_rule_port = "80"
acl.add_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port) acl.add_rule(
acl_rule_permission,
acl_rule_source,
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
)
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == False assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == False
def test_check_acl_block_negative():
"""
Test the block function (negative)
"""
def test_check_acl_block_negative():
"""Test the block function (negative)."""
# Create the Access Control List # Create the Access Control List
acl = AccessControlList() acl = AccessControlList()
@@ -85,21 +76,27 @@ def test_check_acl_block_negative():
acl_rule_protocol = "TCP" acl_rule_protocol = "TCP"
acl_rule_port = "80" acl_rule_port = "80"
acl.add_rule(acl_rule_permission, acl_rule_source, acl_rule_destination, acl_rule_protocol, acl_rule_port) acl.add_rule(
acl_rule_permission,
acl_rule_source,
acl_rule_destination,
acl_rule_protocol,
acl_rule_port,
)
assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == True assert acl.is_blocked("192.168.1.1", "192.168.1.2", "TCP", "80") == True
def test_rule_hash():
"""
Test the rule hash
"""
def test_rule_hash():
"""Test the rule hash."""
# Create the Access Control List # Create the Access Control List
acl = AccessControlList() acl = AccessControlList()
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80") rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
hash_value_local = hash(rule) hash_value_local = hash(rule)
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80") hash_value_remote = acl.get_dictionary_hash(
"DENY", "192.168.1.1", "192.168.1.2", "TCP", "80"
)
assert hash_value_local == hash_value_remote assert hash_value_local == hash_value_remote