#1962: merge dev into branch + fix minor diffs + ensure that imports pull from src

This commit is contained in:
Czar.Echavez
2023-11-08 10:36:47 +00:00
201 changed files with 4799 additions and 16458 deletions

View File

@@ -81,6 +81,17 @@ stages:
displayName: 'Install PrimAITE'
condition: eq( variables['Agent.OS'], 'Windows_NT' )
- script: |
GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl)
python -m pip install $GATE_WHEEL[dev]
displayName: 'Install GATE'
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
- script: |
forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
displayName: 'Install GATE'
condition: eq( variables['Agent.OS'], 'Windows_NT' )
- script: |
primaite setup
displayName: 'Perform PrimAITE Setup'

View File

@@ -36,6 +36,12 @@ SessionManager.
- FTP Services: `FTPClient` and `FTPServer`
- HTTP Services: `WebBrowser` to simulate a web client and `WebServer`
### Removed
- Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol`
- Removed legacy training modules, they are replaced by the new ARCD GATE dependency
- Removed tests for legacy code
## [2.0.0] - 2023-07-26
### Added

View File

@@ -18,17 +18,13 @@ PrimAITE presents the following features:
- Provision of logging to support AI evaluation and metrics gathering;
- Uses the concept of Information Exchange Requirements (IERs) to model background pattern of life and adversarial behaviour;
- Realistic network traffic simulation, including address and sending packets via internet protocols like TCP, UDP, ICMP, and others
- An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP address, destination IP address, protocol and port);
- Routers with traffic routing and firewall capabilities
- Application of IERs to the platform / system laydown adheres to the ACL ruleset;
- Integration with ARCD GATE for agent training
- Presents an OpenAI gym or RLLib interface to the environment, allowing integration with any compliant defensive agents;
- Full capture of discrete logs relating to agent training (full system state, agent actions taken, instantaneous and average reward for every step of every episode);
- NetworkX provides laydown visualisation capability.
- Support for multiple agents, each having their own customisable observation space, action space, and reward function definition, and either deterministic or RL-directed behaviour
## Getting Started with PrimAITE
@@ -50,6 +46,7 @@ python3 -m venv .venv
attrib +h .venv /s /d # Hides the .venv directory
.\.venv\Scripts\activate
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
pip install GATE/arcd_gate-0.1.0-py3-none-any.whl
primaite setup
```
@@ -78,6 +75,7 @@ cd ~/primaite
python3 -m venv .venv
source .venv/bin/activate
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
pip install arcd_gate-0.1.0-py3-none-any.whl
primaite setup
```
@@ -122,6 +120,7 @@ source venv/bin/activate
```bash
python3 -m pip install -e .[dev]
pip install arcd_gate-0.1.0-py3-none-any.whl
```
#### 6. Perform the PrimAITE setup:

View File

@@ -92,24 +92,34 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
.. toctree::
:maxdepth: 8
:caption: Contents:
:caption: About PrimAITE:
:hidden:
source/about
source/dependencies
source/glossary
.. toctree::
:caption: Usage:
:hidden:
source/getting_started
source/about
source/config
source/simulation
source/primaite_session
source/simulation
source/game_layer
source/custom_agent
source/config
.. toctree::
:caption: Developer information:
:hidden:
source/state_system
source/request_system
PrimAITE API <source/_autosummary/primaite>
PrimAITE Tests <source/_autosummary/tests>
source/dependencies
source/glossary
source/migration_1.2_-_2.0
.. TODO: Add project links once public repo has been created
.. toctree::
:caption: Project Links:
:hidden:

View File

@@ -7,408 +7,312 @@
About PrimAITE
==============
PrimAITE is a simulation environment for training agents to protect a computer network from cyber attacks.
Features
********
PrimAITE provides the following features:
* A flexible network / system laydown based on the Python networkx framework
* Nodes and links (edges) host Python classes in order to present attributes and methods (and hence, a more representative model of a platform / system)
* A 'green agent' Information Exchange Requirement (IER) function allows the representation of traffic (protocols and loading) on any / all links. Application of IERs is based on the status of node operating systems and services
* A 'green agent' node Pattern-of-Life (PoL) function allows the representation of core behaviours on nodes (e.g. changing the Hardware state, Software State, Service state, or File System state)
* An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP, destination IP, protocol and port). Application of IERs adheres to any ACL restrictions
* Presents an OpenAI Gym interface to the environment, allowing integration with any OpenAI Gym compliant defensive agents
* Red agent activity based on 'red' IERs and 'red' PoL
* Defined reward function for use with RL agents (based on nodes status, and green / red IER success)
* Fully configurable (network / system laydown, IERs, node PoL, ACL, episode step period, episode max steps) and repeatable to suit the training requirements of agents. Therefore, not bound to a representation of any particular platform, system or technology
* Full capture of discrete metrics relating to agent training (full system state, agent actions taken, average reward)
* Networkx provides laydown visualisation capability
Architecture - Nodes and Links
******************************
**Nodes**
An inheritance model has been adopted in order to model nodes. All nodes have the following base attributes (Class: Node):
* ID
* Name
* Type (e.g. computer, switch, RTU - enumeration)
* Priority (P1, P2, P3, P4 or P5 - enumeration)
* Hardware State (ON, OFF, RESETTING, SHUTTING_DOWN, BOOTING - enumeration)
Active Nodes also have the following attributes (Class: Active Node):
* IP Address
* Software State (GOOD, PATCHING, COMPROMISED - enumeration)
* File System State (GOOD, CORRUPT, DESTROYED, REPAIRING, RESTORING - enumeration)
Service Nodes also have the following attributes (Class: Service Node):
* List of Services (where service is composed of service name and port). There is no theoretical limit on the number of services that can be modelled. Services and protocols are currently intrinsically linked (i.e. a service is an application on a node transmitting traffic of this protocol type)
* Service state (GOOD, PATCHING, COMPROMISED, OVERWHELMED - enumeration)
Passive Nodes are currently not used (but may be employed for non IP-based components such as machinery actuators in future releases).
**Links**
Links are modelled both as network edges (networkx) and as Python classes, in order to extend their functionality. Links include the following attributes:
* ID
* Name
* Bandwidth (bits/s)
* Source node ID
* Destination node ID
* Protocol list (containing the loading of protocols currently running on the link)
When the simulation runs, IERs are applied to the links in order to model traffic loading, individually assigned to each protocol. This allows green (background) and red agent behaviour to be modelled, and defensive agents to identify suspicious traffic patterns at a protocol / traffic loading level of fidelity.
Information Exchange Requirements (IERs)
****************************************
PrimAITE adopts the concept of Information Exchange Requirements (IERs) to model both green agent (background) and red agent (adversary) behaviour. IERs are used to initiate modelling of traffic loading on the network, and have the following attributes:
* ID
* Start step (i.e. which step in the training episode should the IER start)
* End step (i.e. which step in the training episode should the IER end)
* Source node ID
* Destination node ID
* Load (bits/s)
* Protocol
* Port
* Running status (i.e. on / off)
The application of green agent IERs between a source and destination follows a number of rules. Specifically:
1. Does the current simulation time step fall between IER start and end step
2. Is the source node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING)
3. Is the destination node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING)
4. Are there any Access Control List rules in place that prevent the application of this IER
5. Are all switches in the (OSPF) path between source and destination operational (both physically and at an O/S level)
For red agent IERs, the application of IERs between a source and destination follows a number of subtly different rules. Specifically:
1. Does the current simulation time step fall between IER start and end step
2. Is the source node operational, and is the service (protocol / port) associated with the IER (a) present on that node and (b) already in a compromised state
3. Is the destination node operational, and is the service (protocol / port) associated with the IER present on that node
4. Are there any Access Control List rules in place that prevent the application of this IER
5. Are all switches in the (OSPF) path between source and destination operational (both physically and at an O/S level)
Assuming the rules pass, the IER is applied to all relevant links (based on use of OSPF) between source and destination.
Node Pattern-of-Life
********************
Every node can be impacted (i.e. have a status change applied to it) by either green agent pattern-of-life or red agent pattern-of-life. This is distinct from IERs, and allows for attacks (and defence) to be modelled purely within the confines of a node.
The status changes that can be made to a node are as follows:
* All Nodes:
* Hardware State:
* ON
* OFF
* RESETTING - when a status of resetting is entered, the node will automatically exit this state after a number of steps (as defined by the nodeResetDuration configuration item) after which it returns to an ON state
* BOOTING
* SHUTTING_DOWN
* Active Nodes and Service Nodes:
* Software State:
* GOOD
* PATCHING - when a status of patching is entered, the node will automatically exit this state after a number of steps (as defined by the osPatchingDuration configuration item) after which it returns to a GOOD state
* COMPROMISED
* File System State:
* GOOD
* CORRUPT (can be resolved by repair or restore)
* DESTROYED (can be resolved by restore only)
* REPAIRING - when a status of repairing is entered, the node will automatically exit this state after a number of steps (as defined by the fileSystemRepairingLimit configuration item) after which it returns to a GOOD state
* RESTORING - when a status of repairing is entered, the node will automatically exit this state after a number of steps (as defined by the fileSystemRestoringLimit configuration item) after which it returns to a GOOD state
* Service Nodes only:
* Service State (for any associated service):
* GOOD
* PATCHING - when a status of patching is entered, the service will automatically exit this state after a number of steps (as defined by the servicePatchingDuration configuration item) after which it returns to a GOOD state
* COMPROMISED
* OVERWHELMED
Red agent pattern-of-life has an additional feature not found in the green pattern-of-life. This is the ability to influence the state of the attributes of a node via a number of different conditions:
* DIRECT:
The pattern-of-life described by the configuration file item will be applied regardless of any other conditions in the network. This is particularly useful for direct red agent entry into the network.
* IER:
The pattern-of-life described by the configuration file item will be applied to the service on the node, only if there is an IER of the same protocol / service type incoming at the specified timestep.
* SERVICE:
The pattern-of-life described by the configuration file item will be applied to the node based on the state of a service. The service can either be on the same node, or a different node within the network.
Access Control List modelling
*****************************
An Access Control List (ACL) is modelled to provide the means to manage traffic flows in the system. This will allow defensive agents the means to turn on / off rules, or potentially create new rules, to counter an attack.
The ACL follows a standard network firewall format. For example:
.. list-table:: ACL example
:widths: 25 25 25 25 25
:header-rows: 1
* - Permission
- Source IP
- Dest IP
- Protocol
- Port
* - DENY
- 192.168.1.2
- 192.168.1.3
- HTTPS
- 443
* - ALLOW
- 192.168.1.4
- ANY
- SMTP
- 25
* - DENY
- ANY
- 192.168.1.5
- ANY
- ANY
All ACL rules are considered when applying an IER. Logic follows the order of rules, so a DENY or ALLOW for the same parameters will override an earlier entry.
Observation Spaces
******************
The observation space provides the blue agent with information about the current status of nodes and links.
PrimAITE builds on top of Gym Spaces to create an observation space that is easily configurable for users. It's made up of components which are managed by the :py:class:`primaite.environment.observations.ObservationsHandler`. Each training scenario can define its own observation space, and the user can choose which information to inlude, and how it should be formatted.
NodeLinkTable component
-----------------------
For example, the :py:class:`primaite.environment.observations.NodeLinkTable` component represents the status of nodes and links as a ``gym.spaces.Box`` with an example format shown below:
An example observation space is provided below:
.. list-table:: Observation Space example
:widths: 25 25 25 25 25 25 25
:header-rows: 1
* -
- ID
- Hardware State
- Software State
- File System State
- Service / Protocol A
- Service / Protocol B
* - Node A
- 1
- 1
- 1
- 1
- 1
- 1
* - Node B
- 2
- 1
- 3
- 1
- 1
- 1
* - Node C
- 3
- 2
- 1
- 1
- 3
- 2
* - Link 1
- 5
- 0
- 0
- 0
- 0
- 10000
* - Link 2
- 6
- 0
- 0
- 0
- 0
- 10000
* - Link 3
- 7
- 0
- 0
- 0
- 5000
- 0
For the nodes, the following values are represented:
.. code-block::
[
ID
Hardware State (1=ON, 2=OFF, 3=RESETTING, 4=SHUTTING_DOWN, 5=BOOTING)
Operating System State (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
File System State (0=none, 1=GOOD, 2=CORRUPT, 3=DESTROYED, 4=REPAIRING, 5=RESTORING)
Service1/Protocol1 state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
Service2/Protocol2 state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
]
(Note that each service available in the network is provided as a column, although not all nodes may utilise all services)
For the links, the following statuses are represented:
.. code-block::
[
ID
Hardware State (0=not applicable)
Operating System State (0=not applicable)
File System State (0=not applicable)
Service1/Protocol1 state (Traffic load from this protocol on this link)
Service2/Protocol2 state (Traffic load from this protocol on this link)
]
NodeStatus component
----------------------
This is a MultiDiscrete observation space that can be though of as a one-dimensional vector of discrete states.
The example above would have the following structure:
.. code-block::
[
node1_info
node2_info
node3_info
]
Each ``node_info`` contains the following:
.. code-block::
[
hardware_state (0=none, 1=ON, 2=OFF, 3=RESETTING, 4=SHUTTING_DOWN, 5=BOOTING)
software_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
file_system_state (0=none, 1=GOOD, 2=CORRUPT, 3=DESTROYED, 4=REPAIRING, 5=RESTORING)
service1_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
service2_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
]
In a network with three nodes and two services, the full observation space would have 15 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example:
.. code-block::
gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4])
.. note::
NodeStatus observation component provides information only about nodes. Links are not considered.
LinkTrafficLevels
-----------------
This component is a MultiDiscrete space showing the traffic flow levels on the links in the network, after applying a threshold to convert it from a continuous to a discrete value.
There are two configurable parameters:
* ``quantisation_levels`` determines how many discrete bins to use for converting the continuous traffic value to discrete (default is 5).
* ``combine_service_traffic`` determines whether to separately output traffic use for each network protocol or whether to combine them into an overall value for the link. (default is ``True``)
For example, with default parameters and a network with three links, the structure of this component would be:
.. code-block::
[
link1_status
link2_status
link3_status
]
Each ``link_status`` is a number from 0-4 representing the network load in relation to bandwidth.
.. code-block::
0 = No traffic (0%)
1 = low traffic (1%-33%)
2 = medium traffic (33%-66%)
3 = high traffic (66%-99%)
4 = max traffic/ overwhelmed (100%)
Using ``gym`` notation, the shape of the obs space is: ``gym.spaces.MultiDiscrete([5,5,5])``.
Action Spaces
**************
The action space available to the blue agent comes in two types:
1. Node-based
2. Access Control List
3. Any (Agent can take both node-based and ACL-based actions)
The choice of action space used during a training session is determined in the config_[name].yaml file.
**Node-Based**
The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is an OpenAI Gym spaces.Discrete type, as follows:
* Dictionary item {... ,1: [x1, x2, x3,x4] ...}
The placeholders inside the list under the key '1' mean the following:
* [0, num nodes] - Node ID (0 = nothing, node ID)
* [0, 4] - What property it's acting on (0 = nothing, 1 = state, 2 = SoftwareState, 3 = service state, 4 = file system state)
* [0, 3] - Action on property (0 = nothing, 1 = on / scan, 2 = off / repair, 3 = reset / patch / restore)
* [0, num services] - Resolves to service ID (0 = nothing, resolves to service)
**Access Control List**
The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI spaces.Discrete type, as follows:
* Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...}
The placeholders inside the list under the key '1' mean the following:
* [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
* [0, 1] - Permission (0 = DENY, 1 = ALLOW)
* [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
* [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
* [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
* [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
**ANY**
The agent is able to carry out both **Node-Based** and **Access Control List** operations.
This means the dictionary will contain key-value pairs in the format of BOTH Node-Based and Access Control List as seen above.
Rewards
*******
A reward value is presented back to the blue agent on the conclusion of every step. The reward value is calculated via two methods which combine to give the total value:
1. Node and service status
2. IER status
**Node and service status**
On every step, the status of each node is compared against both a reference environment (simulating the situation if the red and blue agents had not impacted the environment)
and the before and after state of the environment. If the comparison against the reference environment shows no difference, then the score provided is "AllOK". If there is a
difference with respect to the reference environment, the before and after states are compared, and a score determined. See :ref:`config` for details of reward values.
**IER status**
On every step, the full IER set is examined to determine whether green and red agent IERs are being permitted to run. Any red agent IERs running incur a penalty; any green agent
IERs not permitted to run also incur a penalty. See :ref:`config` for details of reward values.
Future Enhancements
*******************
The PrimAITE project has an ambition to include the following enhancements in future releases:
* Integration with a suitable standardised framework to allow multi-agent integration
* Integration with external threat emulation tools, either using off-line data, or integrating at runtime
* A flexible system for defining network layouts and host configurations
* Highly configurable network hosts, including definition of software, file system, and network interfaces,
* Realistic network traffic simulation, including address and sending packets via internet protocols like TCP, UDP, ICMP, etc.
* Routers with traffic routing and firewall capabilities
* Interfaces with ARCD GATE to allow training of agents
* Simulation of customisable deterministic agents
* Support for multiple agents, each having their own customisable observation space, action space, and reward function definition.
Structure
*********
PrimAITE consists of a simulator and a 'game' layer that allows agents to interact with the simulator. The simulator is built in a modular way where each component such as network hosts, links, networking devices, softwares, etc. are implemented as instances of a base class, meaning they all support the same interface. This allows for standardised configuration using either the Python API or YAML files.
The game layer is built on top of the simulator and it consumes the simulation action/state interface to allow agents to interact with the simulator. The game layer is also responsible for defining the reward function and observation space for the agents.
..
Architecture - Nodes and Links
******************************
**Nodes**
An inheritance model has been adopted in order to model nodes. All nodes have the following base attributes (Class: Node):
* ID
* Name
* Type (e.g. computer, switch, RTU - enumeration)
* Priority (P1, P2, P3, P4 or P5 - enumeration)
* Hardware State (ON, OFF, RESETTING, SHUTTING_DOWN, BOOTING - enumeration)
Active Nodes also have the following attributes (Class: Active Node):
* IP Address
* Software State (GOOD, PATCHING, COMPROMISED - enumeration)
* File System State (GOOD, CORRUPT, DESTROYED, REPAIRING, RESTORING - enumeration)
Service Nodes also have the following attributes (Class: Service Node):
* List of Services (where service is composed of service name and port). There is no theoretical limit on the number of services that can be modelled. Services and protocols are currently intrinsically linked (i.e. a service is an application on a node transmitting traffic of this protocol type)
* Service state (GOOD, PATCHING, COMPROMISED, OVERWHELMED - enumeration)
Passive Nodes are currently not used (but may be employed for non IP-based components such as machinery actuators in future releases).
**Links**
Links are modelled both as network edges (networkx) and as Python classes, in order to extend their functionality. Links include the following attributes:
* ID
* Name
* Bandwidth (bits/s)
* Source node ID
* Destination node ID
* Protocol list (containing the loading of protocols currently running on the link)
When the simulation runs, IERs are applied to the links in order to model traffic loading, individually assigned to each protocol. This allows green (background) and red agent behaviour to be modelled, and defensive agents to identify suspicious traffic patterns at a protocol / traffic loading level of fidelity.
Information Exchange Requirements (IERs)
****************************************
PrimAITE adopts the concept of Information Exchange Requirements (IERs) to model both green agent (background) and red agent (adversary) behaviour. IERs are used to initiate modelling of traffic loading on the network, and have the following attributes:
* ID
* Start step (i.e. which step in the training episode should the IER start)
* End step (i.e. which step in the training episode should the IER end)
* Source node ID
* Destination node ID
* Load (bits/s)
* Protocol
* Port
* Running status (i.e. on / off)
The application of green agent IERs between a source and destination follows a number of rules. Specifically:
1. Does the current simulation time step fall between IER start and end step
2. Is the source node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING)
3. Is the destination node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING)
4. Are there any Access Control List rules in place that prevent the application of this IER
5. Are all switches in the (OSPF) path between source and destination operational (both physically and at an O/S level)
For red agent IERs, the application of IERs between a source and destination follows a number of subtly different rules. Specifically:
1. Does the current simulation time step fall between IER start and end step
2. Is the source node operational, and is the service (protocol / port) associated with the IER (a) present on that node and (b) already in a compromised state
3. Is the destination node operational, and is the service (protocol / port) associated with the IER present on that node
4. Are there any Access Control List rules in place that prevent the application of this IER
5. Are all switches in the (OSPF) path between source and destination operational (both physically and at an O/S level)
Assuming the rules pass, the IER is applied to all relevant links (based on use of OSPF) between source and destination.
Node Pattern-of-Life
********************
Every node can be impacted (i.e. have a status change applied to it) by either green agent pattern-of-life or red agent pattern-of-life. This is distinct from IERs, and allows for attacks (and defence) to be modelled purely within the confines of a node.
The status changes that can be made to a node are as follows:
* All Nodes:
* Hardware State:
* ON
* OFF
* RESETTING - when a status of resetting is entered, the node will automatically exit this state after a number of steps (as defined by the nodeResetDuration configuration item) after which it returns to an ON state
* BOOTING
* SHUTTING_DOWN
* Active Nodes and Service Nodes:
* Software State:
* GOOD
* PATCHING - when a status of patching is entered, the node will automatically exit this state after a number of steps (as defined by the osPatchingDuration configuration item) after which it returns to a GOOD state
* COMPROMISED
* File System State:
* GOOD
* CORRUPT (can be resolved by repair or restore)
* DESTROYED (can be resolved by restore only)
* REPAIRING - when a status of repairing is entered, the node will automatically exit this state after a number of steps (as defined by the fileSystemRepairingLimit configuration item) after which it returns to a GOOD state
* RESTORING - when a status of repairing is entered, the node will automatically exit this state after a number of steps (as defined by the fileSystemRestoringLimit configuration item) after which it returns to a GOOD state
* Service Nodes only:
* Service State (for any associated service):
* GOOD
* PATCHING - when a status of patching is entered, the service will automatically exit this state after a number of steps (as defined by the servicePatchingDuration configuration item) after which it returns to a GOOD state
* COMPROMISED
* OVERWHELMED
Red agent pattern-of-life has an additional feature not found in the green pattern-of-life. This is the ability to influence the state of the attributes of a node via a number of different conditions:
* DIRECT:
The pattern-of-life described by the configuration file item will be applied regardless of any other conditions in the network. This is particularly useful for direct red agent entry into the network.
* IER:
The pattern-of-life described by the configuration file item will be applied to the service on the node, only if there is an IER of the same protocol / service type incoming at the specified timestep.
* SERVICE:
The pattern-of-life described by the configuration file item will be applied to the node based on the state of a service. The service can either be on the same node, or a different node within the network.
Access Control List modelling
*****************************
An Access Control List (ACL) is modelled to provide the means to manage traffic flows in the system. This will allow defensive agents the means to turn on / off rules, or potentially create new rules, to counter an attack.
The ACL follows a standard network firewall format. For example:
.. list-table:: ACL example
:widths: 25 25 25 25 25
:header-rows: 1
* - Permission
- Source IP
- Dest IP
- Protocol
- Port
* - DENY
- 192.168.1.2
- 192.168.1.3
- HTTPS
- 443
* - ALLOW
- 192.168.1.4
- ANY
- SMTP
- 25
* - DENY
- ANY
- 192.168.1.5
- ANY
- ANY
All ACL rules are considered when applying an IER. Logic follows the order of rules, so a DENY or ALLOW for the same parameters will override an earlier entry.
Observation Spaces
******************
The observation space provides the blue agent with information about the current status of nodes and links.
PrimAITE builds on top of Gym Spaces to create an observation space that is easily configurable for users. It's made up of components which are managed by the :py:class:`primaite.environment.observations.ObservationsHandler`. Each training scenario can define its own observation space, and the user can choose which information to inlude, and how it should be formatted.
NodeLinkTable component
-----------------------
For example, the :py:class:`primaite.environment.observations.NodeLinkTable` component represents the status of nodes and links as a ``gym.spaces.Box`` with an example format shown below:
An example observation space is provided below:
.. list-table:: Observation Space example
:widths: 25 25 25 25 25 25 25
:header-rows: 1
* -
- ID
- Hardware State
- Software State
- File System State
- Service / Protocol A
- Service / Protocol B
* - Node A
- 1
- 1
- 1
- 1
- 1
- 1
* - Node B
- 2
- 1
- 3
- 1
- 1
- 1
* - Node C
- 3
- 2
- 1
- 1
- 3
- 2
* - Link 1
- 5
- 0
- 0
- 0
- 0
- 10000
* - Link 2
- 6
- 0
- 0
- 0
- 0
- 10000
* - Link 3
- 7
- 0
- 0
- 0
- 5000
- 0
For the nodes, the following values are represented:
.. code-block::
[
ID
Hardware State (1=ON, 2=OFF, 3=RESETTING, 4=SHUTTING_DOWN, 5=BOOTING)
Operating System State (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
File System State (0=none, 1=GOOD, 2=CORRUPT, 3=DESTROYED, 4=REPAIRING, 5=RESTORING)
Service1/Protocol1 state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
Service2/Protocol2 state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
]
(Note that each service available in the network is provided as a column, although not all nodes may utilise all services)
For the links, the following statuses are represented:
.. code-block::
[
ID
Hardware State (0=not applicable)
Operating System State (0=not applicable)
File System State (0=not applicable)
Service1/Protocol1 state (Traffic load from this protocol on this link)
Service2/Protocol2 state (Traffic load from this protocol on this link)
]
NodeStatus component
----------------------
This is a MultiDiscrete observation space that can be though of as a one-dimensional vector of discrete states.
The example above would have the following structure:
.. code-block::
[
node1_info
node2_info
node3_info
]
Each ``node_info`` contains the following:
.. code-block::
[
hardware_state (0=none, 1=ON, 2=OFF, 3=RESETTING, 4=SHUTTING_DOWN, 5=BOOTING)
software_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
file_system_state (0=none, 1=GOOD, 2=CORRUPT, 3=DESTROYED, 4=REPAIRING, 5=RESTORING)
service1_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
service2_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
]
In a network with three nodes and two services, the full observation space would have 15 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example:
.. code-block::
gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4])
.. note::
NodeStatus observation component provides information only about nodes. Links are not considered.
LinkTrafficLevels
-----------------
This component is a MultiDiscrete space showing the traffic flow levels on the links in the network, after applying a threshold to convert it from a continuous to a discrete value.
There are two configurable parameters:
* ``quantisation_levels`` determines how many discrete bins to use for converting the continuous traffic value to discrete (default is 5).
* ``combine_service_traffic`` determines whether to separately output traffic use for each network protocol or whether to combine them into an overall value for the link. (default is ``True``)
For example, with default parameters and a network with three links, the structure of this component would be:
.. code-block::
[
link1_status
link2_status
link3_status
]
Each ``link_status`` is a number from 0-4 representing the network load in relation to bandwidth.
.. code-block::
0 = No traffic (0%)
1 = low traffic (1%-33%)
2 = medium traffic (33%-66%)
3 = high traffic (66%-99%)
4 = max traffic/ overwhelmed (100%)
Using ``gym`` notation, the shape of the obs space is: ``gym.spaces.MultiDiscrete([5,5,5])``.
Action Spaces
**************
The action space available to the blue agent comes in two types:
1. Node-based
2. Access Control List
3. Any (Agent can take both node-based and ACL-based actions)
The choice of action space used during a training session is determined in the config_[name].yaml file.
**Node-Based**
The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is an OpenAI Gym spaces.Discrete type, as follows:
* Dictionary item {... ,1: [x1, x2, x3,x4] ...}
The placeholders inside the list under the key '1' mean the following:
* [0, num nodes] - Node ID (0 = nothing, node ID)
* [0, 4] - What property it's acting on (0 = nothing, 1 = state, 2 = SoftwareState, 3 = service state, 4 = file system state)
* [0, 3] - Action on property (0 = nothing, 1 = on / scan, 2 = off / repair, 3 = reset / patch / restore)
* [0, num services] - Resolves to service ID (0 = nothing, resolves to service)
**Access Control List**
The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI spaces.Discrete type, as follows:
* Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...}
The placeholders inside the list under the key '1' mean the following:
* [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
* [0, 1] - Permission (0 = DENY, 1 = ALLOW)
* [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
* [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
* [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
* [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
**ANY**
The agent is able to carry out both **Node-Based** and **Access Control List** operations.
This means the dictionary will contain key-value pairs in the format of BOTH Node-Based and Access Control List as seen above.
Rewards
*******
A reward value is presented back to the blue agent on the conclusion of every step. The reward value is calculated via two methods which combine to give the total value:
1. Node and service status
2. IER status
**Node and service status**
On every step, the status of each node is compared against both a reference environment (simulating the situation if the red and blue agents had not impacted the environment)
and the before and after state of the environment. If the comparison against the reference environment shows no difference, then the score provided is "AllOK". If there is a
difference with respect to the reference environment, the before and after states are compared, and a score determined. See :ref:`config` for details of reward values.
**IER status**
On every step, the full IER set is examined to determine whether green and red agent IERs are being permitted to run. Any red agent IERs running incur a penalty; any green agent
IERs not permitted to run also incur a penalty. See :ref:`config` for details of reward values.
Future Enhancements
*******************
The PrimAITE project has an ambition to include the following enhancements in future releases:
* Integration with a suitable standardised framework to allow multi-agent integration
* Integration with external threat emulation tools, either using off-line data, or integrating at runtime

View File

@@ -1,88 +0,0 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Actions System
==============
``SimComponent``s in the simulation are decoupled from the agent training logic. However, they still need a managed means of accepting requests to perform actions. For this, they use ``RequestManager`` and ``Action``.
Just like other aspects of SimComponent, the actions are not managed centrally for the whole simulation, but instead they are dynamically created and updated based on the nodes, links, and other components that currently exist. This was achieved with the following design decisions:
- API
An 'action' contains two elements:
1. ``request`` - selects which action you want to take on this ``SimComponent``. This is formatted as a list of strings such as `['network', 'node', '<node-uuid>', 'service', '<service-uuid>', 'restart']`.
2. ``context`` - optional extra information that can be used to decide how to process the action. This is formatted as a dictionary. For example, if the action requires authentication, the context can include information about the user that initiated the request to decide if their permissions are sufficient.
- request
The request is a list of strings which help specify who should handle the request. The strings in the request list help RequestManagers traverse the 'ownership tree' of SimComponent. The example given above would be handled in the following way:
1. ``Simulation`` receives `['network', 'node', '<node-uuid>', 'service', '<service-uuid>', 'restart']`.
The first element of the action is ``network``, therefore it passes the action down to its network.
2. ``Network`` receives `['node', '<node-uuid>', 'service', '<service-uuid>', 'restart']`.
The first element of the action is ``node``, therefore the network looks at the node uuid and passes the action down to the node with that uuid.
3. ``Node`` receives `['service', '<service-uuid>', 'restart']`.
The first element of the action is ``service``, therefore the node looks at the service uuid and passes the rest of the action to the service with that uuid.
4. ``Service`` receives ``['restart']``.
Since ``restart`` is a defined action in the service's own RequestManager, the service performs a restart.
Technical Detail
================
This system was achieved by implementing two classes, :py:class:`primaite.simulator.core.Action`, and :py:class:`primaite.simulator.core.RequestManager`.
Action
------
The ``Action`` object stores a reference to a method that performs the action, for example a node could have an action that stores a reference to ``self.turn_on()``. Technically, this can be any callable that accepts `request, context` as it's parameters. In practice, this is often defined using ``lambda`` functions within a component's ``self._init_request_manager()`` method. Optionally, the ``Action`` object can also hold a validator that will permit/deny the action depending on context.
RequestManager
-------------
The ``RequestManager`` object stores a mapping between strings and actions. It is responsible for processing the ``request`` and passing it down the ownership tree. Technically, the ``RequestManager`` is itself a callable that accepts `request, context` tuple, and so it can be chained with other action managers.
A simple example without chaining can be seen in the :py:class:`primaite.simulator.file_system.file_system.File` class.
.. code-block:: python
class File(FileSystemItemABC):
...
def _init_request_manager(self):
...
request_manager.add_request("scan", Action(func=lambda request, context: self.scan()))
request_manager.add_request("repair", Action(func=lambda request, context: self.repair()))
request_manager.add_request("restore", Action(func=lambda request, context: self.restore()))
*ellipses (``...``) used to omit code impertinent to this explanation*
Chaining RequestManagers
-----------------------
Since the method for performing an action needs to accept `request, context` as parameters, and RequestManager itself is a callable that accepts `request, context` as parameters, it possible to use RequestManager as an action. In fact, that is how PrimAITE deals with traversing the ownership tree. Each time an RequestManager accepts a request, it pops the first elements and uses it to decide to which Action it should send the remaining request. However, the Action could have another RequestManager as it's function, therefore the request will be routed again. Each time the request is passed to a new action manager, the first element is popped.
An example of how this works is in the :py:class:`primaite.simulator.network.hardware.base.Node` class.
.. code-block:: python
class Node(SimComponent):
...
def _init_request_manager(self):
...
# a regular action which is processed by the Node itself
request_manager.add_request("turn_on", Action(func=lambda request, context: self.turn_on()))
# if the Node receives a request where the first word is 'service', it will use a dummy manager
# called self._service_request_manager to pass on the reqeust to the relevant service. This dummy
# manager is simply here to map the service UUID that that service's own action manager. This is
# done because the next string after "service" is always the uuid of that service, so we need an
# RequestManager to pop that string before sending it onto the relevant service's RequestManager.
self._service_request_manager = RequestManager()
request_manager.add_request("service", Action(func=self._service_request_manager))
...
def install_service(self, service):
self.services[service.uuid] = service
...
# Here, the service UUID is registered to allow passing actions between the node and the service.
self._service_request_manager.add_request(service.uuid, Action(func=service._request_manager))

View File

@@ -1,489 +1,13 @@
Primaite v3 config
******************
PrimAITE uses a single configuration file to define a cybersecurity scenario. This includes the computer network and multiple agents. There are three main sections: training_config, game, and simulation.
The simulation section describes the simulated network environment with which the agetns interact.
The game section describes the agents and their capabilities. Each agent has a unique type and is associated with a team (GREEN, RED, or BLUE). Each agent has a configurable observation space, action space, and reward function.
The training_config section describes the training parameters for the learning agents. This includes the number of episodes, the number of steps per episode, and the number of steps before the agents start learning. The training_config section also describes the learning algorithm used by the agents. The learning algorithm is specified by the name of the algorithm and the hyperparameters for the algorithm. The hyperparameters are specific to each algorithm and are described in the documentation for each algorithm.
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _config:
The Config Files Explained
==========================
PrimAITE uses two configuration files for its operation:
* **The Training Config**
Used to define the top-level settings of the PrimAITE environment, the reward values, and the session that is to be run.
* **The Lay Down Config**
Used to define the low-level settings of a session, including the network laydown, green / red agent information exchange requirements (IERSs) and Access Control Rules.
Training Config:
*******************
The Training Config file consists of the following attributes:
**Generic Config Values**
* **agent_framework** [enum]
This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following:
* NONE - Where a user developed agent is to be used
* SB3 - Stable Baselines3
* RLLIB - Ray RLlib.
* **agent_identifier**
This identifies the agent to use for the session. Select from one of the following:
* A2C - Advantage Actor Critic
* PPO - Proximal Policy Optimization
* HARDCODED - A custom built deterministic agent
* RANDOM - A Stochastic random agent
* **random_red_agent** [bool]
Determines if the session should be run with a random red agent
* **action_type** [enum]
Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session
* **OBSERVATION_SPACE** [dict]
Allows for user to configure observation space by combining one or more observation components. List of available
components is in :py:mod:`primaite.environment.observations`.
The observation space config item should have a ``components`` key which is a list of components. Each component
config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the
component while it is being initialised.
This example illustrates the correct format for the observation space config item
.. code-block:: yaml
observation_space:
components:
- name: NODE_LINK_TABLE
- name: NODE_STATUSES
- name: LINK_TRAFFIC_LEVELS
- name: ACCESS_CONTROL_LIST
options:
combine_service_traffic : False
quantisation_levels: 99
Currently available components are:
* :py:mod:`NODE_LINK_TABLE<primaite.environment.observations.NodeLinkTable>` this does not accept any additional options
* :py:mod:`NODE_STATUSES<primaite.environment.observations.NodeStatuses>`, this does not accept any additional options
* :py:mod:`ACCESS_CONTROL_LIST<primaite.environment.observations.AccessControlList>`, this does not accept additional options
* :py:mod:`LINK_TRAFFIC_LEVELS<primaite.environment.observations.LinkTrafficLevels>`, this accepts the following options:
* ``combine_service_traffic`` - whether to consider bandwidth use separately for each network protocol or combine them into a single bandwidth reading (boolean)
* ``quantisation_levels`` - how many discrete bandwidth usage levels to use for encoding. This can be an integer equal to or greater than 3.
The other configurable item is ``flatten`` which is false by default. When set to true, the observation space is flattened (turned into a 1-D vector). You should use this if your RL agent does not natively support observation space types like ``gym.Spaces.Tuple``.
* **num_train_episodes** [int]
This defines the number of episodes that the agent will train for.
* **num_train_steps** [int]
Determines the number of steps to run in each episode of the training session.
* **num_eval_episodes** [int]
This defines the number of episodes that the agent will be evaluated over.
* **num_eval_steps** [int]
Determines the number of steps to run in each episode of the evaluation session.
* **time_delay** [int]
The time delay (in milliseconds) to take between each step when running a GENERIC agent session
* **session_type** [text]
Type of session to be run (TRAINING, EVALUATION, or BOTH)
* **load_agent** [bool]
Determine whether to load an agent from file
* **agent_load_file** [text]
File path and file name of agent if you're loading one in
* **observation_space_high_value** [int]
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
* **implicit_acl_rule** [str]
Determines which Explicit rule the ACL list has - two options are: DENY or ALLOW.
* **max_number_acl_rules** [int]
Sets a limit on how many ACL rules there can be in the ACL list throughout the training session.
**Reward-Based Config Values**
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
* **Generic [all_ok]** [float]
The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)
* **Node Hardware State [off_should_be_on]** [float]
The score to give when the node should be on, but is off
* **Node Hardware State [off_should_be_resetting]** [float]
The score to give when the node should be resetting, but is off
* **Node Hardware State [on_should_be_off]** [float]
The score to give when the node should be off, but is on
* **Node Hardware State [on_should_be_resetting]** [float]
The score to give when the node should be resetting, but is on
* **Node Hardware State [resetting_should_be_on]** [float]
The score to give when the node should be on, but is resetting
* **Node Hardware State [resetting_should_be_off]** [float]
The score to give when the node should be off, but is resetting
* **Node Hardware State [resetting]** [float]
The score to give when the node is resetting
* **Node Operating System or Service State [good_should_be_patching]** [float]
The score to give when the state should be patching, but is good
* **Node Operating System or Service State [good_should_be_compromised]** [float]
The score to give when the state should be compromised, but is good
* **Node Operating System or Service State [good_should_be_overwhelmed]** [float]
The score to give when the state should be overwhelmed, but is good
* **Node Operating System or Service State [patching_should_be_good]** [float]
The score to give when the state should be good, but is patching
* **Node Operating System or Service State [patching_should_be_compromised]** [float]
The score to give when the state should be compromised, but is patching
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float]
The score to give when the state should be overwhelmed, but is patching
* **Node Operating System or Service State [patching]** [float]
The score to give when the state is patching
* **Node Operating System or Service State [compromised_should_be_good]** [float]
The score to give when the state should be good, but is compromised
* **Node Operating System or Service State [compromised_should_be_patching]** [float]
The score to give when the state should be patching, but is compromised
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [float]
The score to give when the state should be overwhelmed, but is compromised
* **Node Operating System or Service State [compromised]** [float]
The score to give when the state is compromised
* **Node Operating System or Service State [overwhelmed_should_be_good]** [float]
The score to give when the state should be good, but is overwhelmed
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float]
The score to give when the state should be patching, but is overwhelmed
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float]
The score to give when the state should be compromised, but is overwhelmed
* **Node Operating System or Service State [overwhelmed]** [float]
The score to give when the state is overwhelmed
* **Node File System State [good_should_be_repairing]** [float]
The score to give when the state should be repairing, but is good
* **Node File System State [good_should_be_restoring]** [float]
The score to give when the state should be restoring, but is good
* **Node File System State [good_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is good
* **Node File System State [good_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is good
* **Node File System State [repairing_should_be_good]** [float]
The score to give when the state should be good, but is repairing
* **Node File System State [repairing_should_be_restoring]** [float]
The score to give when the state should be restoring, but is repairing
* **Node File System State [repairing_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is repairing
* **Node File System State [repairing_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is repairing
* **Node File System State [repairing]** [float]
The score to give when the state is repairing
* **Node File System State [restoring_should_be_good]** [float]
The score to give when the state should be good, but is restoring
* **Node File System State [restoring_should_be_repairing]** [float]
The score to give when the state should be repairing, but is restoring
* **Node File System State [restoring_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is restoring
* **Node File System State [restoring_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is restoring
* **Node File System State [restoring]** [float]
The score to give when the state is restoring
* **Node File System State [corrupt_should_be_good]** [float]
The score to give when the state should be good, but is corrupt
* **Node File System State [corrupt_should_be_repairing]** [float]
The score to give when the state should be repairing, but is corrupt
* **Node File System State [corrupt_should_be_restoring]** [float]
The score to give when the state should be restoring, but is corrupt
* **Node File System State [corrupt_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is corrupt
* **Node File System State [corrupt]** [float]
The score to give when the state is corrupt
* **Node File System State [destroyed_should_be_good]** [float]
The score to give when the state should be good, but is destroyed
* **Node File System State [destroyed_should_be_repairing]** [float]
The score to give when the state should be repairing, but is destroyed
* **Node File System State [destroyed_should_be_restoring]** [float]
The score to give when the state should be restoring, but is destroyed
* **Node File System State [destroyed_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is destroyed
* **Node File System State [destroyed]** [float]
The score to give when the state is destroyed
* **Node File System State [scanning]** [float]
The score to give when the state is scanning
* **IER Status [red_ier_running]** [float]
The score to give when a red agent IER is permitted to run
* **IER Status [green_ier_blocked]** [float]
The score to give when a green agent IER is prevented from running
**Patching / Reset Durations**
* **os_patching_duration** [int]
The number of steps to take when patching an Operating System
* **node_reset_duration** [int]
The number of steps to take when resetting a node's hardware state
* **service_patching_duration** [int]
The number of steps to take when patching a service
* **file_system_repairing_limit** [int]:
The number of steps to take when repairing the file system
* **file_system_restoring_limit** [int]
The number of steps to take when restoring the file system
* **file_system_scanning_limit** [int]
The number of steps to take when scanning the file system
* **deterministic** [bool]
Set to true if the agent evaluation should be deterministic. Default is ``False``
* **seed** [int]
Seed used in the randomisation in agent training. Default is ``None``
The Lay Down Config
*******************
The lay down config file consists of the following attributes:
* **itemType: STEPS** [int]
* **item_type: PORTS** [int]
Provides a list of ports modelled in this session
* **item_type: SERVICES** [freetext]
Provides a list of services modelled in this session
* **item_type: NODE**
Defines a node included in the system laydown being simulated. It should consist of the following attributes:
* **id** [int]: Unique ID for this YAML item
* **name** [freetext]: Human-readable name of the component
* **node_class** [enum]: Relates to the base type of the node. Can be SERVICE, ACTIVE or PASSIVE. PASSIVE nodes do not have an operating system or services. ACTIVE nodes have an operating system, but no services. SERVICE nodes have both an operating system and one or more services
* **node_type** [enum]: Relates to the component type. Can be one of CCTV, SWITCH, COMPUTER, LINK, MONITOR, PRINTER, LOP, RTU, ACTUATOR or SERVER
* **priority** [enum]: Provides a priority for each node. Can be one of P1, P2, P3, P4 or P5 (which P1 being the highest)
* **hardware_state** [enum]: The initial hardware state of the node. Can be one of ON, OFF or RESETTING
* **ip_address** [IP address]: The IP address of the component in format xxx.xxx.xxx.xxx
* **software_state** [enum]: The intial state of the node operating system. Can be GOOD, PATCHING or COMPROMISED
* **file_system_state** [enum]: The initial state of the node file system. Can be GOOD, CORRUPT, DESTROYED, REPAIRING or RESTORING
* **services**: For each service associated with the node:
* **name** [freetext]: Free-text name of the service, but must match one of the services defined for the system in the services list
* **port** [int]: Integer value of the port related to this service, but must match one of the ports defined for the system in the ports list
* **state** [enum]: The initial state of the service. Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
* **item_type: LINK**
Defines a link included in the system laydown being simulated. It should consist of the following attributes:
* **id** [int]: Unique ID for this YAML item
* **name** [freetext]: Human-readable name of the component
* **bandwidth** [int]: The bandwidth (in bits/s) of the link
* **source** [int]: The ID of the source node
* **destination** [int]: The ID of the destination node
* **item_type: GREEN_IER**
Defines a green agent Information Exchange Requirement (IER). It should consist of:
* **id** [int]: Unique ID for this YAML item
* **start_step** [int]: The start step (in the episode) for this IER to begin
* **end_step** [int]: The end step (in the episode) for this IER to finish
* **load** [int]: The load (in bits/s) for this IER to apply to links
* **protocol** [freetext]: The protocol to apply to the links. This must match a value in the services list
* **port** [int]: The port that the protocol is running on. This must match a value in the ports list
* **source** [int]: The ID of the source node
* **destination** [int]: The ID of the destination node
* **mission_criticality** [enum]: The mission criticality of this IER (with 5 being highest, 1 lowest)
* **item_type: RED_IER**
Defines a red agent Information Exchange Requirement (IER). It should consist of:
* **id** [int]: Unique ID for this YAML item
* **start_step** [int]: The start step (in the episode) for this IER to begin
* **end_step** [int]: The end step (in the episode) for this IER to finish
* **load** [int]: The load (in bits/s) for this IER to apply to links
* **protocol** [freetext]: The protocol to apply to the links. This must match a value in the services list
* **port** [int]: The port that the protocol is running on. This must match a value in the ports list
* **source** [int]: The ID of the source node
* **destination** [int]: The ID of the destination node
* **mission_criticality** [enum]: Not currently used. Default to 0
* **item_type: GREEN_POL**
Defines a green agent pattern-of-life instruction. It should consist of:
* **id** [int]: Unique ID for this YAML item
* **start_step** [int]: The start step (in the episode) for this PoL to begin
* **end_step** [int]: Not currently used. Default to same as start step
* **nodeId** [int]: The ID of the node to apply the PoL to
* **type** [enum]: The type of PoL to apply. Can be one of OPERATING, OS or SERVICE
* **protocol** [freetext]: The protocol to be affected if SERVICE type is chosen. Must match a value in the services list
* **state** [enuum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for Software State) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state)
* **item_type: RED_POL**
Defines a red agent pattern-of-life instruction. It should consist of:
* **id** [int]: Unique ID for this YAML item
* **start_step** [int]: The start step (in the episode) for this PoL to begin
* **end_step** [int]: Not currently used. Default to same as start step
* **targetNodeId** [int]: The ID of the node to apply the PoL to
* **initiator** [enum]: What initiates the PoL. Can be DIRECT, IER or SERVICE
* **type** [enum]: The type of PoL to apply. Can be one of OPERATING, OS or SERVICE
* **protocol** [freetext]: The protocol to be affected if SERVICE type is chosen. Must match a value in the services list
* **state** [enum]: The state to apply to the node (which represents the PoL change). Can be one of ON, OFF or RESETTING (for node state) or GOOD, PATCHING or COMPROMISED (for Software State) or GOOD, PATCHING, COMPROMISED or OVERWHELMED (for service state) or GOOD, CORRUPT, DESTROYED, REPAIRING or RESTORING (for file system state)
* **sourceNodeId** [int] The ID of the source node containing the service to check (used for SERVICE initiator)
* **sourceNodeService** [freetext]: The service on the source node to check (used for SERVICE initiator). Must match a value in the services list for this node
* **sourceNodeServiceState** [enum]: The state of the source node service to check (used for SERVICE initiator). Can be one of GOOD, PATCHING, COMPROMISED or OVERWHELMED
* **item_type: ACL_RULE**
Defines an initial Access Control List (ACL) rule. It should consist of:
* **id** [int]: Unique ID for this YAML item
* **permission** [enum]: Defines either an allow or deny rule. Value must be either DENY or ALLOW
* **source** [IP address]: Defines the source IP address for the rule in xxx.xxx.xxx.xxx format
* **destination** [IP address]: Defines the destination IP address for the rule in xxx.xxx.xxx.xxx format
* **protocol** [freetext]: Defines the protocol for the rule. Must match a value in the services list
* **port** [int]: Defines the port for the rule. Must match a value in the ports list
* **position** [int]: Defines where to place the ACL rule in the list. Lower index or (higher up in the list) means they are checked first. Index starts at 0 (Python indexes).
This needs a bit of refactoring so I haven't written extensive documentation about the config yet.

View File

@@ -11,132 +11,4 @@ Integrating a user defined blue agent
.. note::
If you are planning to implement custom RL agents into PrimAITE, you must use the project as a repository. If you install PrimAITE as a python package from wheel, custom agents are not supported.
PrimAITE has integration with Ray RLLib and StableBaselines3 agents. All agents interface with PrimAITE through an :py:class:`primaite.agents.agent_abc.AgentSessionABC<Agent Session>` which provides Input/Output of agent savefiles, as well as capturing and plotting performance metrics during training and evaluation. If you wish to integrate a custom blue agent, it is recommended to create a subclass of the :py:class:`primaite.agents.agent_abc.AgentSessionABC` and implement the ``__init__()``, ``_setup()``, ``_save_checkpoint()``, ``learn()``, ``evaluate()``, ``_get_latest_checkpoint``, ``load()``, and ``save()`` methods.
Below is a barebones example of a custom agent implementation:
.. code:: python
# src/primaite/agents/my_custom_agent.py
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
class CustomAgent(AgentSessionABC):
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
assert self._training_config.agent_framework == AgentFramework.CUSTOM
assert self._training_config.agent_identifier == AgentIdentifier.MY_AGENT
self._setup()
def _setup(self):
super()._setup()
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
self._agent = ... # your code to setup agent
def _save_checkpoint(self):
checkpoint_num = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
save_checkpoint = False
if checkpoint_num:
save_checkpoint = episode_count % checkpoint_num == 0
# saves checkpoint if the episode count is not 0 and save_checkpoint flag was set to true
if episode_count and save_checkpoint:
...
# your code to save checkpoint goes here.
# The path should start with self.checkpoints_path and include the episode number.
def learn(self):
...
# call your agent's learning function here.
super().learn() # this will finalise learning and output session metadata
self.save()
def evaluate(self):
...
# call your agent's evaluation function here.
self._env.close()
super().evaluate()
def _get_latest_checkpoint(self):
...
# Load an agent from file.
@classmethod
def load(cls, path):
...
# Create a CustomAgent object which loads model weights from file.
def save(self):
...
# Call your agent's function that saves it to a file
You will also need to modify :py:class:`primaite.primaite_session.PrimaiteSession<PrimaiteSession>` and :py:mod:`primaite.common.enums` to capture your new agent identifiers.
.. code-block:: python
:emphasize-lines: 17, 18
# src/primaite/common/enums.py
class AgentIdentifier(Enum):
"""The Red Agent algo/class."""
A2C = 1
"Advantage Actor Critic"
PPO = 2
"Proximal Policy Optimization"
HARDCODED = 3
"The Hardcoded agents"
DO_NOTHING = 4
"The DoNothing agents"
RANDOM = 5
"The RandomAgent"
DUMMY = 6
"The DummyAgent"
CUSTOM_AGENT = 7
"Your custom agent"
.. code-block:: python
:emphasize-lines: 3, 11, 12
# src/primaite_session.py
from primaite.agents.my_custom_agent import CustomAgent
# ...
def setup(self):
"""Performs the session setup."""
if self._training_config.agent_framework == AgentFramework.CUSTOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
if self._training_config.agent_identifier == AgentIdentifier.CUSTOM_AGENT:
self._agent_session = CustomAgent(self._training_config_path, self._lay_down_config_path)
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
Finally, specify your agent in your training config.
.. code-block:: yaml
# ~/primaite/2.0.0/config/path/to/your/config_main.yaml
# Training Config File
agent_framework: CUSTOM
agent_identifier: CUSTOM_AGENT
random_red_agent: False
# ...
Now you can :ref:`run a primaite session<run a primaite session>` with your custom agent by passing in the custom ``config_main``.
PrimAITE uses ARCD GATE for agent integration. In order to use a custom agent with PrimAITE, you must integrate it with ARCD GATE. Please look at the ARCD GATE documentation for more information.

View File

@@ -0,0 +1,48 @@
PrimAITE Game layer
*******************
The Primaite codebase consists of two main modules:
* ``simulator``: The simulation logic including the network topology, the network state, and behaviour of various hardware and software classes.
* ``game``: The agent-training infrastructure which helps reinforcement learning agents interface with the simulation. This includes the observation, action, and rewards, for RL agents, but also scripted deterministic agents. The game layer orchestrates all the interactions between modules, including ARCD GATE.
These two components have been decoupled to allow the agent training code in ARCD GATE to be reused with other simulators. The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API. The game layer communicates with ARCD gate using the `Farama Gymnasium Spaces API <https://gymnasium.farama.org/api/spaces/>`_.
..
TODO: write up these APIs and link them here.
Game layer
----------
The game layer is responsible for managing agents and getting them to interface with the simulator correctly. It consists of several components:
PrimAITE Session
^^^^^^^^^^^^^^^
``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. It also sends messages to ARCD GATE to perform reinforcement learning. ``PrimaiteSession`` keeps track of multiple agents of different types.
Agents
^^^^^^
All agents inherit from the :py:class:`primaite.game.agent.interface.AbstractAgent` class, which mandates that they have an ObservationManager, ActionManager, and RewardManager. The agent behaviour depends on the type of agent, but there are two main types:
* RL agents action during each step is decided by an RL algorithm which lives inside of ARCD GATE. The agent within PrimAITE just acts to format and forward actions decided by an RL policy.
* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed will be settable.
..
TODO: add seed to stochastic scripted agents
Observations
^^^^^^^^^^^^^^^^^^
An agent's observations are managed by the ``ObservationManager`` class. It generates observations based on the current simulation state dictionary. It also provides the observation space during initial setup. The data is formatted so it's compatible with ``Gymnasium.spaces``. Observation spaces are composed of one or more components which are defined by the ``AbstractObservation`` base class.
Actions
^^^^^^^
An agent's actions are managed by the ``ActionManager``. It converts actions selected by agents (which are typically integers chosen from a ``gymnasium.spaces.Discrete`` space) into simulation-friendly requests. It also provides the action space during initial setup. Action spaces are composed of one or more components which are defined by the ``AbstractAction`` base class.
Rewards
^^^^^^^
An agent's reward function is managed by the ``RewardManager``. It calculates rewards based on the simulation state (in a way similar to observations). Rewards can be defined as a weighted sum of small reward components. For example, an agents reward can be based on the uptime of a database service plus the loss rate of packets between clients and a web server. The reward components are defined by the AbstractReward base class.

View File

@@ -11,7 +11,7 @@ Getting Started
Pre-Requisites
In order to get **PrimAITE** installed, you will need to have a python version between 3.8 and 3.10 installed. If you don't already have it, this is how to install it:
In order to get **PrimAITE** installed, you will need to have a python version between 3.8 and 3.11 installed. If you don't already have it, this is how to install it:
.. code-block:: bash
@@ -33,39 +33,36 @@ In order to get **PrimAITE** installed, you will need to have a python version b
Install PrimAITE
****************
1. Create a primaite directory in your home directory:
1. Create a directory for your PrimAITE project:
.. code-block:: bash
:caption: Unix
mkdir ~/primaite/2.0.0
mkdir ~/primaite/3.0.0
.. code-block:: powershell
:caption: Windows (Powershell)
mkdir ~\primaite\2.0.0
mkdir ~\primaite\3.0.0
2. Navigate to the primaite directory and create a new python virtual environment (venv)
.. code-block:: bash
:caption: Unix
cd ~/primaite/2.0.0
cd ~/primaite/3.0.0
python3 -m venv .venv
.. code-block:: powershell
:caption: Windows (Powershell)
cd ~\primaite\2.0.0
cd ~\primaite\3.0.0
python3 -m venv .venv
attrib +h .venv /s /d # Hides the .venv directory
3. Activate the venv
3. Activate the venv
.. code-block:: bash
:caption: Unix
@@ -78,21 +75,34 @@ Install PrimAITE
.\.venv\Scripts\activate
4. Install PrimAITE using pip from PyPi
4. Install PrimAITE from your saved wheel file
.. code-block:: bash
:caption: Unix
pip install path/to/your/primaite.whl
.. code-block:: powershell
:caption: Windows (Powershell)
pip install path\to\your\primaite.whl
5. Install ARCD GATE from wheel file
.. code-block:: bash
:caption: Unix
pip install primaite
pip install path/to/your/arcd_gate-0.1.0-py3-none-any.whl
.. code-block:: powershell
:caption: Windows (Powershell)
pip install primaite
pip install path\to\your\arcd_gate-0.1.0-py3-none-any.whl
5. Perform the PrimAITE setup
6. Perform the PrimAITE setup
.. code-block:: bash
:caption: Unix
@@ -110,13 +120,14 @@ Clone & Install PrimAITE for Development
To be able to extend PrimAITE further, or to build wheels manually before install, clone the repository to a location
of your choice:
1. Clone the repository
.. code-block:: bash
git clone https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE
cd primaite
Create and activate your Python virtual environment (venv)
2. Create and activate your Python virtual environment (venv)
.. code-block:: bash
:caption: Unix
@@ -130,8 +141,7 @@ Create and activate your Python virtual environment (venv)
python3 -m venv venv
.\venv\Scripts\activate
Install PrimAITE with the dev extra
3. Install PrimAITE with the dev extra
.. code-block:: bash
:caption: Unix
@@ -144,4 +154,16 @@ Install PrimAITE with the dev extra
pip install -e .[dev]
4. Install ARCD GATE from wheel file
.. code-block:: bash
:caption: Unix
pip install GATE/arcd_gate-0.1.0-py3-none-any.whl
.. code-block:: powershell
:caption: Windows (Powershell)
pip install GATE\arcd_gate-0.1.0-py3-none-any.whl
To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`).

View File

@@ -1,57 +0,0 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
v1.2 to v2.0 Migration guide
============================
**1. Installing PrimAITE**
Like before, you can install primaite from the repository by running ``pip install -e .``. But, there is now an additional setup step which does several things, like setting up user directories, copy default configs and notebooks, etc. Once you have installed PrimAITE to your virtual environment, run this command to finalise setup.
.. code-block:: bash
primaite setup
**2. Running a training session**
In version 1.2 of PrimAITE, the main entry point for training or evaluating agents was the ``src/primaite/main.py`` file. v2.0.0 introduced managed 'sessions' which are responsible for reading configuration files, performing training, and writing outputs.
``main.py`` file still runs a training session but it now uses the new `PrimaiteSession`, and it now requires you to provide the path to your config files.
.. code-block:: bash
python src/primaite/main.py --tc path/to/training-config.yaml --ldc path/to/laydown-config.yaml
Alternatively, the session can be invoked via the commandline by running:
.. code-block:: bash
primaite session --tc path/to/training-config.yaml --ldc path/to/laydown-config.yaml
**3. Location of configs**
In version 1.2, training configs and laydown configs were all stored in the project repository under ``src/primaite/config``. Version 2.0.0 introduced user data directories, and now when you install and setup PrimAITE, config files are stored in your user data location. On Linux/OSX, this is stored in ``~/primaite/2.0.0/config``. On Windows, this is stored in ``C:\Users\<your username>\primaite\configs``. Upon first setup, the configs folder is populated with some default yaml files. It is recommended that you store all your custom configuration files here.
**4. Contents of configs**
Some things that were previously part of the laydown config are now part of the traning config.
* Actions
If you have custom configs which use these, you will need to adapt them by moving the configuration from the laydown config to the training config.
Also, there are new configurable items in the training config:
* Observations
* Agent framework
* Agent
* Deep learning framework
* random red agents
* seed
* deterministic
* hard coded agent view
Each of these items have default values which are designed so that PrimAITE has the same behaviour as it did in 1.2.0, so you do not have to specify them.
ACL Rules in laydown configs have a new required parameter: ``position``. The lower the position, the higher up in the ACL table the rule will placed. If you have custom laydowns, you will need to go through them and add a position to each ACL_RULE.

View File

@@ -14,199 +14,200 @@ A PrimAITE session can be ran either with the ``primaite session`` command from
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook.
Both the ``primaite session`` and :func:`primaite.main.run` take a training config and a lay down config as parameters.
.. note::
🚧 *UNDER CONSTRUCTION* 🚧
.. code-block:: bash
:caption: Unix CLI
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --tc ./config/my_training_config.yaml --ldc ./config/my_lay_down_config.yaml
.. code-block:: powershell
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --tc .\config\my_training_config.yaml --ldc .\config\my_lay_down_config.yaml
.. code-block:: python
:caption: Python
from primaite.main import run
training_config = <path to training config yaml file>
lay_down_config = <path to lay down config yaml file>
run(training_config, lay_down_config)
When a session is ran, a session output sub-directory is created in the users app sessions directory (``~/primaite/2.0.0/sessions``).
The sub-directory is formatted as such: ``~/primaite/2.0.0/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-dd>/``
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
``~/primaite/2.0.0/sessions/2023-01-31/2023-01-31_17-30-00/``.
``primaite session`` can be ran in the terminal/command prompt without arguments. It will use the default configs in the directory ``primaite/config/example_config``.
To run a PrimAITE session using legacy training or laydown config files, add the ``--legacy-tc`` and/or ``legacy-ldc`` options.
.. code-block:: bash
:caption: Unix CLI
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --tc ./config/my_legacy_training_config.yaml --legacy-tc --ldc ./config/my_legacy_lay_down_config.yaml --legacy-ldc
.. code-block:: powershell
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --tc .\config\my_legacy_training_config.yaml --legacy-tc --ldc .\config\my_legacy_lay_down_config.yaml --legacy-ldc
.. code-block:: python
:caption: Python
from primaite.main import run
training_config = <path to legacy training config yaml file>
lay_down_config = <path to legacy lay down config yaml file>
run(training_config, lay_down_config, legacy_training_config=True, legacy_lay_down_config=True)
Outputs
-------
PrimAITE produces four types of outputs:
* Session Metadata
* Results
* Diagrams
* Saved agents (training checkpoints and a final trained agent)
**Session Metadata**
PrimAITE creates a ``session_metadata.json`` file that contains the following metadata:
* **uuid** - The UUID assigned to the session upon instantiation.
* **start_datetime** - The date & time the session started in iso format.
* **end_datetime** - The date & time the session ended in iso format.
* **learning**
* **total_episodes** - The total number of training episodes completed.
* **total_time_steps** - The total number of training time steps completed.
* **evaluation**
* **total_episodes** - The total number of evaluation episodes completed.
* **total_time_steps** - The total number of evaluation time steps completed.
* **env**
* **training_config**
* **All training config items**
* **lay_down_config**
* **All lay down config items**
**Results**
PrimAITE automatically creates two sets of results from each learning and evaluation session:
* Average reward per episode - a csv file listing the average reward for each episode of the session. This provides, for example, an indication of the change over a training session of the reward value
* All transactions - a csv file listing the following values for every step of every episode:
* Timestamp
* Episode number
* Step number
* Reward value
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
* Initial observation space (what the blue agent observed when it decided its action)
**Diagrams**
* For each session, PrimAITE automatically creates a visualisation of the system / network lay down configuration.
* For each learning and evaluation task within the session, PrimAITE automatically plots the average reward per episode using PlotLY and saves it to the learning or evaluation subdirectory in the session directory.
**Saved agents**
For each training session, assuming the agent being trained implements the *save()* function and this function is called by the code, PrimAITE automatically saves the agent state.
**Example Session Directory Structure**
.. code-block:: text
~/
└── primaite/
└── 2.0.0/
└── sessions/
└── 2023-07-18/
└── 2023-07-18_11-06-04/
├── evaluation/
│ ├── all_transactions_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
│ └── average_reward_per_episode_2023-07-18_11-06-04.png
├── learning/
│ ├── all_transactions_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.png
│ ├── checkpoints/
│ │ └── sb3ppo_10.zip
│ ├── SB3_PPO.zip
│ └── tensorboard_logs/
│ ├── PPO_1/
│ │ └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0
│ ├── PPO_2/
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1
│ ├── PPO_3/
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2
│ ├── PPO_4/
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3
│ ├── PPO_5/
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4
│ ├── PPO_6/
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5
│ ├── PPO_7/
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6
│ ├── PPO_8/
│ │ └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7
│ ├── PPO_9/
│ │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8
│ └── PPO_10/
│ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9
├── network_2023-07-18_11-06-04.png
└── session_metadata.json
Loading a session
-----------------
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.
.. tabs::
.. code-tab:: bash
..
.. code-block:: bash
:caption: Unix CLI
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --load "path/to/session"
primaite session --tc ./config/my_training_config.yaml --ldc ./config/my_lay_down_config.yaml
.. code-tab:: bash
.. code-block:: powershell
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --load "path\to\session"
primaite session --tc .\config\my_training_config.yaml --ldc .\config\my_lay_down_config.yaml
.. code-tab:: python
.. code-block:: python
:caption: Python
from primaite.main import run
run(session_path=<previous session directory>)
training_config = <path to training config yaml file>
lay_down_config = <path to lay down config yaml file>
run(training_config, lay_down_config)
When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory
When a session is ran, a session output sub-directory is created in the users app sessions directory (``~/primaite/2.0.0/sessions``).
The sub-directory is formatted as such: ``~/primaite/2.0.0/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-dd>/``
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
``~/primaite/2.0.0/sessions/2023-01-31/2023-01-31_17-30-00/``.
``primaite session`` can be ran in the terminal/command prompt without arguments. It will use the default configs in the directory ``primaite/config/example_config``.
To run a PrimAITE session using legacy training or laydown config files, add the ``--legacy-tc`` and/or ``legacy-ldc`` options.
.. code-block:: bash
:caption: Unix CLI
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --tc ./config/my_legacy_training_config.yaml --legacy-tc --ldc ./config/my_legacy_lay_down_config.yaml --legacy-ldc
.. code-block:: powershell
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --tc .\config\my_legacy_training_config.yaml --legacy-tc --ldc .\config\my_legacy_lay_down_config.yaml --legacy-ldc
.. code-block:: python
:caption: Python
from primaite.main import run
training_config = <path to legacy training config yaml file>
lay_down_config = <path to legacy lay down config yaml file>
run(training_config, lay_down_config, legacy_training_config=True, legacy_lay_down_config=True)
Outputs
-------
PrimAITE produces four types of outputs:
* Session Metadata
* Results
* Diagrams
* Saved agents (training checkpoints and a final trained agent)
**Session Metadata**
PrimAITE creates a ``session_metadata.json`` file that contains the following metadata:
* **uuid** - The UUID assigned to the session upon instantiation.
* **start_datetime** - The date & time the session started in iso format.
* **end_datetime** - The date & time the session ended in iso format.
* **learning**
* **total_episodes** - The total number of training episodes completed.
* **total_time_steps** - The total number of training time steps completed.
* **evaluation**
* **total_episodes** - The total number of evaluation episodes completed.
* **total_time_steps** - The total number of evaluation time steps completed.
* **env**
* **training_config**
* **All training config items**
* **lay_down_config**
* **All lay down config items**
**Results**
PrimAITE automatically creates two sets of results from each learning and evaluation session:
* Average reward per episode - a csv file listing the average reward for each episode of the session. This provides, for example, an indication of the change over a training session of the reward value
* All transactions - a csv file listing the following values for every step of every episode:
* Timestamp
* Episode number
* Step number
* Reward value
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
* Initial observation space (what the blue agent observed when it decided its action)
**Diagrams**
* For each session, PrimAITE automatically creates a visualisation of the system / network lay down configuration.
* For each learning and evaluation task within the session, PrimAITE automatically plots the average reward per episode using PlotLY and saves it to the learning or evaluation subdirectory in the session directory.
**Saved agents**
For each training session, assuming the agent being trained implements the *save()* function and this function is called by the code, PrimAITE automatically saves the agent state.
**Example Session Directory Structure**
.. code-block:: text
~/
└── primaite/
└── 2.0.0/
└── sessions/
└── 2023-07-18/
└── 2023-07-18_11-06-04/
├── evaluation/
│ ├── all_transactions_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
│ └── average_reward_per_episode_2023-07-18_11-06-04.png
├── learning/
│ ├── all_transactions_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.csv
│ ├── average_reward_per_episode_2023-07-18_11-06-04.png
│ ├── checkpoints/
│ │ └── sb3ppo_10.zip
│ ├── SB3_PPO.zip
│ └── tensorboard_logs/
│ ├── PPO_1/
│ │ └── events.out.tfevents.1689674765.METD-9PMRFB3.42960.0
│ ├── PPO_2/
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.1
│ ├── PPO_3/
│ │ └── events.out.tfevents.1689674766.METD-9PMRFB3.42960.2
│ ├── PPO_4/
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.3
│ ├── PPO_5/
│ │ └── events.out.tfevents.1689674767.METD-9PMRFB3.42960.4
│ ├── PPO_6/
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.5
│ ├── PPO_7/
│ │ └── events.out.tfevents.1689674768.METD-9PMRFB3.42960.6
│ ├── PPO_8/
│ │ └── events.out.tfevents.1689674769.METD-9PMRFB3.42960.7
│ ├── PPO_9/
│ │ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.8
│ └── PPO_10/
│ └── events.out.tfevents.1689674770.METD-9PMRFB3.42960.9
├── network_2023-07-18_11-06-04.png
└── session_metadata.json
Loading a session
-----------------
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.
.. tabs::
.. code-tab:: bash
:caption: Unix CLI
cd ~/primaite/2.0.0
source ./.venv/bin/activate
primaite session --load "path/to/session"
.. code-tab:: bash
:caption: Powershell CLI
cd ~\primaite\2.0.0
.\.venv\Scripts\activate
primaite session --load "path\to\session"
.. code-tab:: python
:caption: Python
from primaite.main import run
run(session_path=<previous session directory>)
When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory

View File

@@ -0,0 +1,90 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Request System
==============
``SimComponent`` in the simulation are decoupled from the agent training logic. However, they still need a managed means of accepting requests to perform actions. For this, they use ``RequestManager`` and ``RequestType``.
Just like other aspects of SimComponent, the request typess are not managed centrally for the whole simulation, but instead they are dynamically created and updated based on the nodes, links, and other components that currently exist. This was achieved in the following way:
- API
An ``RequestType`` contains two elements:
1. ``request`` - selects which action you want to take on this ``SimComponent``. This is formatted as a list of strings such as `['network', 'node', '<node-uuid>', 'service', '<service-uuid>', 'restart']`.
2. ``context`` - optional extra information that can be used to decide how to process the request. This is formatted as a dictionary. For example, if the request requires authentication, the context can include information about the user that initiated the request to decide if their permissions are sufficient.
- request
The request is a list of strings which help specify who should handle the request. The strings in the request list help RequestManagers traverse the 'ownership tree' of SimComponent. The example given above would be handled in the following way:
1. ``Simulation`` receives `['network', 'node', '<node-uuid>', 'service', '<service-uuid>', 'restart']`.
The first element of the request is ``network``, therefore it passes the request down to its network.
2. ``Network`` receives `['node', '<node-uuid>', 'service', '<service-uuid>', 'restart']`.
The first element of the request is ``node``, therefore the network looks at the node uuid and passes the request down to the node with that uuid.
3. ``Node`` receives `['service', '<service-uuid>', 'restart']`.
The first element of the request is ``service``, therefore the node looks at the service uuid and passes the rest of the request to the service with that uuid.
4. ``Service`` receives ``['restart']``.
Since ``restart`` is a defined request type in the service's own RequestManager, the service performs a restart.
Technical Detail
----------------
This system was achieved by implementing two classes, :py:class:`primaite.simulator.core.RequestType`, and :py:class:`primaite.simulator.core.RequestManager`.
``RequestType``
------
The ``RequestType`` object stores a reference to a method that executes the request, for example a node could have a request type that stores a reference to ``self.turn_on()``. Technically, this can be any callable that accepts `request, context` as it's parameters. In practice, this is often defined using ``lambda`` functions within a component's ``self._init_request_manager()`` method. Optionally, the ``RequestType`` object can also hold a validator that will permit/deny the request depending on context.
``RequestManager``
------------------
The ``RequestManager`` object stores a mapping between strings and request types. It is responsible for processing the request and passing it down the ownership tree. Technically, the ``RequestManager`` is itself a callable that accepts `request, context` tuple, and so it can be chained with other request managers.
A simple example without chaining can be seen in the :py:class:`primaite.simulator.file_system.file_system.File` class.
.. code-block:: python
class File(FileSystemItemABC):
...
def _init_request_manager(self):
...
request_manager.add_request("scan", RequestType(func=lambda request, context: self.scan()))
request_manager.add_request("repair", RequestType(func=lambda request, context: self.repair()))
request_manager.add_request("restore", RequestType(func=lambda request, context: self.restore()))
*ellipses (``...``) used to omit code impertinent to this explanation*
Chaining RequestManagers
-----------------------
A request function needs to be a callable that accepts ``request, context`` as parameters. Since the request manager resolves requests by invoking it with ``request, context`` as parameter, it is possible to use a ``RequestManager`` as a ``RequestType``.
When a RequestManager accepts a request, it pops the first element and uses it to decide where it should send the remaining request. This is how PrimAITE traverses the ownership tree. If the ``RequestType`` has another ``RequestManager`` as its function, the request will be routed again. Each time the request is passed to a new request manager, the first element is popped.
An example of how this works is in the :py:class:`primaite.simulator.network.hardware.base.Node` class.
.. code-block:: python
class Node(SimComponent):
...
def _init_request_manager(self):
...
# a regular action which is processed by the Node itself
request_manager.add_request("turn_on", RequestType(func=lambda request, context: self.turn_on()))
# if the Node receives a request where the first word is 'service', it will use a dummy manager
# called self._service_request_manager to pass on the reqeust to the relevant service. This dummy
# manager is simply here to map the service UUID that that service's own action manager. This is
# done because the next string after "service" is always the uuid of that service, so we need an
# RequestManager to pop that string before sending it onto the relevant service's RequestManager.
self._service_request_manager = RequestManager()
request_manager.add_request("service", RequestType(func=self._service_request_manager))
...
def install_service(self, service):
self.services[service.uuid] = service
...
# Here, the service UUID is registered to allow passing actions between the node and the service.
self._service_request_manager.add_request(service.uuid, RequestType(func=service._request_manager))

View File

@@ -23,4 +23,8 @@ Contents
simulation_components/network/network
simulation_components/system/internal_frame_processing
simulation_components/system/software
action_system
simulation_components/system/data_manipulation_bot
simulation_components/system/database_client_server
simulation_components/system/dns_client_server
simulation_components/system/ftp_client_server
simulation_components/system/web_browser_and_web_server_service

View File

@@ -14,7 +14,7 @@ The ``DatabaseService`` provides a SQL database server simulation by extending t
Key capabilities
^^^^^^^^^^^^^^^^
- Initialises a SQLite database file in the ``Node``'s ``FileSystem`` upon creation.
- Initialises a SQLite database file in the ``Node`` 's ``FileSystem`` upon creation.
- Handles connecting clients by maintaining a dictionary of connections mapped to session IDs.
- Authenticates connections using a configurable password.
- Executes SQL queries against the SQLite database.

View File

@@ -12,7 +12,7 @@ and a domain controller for managing software and users.
Each node of the simulation 'tree' has responsibility for creating, deleting, and updating its direct descendants. Also,
when a component's ``describe_state()`` method is called, it will include the state of its descendants. The
``apply_action()`` method can be used to act on a component or one of its descendatnts. The diagram below shows the
``apply_request()`` method can be used to act on a component or one of its descendatnts. The diagram below shows the
relationship between components.
.. image:: _static/component_relationship.png
@@ -25,9 +25,9 @@ relationship between components.
Actions
=======
Agents can interact with the simulation by using actions. Actions are standardised with the
:py:class:`primaite.simulation.core.Action` class, which just holds a reference to two special functions.
:py:class:`primaite.simulation.core.RequestType` class, which just holds a reference to two special functions.
1. The action function itself, it must accept a `request` parameters which is a list of strings that describe what the
1. The request function itself, it must accept a `request` parameters which is a list of strings that describe what the
action should do. It must also accept a `context` dict which can house additional information surrounding the action.
For example, the context will typically include information about which entity intiated the action.
2. A validator function. This function should return a boolean value that decides if the request is permitted or not.

View File

@@ -0,0 +1,31 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Simulation State
==============
``SimComponent`` in the simulation have a method called ``describe_state`` which returns a dictionary of the state of the component. This is used to report pertinent data that could impact agent's actions or rewards. For instance, the name and health status of a node is reported, which can be used by a reward function to punish corrupted or compromised nodes and reward healthy nodes. Each ``SimComponent`` reports not only it's own attributes in the state but also that of its child components. I.e. a computer node will report the state of its ``FileSystem``, and the ``FileSystem`` will report the state of its files and folders. This happens by recursively calling childrens' own ``describe_state`` methods.
The game layer calls ``describe_state`` on the trunk ``SimComponent`` (the top-level parent) and then pass the state to the agents once per simulation step. For this reason, all ``SimComponent`` must have a ``describe_state`` method, and they must all be linked to the trunk ``SimComponent``.
This code snippet demonstrates how the state information is defined within the ``SimComponent`` class:
.. code-block:: python
class Node(SimComponent):
operating_state: NodeOperatingState = NodeOperatingState.OFF
services: Dict[str, Service] = {}
def describe_state(self) -> Dict:
state = super().describe_state()
state["operating_state"] = self.operating_state.value
state["services"] = {uuid: svc.describe_state() for uuid, svc in self.services.items()}
return state
class Service(SimComponent):
health_state: ServiceHealthState = ServiceHealthState.GOOD
def describe_state(self) -> Dict:
state = super().describe_state()
state["health_state"] = self.health_state.value
return state

View File

@@ -24,7 +24,7 @@ classifiers = [
]
dependencies = [
"gym==0.21.0",
"gymnasium==0.28.1",
"jupyterlab==3.6.1",
"kaleido==0.2.1",
"matplotlib==3.7.1",
@@ -35,7 +35,7 @@ dependencies = [
"polars==0.18.4",
"prettytable==3.8.0",
"PyYAML==6.0",
"stable-baselines3==1.6.2",
"stable-baselines3[extra]==2.1.0",
"tensorflow==2.12.0",
"typer[all]==0.9.0",
"pydantic==2.1.1"

View File

@@ -1 +1 @@
2.0.0
3.0.0a1

View File

@@ -1,2 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Access Control List. Models firewall functionality."""

View File

@@ -1,198 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""A class that implements the access control list implementation for the network."""
import logging
from typing import Dict, Final, List, Union
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import RulePermissionType
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class AccessControlList:
"""Access Control List class."""
def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None:
"""Init."""
# Implicit ALLOW or DENY firewall spec
self.acl_implicit_permission = implicit_permission
# Implicit rule in ACL list
if self.acl_implicit_permission == RulePermissionType.DENY:
self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY")
elif self.acl_implicit_permission == RulePermissionType.ALLOW:
self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY")
else:
raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}")
# Maximum number of ACL Rules in ACL
self.max_acl_rules: int = max_acl_rules
# A list of ACL Rules
self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1)
@property
def acl(self) -> List[Union[ACLRule, None]]:
"""Public access method for private _acl."""
return self._acl + [self.acl_implicit_rule]
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
"""Checks for IP address matches.
:param _rule: The rule object to check
:type _rule: ACLRule
:param _source_ip_address: Source IP address to compare
:type _source_ip_address: str
:param _dest_ip_address: Destination IP address to compare
:type _dest_ip_address: str
:return: True if there is a match, otherwise False.
:rtype: bool
"""
if (
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address)
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address)
or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY")
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
):
return True
else:
return False
def is_blocked(self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool:
"""
Checks for rules that block a protocol / port.
Args:
_source_ip_address: the source IP address to check
_dest_ip_address: the destination IP address to check
_protocol: the protocol to check
_port: the port to check
Returns:
Indicates block if all conditions are satisfied.
"""
for rule in self.acl:
if isinstance(rule, ACLRule):
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY") and (
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
):
# There's a matching rule. Get the permission
if rule.get_permission() == RulePermissionType.DENY:
return True
elif rule.get_permission() == RulePermissionType.ALLOW:
return False
# If there has been no rule to allow the IER through, it will return a blocked signal by default
return True
def add_rule(
self,
_permission: RulePermissionType,
_source_ip: str,
_dest_ip: str,
_protocol: str,
_port: str,
_position: str,
) -> None:
"""
Adds a new rule.
Args:
_permission: the permission value (e.g. "ALLOW" or "DENY")
_source_ip: the source IP address
_dest_ip: the destination IP address
_protocol: the protocol
_port: the port
_position: position to insert ACL rule into ACL list (starting from index 1 and NOT 0)
"""
try:
position_index = int(_position)
except TypeError:
_LOGGER.info(f"Position {_position} could not be converted to integer.")
return
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
# Checks position is in correct range
if self.max_acl_rules - 1 > position_index > -1:
try:
_LOGGER.info(f"Position {position_index} is valid.")
# Check to see Agent will not overwrite current ACL in ACL list
if self._acl[position_index] is None:
_LOGGER.info(f"Inserting rule {new_rule} at position {position_index}")
# Adds rule
self._acl[position_index] = new_rule
else:
# Cannot overwrite it
_LOGGER.info(f"Error: inserting rule at non-empty position {position_index}")
return
except Exception:
_LOGGER.info(f"New Rule could NOT be added to list at position {position_index}.")
else:
_LOGGER.info(f"Position {position_index} is an invalid/overwrites implicit firewall rule")
def remove_rule(
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> None:
"""
Removes a rule.
Args:
_permission: the permission value (e.g. "ALLOW" or "DENY")
_source_ip: the source IP address
_dest_ip: the destination IP address
_protocol: the protocol
_port: the port
"""
rule_to_delete = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
delete_rule_hash = hash(rule_to_delete)
for index in range(0, len(self._acl)):
if isinstance(self._acl[index], ACLRule) and hash(self._acl[index]) == delete_rule_hash:
self._acl[index] = None
def remove_all_rules(self) -> None:
"""Removes all rules."""
for i in range(len(self._acl)):
self._acl[i] = None
def get_dictionary_hash(
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> int:
"""
Produces a hash value for a rule.
Args:
_permission: the permission value (e.g. "ALLOW" or "DENY")
_source_ip: the source IP address
_dest_ip: the destination IP address
_protocol: the protocol
_port: the port
Returns:
Hash value based on rule parameters.
"""
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(rule)
return hash_value
def get_relevant_rules(
self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
) -> Dict[int, ACLRule]:
"""Get all ACL rules that relate to the given arguments.
:param _source_ip_address: the source IP address to check
:param _dest_ip_address: the destination IP address to check
:param _protocol: the protocol to check
:param _port: the port to check
:return: Dictionary of all ACL rules that relate to the given arguments
:rtype: Dict[int, ACLRule]
"""
relevant_rules = {}
for rule in self.acl:
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" or _protocol == "ANY") and (
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" or str(_port) == "ANY"
):
# There's a matching rule.
relevant_rules[self._acl.index(rule)] = rule
return relevant_rules

View File

@@ -1,87 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""A class that implements an access control list rule."""
from primaite.common.enums import RulePermissionType
class ACLRule:
"""Access Control List Rule class."""
def __init__(
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
) -> None:
"""
Initialise an ACL Rule.
:param _permission: The permission (ALLOW or DENY)
:param _source_ip: The source IP address
:param _dest_ip: The destination IP address
:param _protocol: The rule protocol
:param _port: The rule port
"""
self.permission: RulePermissionType = _permission
self.source_ip: str = _source_ip
self.dest_ip: str = _dest_ip
self.protocol: str = _protocol
self.port: str = _port
def __hash__(self) -> int:
"""
Override the hash function.
Returns:
Returns hash of core parameters.
"""
return hash(
(
self.permission,
self.source_ip,
self.dest_ip,
self.protocol,
self.port,
)
)
def get_permission(self) -> str:
"""
Gets the permission attribute.
Returns:
Returns permission attribute
"""
return self.permission
def get_source_ip(self) -> str:
"""
Gets the source IP address attribute.
Returns:
Returns source IP address attribute
"""
return self.source_ip
def get_dest_ip(self) -> str:
"""
Gets the desintation IP address attribute.
Returns:
Returns destination IP address attribute
"""
return self.dest_ip
def get_protocol(self) -> str:
"""
Gets the protocol attribute.
Returns:
Returns protocol attribute
"""
return self.protocol
def get_port(self) -> str:
"""
Gets the port attribute.
Returns:
Returns port attribute
"""
return self.port

View File

@@ -1,2 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Common interface between RL agents from different libraries and PrimAITE."""

View File

@@ -1,319 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Optional, Union
from uuid import uuid4
import primaite
from primaite import getLogger, PRIMAITE_PATHS
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite
from primaite.utils.session_metadata_parser import parse_session_metadata
_LOGGER: Logger = getLogger(__name__)
def get_session_path(session_timestamp: datetime) -> Path:
"""
Get the directory path the session will output to.
This is set in the format of:
~/primaite/2.0.0/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = PRIMAITE_PATHS.user_sessions_path / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
return session_path
class AgentSessionABC(ABC):
"""
An ABC that manages training and/or evaluation of agents in PrimAITE.
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
implemented.
"""
@abstractmethod
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
legacy_training_config: bool = False,
legacy_lay_down_config: bool = False,
) -> None:
"""
Initialise an agent session from config files, or load a previous session.
If training configuration and laydown configuration are provided with a session path,
the session path will be used.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:param session_path: directory path of the session to load
"""
# initialise variables
self._env: Primaite
self._agent = None
self._can_learn: bool = False
self._can_evaluate: bool = False
self.is_eval = False
self.legacy_training_config = legacy_training_config
self.legacy_lay_down_config = legacy_lay_down_config
self.session_timestamp: datetime = datetime.now()
# convert session to path
if session_path is not None:
if not isinstance(session_path, Path):
session_path = Path(session_path)
# if a session path is provided, load it
if not session_path.exists():
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
# load session
self.load(session_path)
else:
# set training config path
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Union[Path, str] = training_config_path
self._training_config: TrainingConfig = training_config.load(
self._training_config_path, legacy_file=legacy_training_config
)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Union[Path, str] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# set random UUID for session
self._uuid = str(uuid4())
"The session timestamp"
self.session_path = get_session_path(self.session_timestamp)
"The Session path"
@property
def timestamp_str(self) -> str:
"""The session timestamp as a string."""
return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
@property
def learning_path(self) -> Path:
"""The learning outputs path."""
path = self.session_path / "learning"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def evaluation_path(self) -> Path:
"""The evaluation outputs path."""
path = self.session_path / "evaluation"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def checkpoints_path(self) -> Path:
"""The Session checkpoints path."""
path = self.learning_path / "checkpoints"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def uuid(self) -> str:
"""The Agent Session UUID."""
return self._uuid
def _write_session_metadata_file(self) -> None:
"""
Write the ``session_metadata.json`` file.
Creates a ``session_metadata.json`` in the ``session_path`` directory
and adds the following key/value pairs:
- uuid: The UUID assigned to the session upon instantiation.
- start_datetime: The date & time the session started in iso format.
- end_datetime: NULL.
- total_episodes: NULL.
- total_time_steps: NULL.
- env:
- training_config:
- All training config items
- lay_down_config:
- All lay down config items
"""
metadata_dict = {
"uuid": self.uuid,
"start_datetime": self.session_timestamp.isoformat(),
"end_datetime": None,
"learning": {"total_episodes": None, "total_time_steps": None},
"evaluation": {"total_episodes": None, "total_time_steps": None},
"env": {
"training_config": self._training_config.to_dict(json_serializable=True),
"lay_down_config": self._lay_down_config,
},
}
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Writing Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
_LOGGER.debug("Finished writing session metadata file")
def _update_session_metadata_file(self) -> None:
"""
Update the ``session_metadata.json`` file.
Updates the `session_metadata.json`` in the ``session_path`` directory
with the following key/value pairs:
- end_datetime: The date & time the session ended in iso format.
- total_episodes: The total number of training episodes completed.
- total_time_steps: The total number of training time steps completed.
"""
with open(self.session_path / "session_metadata.json", "r") as file:
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
if not self.is_eval:
metadata_dict["learning"]["total_episodes"] = self._env.actual_episode_count # noqa
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
else:
metadata_dict["evaluation"]["total_episodes"] = self._env.actual_episode_count # noqa
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
_LOGGER.debug("Finished updating session metadata file")
@abstractmethod
def _setup(self) -> None:
_LOGGER.info(
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
)
_LOGGER.info(f"The output directory for this session is: {self.session_path}")
self._write_session_metadata_file()
self._can_learn = True
self._can_evaluate = False
@abstractmethod
def _save_checkpoint(self) -> None:
pass
@abstractmethod
def learn(
self,
**kwargs: Any,
) -> None:
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if self._can_learn:
_LOGGER.info("Finished learning")
_LOGGER.debug("Writing transactions")
self._update_session_metadata_file()
self._can_evaluate = True
self.is_eval = False
@abstractmethod
def evaluate(
self,
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if self._can_evaluate:
self._update_session_metadata_file()
self.is_eval = True
self._plot_av_reward_per_episode(learning_session=False)
_LOGGER.info("Finished evaluation")
@abstractmethod
def _get_latest_checkpoint(self) -> None:
pass
def load(self, path: Union[str, Path]) -> None:
"""Load an agent from file."""
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
# set training config path
self._training_config_path: Union[Path, str] = training_config_path
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
self._lay_down_config_path: Union[Path, str] = laydown_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
# set random UUID for session
self._uuid = md_dict["uuid"]
# set the session path
self.session_path = path
"The Session path"
@property
def _saved_agent_path(self) -> Path:
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip"
return self.learning_path / file_name
@abstractmethod
def save(self) -> None:
"""Save the agent."""
pass
@abstractmethod
def export(self) -> None:
"""Export the agent to transportable file format."""
pass
def close(self) -> None:
"""Closes the agent."""
self._env.episode_av_reward_writer.close() # noqa
self._env.transaction_writer.close() # noqa
def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None:
# self.close()
title = f"PrimAITE Session {self.timestamp_str} "
subtitle = str(self._training_config)
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
image_file = f"average_reward_per_episode_{self.timestamp_str}.png"
if learning_session:
title += "(Learning)"
path = self.learning_path / csv_file
image_path = self.learning_path / image_file
else:
title += "(Evaluation)"
path = self.evaluation_path / csv_file
image_path = self.evaluation_path / image_file
fig = plot_av_reward_per_episode(path, title, subtitle)
fig.write_image(image_path)
_LOGGER.debug(f"Saved average rewards per episode plot to: {path}")

View File

@@ -1,118 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import time
from abc import abstractmethod
from pathlib import Path
from typing import Any, Optional, Union
import numpy as np
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
class HardCodedAgentSessionABC(AgentSessionABC):
"""
An Agent Session ABC for evaluation deterministic agents.
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
implemented.
"""
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Initialise a hardcoded agent session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
"""
super().__init__(training_config_path, lay_down_config_path, session_path)
self._setup()
def _setup(self) -> None:
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
super()._setup()
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self) -> None:
pass
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs: Any,
) -> None:
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs: np.ndarray) -> None:
pass
def evaluate(
self,
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation
for step in range(time_steps):
# Calculate action
action = self._calculate_action(obs)
# Perform the step
obs, reward, done, info = self._env.step(action)
if done:
break
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
@classmethod
def load(cls, path: Union[str, Path] = None) -> None:
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self) -> None:
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self) -> None:
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -1,515 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from typing import Dict, List, Union
import numpy as np
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
get_node_of_ip,
transform_action_acl_enum,
transform_change_obs_readable,
)
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardCodedAgentView
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
class HardCodedACLAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic ACL agent."""
def _calculate_action(self, obs: np.ndarray) -> int:
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
# Basic view action using only the current observation
return self._calculate_action_basic_view(obs)
else:
# full view action using observation space, action
# history and reward feedback
return self._calculate_action_full_view(obs)
def get_blocked_green_iers(
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, IER]:
"""Get blocked green IERs.
:param green_iers: Green IERs to check for being
:type green_iers: Dict[str, IER]
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: Same as `green_iers` input dict, but filtered to only contain the blocked ones.
:rtype: Dict[str, IER]
"""
blocked_green_iers = {}
for green_ier_id, green_ier in green_iers.items():
source_node_id = green_ier.get_source_node_id()
source_node_address = nodes[source_node_id].ip_address
dest_node_id = green_ier.get_dest_node_id()
dest_node_address = nodes[dest_node_id].ip_address
protocol = green_ier.get_protocol() # e.g. 'TCP'
port = green_ier.get_port()
# Can be blocked by an ACL or by default (no allow rule exists)
if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
blocked_green_iers[green_ier_id] = green_ier
return blocked_green_iers
def get_matching_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
"""Get list of ACL rules which are relevant to an IER.
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: _description_
:rtype: _type_
"""
source_node_id = ier.get_source_node_id()
source_node_address = nodes[source_node_id].ip_address
dest_node_id = ier.get_dest_node_id()
dest_node_address = nodes[dest_node_id].ip_address
protocol = ier.get_protocol() # e.g. 'TCP'
port = ier.get_port()
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_blocking_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
"""
Get blocking ACL rules for an IER.
.. warning::
Can return empty dict but IER can still be blocked by default
(No ALLOW rule, therefore blocked).
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: _description_
:rtype: _type_
"""
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
blocked_rules = {}
for rule_key, rule_value in matching_rules.items():
if rule_value.get_permission() == "DENY":
blocked_rules[rule_key] = rule_value
return blocked_rules
def get_allow_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
"""Get all allowing ACL rules for an IER.
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: _description_
:rtype: _type_
"""
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
allowed_rules = {}
for rule_key, rule_value in matching_rules.items():
if rule_value.get_permission() == "ALLOW":
allowed_rules[rule_key] = rule_value
return allowed_rules
def get_matching_acl_rules(
self,
source_node_id: str,
dest_node_id: str,
protocol: str,
port: str,
acl: AccessControlList,
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
services_list: List[str],
) -> Dict[int, ACLRule]:
"""Filter ACL rules to only those which are relevant to the specified nodes.
:param source_node_id: Source node
:type source_node_id: str
:param dest_node_id: Destination nodes
:type dest_node_id: str
:param protocol: Network protocol
:type protocol: str
:param port: Network port
:type port: str
:param acl: Access Control list which will be filtered
:type acl: AccessControlList
:param nodes: The environment's node directory.
:type nodes: Dict[str, Union[ServiceNode, ActiveNode]]
:param services_list: List of services registered for the environment.
:type services_list: List[str]
:return: Filtered version of 'acl'
:rtype: Dict[str, ACLRule]
"""
if source_node_id != "ANY":
source_node_address = nodes[str(source_node_id)].ip_address
else:
source_node_address = source_node_id
if dest_node_id != "ANY":
dest_node_address = nodes[str(dest_node_id)].ip_address
else:
dest_node_address = dest_node_id
if protocol != "ANY":
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
# TODO: This should throw an error because protocol is a string
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_allow_acl_rules(
self,
source_node_id: int,
dest_node_id: str,
protocol: int,
port: str,
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[int, ACLRule]:
"""List ALLOW rules relating to specified nodes.
:param source_node_id: Source node id
:type source_node_id: int
:param dest_node_id: Destination node
:type dest_node_id: str
:param protocol: Network protocol
:type protocol: int
:param port: Port
:type port: str
:param acl: Firewall ruleset which is applied to the network
:type acl: AccessControlList
:param nodes: The simulation's node store
:type nodes: Dict[str, NodeUnion]
:param services_list: Services list
:type services_list: List[str]
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
desination nodes
:rtype: Dict[str, ACLRule]
"""
matching_rules = self.get_matching_acl_rules(
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
)
allowed_rules = {}
for rule_key, rule_value in matching_rules.items():
if rule_value.get_permission() == "ALLOW":
allowed_rules[rule_key] = rule_value
return allowed_rules
def get_deny_acl_rules(
self,
source_node_id: int,
dest_node_id: str,
protocol: int,
port: str,
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[int, ACLRule]:
"""List DENY rules relating to specified nodes.
:param source_node_id: Source node id
:type source_node_id: int
:param dest_node_id: Destination node
:type dest_node_id: str
:param protocol: Network protocol
:type protocol: int
:param port: Port
:type port: str
:param acl: Firewall ruleset which is applied to the network
:type acl: AccessControlList
:param nodes: The simulation's node store
:type nodes: Dict[str, NodeUnion]
:param services_list: Services list
:type services_list: List[str]
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
desination nodes
:rtype: Dict[str, ACLRule]
"""
matching_rules = self.get_matching_acl_rules(
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
)
allowed_rules = {}
for rule_key, rule_value in matching_rules.items():
if rule_value.get_permission() == "DENY":
allowed_rules[rule_key] = rule_value
return allowed_rules
def _calculate_action_full_view(self, obs: np.ndarray) -> int:
"""
Calculate a good acl-based action for the blue agent to take.
Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
- Which ACL rules already exist, - otherwise:
- The agent would perminently get stuck in a loop of performing the same action over and over.
(best action is to block something, but its already blocked but doesn't know this)
- The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule,
if it doesnt know what rules exist)
- The Green IERs (optional) - It often needs to know which traffic it should be allowing. For example
in the default config one of the green IERs is blocked by default, but it has no way of knowing this
based on the observation space. Additionally, potentially in the future, once a node state
has been fixed (no longer compromised), it needs a way to know it should reallow traffic.
A RL agent can learn what the green IERs are on its own - but the rule based agent cannot easily do this.
There doesn't seem like there's much that can be done if an Operating or OS State is compromised
If a service node becomes compromised there's a decision to make - do we block that service?
Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED
Cons: Will block a green IER, decreasing the reward
We decide to block the service.
Potentially a better solution (for the reward) would be to block the incomming traffic from compromised
nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing
an overwhelmed state, so we don't do this.
:param obs: current observation from the gym environment
:type obs: np.ndarray
:return: Optimal action to take in the environment (chosen from the discrete action space)
:rtype: int
"""
# obs = convert_to_old_obs(obs)
r_obs = transform_change_obs_readable(obs)
_, _, _, *s = r_obs
if len(r_obs) == 4: # only 1 service
s = [*s]
# 1. Check if node is compromised. If so we want to block its outwards services
# a. If it is comprimised check if there's an allow rule we should delete.
# cons: might delete a multi-rule from any source node (ANY -> x)
# b. OPTIONAL (Deny rules not needed): Check if there already exists an existing Deny Rule so not to duplicate
# c. OPTIONAL (no allow rule = blocked): Add a DENY rule
found_action = False
for service_num, service_states in enumerate(s):
for x, service_state in enumerate(service_states):
if service_state == "COMPROMISED":
action_source_id = x + 1 # +1 as 0 is any
action_destination_id = "ANY"
action_protocol = service_num + 1 # +1 as 0 is any
action_port = "ANY"
allow_rules = self.get_allow_acl_rules(
action_source_id,
action_destination_id,
action_protocol,
action_port,
self._env.acl,
self._env.nodes,
self._env.services_list,
)
deny_rules = self.get_deny_acl_rules(
action_source_id,
action_destination_id,
action_protocol,
action_port,
self._env.acl,
self._env.nodes,
self._env.services_list,
)
if len(allow_rules) > 0:
# Check if there's an allow rule we should delete
rule = list(allow_rules.values())[0]
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = rule.get_source_ip()
action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
action_destination_ip = rule.get_dest_ip()
action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
action_protocol_name = rule.get_protocol()
action_protocol = (
self._env.services_list.index(action_protocol_name) + 1
) # convert name e.g. 'TCP' to index
action_port_name = rule.get_port()
action_port = (
self._env.ports_list.index(action_port_name) + 1
) # convert port name e.g. '80' to index
found_action = True
break
elif len(deny_rules) > 0:
# TODO OPTIONAL
# If there's already a DENY RULE, that blocks EVERYTHING from the source ip we don't need
# to create another
# Check to see if the DENY rule really blocks everything (ANY) or just a specific rule
continue
else:
# TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked
action_decision = "CREATE"
action_permission = "DENY"
break
if found_action:
break
# 2. If NO Node is Comprimised, or the node has already been blocked, check the green IERs and
# add an Allow rule if the green IER is being blocked.
# a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule):
# If there's a DENY rule delete it if:
# - There isn't already a deny rule
# - It doesnt allows a comprimised node to become operational.
# b. Add an ALLOW rule if:
# - There isn't already an allow rule
# - It doesnt allows a comprimised node to become operational
if not found_action:
# Which Green IERS are blocked
blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
for ier_key, ier in blocked_green_iers.items():
# Which ALLOW rules are allowing this IER (none)
allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)
# If there are no blocking rules, it may be being blocked by default
# If there is already an allow rule
node_id_to_check = int(ier.get_source_node_id())
service_name_to_check = ier.get_protocol()
service_id_to_check = self._env.services_list.index(service_name_to_check)
# Service state of the the source node in the ier
service_state = s[service_id_to_check][node_id_to_check - 1]
if len(allowing_rules) == 0 and service_state != "COMPROMISED":
action_decision = "CREATE"
action_permission = "ALLOW"
action_source_id = int(ier.get_source_node_id())
action_destination_id = int(ier.get_dest_node_id())
action_protocol_name = ier.get_protocol()
action_protocol = (
self._env.services_list.index(action_protocol_name) + 1
) # convert name e.g. 'TCP' to index
action_port_name = ier.get_port()
action_port = (
self._env.ports_list.index(action_port_name) + 1
) # convert port name e.g. '80' to index
found_action = True
break
if found_action:
action = [
action_decision,
action_permission,
action_source_id,
action_destination_id,
action_protocol,
action_port,
]
action = transform_action_acl_enum(action)
action = get_new_action(action, self._env.action_dict)
else:
# If no good/useful action has been found, just perform a nothing action
action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
action = transform_action_acl_enum(action)
action = get_new_action(action, self._env.action_dict)
return action
def _calculate_action_basic_view(self, obs: np.ndarray) -> int:
"""
Calculate a good acl-based action for the blue agent to take.
Uses ONLY information from the current observation with NO knowledge
of previous actions taken and NO reward feedback.
We rely on randomness to select the precise action, as we want to
block all traffic originating from a compromised node, without being
able to tell:
1. Which ACL rules already exist
2. Which actions the agent has already tried.
There is a high probability that the correct rule will not be deleted
before the state becomes overwhelmed.
Currently, a deny rule does not overwrite an allow rule. The allow
rules must be deleted.
:param obs: current observation from the gym environment
:type obs: np.ndarray
:return: Optimal action to take in the environment (chosen from the discrete action space)
:rtype: int
"""
action_dict = self._env.action_dict
r_obs = transform_change_obs_readable(obs)
_, o, _, *s = r_obs
if len(r_obs) == 4: # only 1 service
s = [*s]
number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links)
for service_num, service_states in enumerate(s):
comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
if len(comprimised_states) == 0:
# No states are COMPROMISED, try the next service
continue
compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = compromised_node
# Randomly select a destination ID to block
action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
action_destination_ip = (
int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
)
action_protocol = service_num + 1 # +1 as 0 is any
# Randomly select a port to block
# Bad assumption that number of protocols equals number of ports
# AND no rules exist with an ANY port
action_port = np.random.choice(list(range(1, len(s) + 1)))
action = [
action_decision,
action_permission,
action_source_ip,
action_destination_ip,
action_protocol,
action_port,
]
action = transform_action_acl_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# If no good/useful action has been found, just perform a nothing action
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
nothing_action = transform_action_acl_enum(nothing_action)
nothing_action = get_new_action(nothing_action, action_dict)
return nothing_action

View File

@@ -1,125 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import numpy as np
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
class HardCodedNodeAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic Node agent."""
def _calculate_action(self, obs: np.ndarray) -> int:
"""
Calculate a good node-based action for the blue agent to take.
:param obs: current observation from the gym environment
:type obs: np.ndarray
:return: Optimal action to take in the environment (chosen from the discrete action space)
:rtype: int
"""
action_dict = self._env.action_dict
r_obs = transform_change_obs_readable(obs)
_, o, os, *s = r_obs
if len(r_obs) == 4: # only 1 service
s = [*s]
# Check in order of most important states (order doesn't currently
# matter, but it probably should)
# First see if any OS states are compromised
for x, os_state in enumerate(os):
if os_state == "COMPROMISED":
action_node_id = x + 1
action_node_property = "OS"
property_action = "PATCHING"
action_service_index = 0 # does nothing isn't relevant for os
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# Next, see if any Services are compromised
# We fix the compromised state before overwhelemd state,
# If a compromised entry node is fixed before the overwhelmed state is triggered, instruction is ignored
for service_num, service in enumerate(s):
for x, service_state in enumerate(service):
if service_state == "COMPROMISED":
action_node_id = x + 1
action_node_property = "SERVICE"
property_action = "PATCHING"
action_service_index = service_num
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# Next, See if any services are overwhelmed
# perhaps this should be fixed automatically when the compromised PCs issues are also resolved
# Currently there's no reason that an Overwhelmed state cannot be resolved before resolving the compromised PCs
for service_num, service in enumerate(s):
for x, service_state in enumerate(service):
if service_state == "OVERWHELMED":
action_node_id = x + 1
action_node_property = "SERVICE"
property_action = "PATCHING"
action_service_index = service_num
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# Finally, turn on any off nodes
for x, operating_state in enumerate(o):
if os_state == "OFF":
action_node_id = x + 1
action_node_property = "OPERATING"
property_action = "ON" # Why reset it when we can just turn it on
action_service_index = 0 # does nothing isn't relevant for operating state
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
# TODO: transform_action_node_enum takes only one argument, not sure why two are given here.
action = transform_action_node_enum(action, action_dict)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step
return action
# If no good actions, just go with an action that wont do any harm
action_node_id = 1
action_node_property = "NONE"
property_action = "ON"
action_service_index = 0
action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
action = transform_action_node_enum(action)
action = get_new_action(action, action_dict)
return action

View File

@@ -1,287 +0,0 @@
# # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
# from __future__ import annotations
# import json
# import shutil
# import zipfile
# from datetime import datetime
# from logging import Logger
# from pathlib import Path
# from typing import Any, Callable, Dict, Optional, Union
# from uuid import uuid4
# from primaite import getLogger
# from primaite.agents.agent_abc import AgentSessionABC
# from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
# from primaite.environment.primaite_env import Primaite
# # from ray.rllib.algorithms import Algorithm
# # from ray.rllib.algorithms.a2c import A2CConfig
# # from ray.rllib.algorithms.ppo import PPOConfig
# # from ray.tune.logger import UnifiedLogger
# # from ray.tune.registry import register_env
# # from primaite.exceptions import RLlibAgentError
# _LOGGER: Logger = getLogger(__name__)
# # TODO: verify type of env_config
# def _env_creator(env_config: Dict[str, Any]) -> Primaite:
# return Primaite(
# training_config_path=env_config["training_config_path"],
# lay_down_config_path=env_config["lay_down_config_path"],
# session_path=env_config["session_path"],
# timestamp_str=env_config["timestamp_str"],
# )
# # # TODO: verify type hint return type
# # def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
# # logdir = session_path / "ray_results"
# # logdir.mkdir(parents=True, exist_ok=True)
# # def logger_creator(config: Dict) -> UnifiedLogger:
# # return UnifiedLogger(config, logdir, loggers=None)
# return logger_creator
# # class RLlibAgent(AgentSessionABC):
# # """An AgentSession class that implements a Ray RLlib agent."""
# # def __init__(
# # self,
# # training_config_path: Optional[Union[str, Path]] = "",
# # lay_down_config_path: Optional[Union[str, Path]] = "",
# # session_path: Optional[Union[str, Path]] = None,
# # ) -> None:
# # """
# # Initialise the RLLib Agent training session.
# # :param training_config_path: YAML file containing configurable items defined in
# # `primaite.config.training_config.TrainingConfig`
# # :type training_config_path: Union[path, str]
# # :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
# # :type lay_down_config_path: Union[path, str]
# # :raises ValueError: If the training config contains a bad value for agent_framework (should be "RLLIB")
# # :raises ValueError: If the training config contains a bad value for agent_identifies (should be `PPO`
# # or `A2C`)
# # """
# # # TODO: implement RLlib agent loading
# # if session_path is not None:
# # msg = "RLlib agent loading has not been implemented yet"
# # _LOGGER.critical(msg)
# # raise NotImplementedError(msg)
# # super().__init__(training_config_path, lay_down_config_path)
# # if self._training_config.session_type == SessionType.EVAL:
# # msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
# # _LOGGER.critical(msg)
# # raise RLlibAgentError(msg)
# # if not self._training_config.agent_framework == AgentFramework.RLLIB:
# # msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
# # _LOGGER.error(msg)
# # raise ValueError(msg)
# # self._agent_config_class: Union[PPOConfig, A2CConfig]
# # if self._training_config.agent_identifier == AgentIdentifier.PPO:
# # self._agent_config_class = PPOConfig
# # elif self._training_config.agent_identifier == AgentIdentifier.A2C:
# # self._agent_config_class = A2CConfig
# # else:
# # msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
# # _LOGGER.error(msg)
# # raise ValueError(msg)
# # self._agent_config: Union[PPOConfig, A2CConfig]
# # self._current_result: dict
# # self._setup()
# # _LOGGER.debug(
# # f"Created {self.__class__.__name__} using: "
# # f"agent_framework={self._training_config.agent_framework}, "
# # f"agent_identifier="
# # f"{self._training_config.agent_identifier}, "
# # f"deep_learning_framework="
# # f"{self._training_config.deep_learning_framework}"
# # )
# # self._train_agent = None # Required to capture the learning agent to close after eval
# # def _update_session_metadata_file(self) -> None:
# # """
# # Update the ``session_metadata.json`` file.
# # Updates the `session_metadata.json`` in the ``session_path`` directory
# # with the following key/value pairs:
# # - end_datetime: The date & time the session ended in iso format.
# # - total_episodes: The total number of training episodes completed.
# # - total_time_steps: The total number of training time steps completed.
# # """
# # with open(self.session_path / "session_metadata.json", "r") as file:
# # metadata_dict = json.load(file)
# # metadata_dict["end_datetime"] = datetime.now().isoformat()
# # if not self.is_eval:
# # metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa
# # metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
# # else:
# # metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa
# # metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
# # filepath = self.session_path / "session_metadata.json"
# # _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
# # with open(filepath, "w") as file:
# # json.dump(metadata_dict, file)
# # _LOGGER.debug("Finished updating session metadata file")
# # def _setup(self) -> None:
# # super()._setup()
# # register_env("primaite", _env_creator)
# # self._agent_config = self._agent_config_class()
# # self._agent_config.environment(
# # env="primaite",
# # env_config=dict(
# # training_config_path=self._training_config_path,
# # lay_down_config_path=self._lay_down_config_path,
# # session_path=self.session_path,
# # timestamp_str=self.timestamp_str,
# # ),
# # )
# # self._agent_config.seed = self._training_config.seed
# # self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
# # self._agent_config.framework(framework="tf")
# # self._agent_config.rollouts(
# # num_rollout_workers=1,
# # num_envs_per_worker=1,
# # horizon=self._training_config.num_train_steps,
# # )
# # self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
# # def _save_checkpoint(self) -> None:
# # checkpoint_n = self._training_config.checkpoint_every_n_episodes
# # episode_count = self._current_result["episodes_total"]
# # save_checkpoint = False
# # if checkpoint_n:
# # save_checkpoint = episode_count % checkpoint_n == 0
# # if episode_count and save_checkpoint:
# # self._agent.save(str(self.checkpoints_path))
# # def learn(
# # self,
# # **kwargs: Any,
# # ) -> None:
# # """
# # Evaluate the agent.
# # :param kwargs: Any agent-specific key-word args to be passed.
# # """
# # time_steps = self._training_config.num_train_steps
# # episodes = self._training_config.num_train_episodes
# # _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
# # for i in range(episodes):
# # self._current_result = self._agent.train()
# # self._save_checkpoint()
# # self.save()
# # super().learn()
# # # Done this way as the RLlib eval can only be performed if the session hasn't been stopped
# # if self._training_config.session_type is not SessionType.TRAIN:
# # self._train_agent = self._agent
# # else:
# # self._agent.stop()
# # self._plot_av_reward_per_episode(learning_session=True)
# # def _unpack_saved_agent_into_eval(self) -> Path:
# # """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
# # agent_restore_path = self.evaluation_path / "agent_restore"
# # if agent_restore_path.exists():
# # shutil.rmtree(agent_restore_path)
# # agent_restore_path.mkdir()
# # with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
# # zip_file.extractall(agent_restore_path)
# # return agent_restore_path
# # def _setup_eval(self):
# # self._can_learn = False
# # self._can_evaluate = True
# # self._agent.restore(str(self._unpack_saved_agent_into_eval()))
# # def evaluate(
# # self,
# # **kwargs,
# # ):
# # """
# # Evaluate the agent.
# # :param kwargs: Any agent-specific key-word args to be passed.
# # """
# # time_steps = self._training_config.num_eval_steps
# # episodes = self._training_config.num_eval_episodes
# # self._setup_eval()
# # self._env: Primaite = Primaite(
# # self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
# # )
# # self._env.set_as_eval()
# # self.is_eval = True
# # if self._training_config.deterministic:
# # deterministic_str = "deterministic"
# # else:
# # deterministic_str = "non-deterministic"
# # _LOGGER.info(
# # f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
# # )
# # for episode in range(episodes):
# # obs = self._env.reset()
# # for step in range(time_steps):
# # action = self._agent.compute_single_action(observation=obs, explore=False)
# # obs, rewards, done, info = self._env.step(action)
# # self._env.reset()
# # self._env.close()
# # super().evaluate()
# # # Now we're safe to close the learning agent and write the mean rewards per episode for it
# # if self._training_config.session_type is not SessionType.TRAIN:
# # self._train_agent.stop()
# # self._plot_av_reward_per_episode(learning_session=True)
# # # Perform a clean-up of the unpacked agent
# # if (self.evaluation_path / "agent_restore").exists():
# # shutil.rmtree((self.evaluation_path / "agent_restore"))
# # def _get_latest_checkpoint(self) -> None:
# # raise NotImplementedError
# # @classmethod
# # def load(cls, path: Union[str, Path]) -> RLlibAgent:
# # """Load an agent from file."""
# # raise NotImplementedError
# # def save(self, overwrite_existing: bool = True) -> None:
# # """Save the agent."""
# # # Make temp dir to save in isolation
# # temp_dir = self.learning_path / str(uuid4())
# # temp_dir.mkdir()
# # # Save the agent to the temp dir
# # self._agent.save(str(temp_dir))
# # # Capture the saved Rllib checkpoint inside the temp directory
# # for file in temp_dir.iterdir():
# # checkpoint_dir = file
# # break
# # # Zip the folder
# # shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# # # Drop the temp directory
# # shutil.rmtree(temp_dir)
# # def export(self) -> None:
# # """Export the agent to transportable file format."""
# # raise NotImplementedError

View File

@@ -1,206 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
import json
from logging import Logger
from pathlib import Path
from typing import Any, Optional, Union
import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER: Logger = getLogger(__name__)
class SB3Agent(AgentSessionABC):
"""An AgentSession class that implements a Stable Baselines3 agent."""
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = None,
lay_down_config_path: Optional[Union[str, Path]] = None,
session_path: Optional[Union[str, Path]] = None,
legacy_training_config: bool = False,
legacy_lay_down_config: bool = False,
) -> None:
"""
Initialise the SB3 Agent training session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "SB3")
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
or `A2C`)
"""
super().__init__(
training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_class: Union[PPO, A2C]
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_class = PPO
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_class = A2C
else:
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}"
_LOGGER.error(msg)
raise ValueError(msg)
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
f"agent_identifier="
f"{self._training_config.agent_identifier}"
)
self.is_eval = False
self._setup()
def _setup(self) -> None:
"""Set up the SB3 Agent."""
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
legacy_training_config=self.legacy_training_config,
legacy_lay_down_config=self.legacy_lay_down_config,
)
# check if there is a zip file that needs to be loaded
load_file = next(self.session_path.rglob("*.zip"), None)
if not load_file:
# create a new env and agent
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_train_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
else:
# set env values from session metadata
with open(self.session_path / "session_metadata.json", "r") as file:
md_dict = json.load(file)
# load environment values
if self.is_eval:
# evaluation always starts at 0
self._env.episode_count = 0
self._env.total_step_count = 0
else:
# carry on from previous learning sessions
self._env.episode_count = md_dict["learning"]["total_episodes"]
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
# load the file
self._agent = self._agent_class.load(load_file, env=self._env)
# set agent values
self._agent.verbose = self.sb3_output_verbose_level
self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"
super()._setup()
def _save_checkpoint(self) -> None:
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs: Any,
) -> None:
"""
Train the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_train_steps
episodes = self._training_config.num_train_episodes
self.is_eval = False
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self._env._write_av_reward_per_episode() # noqa
self.save()
self._env.close()
super().learn()
# save agent
self.save()
self._plot_av_reward_per_episode(learning_session=True)
def evaluate(
self,
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
self._env.set_as_eval()
self.is_eval = True
if self._training_config.deterministic:
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
_LOGGER.info(
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
)
for episode in range(episodes):
obs = self._env.reset()
for step in range(time_steps):
action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic)
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
self._env._write_av_reward_per_episode() # noqa
self._env.close()
super().evaluate()
def save(self) -> None:
"""Save the agent."""
self._agent.save(self._saved_agent_path)
def export(self) -> None:
"""Export the agent to transportable file format."""
raise NotImplementedError

View File

@@ -1,59 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import numpy as np
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
class RandomAgent(HardCodedAgentSessionABC):
"""
A Random Agent.
Get a completely random action from the action space.
"""
def _calculate_action(self, obs: np.ndarray) -> int:
return self._env.action_space.sample()
class DummyAgent(HardCodedAgentSessionABC):
"""
A Dummy Agent.
All action spaces setup so dummy action is always 0 regardless of action type used.
"""
def _calculate_action(self, obs: np.ndarray) -> int:
return 0
class DoNothingACLAgent(HardCodedAgentSessionABC):
"""
A do nothing ACL agent.
A valid ACL action that has no effect; does nothing.
"""
def _calculate_action(self, obs: np.ndarray) -> int:
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
nothing_action = transform_action_acl_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)
return nothing_action
class DoNothingNodeAgent(HardCodedAgentSessionABC):
"""
A do nothing Node agent.
A valid Node action that has no effect; does nothing.
"""
def _calculate_action(self, obs: np.ndarray) -> int:
nothing_action = [1, "NONE", "ON", 0]
nothing_action = transform_action_node_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)
# nothing_action should currently always be 0
return nothing_action

View File

@@ -1,450 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from typing import Dict, List, Union
import numpy as np
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
HardwareState,
LinkStatus,
NodeHardwareAction,
NodePOLType,
NodeSoftwareAction,
SoftwareState,
)
def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
"""Convert a node action from enumerated format to readable format.
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: The same action list, but with the encodings translated back into meaningful labels
:rtype: List[Union[int,str]]
"""
action_node_property = NodePOLType(action[1]).name
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
return new_action
def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]:
"""
Transform an ACL action to a more readable format.
example:
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: The same action list, but with the encodings translated back into meaningful labels
:rtype: List[Union[int,str]]
"""
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
action_permissions = {0: "DENY", 1: "ALLOW"}
action_decision = action_decisions[action[0]]
action_permission = action_permissions[action[1]]
# For IPs, Ports and Protocols, 0 means any, otherwise its just an index
new_action = [action_decision, action_permission] + list(action[2:6])
for n, val in enumerate(list(action[2:6])):
if val == 0:
new_action[n + 2] = "ANY"
return new_action
def is_valid_node_action(action: List[int]) -> bool:
"""
Is the node action an actual valid action.
Only uses information about the action to determine if the action has an effect
Does NOT consider:
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
- Node already being in that state (turning an ON node ON)
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: Whether the action is valid
:rtype: bool
"""
action_r = transform_action_node_readable(action)
node_property = action_r[1]
node_action = action_r[2]
# print("node property", node_property, "\nnode action", node_action)
if node_property == "NONE":
return False
if node_action == "NONE":
return False
if node_property == "OPERATING" and node_action == "PATCHING":
# Operating State cannot PATCH
return False
if node_property != "OPERATING" and node_action not in [
"NONE",
"PATCHING",
]:
# Software States can only do Nothing or Patch
return False
return True
def is_valid_acl_action(action: List[int]) -> bool:
"""
Is the ACL action an actual valid action.
Only uses information about the action to determine if the action has an effect.
Does NOT consider:
- Trying to create identical rules
- Trying to create a rule which is a subset of another rule (caused by "ANY")
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: Whether the action is valid
:rtype: bool
"""
action_r = transform_action_acl_readable(action)
action_decision = action_r[0]
action_permission = action_r[1]
action_source_id = action_r[2]
action_destination_id = action_r[3]
if action_decision == "NONE":
return False
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
# ACL rule towards itself
return False
if action_permission == "DENY":
# DENY is unnecessary, we can create and delete allow rules instead
# No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY.
return False
return True
def is_valid_acl_action_extra(action: List[int]) -> bool:
"""
Harsher version of valid acl actions, does not allow action.
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: Whether the action is valid
:rtype: bool
"""
if is_valid_acl_action(action) is False:
return False
action_r = transform_action_acl_readable(action)
action_protocol = action_r[4]
action_port = action_r[5]
# Don't allow protocols or ports to be ANY
# in the future we might want to do the opposite, and only have ANY option for ports and service
if action_protocol == "ANY":
return False
if action_port == "ANY":
return False
return True
def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
"""Transform list of transactions to readable list of each observation property.
example:
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
:param obs: Raw observation from the environment.
:type obs: np.ndarray
:return: The same observation, but the encoded integer values are replaced with readable names.
:rtype: List[List[Union[str, int]]]
"""
ids = [i for i in obs[:, 0]]
operating_states = [HardwareState(i).name for i in obs[:, 1]]
os_states = [SoftwareState(i).name for i in obs[:, 2]]
new_obs = [ids, operating_states, os_states]
for service in range(4, obs.shape[1]):
# Links bit/s don't have a service state
service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
new_obs.append(service_states)
return new_obs
def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
"""Transform observation to readable format.
example
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
:param obs: Raw observation from the environment.
:type obs: np.ndarray
:return: The same observation, but the encoded integer values are replaced with readable names.
:rtype: List[List[Union[str, int]]]
"""
changed_obs = transform_change_obs_readable(obs)
new_obs = list(zip(*changed_obs))
# Convert list of tuples to list of lists
new_obs = [list(i) for i in new_obs]
return new_obs
def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray:
"""Convert original gym Box observation space to new multiDiscrete observation space.
:param obs: observation in the 'old' (NodeLinkTable) format
:type obs: np.ndarray
:param num_nodes: number of nodes in the network, defaults to 10
:type num_nodes: int, optional
:return: reformatted observation
:rtype: np.ndarray
"""
# Remove ID columns, remove links and flatten to MultiDiscrete observation space
new_obs = obs[:num_nodes, 1:].flatten()
return new_obs
def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray:
"""Convert to old observation.
Links filled with 0's as no information is included in new observation space.
example:
obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])
new_obs = array([[ 1, 1, 1, 1],
[ 2, 1, 1, 1],
[ 3, 1, 1, 1],
...
[20, 0, 0, 0]])
:param obs: observation in the 'new' (MultiDiscrete) format
:type obs: np.ndarray
:param num_nodes: number of nodes in the network, defaults to 10
:type num_nodes: int, optional
:param num_links: number of links in the network, defaults to 10
:type num_links: int, optional
:param num_services: number of services on the network, defaults to 1
:type num_services: int, optional
:return: 2-d BOX observation space, in the same format as NodeLinkTable
:rtype: np.ndarray
"""
# Convert back to more readable, original format
reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
# Add empty links back and add node ID back
s = np.zeros(
[reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1],
dtype=np.int64,
)
s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back
s[:num_nodes, 1:] = reshaped_nodes # put values back in
new_obs = s
# Add links back in
links = obs[-num_links:]
# Links will be added to the last protocol/service slot but they are not specific to that service
new_obs[num_nodes:, -1] = links
return new_obs
def describe_obs_change(
obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1
) -> str:
"""Build a string describing the difference between two observations.
example:
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
output = 'ID 1: SERVICE 2 set to GOOD'
:param obs1: First observation
:type obs1: np.ndarray
:param obs2: Second observation
:type obs2: np.ndarray
:param num_nodes: How many nodes are in the network laydown, defaults to 10
:type num_nodes: int, optional
:param num_links: How many links are in the network laydown, defaults to 10
:type num_links: int, optional
:param num_services: How many services are configured for this scenario, defaults to 1
:type num_services: int, optional
:return: A multi-line string with a human-readable description of the difference.
:rtype: str
"""
obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
list_of_changes = []
for n, row in enumerate(obs1 - obs2):
if row.any() != 0:
relevant_changes = np.where(row != 0, obs2[n], -1)
relevant_changes[0] = obs2[n, 0] # ID is always relevant
is_link = relevant_changes[0] > num_nodes
desc = _describe_obs_change_helper(relevant_changes, is_link)
list_of_changes.append(desc)
change_string = "\n ".join(list_of_changes)
if len(list_of_changes) > 0:
change_string = "\n " + change_string
return change_string
def _describe_obs_change_helper(obs_change: List[int], is_link: bool) -> str:
"""
Helper funcion to describe what has changed.
example:
[ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD"
Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.'
:param obs_change: List of integers generated within the `describe_obs_change` function. It should correspond to one
row of the observation table, and have `-1` at locations where the observation hasn't changed, and the new
status where it has changed.
:type obs_change: List[int]
:param is_link: Whether the row of the observation space corresponds to a link. False means it represents a node.
:type is_link: bool
:return: A human-readable description of the difference between the two observation rows.
:rtype: str
"""
# Indexes where a change has occured, not including 0th index
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
# Node pol types, Indexes >= 3 are service nodes
NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed]
# Account for hardware states, software sattes and links
states = [
LinkStatus(obs_change[i]).name
if is_link
else HardwareState(obs_change[i]).name
if i == 1
else SoftwareState(obs_change[i]).name
for i in index_changed
]
if not is_link:
desc = f"ID {obs_change[0]}:"
for node_pol_type, state in list(zip(NodePOLTypes, states)):
desc = desc + " " + node_pol_type + " changed to " + state + "."
else:
desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}."
return desc
def transform_action_node_enum(action: List[Union[str, int]]) -> List[int]:
"""Convert a node action from readable string format, to enumerated format.
example:
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
:param action: Action in 'readable' format
:type action: List[Union[str,int]]
:return: Action with verbs encoded as ints
:rtype: List[int]
"""
action_node_id = action[0]
action_node_property = NodePOLType[action[1]].value
if action[1] == "OPERATING":
property_action = NodeHardwareAction[action[2]].value
elif action[1] == "OS" or action[1] == "SERVICE":
property_action = NodeSoftwareAction[action[2]].value
else:
property_action = 0
action_service_index = action[3]
new_action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
return new_action
def transform_action_acl_enum(action: List[Union[int, str]]) -> np.ndarray:
"""
Convert acl action from readable str format, to enumerated format.
:param action: ACL-based action expressed as a list of human-readable ints and strings
:type action: List[Union[int,str]]
:return: The same action but encoded to contain only integers.
:rtype: np.ndarray
"""
action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
action_permissions = {"DENY": 0, "ALLOW": 1}
action_decision = action_decisions[action[0]]
action_permission = action_permissions[action[1]]
# For IPs, Ports and Protocols, ANY has value 0, otherwise its just an index
new_action = [action_decision, action_permission] + list(action[2:6])
for n, val in enumerate(list(action[2:6])):
if val == "ANY":
new_action[n + 2] = 0
new_action = np.array(new_action)
return new_action
def get_node_of_ip(ip: str, node_dict: Dict[str, NodeUnion]) -> str:
"""Get the node ID of an IP address.
node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes)
:param ip: The IP address of the node whose ID is required
:type ip: str
:param node_dict: The environment's node registry dictionary
:type node_dict: Dict[str,NodeUnion]
:return: The key from the registry dict that corresponds to the node with the IP adress provided by `ip`
:rtype: str
"""
for node_key, node_value in node_dict.items():
node_ip = node_value.ip_address
if node_ip == ip:
return node_key
def get_new_action(old_action: np.ndarray, action_dict: Dict[int, List]) -> int:
"""
Get new action (e.g. 32) from old action e.g. [1,1,1,0].
Old_action can be either node or acl action type
:param old_action: Action expressed as a list of choices, eg. [1,1,1,0]
:type old_action: np.ndarray
:param action_dict: Dictionary for translating the multidiscrete actions into the list-based actions.
:type action_dict: Dict[int,List]
:return: Action key correspoinding to the input `old_action`
:rtype: int
"""
for key, val in action_dict.items():
if list(val) == list(old_action):
return key
# Not all possible actions are included in dict, only valid action are
# if action is not in the dict, its an invalid action so return 0
return 0

View File

@@ -10,7 +10,6 @@ import yaml
from typing_extensions import Annotated
from primaite import PRIMAITE_PATHS
from primaite.data_viz import PlotlyTemplate
app = typer.Typer()
@@ -30,7 +29,7 @@ def reset_notebooks(overwrite: bool = True) -> None:
:param overwrite: If True, will overwrite existing demo notebooks.
"""
from primaite.setup import reset_demo_notebooks
from src.primaite.setup import reset_demo_notebooks
reset_demo_notebooks.run(overwrite)
@@ -81,14 +80,6 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) ->
print(f"PrimAITE Log Level: {level}")
@app.command()
def notebooks() -> None:
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
from primaite.notebooks import start_jupyter_session
start_jupyter_session()
@app.command()
def version() -> None:
"""Get the installed PrimAITE version number."""
@@ -97,14 +88,6 @@ def version() -> None:
print(primaite.__version__)
@app.command()
def clean_up() -> None:
"""Cleans up left over files from previous version installations."""
from primaite.setup import old_installation_clean_up
old_installation_clean_up.run()
@app.command()
def setup(overwrite_existing: bool = True) -> None:
"""
@@ -112,8 +95,10 @@ def setup(overwrite_existing: bool = True) -> None:
WARNING: All user-data will be lost.
"""
from arcd_gate.cli import setup as gate_setup
from primaite import getLogger
from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs
from src.primaite.setup import reset_demo_notebooks, reset_example_configs
_LOGGER = getLogger(__name__)
@@ -130,84 +115,32 @@ def setup(overwrite_existing: bool = True) -> None:
_LOGGER.info("Rebuilding the example notebooks...")
reset_example_configs.run(overwrite_existing=True)
_LOGGER.info("Performing a clean-up of previous PrimAITE installations...")
old_installation_clean_up.run()
_LOGGER.info("Setting up ARCD GATE...")
gate_setup()
_LOGGER.info("PrimAITE setup complete!")
@app.command()
def session(
tc: Optional[str] = None,
ldc: Optional[str] = None,
load: Optional[str] = None,
legacy_tc: bool = False,
legacy_ldc: bool = False,
config: Optional[str] = None,
) -> None:
"""
Run a PrimAITE session.
tc: The training config filepath. Optional. If no value is passed then
example default training config is used from:
~/primaite/2.0.0/config/example_config/training/training_config_main.yaml.
ldc: The lay down config file path. Optional. If no value is passed then
example default lay down config is used from:
~/primaite/2.0.0/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
load: The directory of a previous session. Optional. If no value is passed, then the session
will use the default training config and laydown config. Inversely, if a training config and laydown config
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
and laydown config.
legacy_tc: If the training config file is a legacy file from PrimAITE < 2.0.
legacy_ldf: If the lay down config file is a legacy file from PrimAITE < 2.0.
:param config: The path to the config file. Optional, if None, the example config will be used.
:type config: Optional[str]
"""
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
from primaite.main import run
from threading import Thread
if load is not None:
# run a loaded session
run(session_path=load)
from src.primaite.config.load import example_config_path
from src.primaite.main import run
from src.primaite.utils.start_gate_server import start_gate_server
else:
# start a new session using tc and ldc
if not tc:
tc = main_training_config_path()
server_thread = Thread(target=start_gate_server)
server_thread.start()
if not ldc:
ldc = dos_very_basic_config_path()
run(
training_config_path=tc,
lay_down_config_path=ldc,
legacy_training_config=legacy_tc,
legacy_lay_down_config=legacy_ldc,
)
@app.command()
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
"""
View or set the plotly template for Session plots.
To View, simply call: primaite plotly-template
To set, call: primaite plotly-template <desired template>
For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK
"""
if PRIMAITE_PATHS.app_config_file_path.exists():
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
if template:
primaite_config["session"]["outputs"]["plots"]["template"] = template.value
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
yaml.dump(primaite_config, file)
print(f"PrimAITE plotly template: {template.value}")
else:
template = primaite_config["session"]["outputs"]["plots"]["template"]
print(f"PrimAITE plotly template: {template}")
if not config:
config = example_config_path()
print(config)
run(config_path=config)

View File

@@ -1,2 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Objects which are shared between many PrimAITE modules."""

View File

@@ -1,8 +0,0 @@
from typing import Union
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode]
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""

View File

@@ -1,208 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Enumerations for APE."""
from enum import Enum, IntEnum
class NodeType(Enum):
"""Node type enumeration."""
CCTV = 1
SWITCH = 2
COMPUTER = 3
LINK = 4
MONITOR = 5
PRINTER = 6
LOP = 7
RTU = 8
ACTUATOR = 9
SERVER = 10
class Priority(Enum):
"""Node priority enumeration."""
P1 = 1
P2 = 2
P3 = 3
P4 = 4
P5 = 5
class HardwareState(Enum):
"""Node hardware state enumeration."""
NONE = 0
ON = 1
OFF = 2
RESETTING = 3
SHUTTING_DOWN = 4
BOOTING = 5
class SoftwareState(Enum):
"""Software or Service state enumeration."""
NONE = 0
GOOD = 1
PATCHING = 2
COMPROMISED = 3
OVERWHELMED = 4
class NodePOLType(Enum):
"""Node Pattern of Life type enumeration."""
NONE = 0
OPERATING = 1
OS = 2
SERVICE = 3
FILE = 4
class NodePOLInitiator(Enum):
"""Node Pattern of Life initiator enumeration."""
DIRECT = 1
IER = 2
SERVICE = 3
class Protocol(Enum):
"""Service protocol enumeration."""
LDAP = 0
FTP = 1
HTTPS = 2
SMTP = 3
RTP = 4
IPP = 5
TCP = 6
NONE = 7
class SessionType(Enum):
"""The type of PrimAITE Session to be run."""
TRAIN = 1
"Train an agent"
EVAL = 2
"Evaluate an agent"
TRAIN_EVAL = 3
"Train then evaluate an agent"
class AgentFramework(Enum):
"""The agent algorithm framework/package."""
CUSTOM = 0
"Custom Agent"
SB3 = 1
"Stable Baselines3"
# RLLIB = 2
# "Ray RLlib"
class DeepLearningFramework(Enum):
"""The deep learning framework."""
TF = "tf"
"Tensorflow"
TF2 = "tf2"
"Tensorflow 2.x"
TORCH = "torch"
"PyTorch"
class AgentIdentifier(Enum):
"""The Red Agent algo/class."""
A2C = 1
"Advantage Actor Critic"
PPO = 2
"Proximal Policy Optimization"
HARDCODED = 3
"The Hardcoded agents"
DO_NOTHING = 4
"The DoNothing agents"
RANDOM = 5
"The RandomAgent"
DUMMY = 6
"The DummyAgent"
class HardCodedAgentView(Enum):
"""The view the deterministic hard-coded agent has of the environment."""
BASIC = 1
"The current observation space only"
FULL = 2
"Full environment view with actions taken and reward feedback"
class ActionType(Enum):
"""Action type enumeration."""
NODE = 0
ACL = 1
ANY = 2
# TODO: this is not used anymore, write a ticket to delete it.
class ObservationType(Enum):
"""Observation type enumeration."""
BOX = 0
MULTIDISCRETE = 1
class FileSystemState(Enum):
"""File System State."""
GOOD = 1
CORRUPT = 2
DESTROYED = 3
REPAIRING = 4
RESTORING = 5
class NodeHardwareAction(Enum):
"""Node hardware action."""
NONE = 0
ON = 1
OFF = 2
RESET = 3
class NodeSoftwareAction(Enum):
"""Node software action."""
NONE = 0
PATCHING = 1
class LinkStatus(Enum):
"""Link traffic status."""
NONE = 0
LOW = 1
MEDIUM = 2
HIGH = 3
OVERLOAD = 4
class SB3OutputVerboseLevel(IntEnum):
"""The Stable Baselines3 learn/eval output verbosity level."""
NONE = 0
INFO = 1
DEBUG = 2
class RulePermissionType(Enum):
"""Any firewall rule type."""
NONE = 0
DENY = 1
ALLOW = 2

View File

@@ -1,47 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The protocol class."""
class Protocol(object):
"""Protocol class."""
def __init__(self, _name: str) -> None:
"""
Initialise a protocol.
:param _name: The name of the protocol
:type _name: str
"""
self.name: str = _name
self.load: int = 0 # bps
def get_name(self) -> str:
"""
Gets the protocol name.
Returns:
The protocol name
"""
return self.name
def get_load(self) -> int:
"""
Gets the protocol load.
Returns:
The protocol load (bps)
"""
return self.load
def add_load(self, _load: int) -> None:
"""
Adds load to the protocol.
Args:
_load: The load to add
"""
self.load += _load
def clear_load(self) -> None:
"""Clears the load on this protocol."""
self.load = 0

View File

@@ -1,28 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The Service class."""
from primaite.common.enums import SoftwareState
class Service(object):
"""Service class."""
def __init__(self, name: str, port: str, software_state: SoftwareState) -> None:
"""
Initialise a service.
:param name: The service name.
:param port: The service port.
:param software_state: The service SoftwareState.
"""
self.name: str = name
self.port: str = port
self.software_state: SoftwareState = software_state
self.patching_count: int = 0
def reduce_patching_count(self) -> None:
"""Reduces the patching count for the service."""
self.patching_count -= 1
if self.patching_count <= 0:
self.patching_count = 0
self.software_state = SoftwareState.GOOD

View File

@@ -0,0 +1,726 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 20
n_learn_steps: 128
n_eval_episodes: 20
n_eval_steps: 128
game_config:
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
- node_ref: client_2
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: GATERLAgent
observation_space:
type: UC2BlueObservation
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
ics: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 2
service_id: 1
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 2
service_id: 1
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 2
service_id: 1
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 2
service_id: 1
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 2
service_id: 1
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 2
service_id: 1
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 2
service_id: 1
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 2
service_id: 1
9:
action: "NODE_FILE_SCAN"
options:
node_id: 3
folder_id: 1
file_id: 1
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 3
folder_id: 1
file_id: 1
11:
action: "NODE_FILE_DELETE"
options:
node_id: 3
folder_id: 1
file_id: 1
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 3
folder_id: 1
file_id: 1
13:
action: "NODE_FILE_RESTORE"
options:
node_id: 3
folder_id: 1
file_id: 1
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 3
folder_id: 1
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 3
folder_id: 1
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 3
folder_id: 1
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 3
folder_id: 1
18:
action: "NODE_OS_SCAN"
options:
node_id: 3
19:
action: "NODE_SHUTDOWN"
options:
node_id: 6
20:
action: "NODE_STARTUP"
options:
node_id: 6
21:
action: "NODE_RESET"
options:
node_id: 6
22:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
28:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 0
29:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 1
30:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 2
31:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 3
32:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 4
33:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 5
34:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 6
35:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 7
36:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 8
37:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 9
38:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 2
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 2
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 7
nic_id: 1
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 7
nic_id: 1
options:
nodes:
- node_ref: router_1
- node_ref: switch_1
- node_ref: switch_2
- node_ref: domain_controller
- node_ref: web_server
- node_ref: database_server
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
agent_settings:
# ...
simulation:
network:
nodes:
- ref: router_1
type: router
hostname: router_1
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
acl:
0:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
action: PERMIT
src_port: DNS
dst_port: DNS
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- ref: switch_1
type: switch
hostname: switch_1
num_ports: 8
- ref: switch_2
type: switch
hostname: switch_2
num_ports: 8
- ref: domain_controller
type: server
hostname: domain_controller
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- ref: domain_controller_dns_server
type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
type: server
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
type: server
hostname: database_server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: database_service
type: DatabaseService
- ref: backup_server
type: server
hostname: backup_server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
- ref: security_suite
type: server
hostname: security_suite
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
nics:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
- ref: data_manipulation_bot
type: DataManipulationBot
- ref: client_1_dns_client
type: DNSClient
- ref: client_2
type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_2_web_browser
type: WebBrowser
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1
endpoint_a_port: 1
endpoint_b_ref: switch_1
endpoint_b_port: 8
- ref: router_1___switch_2
endpoint_a_ref: router_1
endpoint_a_port: 2
endpoint_b_ref: switch_2
endpoint_b_port: 8
- ref: switch_1___domain_controller
endpoint_a_ref: switch_1
endpoint_a_port: 1
endpoint_b_ref: domain_controller
endpoint_b_port: 1
- ref: switch_1___web_server
endpoint_a_ref: switch_1
endpoint_a_port: 2
endpoint_b_ref: web_server
endpoint_b_port: 1
- ref: switch_1___database_server
endpoint_a_ref: switch_1
endpoint_a_port: 3
endpoint_b_ref: database_server
endpoint_b_port: 1
- ref: switch_1___backup_server
endpoint_a_ref: switch_1
endpoint_a_port: 4
endpoint_b_ref: backup_server
endpoint_b_port: 1
- ref: switch_1___security_suite
endpoint_a_ref: switch_1
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 1
- ref: switch_2___client_1
endpoint_a_ref: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_port: 1
- ref: switch_2___security_suite
endpoint_a_ref: switch_2
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 2

View File

@@ -1,166 +0,0 @@
- item_type: PORTS
ports_list:
- port: '80'
- item_type: SERVICES
service_list:
- name: TCP
- item_type: NODE
node_id: '1'
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '2'
name: SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '3'
name: PC2
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.4
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '4'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.5
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '5'
name: SWITCH2
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.6
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '6'
name: SWITCH3
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.7
software_state: GOOD
file_system_state: GOOD
- item_type: LINK
id: '7'
name: link1
bandwidth: 1000000000
source: '1'
destination: '4'
- item_type: LINK
id: '8'
name: link2
bandwidth: 1000000000
source: '4'
destination: '2'
- item_type: LINK
id: '9'
name: link3
bandwidth: 1000000000
source: '2'
destination: '5'
- item_type: LINK
id: '10'
name: link4
bandwidth: 1000000000
source: '2'
destination: '6'
- item_type: LINK
id: '11'
name: link5
bandwidth: 1000000000
source: '5'
destination: '3'
- item_type: LINK
id: '12'
name: link6
bandwidth: 1000000000
source: '6'
destination: '3'
- item_type: GREEN_IER
id: '13'
start_step: 1
end_step: 128
load: 100000
protocol: TCP
port: '80'
source: '3'
destination: '2'
mission_criticality: 5
- item_type: RED_POL
id: '14'
start_step: 50
end_step: 50
targetNodeId: '1'
initiator: DIRECT
type: SERVICE
protocol: TCP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: RED_IER
id: '15'
start_step: 60
end_step: 100
load: 1000000
protocol: TCP
port: '80'
source: '1'
destination: '2'
mission_criticality: 0
- item_type: RED_POL
id: '16'
start_step: 80
end_step: 80
targetNodeId: '2'
initiator: IER
type: SERVICE
protocol: TCP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: ACL_RULE
id: '17'
permission: ALLOW
source: ANY
destination: ANY
protocol: ANY
port: ANY
position: 0

View File

@@ -1,366 +0,0 @@
- item_type: PORTS
ports_list:
- port: '80'
- item_type: SERVICES
service_list:
- name: TCP
- item_type: NODE
node_id: '1'
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.11
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '2'
name: PC2
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.12
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '3'
name: PC3
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.13
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '4'
name: PC4
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.20.14
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '5'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '6'
name: IDS
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.4
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '7'
name: SWITCH2
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '8'
name: LOP1
node_class: SERVICE
node_type: LOP
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.12
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '9'
name: SERVER1
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.14
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '10'
name: SERVER2
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.20.15
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: LINK
id: '11'
name: link1
bandwidth: 1000000000
source: '1'
destination: '5'
- item_type: LINK
id: '12'
name: link2
bandwidth: 1000000000
source: '2'
destination: '5'
- item_type: LINK
id: '13'
name: link3
bandwidth: 1000000000
source: '3'
destination: '5'
- item_type: LINK
id: '14'
name: link4
bandwidth: 1000000000
source: '4'
destination: '5'
- item_type: LINK
id: '15'
name: link5
bandwidth: 1000000000
source: '5'
destination: '6'
- item_type: LINK
id: '16'
name: link6
bandwidth: 1000000000
source: '5'
destination: '8'
- item_type: LINK
id: '17'
name: link7
bandwidth: 1000000000
source: '6'
destination: '7'
- item_type: LINK
id: '18'
name: link8
bandwidth: 1000000000
source: '8'
destination: '7'
- item_type: LINK
id: '19'
name: link9
bandwidth: 1000000000
source: '7'
destination: '9'
- item_type: LINK
id: '20'
name: link10
bandwidth: 1000000000
source: '7'
destination: '10'
- item_type: GREEN_IER
id: '21'
start_step: 1
end_step: 128
load: 100000
protocol: TCP
port: '80'
source: '1'
destination: '9'
mission_criticality: 2
- item_type: GREEN_IER
id: '22'
start_step: 1
end_step: 128
load: 100000
protocol: TCP
port: '80'
source: '2'
destination: '9'
mission_criticality: 2
- item_type: GREEN_IER
id: '23'
start_step: 1
end_step: 128
load: 100000
protocol: TCP
port: '80'
source: '9'
destination: '3'
mission_criticality: 5
- item_type: GREEN_IER
id: '24'
start_step: 1
end_step: 128
load: 100000
protocol: TCP
port: '80'
source: '4'
destination: '10'
mission_criticality: 2
- item_type: ACL_RULE
id: '25'
permission: ALLOW
source: 192.168.10.11
destination: 192.168.10.14
protocol: TCP
port: 80
position: 0
- item_type: ACL_RULE
id: '26'
permission: ALLOW
source: 192.168.10.12
destination: 192.168.10.14
protocol: TCP
port: 80
position: 1
- item_type: ACL_RULE
id: '27'
permission: ALLOW
source: 192.168.10.13
destination: 192.168.10.14
protocol: TCP
port: 80
position: 2
- item_type: ACL_RULE
id: '28'
permission: ALLOW
source: 192.168.20.14
destination: 192.168.20.15
protocol: TCP
port: 80
position: 3
- item_type: ACL_RULE
id: '29'
permission: ALLOW
source: 192.168.10.14
destination: 192.168.10.13
protocol: TCP
port: 80
position: 4
- item_type: ACL_RULE
id: '30'
permission: DENY
source: 192.168.10.11
destination: 192.168.20.15
protocol: TCP
port: 80
position: 5
- item_type: ACL_RULE
id: '31'
permission: DENY
source: 192.168.10.12
destination: 192.168.20.15
protocol: TCP
port: 80
position: 6
- item_type: ACL_RULE
id: '32'
permission: DENY
source: 192.168.10.13
destination: 192.168.20.15
protocol: TCP
port: 80
position: 7
- item_type: ACL_RULE
id: '33'
permission: DENY
source: 192.168.20.14
destination: 192.168.10.14
protocol: TCP
port: 80
position: 8
- item_type: RED_POL
id: '34'
start_step: 20
end_step: 20
targetNodeId: '1'
initiator: DIRECT
type: SERVICE
protocol: TCP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: RED_POL
id: '35'
start_step: 20
end_step: 20
targetNodeId: '2'
initiator: DIRECT
type: SERVICE
protocol: TCP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: RED_IER
id: '36'
start_step: 30
end_step: 128
load: 440000000
protocol: TCP
port: '80'
source: '1'
destination: '9'
mission_criticality: 0
- item_type: RED_IER
id: '37'
start_step: 30
end_step: 128
load: 440000000
protocol: TCP
port: '80'
source: '2'
destination: '9'
mission_criticality: 0
- item_type: RED_POL
id: '38'
start_step: 30
end_step: 30
targetNodeId: '9'
initiator: IER
type: SERVICE
protocol: TCP
state: OVERWHELMED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA

View File

@@ -1,164 +0,0 @@
- item_type: PORTS
ports_list:
- port: '80'
- item_type: SERVICES
service_list:
- name: TCP
- item_type: NODE
node_id: '1'
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '2'
name: PC2
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '3'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.1
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '4'
name: SERVER1
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.4
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: LINK
id: '5'
name: link1
bandwidth: 1000000000
source: '1'
destination: '3'
- item_type: LINK
id: '6'
name: link2
bandwidth: 1000000000
source: '2'
destination: '3'
- item_type: LINK
id: '7'
name: link3
bandwidth: 1000000000
source: '3'
destination: '4'
- item_type: GREEN_IER
id: '8'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '1'
destination: '4'
mission_criticality: 1
- item_type: GREEN_IER
id: '9'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '2'
destination: '4'
mission_criticality: 1
- item_type: GREEN_IER
id: '10'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '4'
destination: '2'
mission_criticality: 5
- item_type: ACL_RULE
id: '11'
permission: ALLOW
source: 192.168.1.2
destination: 192.168.1.4
protocol: TCP
port: 80
position: 0
- item_type: ACL_RULE
id: '12'
permission: ALLOW
source: 192.168.1.3
destination: 192.168.1.4
protocol: TCP
port: 80
position: 1
- item_type: ACL_RULE
id: '13'
permission: ALLOW
source: 192.168.1.4
destination: 192.168.1.3
protocol: TCP
port: 80
position: 2
- item_type: RED_POL
id: '14'
start_step: 20
end_step: 20
targetNodeId: '1'
initiator: DIRECT
type: SERVICE
protocol: TCP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: RED_IER
id: '15'
start_step: 30
end_step: 256
load: 10000000
protocol: TCP
port: '80'
source: '1'
destination: '4'
mission_criticality: 0
- item_type: RED_POL
id: '16'
start_step: 40
end_step: 40
targetNodeId: '4'
initiator: IER
type: SERVICE
protocol: TCP
state: OVERWHELMED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA

View File

@@ -1,546 +0,0 @@
- item_type: PORTS
ports_list:
- port: '80'
- port: '1433'
- port: '53'
- item_type: SERVICES
service_list:
- name: TCP
- name: TCP_SQL
- name: UDP
- item_type: NODE
node_id: '1'
name: CLIENT_1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.11
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- item_type: NODE
node_id: '2'
name: CLIENT_2
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.10.12
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: NODE
node_id: '3'
name: SWITCH_1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.10.1
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '4'
name: SECURITY_SUITE
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.10
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- item_type: NODE
node_id: '5'
name: MANAGEMENT_CONSOLE
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.12
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- item_type: NODE
node_id: '6'
name: SWITCH_2
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.2.1
software_state: GOOD
file_system_state: GOOD
- item_type: NODE
node_id: '7'
name: WEB_SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.2.10
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: TCP_SQL
port: '1433'
state: GOOD
- item_type: NODE
node_id: '8'
name: DATABASE_SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.2.14
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: TCP_SQL
port: '1433'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- item_type: NODE
node_id: '9'
name: BACKUP_SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.2.16
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- item_type: LINK
id: '10'
name: LINK_1
bandwidth: 1000000000
source: '1'
destination: '3'
- item_type: LINK
id: '11'
name: LINK_2
bandwidth: 1000000000
source: '2'
destination: '3'
- item_type: LINK
id: '12'
name: LINK_3
bandwidth: 1000000000
source: '3'
destination: '4'
- item_type: LINK
id: '13'
name: LINK_4
bandwidth: 1000000000
source: '3'
destination: '5'
- item_type: LINK
id: '14'
name: LINK_5
bandwidth: 1000000000
source: '4'
destination: '6'
- item_type: LINK
id: '15'
name: LINK_6
bandwidth: 1000000000
source: '5'
destination: '6'
- item_type: LINK
id: '16'
name: LINK_7
bandwidth: 1000000000
source: '6'
destination: '7'
- item_type: LINK
id: '17'
name: LINK_8
bandwidth: 1000000000
source: '6'
destination: '8'
- item_type: LINK
id: '18'
name: LINK_9
bandwidth: 1000000000
source: '6'
destination: '9'
- item_type: GREEN_IER
id: '19'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '1'
destination: '7'
mission_criticality: 5
- item_type: GREEN_IER
id: '20'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '7'
destination: '1'
mission_criticality: 5
- item_type: GREEN_IER
id: '21'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '2'
destination: '7'
mission_criticality: 5
- item_type: GREEN_IER
id: '22'
start_step: 1
end_step: 256
load: 10000
protocol: TCP
port: '80'
source: '7'
destination: '2'
mission_criticality: 5
- item_type: GREEN_IER
id: '23'
start_step: 1
end_step: 256
load: 5000
protocol: TCP_SQL
port: '1433'
source: '7'
destination: '8'
mission_criticality: 5
- item_type: GREEN_IER
id: '24'
start_step: 1
end_step: 256
load: 100000
protocol: TCP_SQL
port: '1433'
source: '8'
destination: '7'
mission_criticality: 5
- item_type: GREEN_IER
id: '25'
start_step: 1
end_step: 256
load: 50000
protocol: TCP
port: '80'
source: '1'
destination: '9'
mission_criticality: 2
- item_type: GREEN_IER
id: '26'
start_step: 1
end_step: 256
load: 50000
protocol: TCP
port: '80'
source: '2'
destination: '9'
mission_criticality: 2
- item_type: GREEN_IER
id: '27'
start_step: 1
end_step: 256
load: 5000
protocol: TCP
port: '80'
source: '5'
destination: '7'
mission_criticality: 1
- item_type: GREEN_IER
id: '28'
start_step: 1
end_step: 256
load: 5000
protocol: TCP
port: '80'
source: '7'
destination: '5'
mission_criticality: 1
- item_type: GREEN_IER
id: '29'
start_step: 1
end_step: 256
load: 5000
protocol: TCP
port: '80'
source: '5'
destination: '8'
mission_criticality: 1
- item_type: GREEN_IER
id: '30'
start_step: 1
end_step: 256
load: 5000
protocol: TCP
port: '80'
source: '8'
destination: '5'
mission_criticality: 1
- item_type: GREEN_IER
id: '31'
start_step: 1
end_step: 256
load: 5000
protocol: TCP
port: '80'
source: '5'
destination: '9'
mission_criticality: 1
- item_type: GREEN_IER
id: '32'
start_step: 1
end_step: 256
load: 5000
protocol: TCP
port: '80'
source: '9'
destination: '5'
mission_criticality: 1
- item_type: ACL_RULE
id: '33'
permission: ALLOW
source: 192.168.10.11
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 0
- item_type: ACL_RULE
id: '34'
permission: ALLOW
source: 192.168.10.11
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 1
- item_type: ACL_RULE
id: '35'
permission: ALLOW
source: 192.168.10.12
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 2
- item_type: ACL_RULE
id: '36'
permission: ALLOW
source: 192.168.10.12
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 3
- item_type: ACL_RULE
id: '37'
permission: ALLOW
source: 192.168.2.10
destination: 192.168.10.11
protocol: ANY
port: ANY
position: 4
- item_type: ACL_RULE
id: '38'
permission: ALLOW
source: 192.168.2.10
destination: 192.168.10.12
protocol: ANY
port: ANY
position: 5
- item_type: ACL_RULE
id: '39'
permission: ALLOW
source: 192.168.2.10
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 6
- item_type: ACL_RULE
id: '40'
permission: ALLOW
source: 192.168.2.14
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 7
- item_type: ACL_RULE
id: '41'
permission: ALLOW
source: 192.168.10.11
destination: 192.168.2.16
protocol: ANY
port: ANY
position: 8
- item_type: ACL_RULE
id: '42'
permission: ALLOW
source: 192.168.10.12
destination: 192.168.2.16
protocol: ANY
port: ANY
position: 9
- item_type: ACL_RULE
id: '43'
permission: ALLOW
source: 192.168.1.12
destination: 192.168.2.10
protocol: ANY
port: ANY
position: 10
- item_type: ACL_RULE
id: '44'
permission: ALLOW
source: 192.168.1.12
destination: 192.168.2.14
protocol: ANY
port: ANY
position: 11
- item_type: ACL_RULE
id: '45'
permission: ALLOW
source: 192.168.1.12
destination: 192.168.2.16
protocol: ANY
port: ANY
position: 12
- item_type: ACL_RULE
id: '46'
permission: ALLOW
source: 192.168.2.10
destination: 192.168.1.12
protocol: ANY
port: ANY
position: 13
- item_type: ACL_RULE
id: '47'
permission: ALLOW
source: 192.168.2.14
destination: 192.168.1.12
protocol: ANY
port: ANY
position: 14
- item_type: ACL_RULE
id: '48'
permission: ALLOW
source: 192.168.2.16
destination: 192.168.1.12
protocol: ANY
port: ANY
position: 15
- item_type: ACL_RULE
id: '49'
permission: DENY
source: ANY
destination: ANY
protocol: ANY
port: ANY
position: 16
- item_type: RED_POL
id: '50'
start_step: 50
end_step: 50
targetNodeId: '1'
initiator: DIRECT
type: SERVICE
protocol: UDP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: RED_IER
id: '51'
start_step: 75
end_step: 105
load: 10000
protocol: UDP
port: '53'
source: '1'
destination: '8'
mission_criticality: 0
- item_type: RED_POL
id: '52'
start_step: 100
end_step: 100
targetNodeId: '8'
initiator: IER
type: SERVICE
protocol: UDP
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- item_type: RED_POL
id: '53'
start_step: 105
end_step: 105
targetNodeId: '8'
initiator: SERVICE
type: FILE
protocol: NA
state: CORRUPT
sourceNodeId: '8'
sourceNodeService: UDP
sourceNodeServiceState: COMPROMISED
- item_type: RED_POL
id: '54'
start_step: 105
end_step: 105
targetNodeId: '8'
initiator: SERVICE
type: SERVICE
protocol: TCP_SQL
state: COMPROMISED
sourceNodeId: '8'
sourceNodeService: UDP
sourceNodeServiceState: COMPROMISED
- item_type: RED_POL
id: '55'
start_step: 125
end_step: 125
targetNodeId: '7'
initiator: SERVICE
type: SERVICE
protocol: TCP
state: OVERWHELMED
sourceNodeId: '8'
sourceNodeService: TCP_SQL
sourceNodeServiceState: COMPROMISED

View File

@@ -1,168 +0,0 @@
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent evaluation will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: ANY
# observation space
observation_space:
flatten: true
components:
- name: NODE_LINK_TABLE
- name: NODE_STATUSES
- name: LINK_TRAFFIC_LEVELS
- name: ACCESS_CONTROL_LIST
# Number of episodes for training to run per session
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 10
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# Implicit ACL firewall rule at end of ACL list to be the default action (ALLOW or DENY)
implicit_acl_rule: DENY
# Total number of ACL rules allowed in the environment
max_number_acl_rules: 30
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time take to repair the file system
file_system_restoring_limit: 5 # The time take to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -1,141 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Final, List, Union
import yaml
from primaite import getLogger, PRIMAITE_PATHS
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "lay_down"
def convert_legacy_lay_down_config(legacy_config: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Convert a legacy lay down config to the new format.
:param legacy_config: A legacy lay down config.
"""
field_conversion_map = {
"itemType": "item_type",
"portsList": "ports_list",
"serviceList": "service_list",
"baseType": "node_class",
"nodeType": "node_type",
"hardwareState": "hardware_state",
"softwareState": "software_state",
"startStep": "start_step",
"endStep": "end_step",
"fileSystemState": "file_system_state",
"ipAddress": "ip_address",
"missionCriticality": "mission_criticality",
}
new_config = []
for item in legacy_config:
if "itemType" in item:
if item["itemType"] in ["ACTIONS", "STEPS"]:
continue
new_dict = {}
for key in item.keys():
conversion_key = field_conversion_map.get(key)
if key == "id" and "itemType" in item:
if item["itemType"] == "NODE":
conversion_key = "node_id"
if conversion_key:
new_dict[conversion_key] = item[key]
else:
new_dict[key] = item[key]
new_config.append(new_dict)
return new_config
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
"""
Read in a lay down config yaml file.
:param file_path: The config file path.
:param legacy_file: True if the config file is legacy format, otherwise False.
:return: The lay down config as a dict.
:raises ValueError: If the file_path does not exist.
"""
if not isinstance(file_path, Path):
file_path = Path(file_path)
if file_path.exists():
with open(file_path, "r") as file:
config = yaml.safe_load(file)
_LOGGER.debug(f"Loading lay down config file: {file_path}")
if legacy_file:
try:
config = convert_legacy_lay_down_config(config)
except KeyError:
msg = (
f"Failed to convert lay down config file {file_path} "
f"from legacy format. Attempting to use file as is."
)
_LOGGER.error(msg)
return config
msg = f"Cannot load the lay down config as it does not exist: {file_path}"
_LOGGER.error(msg)
raise ValueError(msg)
def ddos_basic_one_config_path() -> Path:
"""
The path to the example lay_down_config_1_DDOS_basic.yaml file.
:return: The file path.
"""
path = _EXAMPLE_LAY_DOWN / "lay_down_config_1_DDOS_basic.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
def ddos_basic_two_config_path() -> Path:
"""
The path to the example lay_down_config_2_DDOS_basic.yaml file.
:return: The file path.
"""
path = _EXAMPLE_LAY_DOWN / "lay_down_config_2_DDOS_basic.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
def dos_very_basic_config_path() -> Path:
"""
The path to the example lay_down_config_3_DOS_very_basic.yaml file.
:return: The file path.
"""
path = _EXAMPLE_LAY_DOWN / "lay_down_config_3_DOS_very_basic.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
def data_manipulation_config_path() -> Path:
"""
The path to the example lay_down_config_5_data_manipulation.yaml file.
:return: The file path.
"""
path = _EXAMPLE_LAY_DOWN / "lay_down_config_5_data_manipulation.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path

View File

@@ -0,0 +1,45 @@
from pathlib import Path
from typing import Dict, Final, Union
import yaml
from primaite import getLogger, PRIMAITE_PATHS
_LOGGER = getLogger(__name__)
_EXAMPLE_CFG: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config"
def load(file_path: Union[str, Path]) -> Dict:
"""
Read a YAML file and return the contents as a dictionary.
:param file_path: Path to the YAML file.
:type file_path: Union[str, Path]
:return: Config dictionary
:rtype: Dict
"""
if not isinstance(file_path, Path):
file_path = Path(file_path)
if not file_path.exists():
_LOGGER.error(f"File does not exist: {file_path}")
raise FileNotFoundError(f"File does not exist: {file_path}")
with open(file_path, "r") as file:
config = yaml.safe_load(file)
_LOGGER.debug(f"Loaded config from {file_path}")
return config
def example_config_path() -> Path:
"""
Get the path to the example config.
:return: Path to the example config.
:rtype: Path
"""
path = _EXAMPLE_CFG / "example_config.yaml"
if not path.exists():
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
return path

View File

@@ -1,438 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from __future__ import annotations
from dataclasses import dataclass, field
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Final, Optional, Union
import yaml
from primaite import getLogger, PRIMAITE_PATHS
from primaite.common.enums import (
ActionType,
AgentFramework,
AgentIdentifier,
DeepLearningFramework,
HardCodedAgentView,
RulePermissionType,
SB3OutputVerboseLevel,
SessionType,
)
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_TRAINING: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "training"
def main_training_config_path() -> Path:
"""
The path to the example training_config_main.yaml file.
:return: The file path.
"""
path = _EXAMPLE_TRAINING / "training_config_main.yaml"
if not path.exists():
msg = "Example config not found. Please run 'primaite setup'"
_LOGGER.critical(msg)
raise FileNotFoundError(msg)
return path
@dataclass()
class TrainingConfig:
"""The Training Config class."""
agent_framework: AgentFramework = AgentFramework.SB3
"The AgentFramework"
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
"The DeepLearningFramework"
agent_identifier: AgentIdentifier = AgentIdentifier.PPO
"The AgentIdentifier"
hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL
"The view the deterministic hard-coded agent has of the environment"
random_red_agent: bool = False
"Creates Random Red Agent Attacks"
action_type: ActionType = ActionType.ANY
"The ActionType to use"
num_train_episodes: int = 10
"The number of episodes to train over during an training session"
num_train_steps: int = 256
"The number of steps in an episode during an training session"
num_eval_episodes: int = 1
"The number of episodes to train over during an evaluation session"
num_eval_steps: int = 256
"The number of steps in an episode during an evaluation session"
checkpoint_every_n_episodes: int = 5
"The agent will save a checkpoint every n episodes"
observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]})
"The observation space config dict"
time_delay: int = 10
"The delay between steps (ms). Applies to generic agents only"
# file
session_type: SessionType = SessionType.TRAIN
"The type of PrimAITE session to run"
load_agent: bool = False
"Determine whether to load an agent from file"
agent_load_file: Optional[str] = None
"File path and file name of agent if you're loading one in"
# Environment
observation_space_high_value: int = 1000000000
"The high value for the observation space"
sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE
"Stable Baselines3 learn/eval output verbosity level"
implicit_acl_rule: RulePermissionType = RulePermissionType.DENY
"ALLOW or DENY implicit firewall rule to go at the end of list of ACL list."
max_number_acl_rules: int = 30
"Sets a limit for number of acl rules allowed in the list and environment."
# Reward values
# Generic
all_ok: float = 0
# Node Hardware State
off_should_be_on: float = -0.001
off_should_be_resetting: float = -0.0005
on_should_be_off: float = -0.0002
on_should_be_resetting: float = -0.0005
resetting_should_be_on: float = -0.0005
resetting_should_be_off: float = -0.0002
resetting: float = -0.0003
# Node Software or Service State
good_should_be_patching: float = 0.0002
good_should_be_compromised: float = 0.0005
good_should_be_overwhelmed: float = 0.0005
patching_should_be_good: float = -0.0005
patching_should_be_compromised: float = 0.0002
patching_should_be_overwhelmed: float = 0.0002
patching: float = -0.0003
compromised_should_be_good: float = -0.002
compromised_should_be_patching: float = -0.002
compromised_should_be_overwhelmed: float = -0.002
compromised: float = -0.002
overwhelmed_should_be_good: float = -0.002
overwhelmed_should_be_patching: float = -0.002
overwhelmed_should_be_compromised: float = -0.002
overwhelmed: float = -0.002
# Node File System State
good_should_be_repairing: float = 0.0002
good_should_be_restoring: float = 0.0002
good_should_be_corrupt: float = 0.0005
good_should_be_destroyed: float = 0.001
repairing_should_be_good: float = -0.0005
repairing_should_be_restoring: float = 0.0002
repairing_should_be_corrupt: float = 0.0002
repairing_should_be_destroyed: float = 0.0000
repairing: float = -0.0003
restoring_should_be_good: float = -0.001
restoring_should_be_repairing: float = -0.0002
restoring_should_be_corrupt: float = 0.0001
restoring_should_be_destroyed: float = 0.0002
restoring: float = -0.0006
corrupt_should_be_good: float = -0.001
corrupt_should_be_repairing: float = -0.001
corrupt_should_be_restoring: float = -0.001
corrupt_should_be_destroyed: float = 0.0002
corrupt: float = -0.001
destroyed_should_be_good: float = -0.002
destroyed_should_be_repairing: float = -0.002
destroyed_should_be_restoring: float = -0.002
destroyed_should_be_corrupt: float = -0.002
destroyed: float = -0.002
scanning: float = -0.0002
# IER status
red_ier_running: float = -0.0005
green_ier_blocked: float = -0.001
# Patching / Reset durations
os_patching_duration: int = 5
"The time taken to patch the OS"
node_reset_duration: int = 5
"The time taken to reset a node (hardware)"
node_booting_duration: int = 3
"The Time taken to turn on the node"
node_shutdown_duration: int = 2
"The time taken to turn off the node"
service_patching_duration: int = 5
"The time taken to patch a service"
file_system_repairing_limit: int = 5
"The time take to repair the file system"
file_system_restoring_limit: int = 5
"The time take to restore the file system"
file_system_scanning_limit: int = 5
"The time taken to scan the file system"
deterministic: bool = False
"If true, the training will be deterministic"
seed: Optional[int] = None
"The random number generator seed to be used while training the agent"
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
"""
Create an instance of TrainingConfig from a dict.
:param config_dict: The training config dict.
:return: The instance of TrainingConfig.
"""
field_enum_map = {
"agent_framework": AgentFramework,
"deep_learning_framework": DeepLearningFramework,
"agent_identifier": AgentIdentifier,
"action_type": ActionType,
"session_type": SessionType,
"sb3_output_verbose_level": SB3OutputVerboseLevel,
"hard_coded_agent_view": HardCodedAgentView,
"implicit_acl_rule": RulePermissionType,
}
# convert the string representation of enums into the actual enum values themselves?
for key, value in field_enum_map.items():
if key in config_dict:
config_dict[key] = value[config_dict[key]]
return TrainingConfig(**config_dict)
def to_dict(self, json_serializable: bool = True) -> Dict:
"""
Serialise the ``TrainingConfig`` as dict.
:param json_serializable: If True, Enums are converted to their
string name.
:return: The ``TrainingConfig`` as a dict.
"""
data = self.__dict__
if json_serializable:
data["agent_framework"] = self.agent_framework.name
data["deep_learning_framework"] = self.deep_learning_framework.name
data["agent_identifier"] = self.agent_identifier.name
data["action_type"] = self.action_type.name
data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name
data["session_type"] = self.session_type.name
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
data["implicit_acl_rule"] = self.implicit_acl_rule.name
return data
def __str__(self) -> str:
obs_str = ",".join([c["name"] for c in self.observation_space["components"]])
tc = f"{self.agent_framework}, "
# if self.agent_framework is AgentFramework.RLLIB:
# tc += f"{self.deep_learning_framework}, "
tc += f"{self.agent_identifier}, "
if self.agent_identifier is AgentIdentifier.HARDCODED:
tc += f"{self.hard_coded_agent_view}, "
tc += f"{self.action_type}, "
tc += f"observation_space={obs_str}, "
if self.session_type is SessionType.TRAIN:
tc += f"{self.num_train_episodes} episodes @ "
tc += f"{self.num_train_steps} steps"
elif self.session_type is SessionType.EVAL:
tc += f"{self.num_eval_episodes} episodes @ "
tc += f"{self.num_eval_steps} steps"
else:
tc += f"Training: {self.num_eval_episodes} episodes @ "
tc += f"{self.num_eval_steps} steps"
tc += f"Evaluation: {self.num_eval_episodes} episodes @ "
tc += f"{self.num_eval_steps} steps"
return tc
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
"""
Read in a training config yaml file.
:param file_path: The config file path.
:param legacy_file: True if the config file is legacy format, otherwise
False.
:return: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
:raises ValueError: If the file_path does not exist.
:raises TypeError: When the TrainingConfig object cannot be created
using the values from the config file read from ``file_path``.
"""
if not isinstance(file_path, Path):
file_path = Path(file_path)
if file_path.exists():
with open(file_path, "r") as file:
config = yaml.safe_load(file)
_LOGGER.debug(f"Loading training config file: {file_path}")
if legacy_file:
try:
config = convert_legacy_training_config_dict(config)
except KeyError as e:
msg = (
f"Failed to convert training config file {file_path} "
f"from legacy format. Attempting to use file as is."
)
_LOGGER.error(msg)
raise e
try:
return TrainingConfig.from_dict(config)
except TypeError as e:
msg = f"Error when creating an instance of {TrainingConfig} " f"from the training config file {file_path}"
_LOGGER.critical(msg, exc_info=True)
raise e
msg = f"Cannot load the training config as it does not exist: {file_path}"
_LOGGER.error(msg)
raise ValueError(msg)
def convert_legacy_training_config_dict(
legacy_config_dict: Dict[str, Any],
agent_framework: AgentFramework = AgentFramework.SB3,
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
action_type: ActionType = ActionType.ANY,
num_train_steps: int = 256,
num_eval_steps: int = 256,
num_train_episodes: int = 10,
num_eval_episodes: int = 1,
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
:param legacy_config_dict: A legacy training config dict.
:param agent_framework: The agent framework to use as legacy training
configs don't have agent_framework values.
:param agent_identifier: The red agent identifier to use as legacy
training configs don't have agent_identifier values.
:param action_type: The action space type to set as legacy training configs
don't have action_type values.
:param num_train_steps: The number of train steps to set as legacy training configs
don't have num_train_steps values.
:param num_eval_steps: The number of eval steps to set as legacy training configs
don't have num_eval_steps values.
:param num_train_episodes: The number of train episodes to set as legacy training configs
don't have num_train_episodes values.
:param num_eval_episodes: The number of eval episodes to set as legacy training configs
don't have num_eval_episodes values.
:return: The converted training config dict.
"""
config_dict = {
"agent_framework": agent_framework.name,
"agent_identifier": agent_identifier.name,
"action_type": action_type.name,
"num_train_steps": num_train_steps,
"num_eval_steps": num_eval_steps,
"num_train_episodes": num_train_episodes,
"num_eval_episodes": num_eval_episodes,
"sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name,
}
session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}
legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]]
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)
if new_key:
config_dict[new_key] = value
return config_dict
def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
"""
Maps legacy training config keys to the new format keys.
:param legacy_key: A legacy training config key.
:return: The mapped key.
"""
key_mapping = {
"agentIdentifier": None,
"numEpisodes": "num_train_episodes",
"numSteps": "num_train_steps",
"timeDelay": "time_delay",
"configFilename": None,
"sessionType": "session_type",
"loadAgent": "load_agent",
"agentLoadFile": "agent_load_file",
"observationSpaceHighValue": "observation_space_high_value",
"allOk": "all_ok",
"offShouldBeOn": "off_should_be_on",
"offShouldBeResetting": "off_should_be_resetting",
"onShouldBeOff": "on_should_be_off",
"onShouldBeResetting": "on_should_be_resetting",
"resettingShouldBeOn": "resetting_should_be_on",
"resettingShouldBeOff": "resetting_should_be_off",
"resetting": "resetting",
"goodShouldBePatching": "good_should_be_patching",
"goodShouldBeCompromised": "good_should_be_compromised",
"goodShouldBeOverwhelmed": "good_should_be_overwhelmed",
"patchingShouldBeGood": "patching_should_be_good",
"patchingShouldBeCompromised": "patching_should_be_compromised",
"patchingShouldBeOverwhelmed": "patching_should_be_overwhelmed",
"patching": "patching",
"compromisedShouldBeGood": "compromised_should_be_good",
"compromisedShouldBePatching": "compromised_should_be_patching",
"compromisedShouldBeOverwhelmed": "compromised_should_be_overwhelmed",
"compromised": "compromised",
"overwhelmedShouldBeGood": "overwhelmed_should_be_good",
"overwhelmedShouldBePatching": "overwhelmed_should_be_patching",
"overwhelmedShouldBeCompromised": "overwhelmed_should_be_compromised",
"overwhelmed": "overwhelmed",
"goodShouldBeRepairing": "good_should_be_repairing",
"goodShouldBeRestoring": "good_should_be_restoring",
"goodShouldBeCorrupt": "good_should_be_corrupt",
"goodShouldBeDestroyed": "good_should_be_destroyed",
"repairingShouldBeGood": "repairing_should_be_good",
"repairingShouldBeRestoring": "repairing_should_be_restoring",
"repairingShouldBeCorrupt": "repairing_should_be_corrupt",
"repairingShouldBeDestroyed": "repairing_should_be_destroyed",
"repairing": "repairing",
"restoringShouldBeGood": "restoring_should_be_good",
"restoringShouldBeRepairing": "restoring_should_be_repairing",
"restoringShouldBeCorrupt": "restoring_should_be_corrupt",
"restoringShouldBeDestroyed": "restoring_should_be_destroyed",
"restoring": "restoring",
"corruptShouldBeGood": "corrupt_should_be_good",
"corruptShouldBeRepairing": "corrupt_should_be_repairing",
"corruptShouldBeRestoring": "corrupt_should_be_restoring",
"corruptShouldBeDestroyed": "corrupt_should_be_destroyed",
"corrupt": "corrupt",
"destroyedShouldBeGood": "destroyed_should_be_good",
"destroyedShouldBeRepairing": "destroyed_should_be_repairing",
"destroyedShouldBeRestoring": "destroyed_should_be_restoring",
"destroyedShouldBeCorrupt": "destroyed_should_be_corrupt",
"destroyed": "destroyed",
"scanning": "scanning",
"redIerRunning": "red_ier_running",
"greenIerBlocked": "green_ier_blocked",
"osPatchingDuration": "os_patching_duration",
"nodeResetDuration": "node_reset_duration",
"nodeBootingDuration": "node_booting_duration",
"nodeShutdownDuration": "node_shutdown_duration",
"servicePatchingDuration": "service_patching_duration",
"fileSystemRepairingLimit": "file_system_repairing_limit",
"fileSystemRestoringLimit": "file_system_restoring_limit",
"fileSystemScanningLimit": "file_system_scanning_limit",
}
return key_mapping[legacy_key]

View File

@@ -1,15 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Utility to generate plots of sessions metrics after PrimAITE."""
from enum import Enum
class PlotlyTemplate(Enum):
"""The built-in plotly templates."""
PLOTLY = "plotly"
PLOTLY_WHITE = "plotly_white"
PLOTLY_DARK = "plotly_dark"
GGPLOT2 = "ggplot2"
SEABORN = "seaborn"
SIMPLE_WHITE = "simple_white"
NONE = "none"

View File

@@ -1,73 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Dict, Optional, Union
import plotly.graph_objects as go
import polars as pl
import yaml
from plotly.graph_objs import Figure
from primaite import PRIMAITE_PATHS
def get_plotly_config() -> Dict:
"""Get the plotly config from primaite_config.yaml."""
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
return primaite_config["session"]["outputs"]["plots"]
def plot_av_reward_per_episode(
av_reward_per_episode_csv: Union[str, Path],
title: Optional[str] = None,
subtitle: Optional[str] = None,
) -> Figure:
"""
Plot the average reward per episode from a csv session output.
:param av_reward_per_episode_csv: The average reward per episode csv
file path.
:param title: The plot title. This is optional.
:param subtitle: The plot subtitle. This is optional.
:return: The plot as an instance of ``plotly.graph_objs._figure.Figure``.
"""
df = pl.read_csv(av_reward_per_episode_csv)
if title:
if subtitle:
title = f"{title} <br>{subtitle}</sup>"
else:
if subtitle:
title = subtitle
config = get_plotly_config()
layout = go.Layout(
autosize=config["size"]["auto_size"],
width=config["size"]["width"],
height=config["size"]["height"],
)
# Create the line graph with a colored line
fig = go.Figure(layout=layout)
fig.update_layout(template=config["template"])
fig.add_trace(
go.Scatter(
x=df["Episode"],
y=df["Average Reward"],
mode="lines",
name="Mean Reward per Episode",
)
)
# Set the layout of the graph
fig.update_layout(
xaxis={
"title": "Episode",
"type": "linear",
"rangeslider": {"visible": config["range_slider"]},
},
yaxis={"title": "Average Reward"},
title=title,
showlegend=False,
)
return fig

View File

@@ -1,2 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network."""

View File

@@ -1,735 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from logging import Logger
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
import numpy as np
from gym import spaces
from primaite.acl.acl_rule import ACLRule
from primaite.common.enums import FileSystemState, HardwareState, RulePermissionType, SoftwareState
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
# This dependency is only needed for type hints,
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
# Therefore, this avoids circular dependency problem.
if TYPE_CHECKING:
from primaite.environment.primaite_env import Primaite
_LOGGER: Logger = logging.getLogger(__name__)
class AbstractObservationComponent(ABC):
"""Represents a part of the PrimAITE observation space."""
@abstractmethod
def __init__(self, env: "Primaite") -> None:
"""
Initialise observation component.
:param env: Primaite training environment.
:type env: Primaite
"""
_LOGGER.info(f"Initialising {self} observation component")
self.env: "Primaite" = env
self.space: spaces.Space
self.current_observation: np.ndarray # type might be too restrictive?
self.structure: List[str]
return NotImplemented
@abstractmethod
def update(self) -> None:
"""Update the observation based on the current state of the environment."""
self.current_observation = NotImplemented
@abstractmethod
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
return NotImplemented
class NodeLinkTable(AbstractObservationComponent):
"""
Table with nodes and links as rows and hardware/software status as cols.
This will create the observation space formatted as a table of integers.
There is one row per node, followed by one row per link.
The number of columns is 4 plus one per service. They are:
* node/link ID
* node hardware status / 0 for links
* node operating system status (if active/service) / 0 for links
* node file system status (active/service only) / 0 for links
* node service1 status / traffic load from that service for links
* node service2 status / traffic load from that service for links
* ...
* node serviceN status / traffic load from that service for links
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(12, 7)``
"""
_FIXED_PARAMETERS: int = 4
_MAX_VAL: int = 1_000_000_000
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite") -> None:
"""
Initialise a NodeLinkTable observation space component.
:param env: Training environment.
:type env: Primaite
"""
super().__init__(env)
# 1. Define the shape of your observation space component
num_items = self.env.num_links + self.env.num_nodes
num_columns = self.env.num_services + self._FIXED_PARAMETERS
observation_shape = (num_items, num_columns)
# 2. Create Observation space
self.space = spaces.Box(
low=0,
high=self._MAX_VAL,
shape=observation_shape,
dtype=self._DATA_TYPE,
)
# 3. Initialise Observation with zeroes
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self) -> None:
"""
Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeLinkTable`
"""
item_index = 0
nodes = self.env.nodes
links = self.env.links
# Do nodes first
for _, node in nodes.items():
self.current_observation[item_index][0] = int(node.node_id)
self.current_observation[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.current_observation[item_index][2] = node.software_state.value
self.current_observation[item_index][3] = node.file_system_state_observed.value
else:
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
service_index = 4
if isinstance(node, ServiceNode):
for service in self.env.services_list:
if node.has_service(service):
self.current_observation[item_index][service_index] = node.get_service_state(service).value
else:
self.current_observation[item_index][service_index] = 0
service_index += 1
else:
# Not a service node
for service in self.env.services_list:
self.current_observation[item_index][service_index] = 0
service_index += 1
item_index += 1
# Now do links
for _, link in links.items():
self.current_observation[item_index][0] = int(link.get_id())
self.current_observation[item_index][1] = 0
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.current_observation[item_index][protocol_index + 4] = protocol.get_load()
protocol_index += 1
item_index += 1
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
nodes = self.env.nodes.values()
links = self.env.links.values()
structure = []
for i, node in enumerate(nodes):
node_id = node.node_id
node_labels = [
f"node_{node_id}_id",
f"node_{node_id}_hardware_status",
f"node_{node_id}_os_status",
f"node_{node_id}_fs_status",
]
for j, serv in enumerate(self.env.services_list):
node_labels.append(f"node_{node_id}_service_{serv}_status")
structure.extend(node_labels)
for i, link in enumerate(links):
link_id = link.id
link_labels = [
f"link_{link_id}_id",
f"link_{link_id}_n/a",
f"link_{link_id}_n/a",
f"link_{link_id}_n/a",
]
for j, serv in enumerate(self.env.services_list):
link_labels.append(f"link_{link_id}_service_{serv}_load")
structure.extend(link_labels)
return structure
class NodeStatuses(AbstractObservationComponent):
"""
Flat list of nodes' hardware, OS, file system, and service states.
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
integers.
Each node has 3 elements plus 1 per service. It will have the following structure:
.. code-block::
[
node1 hardware state,
node1 OS state,
node1 file system state,
node1 service1 state,
node1 service2 state,
node1 serviceN state (one for each service),
node2 hardware state,
node2 OS state,
node2 file system state,
node2 service1 state,
node2 service2 state,
node2 serviceN state (one for each service),
...
]
"""
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite") -> None:
"""
Initialise a NodeStatuses observation component.
:param env: Training environment.
:type env: Primaite
"""
super().__init__(env)
# 1. Define the shape of your observation space component
node_shape = [
len(HardwareState) + 1,
len(SoftwareState) + 1,
len(FileSystemState) + 1,
]
services_shape = [len(SoftwareState) + 1] * self.env.num_services
node_shape = node_shape + services_shape
shape = node_shape * self.env.num_nodes
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self) -> None:
"""
Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeStatuses`
"""
obs = []
for _, node in self.env.nodes.items():
hardware_state = node.hardware_state.value
software_state = 0
file_system_state = 0
service_states = [0] * self.env.num_services
if isinstance(node, ActiveNode):
software_state = node.software_state.value
file_system_state = node.file_system_state_observed.value
if isinstance(node, ServiceNode):
for i, service in enumerate(self.env.services_list):
if node.has_service(service):
service_states[i] = node.get_service_state(service).value
obs.extend(
[
hardware_state,
software_state,
file_system_state,
*service_states,
]
)
self.current_observation[:] = obs
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
services = self.env.services_list
structure = []
for _, node in self.env.nodes.items():
node_id = node.node_id
structure.append(f"node_{node_id}_hardware_state_NONE")
for state in HardwareState:
structure.append(f"node_{node_id}_hardware_state_{state.name}")
structure.append(f"node_{node_id}_software_state_NONE")
for state in SoftwareState:
structure.append(f"node_{node_id}_software_state_{state.name}")
structure.append(f"node_{node_id}_file_system_state_NONE")
for state in FileSystemState:
structure.append(f"node_{node_id}_file_system_state_{state.name}")
for service in services:
structure.append(f"node_{node_id}_service_{service}_state_NONE")
for state in SoftwareState:
structure.append(f"node_{node_id}_service_{service}_state_{state.name}")
return structure
class LinkTrafficLevels(AbstractObservationComponent):
"""
Flat list of traffic levels encoded into banded categories.
For each link, total traffic or traffic per service is encoded into a categorical value.
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
* 0 = No traffic (0% of bandwidth)
* 1 = No traffic (0%-33% of bandwidth)
* 2 = No traffic (33%-66% of bandwidth)
* 3 = No traffic (66%-100% of bandwidth)
* 4 = No traffic (100% of bandwidth)
.. note::
The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.
"""
_DATA_TYPE: type = np.int64
def __init__(
self,
env: "Primaite",
combine_service_traffic: bool = False,
quantisation_levels: int = 5,
) -> None:
"""
Initialise a LinkTrafficLevels observation component.
:param env: The environment that forms the basis of the observations
:type env: Primaite
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
defaults to False
:type combine_service_traffic: bool, optional
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical
value, defaults to 5
:type quantisation_levels: int, optional
"""
if quantisation_levels < 3:
_msg = (
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
f"Resetting to default value (5)"
)
_LOGGER.warning(_msg)
quantisation_levels = 5
super().__init__(env)
self._combine_service_traffic: bool = combine_service_traffic
self._quantisation_levels: int = quantisation_levels
self._entries_per_link: int = 1
if not self._combine_service_traffic:
self._entries_per_link = self.env.num_services
# 1. Define the shape of your observation space component
shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self) -> None:
"""
Update the observation based on current environment state.
The structure of the observation space is described in :class:`.LinkTrafficLevels`
"""
obs = []
for _, link in self.env.links.items():
bandwidth = link.bandwidth
if self._combine_service_traffic:
loads = [link.get_current_load()]
else:
loads = [protocol.get_load() for protocol in link.protocol_list]
for load in loads:
if load <= 0:
traffic_level = 0
elif load >= bandwidth:
traffic_level = self._quantisation_levels - 1
else:
traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1
obs.append(int(traffic_level))
self.current_observation[:] = obs
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
structure = []
for _, link in self.env.links.items():
link_id = link.id
if self._combine_service_traffic:
protocols = ["overall"]
else:
protocols = [protocol.name for protocol in link.protocol_list]
for p in protocols:
for i in range(self._quantisation_levels):
structure.append(f"link_{link_id}_{p}_traffic_level_{i}")
return structure
class AccessControlList(AbstractObservationComponent):
"""Flat list of all the Access Control Rules in the Access Control List.
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
integers.
Each ACL Rule has 6 elements. It will have the following structure:
.. code-block::
[
acl_rule1 permission,
acl_rule1 source_ip,
acl_rule1 dest_ip,
acl_rule1 protocol,
acl_rule1 port,
acl_rule1 position,
acl_rule2 permission,
acl_rule2 source_ip,
acl_rule2 dest_ip,
acl_rule2 protocol,
acl_rule2 port,
acl_rule2 position,
...
]
Terms (for ACL Observation Space):
[0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
[0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
[0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
[0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
[0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
[0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object.
"""
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite") -> None:
"""
Initialise an AccessControlList observation component.
:param env: The environment that forms the basis of the observations
:type env: Primaite
"""
super().__init__(env)
# 1. Define the shape of your observation space component
# The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports.
# Number of ACL rules incremented by 1 for positions starting at index 0.
acl_shape = [
len(RulePermissionType),
len(env.nodes) + 2,
len(env.nodes) + 2,
len(env.services_list) + 2,
len(env.ports_list) + 2,
env.max_number_acl_rules,
]
shape = acl_shape * self.env.max_number_acl_rules
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self) -> None:
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.AccessControlList`
"""
obs = []
for index in range(0, len(self.env.acl.acl)):
acl_rule = self.env.acl.acl[index]
if isinstance(acl_rule, ACLRule):
permission = acl_rule.permission
source_ip = acl_rule.source_ip
dest_ip = acl_rule.dest_ip
protocol = acl_rule.protocol
port = acl_rule.port
position = index
# Map each ACL attribute from what it was to an integer to fit the observation space
source_ip_int = None
dest_ip_int = None
if permission == RulePermissionType.DENY:
permission_int = 1
else:
permission_int = 2
if source_ip == "ANY":
source_ip_int = 1
else:
# Map Node ID (+ 1) to source IP address
nodes = list(self.env.nodes.values())
for node in nodes:
if (
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
) and node.ip_address == source_ip:
source_ip_int = int(node.node_id) + 1
break
if dest_ip == "ANY":
dest_ip_int = 1
else:
# Map Node ID (+ 1) to dest IP address
# Index of Nodes start at 1 so + 1 is needed so NA can be added.
nodes = list(self.env.nodes.values())
for node in nodes:
if (
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
) and node.ip_address == dest_ip:
dest_ip_int = int(node.node_id) + 1
if protocol == "ANY":
protocol_int = 1
else:
# Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY
try:
protocol_int = self.env.services_list.index(protocol) + 2
except AttributeError:
_LOGGER.info(f"Service {protocol} could not be found")
protocol_int = None
if port == "ANY":
port_int = 1
else:
if port in self.env.ports_list:
port_int = self.env.ports_list.index(port) + 2
else:
_LOGGER.info(f"Port {port} could not be found.")
port_int = None
# Add to current obs
obs.extend(
[
permission_int,
source_ip_int,
dest_ip_int,
protocol_int,
port_int,
position,
]
)
else:
# The Nothing or NA representation of 'NONE' ACL rules
obs.extend([0, 0, 0, 0, 0, 0])
self.current_observation[:] = obs
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
structure = []
for acl_rule in self.env.acl.acl:
acl_rule_id = self.env.acl.acl.index(acl_rule)
for permission in RulePermissionType:
structure.append(f"acl_rule_{acl_rule_id}_permission_{permission.name}")
structure.append(f"acl_rule_{acl_rule_id}_source_ip_ANY")
for node in self.env.nodes.keys():
structure.append(f"acl_rule_{acl_rule_id}_source_ip_{node}")
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_ANY")
for node in self.env.nodes.keys():
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_{node}")
structure.append(f"acl_rule_{acl_rule_id}_service_ANY")
for service in self.env.services_list:
structure.append(f"acl_rule_{acl_rule_id}_service_{service}")
structure.append(f"acl_rule_{acl_rule_id}_port_ANY")
for port in self.env.ports_list:
structure.append(f"acl_rule_{acl_rule_id}_port_{port}")
return structure
class ObservationsHandler:
"""
Component-based observation space handler.
This allows users to configure observation spaces by mixing and matching components. Each component can also define
further parameters to make them more flexible.
"""
_REGISTRY: Final[Dict[str, type]] = {
"NODE_LINK_TABLE": NodeLinkTable,
"NODE_STATUSES": NodeStatuses,
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
"ACCESS_CONTROL_LIST": AccessControlList,
}
def __init__(self) -> None:
"""Initialise the observation handler."""
self.registered_obs_components: List[AbstractObservationComponent] = []
# internal the observation space (unflattened version of space if flatten=True)
self._space: spaces.Space
# flattened version of the observation space
self._flat_space: spaces.Space
self._observation: Union[Tuple[np.ndarray], np.ndarray]
# used for transactions and when flatten=true
self._flat_observation: np.ndarray
def update_obs(self) -> None:
"""Fetch fresh information about the environment."""
current_obs = []
for obs in self.registered_obs_components:
obs.update()
current_obs.append(obs.current_observation)
if len(current_obs) == 1:
self._observation = current_obs[0]
else:
self._observation = tuple(current_obs)
self._flat_observation = spaces.flatten(self._space, self._observation)
def register(self, obs_component: AbstractObservationComponent) -> None:
"""
Add a component for this handler to track.
:param obs_component: The component to add.
:type obs_component: AbstractObservationComponent
"""
self.registered_obs_components.append(obs_component)
self.update_space()
def deregister(self, obs_component: AbstractObservationComponent) -> None:
"""
Remove a component from this handler.
:param obs_component: Which component to remove. It must exist within this object's
``registered_obs_components`` attribute.
:type obs_component: AbstractObservationComponent
"""
self.registered_obs_components.remove(obs_component)
self.update_space()
def update_space(self) -> None:
"""Rebuild the handler's composite observation space from its components."""
component_spaces = []
for obs_comp in self.registered_obs_components:
component_spaces.append(obs_comp.space)
# if there are multiple components, build a composite tuple space
if len(component_spaces) == 1:
self._space = component_spaces[0]
else:
self._space = spaces.Tuple(component_spaces)
if len(component_spaces) > 0:
self._flat_space = spaces.flatten_space(self._space)
else:
self._flat_space = spaces.Box(0, 1, (0,))
@property
def space(self) -> spaces.Space:
"""Observation space, return the flattened version if flatten is True."""
if len(self.registered_obs_components) > 1:
return self._flat_space
else:
return self._space
@property
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
"""Current observation, return the flattened version if flatten is True."""
if len(self.registered_obs_components) > 1:
return self._flat_observation
else:
return self._observation
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler":
"""
Parse a config dictinary, return a new observation handler populated with new observation component objects.
The expected format for the config dictionary is:
.. code-block:: python
config = {
components: [
{
"name": "<COMPONENT1_NAME>"
},
{
"name": "<COMPONENT2_NAME>"
"options": {"opt1": val1, "opt2": val2}
},
{
...
},
]
}
:return: Observation handler
:rtype: primaite.environment.observations.ObservationsHandler
"""
# Instantiate the handler
handler = cls()
for component_cfg in obs_space_config["components"]:
# Figure out which class can instantiate the desired component
comp_type = component_cfg["name"]
comp_class = cls._REGISTRY[comp_type]
# Create the component with options from the YAML
options = component_cfg.get("options") or {}
component = comp_class(env, **options)
handler.register(component)
handler.update_obs()
return handler
def describe_structure(self) -> List[str]:
"""
Create a list of names for the features of the obs space.
The order of labels follows the flattened version of the space.
"""
# as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have
# to fake it. each component has to just hard-code the expected label order after flattening...
labels = []
for obs_comp in self.registered_obs_components:
labels.extend(obs_comp.structure)
return labels

File diff suppressed because it is too large Load Diff

View File

@@ -1,386 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Implements reward function."""
from logging import Logger
from typing import Dict, TYPE_CHECKING, Union
from primaite import getLogger
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.common.service import Service
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
if TYPE_CHECKING:
from primaite.config.training_config import TrainingConfig
from primaite.pol.ier import IER
_LOGGER: Logger = getLogger(__name__)
def calculate_reward_function(
initial_nodes: Dict[str, NodeUnion],
final_nodes: Dict[str, NodeUnion],
reference_nodes: Dict[str, NodeUnion],
green_iers: Dict[str, "IER"],
green_iers_reference: Dict[str, "IER"],
red_iers: Dict[str, "IER"],
step_count: int,
config_values: "TrainingConfig",
) -> float:
"""
Compares the states of the initial and final nodes/links to get a reward.
Args:
initial_nodes: The nodes before red and blue agents take effect
final_nodes: The nodes after red and blue agents take effect
reference_nodes: The nodes if there had been no red or blue effect
green_iers: The green IERs (should be running)
red_iers: Should be stopeed (ideally) by the blue agent
step_count: current step
config_values: Config values
"""
reward_value: float = 0.0
# For each node, compare hardware state, SoftwareState, service states
for node_key, final_node in final_nodes.items():
initial_node = initial_nodes[node_key]
reference_node = reference_nodes[node_key]
# Hardware State
reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)
# Software State
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)
# Service State
if isinstance(final_node, ServiceNode):
reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)
# File System State
if isinstance(final_node, ActiveNode):
reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)
# Go through each red IER - penalise if it is running
for ier_key, ier_value in red_iers.items():
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
if step_count >= start_step and step_count <= stop_step:
if ier_value.get_is_running():
reward_value += config_values.red_ier_running
# Go through each green IER - penalise if it's not running (weighted)
# but only if it's supposed to be running (it's running in reference)
for ier_key, ier_value in green_iers.items():
reference_ier = green_iers_reference[ier_key]
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
if step_count >= start_step and step_count <= stop_step:
reference_blocked = not reference_ier.get_is_running()
live_blocked = not ier_value.get_is_running()
ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality()
if live_blocked and not reference_blocked:
reward_value += ier_reward
elif live_blocked and reference_blocked:
_LOGGER.debug(
(
f"IER {ier_key} is blocked in the reference and live environments. "
f"Penalty of {ier_reward} was NOT applied."
)
)
elif not live_blocked and reference_blocked:
_LOGGER.debug(
(
f"IER {ier_key} is blocked in the reference env but not in the live one. "
f"Penalty of {ier_reward} was NOT applied."
)
)
return reward_value
def score_node_operating_state(
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the hardware state of a node.
Args:
final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score: float = 0.0
final_node_operating_state = final_node.hardware_state
reference_node_operating_state = reference_node.hardware_state
if final_node_operating_state == reference_node_operating_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare reference and final (current) state of node (i.e. at every step)
if reference_node_operating_state == HardwareState.ON:
if final_node_operating_state == HardwareState.OFF:
score += config_values.off_should_be_on
elif final_node_operating_state == HardwareState.RESETTING:
score += config_values.resetting_should_be_on
else:
pass
elif reference_node_operating_state == HardwareState.OFF:
if final_node_operating_state == HardwareState.ON:
score += config_values.on_should_be_off
elif final_node_operating_state == HardwareState.RESETTING:
score += config_values.resetting_should_be_off
else:
pass
elif reference_node_operating_state == HardwareState.RESETTING:
if final_node_operating_state == HardwareState.ON:
score += config_values.on_should_be_resetting
elif final_node_operating_state == HardwareState.OFF:
score += config_values.off_should_be_resetting
elif final_node_operating_state == HardwareState.RESETTING:
score += config_values.resetting
else:
pass
else:
pass
return score
def score_node_os_state(
final_node: Union[ActiveNode, ServiceNode],
initial_node: Union[ActiveNode, ServiceNode],
reference_node: Union[ActiveNode, ServiceNode],
config_values: "TrainingConfig",
) -> float:
"""
Calculates score relating to the Software State of a node.
Args:
final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score: float = 0.0
final_node_os_state = final_node.software_state
reference_node_os_state = reference_node.software_state
if final_node_os_state == reference_node_os_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare reference and final (current) state of node (i.e. at every step)
if reference_node_os_state == SoftwareState.GOOD:
if final_node_os_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_good
elif final_node_os_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_good
else:
pass
elif reference_node_os_state == SoftwareState.PATCHING:
if final_node_os_state == SoftwareState.GOOD:
score += config_values.good_should_be_patching
elif final_node_os_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_patching
elif final_node_os_state == SoftwareState.PATCHING:
score += config_values.patching
else:
pass
elif reference_node_os_state == SoftwareState.COMPROMISED:
if final_node_os_state == SoftwareState.GOOD:
score += config_values.good_should_be_compromised
elif final_node_os_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_compromised
elif final_node_os_state == SoftwareState.COMPROMISED:
score += config_values.compromised
else:
pass
else:
pass
return score
def score_node_service_state(
final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
) -> float:
"""
Calculates score relating to the service state(s) of a node.
Args:
final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score: float = 0.0
final_node_services: Dict[str, Service] = final_node.services
reference_node_services: Dict[str, Service] = reference_node.services
for service_key, final_service in final_node_services.items():
reference_service = reference_node_services[service_key]
final_service = final_node_services[service_key]
if final_service.software_state == reference_service.software_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare reference and final state of node (i.e. at every step)
if reference_service.software_state == SoftwareState.GOOD:
if final_service.software_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_good
elif final_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_good
elif final_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed_should_be_good
else:
pass
elif reference_service.software_state == SoftwareState.PATCHING:
if final_service.software_state == SoftwareState.GOOD:
score += config_values.good_should_be_patching
elif final_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_patching
elif final_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed_should_be_patching
elif final_service.software_state == SoftwareState.PATCHING:
score += config_values.patching
else:
pass
elif reference_service.software_state == SoftwareState.COMPROMISED:
if final_service.software_state == SoftwareState.GOOD:
score += config_values.good_should_be_compromised
elif final_service.software_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_compromised
elif final_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised
elif final_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed_should_be_compromised
else:
pass
elif reference_service.software_state == SoftwareState.OVERWHELMED:
if final_service.software_state == SoftwareState.GOOD:
score += config_values.good_should_be_overwhelmed
elif final_service.software_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_overwhelmed
elif final_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_overwhelmed
elif final_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed
else:
pass
else:
pass
return score
def score_node_file_system(
final_node: Union[ActiveNode, ServiceNode],
initial_node: Union[ActiveNode, ServiceNode],
reference_node: Union[ActiveNode, ServiceNode],
config_values: "TrainingConfig",
) -> float:
"""
Calculates score relating to the file system state of a node.
Args:
final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
"""
score: float = 0.0
final_node_file_system_state = final_node.file_system_state_actual
reference_node_file_system_state = reference_node.file_system_state_actual
final_node_scanning_state = final_node.file_system_scanning
reference_node_scanning_state = reference_node.file_system_scanning
# File System State
if final_node_file_system_state == reference_node_file_system_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare reference and final state of node (i.e. at every step)
if reference_node_file_system_state == FileSystemState.GOOD:
if final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_good
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_good
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_good
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_good
else:
pass
elif reference_node_file_system_state == FileSystemState.REPAIRING:
if final_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_repairing
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_repairing
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_repairing
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_repairing
elif final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing
else:
pass
elif reference_node_file_system_state == FileSystemState.RESTORING:
if final_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_restoring
elif final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_restoring
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_restoring
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_restoring
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring
else:
pass
elif reference_node_file_system_state == FileSystemState.CORRUPT:
if final_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_corrupt
elif final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_corrupt
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_corrupt
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_corrupt
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt
else:
pass
elif reference_node_file_system_state == FileSystemState.DESTROYED:
if final_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_destroyed
elif final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_destroyed
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_destroyed
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_destroyed
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed
else:
pass
else:
pass
# Scanning State
if final_node_scanning_state == reference_node_scanning_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# We're scanning the file system which incurs a penalty (as it slows down systems)
score += config_values.scanning
return score

View File

@@ -5,12 +5,6 @@ class PrimaiteError(Exception):
pass
class RLlibAgentError(PrimaiteError):
"""Raised when there is a generic error with a RLlib agent that is specific to PRimAITE."""
pass
class NetworkError(PrimaiteError):
"""Raised when an error occurs at the network level."""

View File

@@ -0,0 +1 @@
"""PrimAITE Game Layer."""

View File

@@ -0,0 +1,31 @@
# flake8: noqa
from typing import Dict, Optional, Tuple
from gymnasium.core import ActType, ObsType
from src.primaite.game.agent.actions import ActionManager
from src.primaite.game.agent.interface import AbstractGATEAgent, ObsType
from src.primaite.game.agent.observations import ObservationSpace
from src.primaite.game.agent.rewards import RewardFunction
class GATERLAgent(AbstractGATEAgent):
...
# The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents,
# because when we are supporting MARL, the actions form multiple agents will have to be batched
# For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply
# with the actions for those agents.
def __init__(
self,
agent_name: str | None,
action_space: ActionManager | None,
observation_space: ObservationSpace | None,
reward_function: RewardFunction | None,
) -> None:
super().__init__(agent_name, action_space, observation_space, reward_function)
self.most_recent_action: ActType
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
return self.most_recent_action

View File

View File

@@ -0,0 +1,866 @@
"""
This module contains the ActionManager class which belongs to the Agent class.
An agent's action space is made up of a collection of actions. Each action is an instance of a subclass of
AbstractAction. The ActionManager is responsible for:
1. Creating the action space from a list of action types.
2. Converting an integer action choice into a specific action and parameter choice.
3. Converting an action and parameter choice into a request which can be ingested by the PrimAITE simulation. This
ensures that requests conform to the simulator's request format.
"""
import itertools
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite import getLogger
from src.primaite.simulator.sim_container import Simulation
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from src.primaite.game.session import PrimaiteSession
class AbstractAction(ABC):
"""Base class for actions."""
@abstractmethod
def __init__(self, manager: "ActionManager", **kwargs) -> None:
"""
Init method for action.
All action init functions should accept **kwargs as a way of ignoring extra arguments.
Since many parameters are defined for the action space as a whole (such as max files per folder, max services
per node), we need to pass those options to every action that gets created. To pervent verbosity, these
parameters are just broadcasted to all actions and the actions can pay attention to the ones that apply.
"""
self.name: str = ""
"""Human-readable action identifier used for printing, logging, and reporting."""
self.shape: Dict[str, int] = {}
"""Dictionary describing the number of options for each parameter of this action. The keys of this dict must
align with the keyword args of the form_request method."""
self.manager: ActionManager = manager
"""Reference to the ActionManager which created this action. This is used to access the session and simulation
objects."""
@abstractmethod
def form_request(self) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return []
class DoNothingAction(AbstractAction):
"""Action which does nothing. This is here to allow agents to be idle if they choose to."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
self.name = "DONOTHING"
self.shape: Dict[str, int] = {
"dummy": 1,
}
# This action does not accept any parameters, therefore it technically has a gymnasium shape of Discrete(1),
# i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter
# with one option. This just aids the Action Manager to enumerate all possibilities.
def form_request(self, **kwargs) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["do_nothing"]
class NodeServiceAbstractAction(AbstractAction):
"""
Base class for service actions.
Any action which applies to a service and uses node_id and service_id as its only two parameters can inherit from
this base class.
"""
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
self.verb: str
def form_request(self, node_id: int, service_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
service_uuid = self.manager.get_service_uuid_by_idx(node_id, service_id)
if node_uuid is None or service_uuid is None:
return ["do_nothing"]
return ["network", "node", node_uuid, "services", service_uuid, self.verb]
class NodeServiceScanAction(NodeServiceAbstractAction):
"""Action which scans a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "scan"
class NodeServiceStopAction(NodeServiceAbstractAction):
"""Action which stops a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "stop"
class NodeServiceStartAction(NodeServiceAbstractAction):
"""Action which starts a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "start"
class NodeServicePauseAction(NodeServiceAbstractAction):
"""Action which pauses a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "pause"
class NodeServiceResumeAction(NodeServiceAbstractAction):
"""Action which resumes a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "resume"
class NodeServiceRestartAction(NodeServiceAbstractAction):
"""Action which restarts a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "restart"
class NodeServiceDisableAction(NodeServiceAbstractAction):
"""Action which disables a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "disable"
class NodeServiceEnableAction(NodeServiceAbstractAction):
"""Action which enables a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "enable"
class NodeFolderAbstractAction(AbstractAction):
"""
Base class for folder actions.
Any action which applies to a folder and uses node_id and folder_id as its only two parameters can inherit from
this base class.
"""
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
self.verb: str
def form_request(self, node_id: int, folder_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
if node_uuid is None or folder_uuid is None:
return ["do_nothing"]
return ["network", "node", node_uuid, "file_system", "folder", folder_uuid, self.verb]
class NodeFolderScanAction(NodeFolderAbstractAction):
"""Action which scans a folder."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "scan"
class NodeFolderCheckhashAction(NodeFolderAbstractAction):
"""Action which checks the hash of a folder."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "checkhash"
class NodeFolderRepairAction(NodeFolderAbstractAction):
"""Action which repairs a folder."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "repair"
class NodeFolderRestoreAction(NodeFolderAbstractAction):
"""Action which restores a folder."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "restore"
class NodeFileAbstractAction(AbstractAction):
"""Abstract base class for file actions.
Any action which applies to a file and uses node_id, folder_id, and file_id as its only three parameters can inherit
from this base class.
"""
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
self.verb: str
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
folder_uuid = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
file_uuid = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
if node_uuid is None or folder_uuid is None or file_uuid is None:
return ["do_nothing"]
return ["network", "node", node_uuid, "file_system", "folder", folder_uuid, "files", file_uuid, self.verb]
class NodeFileScanAction(NodeFileAbstractAction):
"""Action which scans a file."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "scan"
class NodeFileCheckhashAction(NodeFileAbstractAction):
"""Action which checks the hash of a file."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "checkhash"
class NodeFileDeleteAction(NodeFileAbstractAction):
"""Action which deletes a file."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "delete"
class NodeFileRepairAction(NodeFileAbstractAction):
"""Action which repairs a file."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "repair"
class NodeFileRestoreAction(NodeFileAbstractAction):
"""Action which restores a file."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "restore"
class NodeFileCorruptAction(NodeFileAbstractAction):
"""Action which corrupts a file."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "corrupt"
class NodeAbstractAction(AbstractAction):
"""
Abstract base class for node actions.
Any action which applies to a node and uses node_id as its only parameter can inherit from this base class.
"""
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
self.verb: str
def form_request(self, node_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
return ["network", "node", node_uuid, self.verb]
class NodeOSScanAction(NodeAbstractAction):
"""Action which scans a node's OS."""
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "scan"
class NodeShutdownAction(NodeAbstractAction):
"""Action which shuts down a node."""
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "shutdown"
class NodeStartupAction(NodeAbstractAction):
"""Action which starts up a node."""
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "startup"
class NodeResetAction(NodeAbstractAction):
"""Action which resets a node."""
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "reset"
class NetworkACLAddRuleAction(AbstractAction):
"""Action which adds a rule to a router's ACL."""
def __init__(
self,
manager: "ActionManager",
target_router_uuid: str,
max_acl_rules: int,
num_ips: int,
num_ports: int,
num_protocols: int,
**kwargs,
) -> None:
"""Init method for NetworkACLAddRuleAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param target_router_uuid: UUID of the router to which the ACL rule should be added.
:type target_router_uuid: str
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
:type max_acl_rules: int
:param num_ips: Number of IP addresses in the simulation.
:type num_ips: int
:param num_ports: Number of ports in the simulation.
:type num_ports: int
:param num_protocols: Number of protocols in the simulation.
:type num_protocols: int
"""
super().__init__(manager=manager)
num_permissions = 3
self.shape: Dict[str, int] = {
"position": max_acl_rules,
"permission": num_permissions,
"source_ip_id": num_ips,
"dest_ip_id": num_ips,
"source_port_id": num_ports,
"dest_port_id": num_ports,
"protocol_id": num_protocols,
}
self.target_router_uuid: str = target_router_uuid
def form_request(
self,
position: int,
permission: int,
source_ip_id: int,
dest_ip_id: int,
source_port_id: int,
dest_port_id: int,
protocol_id: int,
) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if permission == 0:
permission_str = "UNUSED"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
elif permission == 1:
permission_str = "ALLOW"
elif permission == 2:
permission_str = "DENY"
else:
_LOGGER.warn(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if protocol_id == 0:
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
if protocol_id == 1:
protocol = "ALL"
else:
protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
# subtract 2 to account for UNUSED=0 and ALL=1.
if source_ip_id in [0, 1]:
src_ip = "ALL"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
else:
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if source_port_id == 1:
src_port = "ALL"
else:
src_port = self.manager.get_port_by_idx(source_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if dest_ip_id in (0, 1):
dst_ip = "ALL"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
else:
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id)
# subtract 2 to account for UNUSED=0, and ALL=1
if dest_port_id == 1:
dst_port = "ALL"
else:
dst_port = self.manager.get_port_by_idx(dest_port_id)
# subtract 2 to account for UNUSED=0, and ALL=1
return [
"network",
"node",
self.target_router_uuid,
"acl",
"add_rule",
permission_str,
protocol,
src_ip,
src_port,
dst_ip,
dst_port,
position,
]
class NetworkACLRemoveRuleAction(AbstractAction):
"""Action which removes a rule from a router's ACL."""
def __init__(self, manager: "ActionManager", target_router_uuid: str, max_acl_rules: int, **kwargs) -> None:
"""Init method for NetworkACLRemoveRuleAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param target_router_uuid: UUID of the router from which the ACL rule should be removed.
:type target_router_uuid: str
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
:type max_acl_rules: int
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"position": max_acl_rules}
self.target_router_uuid: str = target_router_uuid
def form_request(self, position: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", self.target_router_uuid, "acl", "remove_rule", position]
class NetworkNICAbstractAction(AbstractAction):
"""
Abstract base class for NIC actions.
Any action which applies to a NIC and uses node_id and nic_id as its only two parameters can inherit from this base
class.
"""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
"""Init method for NetworkNICAbstractAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param num_nodes: Number of nodes in the simulation.
:type num_nodes: int
:param max_nics_per_node: Maximum number of NICs per node.
:type max_nics_per_node: int
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
self.verb: str
def form_request(self, node_id: int, nic_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_idx=node_id)
nic_uuid = self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id)
if node_uuid is None or nic_uuid is None:
return ["do_nothing"]
return [
"network",
"node",
node_uuid,
"nic",
nic_uuid,
self.verb,
]
class NetworkNICEnableAction(NetworkNICAbstractAction):
"""Action which enables a NIC."""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "enable"
class NetworkNICDisableAction(NetworkNICAbstractAction):
"""Action which disables a NIC."""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "disable"
class ActionManager:
"""Class which manages the action space for an agent."""
__act_class_identifiers: Dict[str, type] = {
"DONOTHING": DoNothingAction,
"NODE_SERVICE_SCAN": NodeServiceScanAction,
"NODE_SERVICE_STOP": NodeServiceStopAction,
"NODE_SERVICE_START": NodeServiceStartAction,
"NODE_SERVICE_PAUSE": NodeServicePauseAction,
"NODE_SERVICE_RESUME": NodeServiceResumeAction,
"NODE_SERVICE_RESTART": NodeServiceRestartAction,
"NODE_SERVICE_DISABLE": NodeServiceDisableAction,
"NODE_SERVICE_ENABLE": NodeServiceEnableAction,
"NODE_FILE_SCAN": NodeFileScanAction,
"NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
"NODE_FILE_DELETE": NodeFileDeleteAction,
"NODE_FILE_REPAIR": NodeFileRepairAction,
"NODE_FILE_RESTORE": NodeFileRestoreAction,
"NODE_FILE_CORRUPT": NodeFileCorruptAction,
"NODE_FOLDER_SCAN": NodeFolderScanAction,
"NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction,
"NODE_FOLDER_REPAIR": NodeFolderRepairAction,
"NODE_FOLDER_RESTORE": NodeFolderRestoreAction,
"NODE_OS_SCAN": NodeOSScanAction,
"NODE_SHUTDOWN": NodeShutdownAction,
"NODE_STARTUP": NodeStartupAction,
"NODE_RESET": NodeResetAction,
"NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction,
"NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
}
"""Dictionary which maps action type strings to the corresponding action class."""
def __init__(
self,
session: "PrimaiteSession", # reference to session for looking up stuff
actions: List[str], # stores list of actions available to agent
node_uuids: List[str], # allows mapping index to node
max_folders_per_node: int = 2, # allows calculating shape
max_files_per_folder: int = 2, # allows calculating shape
max_services_per_node: int = 2, # allows calculating shape
max_nics_per_node: int = 8, # allows calculating shape
max_acl_rules: int = 10, # allows calculating shape
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP"], # allow mapping index to port
ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address.
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
) -> None:
"""Init method for ActionManager.
:param session: Reference to the session to which the agent belongs.
:type session: PrimaiteSession
:param actions: List of action types which should be made available to the agent.
:type actions: List[str]
:param node_uuids: List of node UUIDs that this agent can act on.
:type node_uuids: List[str]
:param max_folders_per_node: Maximum number of folders per node. Used for calculating action shape.
:type max_folders_per_node: int
:param max_files_per_folder: Maximum number of files per folder. Used for calculating action shape.
:type max_files_per_folder: int
:param max_services_per_node: Maximum number of services per node. Used for calculating action shape.
:type max_services_per_node: int
:param max_nics_per_node: Maximum number of NICs per node. Used for calculating action shape.
:type max_nics_per_node: int
:param max_acl_rules: Maximum number of ACL rules per router. Used for calculating action shape.
:type max_acl_rules: int
:param protocols: List of protocols that are available in the simulation. Used for calculating action shape.
:type protocols: List[str]
:param ports: List of ports that are available in the simulation. Used for calculating action shape.
:type ports: List[str]
:param ip_address_list: List of IP addresses that known to this agent. Used for calculating action shape.
:type ip_address_list: Optional[List[str]]
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
self.session: "PrimaiteSession" = session
self.sim: Simulation = self.session.simulation
self.node_uuids: List[str] = node_uuids
self.protocols: List[str] = protocols
self.ports: List[str] = ports
self.ip_address_list: List[str]
if ip_address_list is not None:
self.ip_address_list = ip_address_list
else:
self.ip_address_list = []
for node_uuid in self.node_uuids:
node_obj = self.sim.network.nodes[node_uuid]
nics = node_obj.nics
for nic_uuid, nic_obj in nics.items():
self.ip_address_list.append(nic_obj.ip_address)
# action_args are settings which are applied to the action space as a whole.
global_action_args = {
"num_nodes": len(node_uuids),
"num_folders": max_folders_per_node,
"num_files": max_files_per_folder,
"num_services": max_services_per_node,
"num_nics": max_nics_per_node,
"num_acl_rules": max_acl_rules,
"num_protocols": len(self.protocols),
"num_ports": len(self.protocols),
"num_ips": len(self.ip_address_list),
"max_acl_rules": max_acl_rules,
"max_nics_per_node": max_nics_per_node,
}
self.actions: Dict[str, AbstractAction] = {}
for act_spec in actions:
# each action is provided into the action space config like this:
# - type: ACTION_TYPE
# options:
# option_1: value1
# option_2: value2
# where `type` decides which AbstractAction subclass should be used
# and `options` is an optional dict of options to pass to the init method of the action class
act_type = act_spec.get("type")
act_options = act_spec.get("options", {})
self.actions[act_type] = self.__act_class_identifiers[act_type](self, **global_action_args, **act_options)
self.action_map: Dict[int, Tuple[str, Dict]] = {}
"""
Action mapping that converts an integer to a specific action and parameter choice.
For example :
{0: ("NODE_SERVICE_SCAN", {node_id:0, service_id:2})}
"""
if act_map is None:
self.action_map = self._enumerate_actions()
else:
self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()}
# make sure all numbers between 0 and N are represented as dict keys in action map
assert all([i in self.action_map.keys() for i in range(len(self.action_map))])
def _enumerate_actions(
self,
) -> Dict[int, Tuple[str, Dict]]:
"""Generate a list of all the possible actions that could be taken.
This enumerates all actions all combinations of parametes you could choose for those actions. The output
of this function is intended to populate the self.action_map parameter in the situation where the user provides
a list of action types, but doesn't specify any subset of actions that should be made available to the agent.
The enumeration relies on the Actions' `shape` attribute.
:return: An action map maps consecutive integers to a combination of Action type and parameter choices.
An example output could be:
{0: ("DONOTHING", {'dummy': 0}),
1: ("NODE_OS_SCAN", {'node_id': 0}),
2: ("NODE_OS_SCAN", {'node_id': 1}),
3: ("NODE_FOLDER_SCAN", {'node_id:0, folder_id:0}),
... #etc...
}
:rtype: Dict[int, Tuple[AbstractAction, Dict]]
"""
all_action_possibilities = []
for act_name, action in self.actions.items():
param_names = list(action.shape.keys())
num_possibilities = list(action.shape.values())
possibilities = [range(n) for n in num_possibilities]
param_combinations = list(itertools.product(*possibilities))
all_action_possibilities.extend(
[
(act_name, {param_names[i]: param_combinations[j][i] for i in range(len(param_names))})
for j in range(len(param_combinations))
]
)
return {i: p for i, p in enumerate(all_action_possibilities)}
def get_action(self, action: int) -> Tuple[str, Dict]:
"""Produce action in CAOS format."""
"""the agent chooses an action (as an integer), this is converted into an action in CAOS format"""
"""The CAOS format is basically a action identifier, followed by parameters stored in a dictionary"""
act_identifier, act_options = self.action_map[action]
return act_identifier, act_options
def form_request(self, action_identifier: str, action_options: Dict) -> List[str]:
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
act_obj = self.actions[action_identifier]
return act_obj.form_request(**action_options)
@property
def space(self) -> spaces.Space:
"""Return the gymnasium action space for this agent."""
return spaces.Discrete(len(self.action_map))
def get_node_uuid_by_idx(self, node_idx: int) -> str:
"""
Get the node UUID corresponding to the given index.
:param node_idx: The index of the node to retrieve.
:type node_idx: int
:return: The node UUID.
:rtype: str
"""
return self.node_uuids[node_idx]
def get_folder_uuid_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]:
"""
Get the folder UUID corresponding to the given node and folder indices.
:param node_idx: The index of the node.
:type node_idx: int
:param folder_idx: The index of the folder on the node.
:type folder_idx: int
:return: The UUID of the folder. Or None if the node has fewer folders than the given index.
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
folder_uuids = list(node.file_system.folders.keys())
return folder_uuids[folder_idx] if len(folder_uuids) > folder_idx else None
def get_file_uuid_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]:
"""Get the file UUID corresponding to the given node, folder, and file indices.
:param node_idx: The index of the node.
:type node_idx: int
:param folder_idx: The index of the folder on the node.
:type folder_idx: int
:param file_idx: The index of the file in the folder.
:type file_idx: int
:return: The UUID of the file. Or None if the node has fewer folders than the given index, or the folder has
fewer files than the given index.
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
folder_uuids = list(node.file_system.folders.keys())
if len(folder_uuids) <= folder_idx:
return None
folder = node.file_system.folders[folder_uuids[folder_idx]]
file_uuids = list(folder.files.keys())
return file_uuids[file_idx] if len(file_uuids) > file_idx else None
def get_service_uuid_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]:
"""Get the service UUID corresponding to the given node and service indices.
:param node_idx: The index of the node.
:type node_idx: int
:param service_idx: The index of the service on the node.
:type service_idx: int
:return: The UUID of the service. Or None if the node has fewer services than the given index.
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
service_uuids = list(node.services.keys())
return service_uuids[service_idx] if len(service_uuids) > service_idx else None
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
"""Get the internet protocol corresponding to the given index.
:param protocol_idx: The index of the protocol to retrieve.
:type protocol_idx: int
:return: The protocol.
:rtype: str
"""
return self.protocols[protocol_idx]
def get_ip_address_by_idx(self, ip_idx: int) -> str:
"""
Get the IP address corresponding to the given index.
:param ip_idx: The index of the IP address to retrieve.
:type ip_idx: int
:return: The IP address.
:rtype: str
"""
return self.ip_address_list[ip_idx]
def get_port_by_idx(self, port_idx: int) -> str:
"""
Get the port corresponding to the given index.
:param port_idx: The index of the port to retrieve.
:type port_idx: int
:return: The port.
:rtype: str
"""
return self.ports[port_idx]
def get_nic_uuid_by_idx(self, node_idx: int, nic_idx: int) -> str:
"""
Get the NIC UUID corresponding to the given node and NIC indices.
:param node_idx: The index of the node.
:type node_idx: int
:param nic_idx: The index of the NIC on the node.
:type nic_idx: int
:return: The NIC UUID.
:rtype: str
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node_obj = self.sim.network.nodes[node_uuid]
nics = list(node_obj.nics.keys())
if len(nics) <= nic_idx:
return None
return nics[nic_idx]
@classmethod
def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager":
"""
Construct an ActionManager from a config definition.
The action space config supports the following three sections:
1. ``action_list``
``action_list`` contians a list action components which need to be included in the action space.
Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options
which will be passed to the action class's __init__ method during initialisation.
2. ``action_map``
Since the agent uses a discrete action space which acts as a flattened version of the component-based
action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful
action and values of parameters. For example action 0 can correspond to do nothing, action 1 can
correspond to "NODE_SERVICE_SCAN" with ``node_id=1`` and ``service_id=1``, action 2 can be "
3. ``options``
``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method.
These options are used to calculate the shape of the action space, and to provide additional information
to the ActionManager which is required to convert the agent's action choice into a CAOS request.
:param session: The Primaite Session to which the agent belongs.
:type session: PrimaiteSession
:param cfg: The action space config.
:type cfg: Dict
:return: The constructed ActionManager.
:rtype: ActionManager
"""
obj = cls(
session=session,
actions=cfg["action_list"],
# node_uuids=cfg["options"]["node_uuids"],
**cfg["options"],
protocols=session.options.protocols,
ports=session.options.ports,
ip_address_list=None,
act_map=cfg.get("action_map"),
)
return obj

View File

@@ -0,0 +1,116 @@
"""Interface for agents."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TypeAlias, Union
import numpy as np
from src.primaite.game.agent.actions import ActionManager
from src.primaite.game.agent.observations import ObservationSpace
from src.primaite.game.agent.rewards import RewardFunction
ObsType: TypeAlias = Union[Dict, np.ndarray]
class AbstractAgent(ABC):
"""Base class for scripted and RL agents."""
def __init__(
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationSpace],
reward_function: Optional[RewardFunction],
) -> None:
"""
Initialize an agent.
:param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
:type agent_name: Optional[str]
:param action_space: Action space for the agent.
:type action_space: Optional[ActionManager]
:param observation_space: Observation space for the agent.
:type observation_space: Optional[ObservationSpace]
:param reward_function: Reward function for the agent.
:type reward_function: Optional[RewardFunction]
"""
self.agent_name: str = agent_name or "unnamed_agent"
self.action_space: Optional[ActionManager] = action_space
self.observation_space: Optional[ObservationSpace] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
# exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info
# by for example specifying target ip addresses, or converting a node ID into a uuid
self.execution_definition = None
def convert_state_to_obs(self, state: Dict) -> ObsType:
"""
Convert a state from the simulator into an observation for the agent using the observation space.
state : dict state directly from simulation.describe_state
output : dict state according to CAOS.
"""
return self.observation_space.observe(state)
def calculate_reward_from_state(self, state: Dict) -> float:
"""
Use the reward function to calculate a reward from the state.
:param state: State of the environment.
:type state: Dict
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.calculate(state)
@abstractmethod
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
"""
Return an action to be taken in the environment.
Subclasses should implement agent logic here. It should use the observation as input to decide best next action.
:param obs: Observation of the environment.
:type obs: ObsType
:param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted?
:type reward: float, optional
:return: Action to be taken in the environment.
:rtype: Tuple[str, Dict]
"""
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
# then use a bespoke conversion to take 1-40 int back into CAOS action
return ("DO_NOTHING", {})
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]:
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
request = self.action_space.form_request(action_identifier=action, action_options=options)
return request
class AbstractScriptedAgent(AbstractAgent):
"""Base class for actors which generate their own behaviour."""
...
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
:param obs: _description_
:type obs: ObsType
:param reward: _description_, defaults to None
:type reward: float, optional
:return: _description_
:rtype: Tuple[str, Dict]
"""
return self.action_space.get_action(self.action_space.space.sample())
class AbstractGATEAgent(AbstractAgent):
"""Base class for actors controlled via external messages, such as RL policies."""
...

View File

@@ -0,0 +1,984 @@
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite import getLogger
from src.primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from src.primaite.game.session import PrimaiteSession
class AbstractObservation(ABC):
"""Abstract class for an observation space component."""
@abstractmethod
def observe(self, state: Dict) -> Any:
"""
Return an observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Any
"""
pass
@property
@abstractmethod
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space."""
pass
@classmethod
@abstractmethod
def from_config(cls, config: Dict, session: "PrimaiteSession"):
"""Create this observation space component form a serialised format.
The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation,
a subclass of this class may need to translate from a 'reference' to a UUID.
"""
pass
class FileObservation(AbstractObservation):
"""Observation of a file on a node in the network."""
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""
Initialise file observation.
:param where: Store information about where in the simulation state dictionary to find the relevatn information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a file looks like this:
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>,'files',<file_name>]
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.default_observation: spaces.Space = {"health_status": 0}
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
file_state = access_from_nested_dict(state, self.where)
if file_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"health_status": file_state["health_status"]}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict({"health_status": spaces.Discrete(6)})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: List[str] = None) -> "FileObservation":
"""Create file observation from a config.
:param config: Dictionary containing the configuration for this file observation.
:type config: Dict
:param session: _description_
:type session: PrimaiteSession
:param parent_where: _description_, defaults to None
:type parent_where: _type_, optional
:return: _description_
:rtype: _type_
"""
return cls(where=parent_where + ["files", config["file_name"]])
class ServiceObservation(AbstractObservation):
"""Observation of a service in the network."""
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
"Default observation is what should be returned when the service doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise service observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_uuid>,'services', <service_uuid>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
service_state = access_from_nested_dict(state, self.where)
if service_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"operating_status": service_state["operating_state"], "health_status": service_state["health_state"]}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)})
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None
) -> "ServiceObservation":
"""Create service observation from a config.
:param config: Dictionary containing the configuration for this service observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
:type parent_where: Optional[List[str]], optional
:return: Constructed service observation
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid])
class LinkObservation(AbstractObservation):
"""Observation of a link in the network."""
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
"Default observation is what should be returned when the link doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise link observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_uuid>,'servics', <service_uuid>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
utilisation_fraction = load / bandwidth
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
utilisation_category = int(utilisation_fraction * 10) + 1
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
return {"PROTOCOLS": {"ALL": utilisation_category}}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation":
"""Create link observation from a config.
:param config: Dictionary containing the configuration for this link observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:return: Constructed link observation
:rtype: LinkObservation
"""
return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]])
class FolderObservation(AbstractObservation):
"""Folder observation, including files inside of the folder."""
def __init__(
self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2
) -> None:
"""Initialise folder Observation, including files inside of the folder.
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
A typical location for a file looks like this:
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>]
:type where: Optional[List[str]]
:param max_files: As size of the space must remain static, define max files that can be in this folder
, defaults to 5
:type max_files: int, optional
:param file_positions: Defines the positioning within the observation space of particular files. This ensures
that even if new files are created, the existing files will always occupy the same space in the observation
space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the
observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same
name, it will take the position defined in this dict. Defaults to {}
:type file_positions: Dict[int, str], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.files: List[FileObservation] = files
while len(self.files) < num_files_per_folder:
self.files.append(FileObservation())
while len(self.files) > num_files_per_folder:
truncated_file = self.files.pop()
msg = f"Too many files in folde observation. Truncating file {truncated_file}"
_LOGGER.warn(msg)
self.default_observation = {
"health_status": 0,
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
folder_state = access_from_nested_dict(state, self.where)
if folder_state is NOT_PRESENT_IN_STATE:
return self.default_observation
health_status = folder_state["health_status"]
obs = {}
obs["health_status"] = health_status
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict(
{
"health_status": spaces.Discrete(6),
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
}
)
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]], num_files_per_folder: int = 2
) -> "FolderObservation":
"""Create folder observation from a config. Also creates child file observations.
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
folder and the files inside of it.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
parent node. A typical location for a node ``where`` can be:
['network','nodes',<node_uuid>,'file_system']
:type parent_where: Optional[List[str]]
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
observation size) , defaults to 2
:type num_files_per_folder: int, optional
:return: Constructed folder observation
:rtype: FolderObservation
"""
where = parent_where + ["folders", config["folder_name"]]
file_configs = config["files"]
files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs]
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
class NicObservation(AbstractObservation):
"""Observation of a Network Interface Card (NIC) in the network."""
default_observation: spaces.Space = {"nic_status": 0}
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise NIC observation.
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
example may look like this:
['network','nodes',<node_uuid>,'NICs',<nic_uuid>]
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
:type where: Optional[Tuple[str]], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
return self.default_observation
else:
return {"nic_status": 1 if nic_state["enabled"] else 2}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict({"nic_status": spaces.Discrete(3)})
@classmethod
def from_config(
cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]
) -> "NicObservation":
"""Create NIC observation from a config.
:param config: Dictionary containing the configuration for this NIC observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_uuid>]
:type parent_where: Optional[List[str]]
:return: Constructed NIC observation
:rtype: NicObservation
"""
return cls(where=parent_where + ["NICs", config["nic_uuid"]])
class NodeObservation(AbstractObservation):
"""Observation of a node in the network. Includes services, folders and NICs."""
def __init__(
self,
where: Optional[Tuple[str]] = None,
services: List[ServiceObservation] = [],
folders: List[FolderObservation] = [],
nics: List[NicObservation] = [],
logon_status: bool = False,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
num_nics_per_node: int = 2,
) -> None:
"""
Configurable observation for a node in the simulation.
:param where: Where in the simulation state dictionary for find relevant information for this observation.
A typical location for a node looks like this:
['network','nodes',<node_uuid>]. If empty list, a default null observation will be output, defaults to []
:type where: List[str], optional
:param services: Mapping between position in observation space and service UUID, defaults to {}
:type services: Dict[int,str], optional
:param max_services: Max number of services that can be presented in observation space for this node
, defaults to 2
:type max_services: int, optional
:param folders: Mapping between position in observation space and folder name, defaults to {}
:type folders: Dict[int,str], optional
:param max_folders: Max number of folders in this node's obs space, defaults to 2
:type max_folders: int, optional
:param nics: Mapping between position in observation space and NIC UUID, defaults to {}
:type nics: Dict[int,str], optional
:param max_nics: Max number of NICS in this node's obs space, defaults to 5
:type max_nics: int, optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.services: List[ServiceObservation] = services
while len(self.services) < num_services_per_node:
# add empty service observation without `where` parameter so it always returns default (blank) observation
self.services.append(ServiceObservation())
while len(self.services) > num_services_per_node:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warn(msg)
# truncate service list
self.folders: List[FolderObservation] = folders
# add empty folder observation without `where` parameter that will always return default (blank) observations
while len(self.folders) < num_folders_per_node:
self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
while len(self.folders) > num_folders_per_node:
truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
_LOGGER.warn(msg)
self.nics: List[NicObservation] = nics
while len(self.nics) < num_nics_per_node:
self.nics.append(NicObservation())
while len(self.nics) > num_nics_per_node:
truncated_nic = self.nics.pop()
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
_LOGGER.warn(msg)
self.logon_status: bool = logon_status
self.default_observation: Dict = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.nics)},
"operating_status": 0,
}
if self.logon_status:
self.default_observation["logon_status"] = 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
node_state = access_from_nested_dict(state, self.where)
if node_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
obs["operating_status"] = node_state["operating_state"]
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
if self.logon_status:
obs["logon_status"] = 0
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
space_shape = {
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}),
}
if self.logon_status:
space_shape["logon_status"] = spaces.Discrete(3)
return spaces.Dict(space_shape)
@classmethod
def from_config(
cls,
config: Dict,
session: "PrimaiteSession",
parent_where: Optional[List[str]] = None,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
num_nics_per_node: int = 2,
) -> "NodeObservation":
"""Create node observation from a config. Also creates child service, folder and NIC observations.
:param config: Dictionary containing the configuration for this node observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
network. A typical location for it would be: ['network',]
:type parent_where: Optional[List[str]]
:param num_services_per_node: How many spaces for services are in this node observation (to preserve static
observation size) , defaults to 2
:type num_services_per_node: int, optional
:param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
observation size) , defaults to 2
:type num_folders_per_node: int, optional
:param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
observation size) , defaults to 2
:type num_files_per_folder: int, optional
:return: Constructed node observation
:rtype: NodeObservation
"""
node_uuid = session.ref_map_nodes[config["node_ref"]]
if parent_where is None:
where = ["network", "nodes", node_uuid]
else:
where = parent_where + ["nodes", node_uuid]
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs]
folder_configs = config.get("folders", {})
folders = [
FolderObservation.from_config(
config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder
)
for c in folder_configs
]
nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys()
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs]
logon_status = config.get("logon_status", False)
return cls(
where=where,
services=services,
folders=folders,
nics=nics,
logon_status=logon_status,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
num_nics_per_node=num_nics_per_node,
)
class AclObservation(AbstractObservation):
"""Observation of an Access Control List (ACL) in the network."""
# TODO: should where be optional, and we can use where=None to pad the observation space?
# definitely the current approach does not support tracking files that aren't specified by name, for example
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# this needs adding, but not for the MVP.
def __init__(
self,
node_ip_to_id: Dict[str, int],
ports: List[int],
protocols: List[str],
where: Optional[Tuple[str]] = None,
num_rules: int = 10,
) -> None:
"""Initialise ACL observation.
:param node_ip_to_id: Mapping between IP address and ID.
:type node_ip_to_id: Dict[str, int]
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
:type ports: List[int]
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
:type protocols: list[str]
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
example may look like this:
['network','nodes',<router_uuid>,'acl','acl']
:type where: Optional[Tuple[str]], optional
:param num_rules: , defaults to 10
:type num_rules: int, optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.num_rules: int = num_rules
self.node_to_id: Dict[str, int] = node_ip_to_id
"List of node IP addresses, order in this list determines how they are converted to an ID"
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
"List of ports which are part of the game that define the ordering when converting to an ID"
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
"List of protocols which are part of the game, defines ordering when converting to an ID"
self.default_observation: Dict = {
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
for i, rule_state in acl_state.items():
if rule_state is None:
obs[i + 1] = {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
else:
obs[i + 1] = {
"position": i,
"permission": rule_state["action"],
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
"source_port": self.port_to_id[rule_state["src_port"]],
"dest_node_id": self.node_to_id[rule_state["dst_ip_address"]],
"dest_port": self.port_to_id[rule_state["dst_port"]],
"protocol": self.protocol_to_id[rule_state["protocol"]],
}
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation":
"""Generate ACL observation from a config.
:param config: Dictionary containing the configuration for this ACL observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:return: Observation object
:rtype: AclObservation
"""
max_acl_rules = config["options"]["max_acl_rules"]
node_ip_to_idx = {}
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_ref"]
nic_num = ip_map_config["nic_num"]
node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]]
nic_obj = node_obj.ethernet_port[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_uuid = session.ref_map_nodes[config["router_node_ref"]]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=session.options.ports,
protocols=session.options.protocols,
where=["network", "nodes", router_uuid, "acl", "acl"],
num_rules=max_acl_rules,
)
class NullObservation(AbstractObservation):
"""Null observation, returns a single 0 value for the observation space."""
def __init__(self, where: Optional[List[str]] = None):
"""Initialise null observation."""
self.default_observation: Dict = {}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
return 0
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":
"""
Create null observation from a config.
The parameters are ignored, they are here to match the signature of the other observation classes.
"""
return cls()
class ICSObservation(NullObservation):
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
pass
class UC2BlueObservation(AbstractObservation):
"""Container for all observations used by the blue agent in UC2.
TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler
for the purpose of compiling several observation components.
"""
def __init__(
self,
nodes: List[NodeObservation],
links: List[LinkObservation],
acl: AclObservation,
ics: ICSObservation,
where: Optional[List[str]] = None,
) -> None:
"""Initialise UC2 blue observation.
:param nodes: List of node observations
:type nodes: List[NodeObservation]
:param links: List of link observations
:type links: List[LinkObservation]
:param acl: The Access Control List observation
:type acl: AclObservation
:param ics: The ICS observation
:type ics: ICSObservation
:param where: Where in the simulation state dict to find information. Not used in this particular observation
because it only compiles other observations and doesn't contribute any new information, defaults to None
:type where: Optional[List[str]], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.nodes: List[NodeObservation] = nodes
self.links: List[LinkObservation] = links
self.acl: AclObservation = acl
self.ics: ICSObservation = ics
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
"ACL": self.acl.default_observation,
"ICS": self.ics.default_observation,
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
obs = {}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
obs["ACL"] = self.acl.observe(state)
obs["ICS"] = self.ics.observe(state)
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Space
:rtype: spaces.Space
"""
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
"ACL": self.acl.space,
"ICS": self.ics.space,
}
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2BlueObservation":
"""Create UC2 blue observation from a config.
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
links, ACL and ICS observations.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
:return: Constructed UC2 blue observation
:rtype: UC2BlueObservation
"""
node_configs = config["nodes"]
num_services_per_node = config["num_services_per_node"]
num_folders_per_node = config["num_folders_per_node"]
num_files_per_folder = config["num_files_per_folder"]
num_nics_per_node = config["num_nics_per_node"]
nodes = [
NodeObservation.from_config(
config=n,
session=session,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
num_nics_per_node=num_nics_per_node,
)
for n in node_configs
]
link_configs = config["links"]
links = [LinkObservation.from_config(config=link, session=session) for link in link_configs]
acl_config = config["acl"]
acl = AclObservation.from_config(config=acl_config, session=session)
ics_config = config["ics"]
ics = ICSObservation.from_config(config=ics_config, session=session)
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
return new
class UC2RedObservation(AbstractObservation):
"""Container for all observations used by the red agent in UC2."""
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
super().__init__()
self.where: Optional[List[str]] = where
self.nodes: List[NodeObservation] = nodes
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
if self.where is None:
return self.default_observation
obs = {}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
}
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2RedObservation":
"""
Create UC2 red observation from a config.
:param config: Dictionary containing the configuration for this UC2 red observation.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
"""
node_configs = config["nodes"]
nodes = [NodeObservation.from_config(config=cfg, session=session) for cfg in node_configs]
return cls(nodes=nodes, where=["network"])
class UC2GreenObservation(NullObservation):
"""Green agent observation. As the green agent's actions don't depend on the observation, this is empty."""
pass
class ObservationSpace:
"""
Manage the observations of an Agent.
The observation space has the purpose of:
1. Reading the outputted state from the PrimAITE Simulation.
2. Selecting parts of the simulation state that are requested by the simulation config
3. Formatting this information so an agent can use it to make decisions.
"""
# TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed
# to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next
# refactor.
def __init__(self, observation: AbstractObservation) -> None:
"""Initialise observation space.
:param observation: Observation object
:type observation: AbstractObservation
"""
self.obs: AbstractObservation = observation
def observe(self, state: Dict) -> Dict:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
"""
return self.obs.observe(state)
@property
def space(self) -> None:
"""Gymnasium space object describing the observation space shape."""
return self.obs.space
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
"""Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.
It should contain the key 'type' which selects which observation class to use (from a choice of:
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
The other key is 'options' which are passed to the constructor of the selected observation class.
:type config: Dict
:param session: Reference to the PrimaiteSession object that spawned this observation.
:type session: PrimaiteSession
"""
if config["type"] == "UC2BlueObservation":
return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session))
elif config["type"] == "UC2RedObservation":
return cls(UC2RedObservation.from_config(config.get("options", {}), session=session))
elif config["type"] == "UC2GreenObservation":
return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session))
else:
raise ValueError("Observation space type invalid")

View File

@@ -0,0 +1,284 @@
"""
Manages the reward function for the agent.
Each agent is equipped with a RewardFunction, which is made up of a list of reward components. The components are
designed to calculate a reward value based on the current state of the simulation. The overall reward function is a
weighed sum of the components.
The reward function is typically specified using a config yaml file or a config dictionary. The following example shows
the structure:
```yaml
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_database_client
```
"""
from abc import abstractmethod
from typing import Dict, List, Tuple, TYPE_CHECKING
from primaite import getLogger
from src.primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from src.primaite.game.session import PrimaiteSession
class AbstractReward:
"""Base class for reward function components."""
@abstractmethod
def calculate(self, state: Dict) -> float:
"""Calculate the reward for the current state."""
return 0.0
@classmethod
@abstractmethod
def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:return: The reward component.
:rtype: AbstractReward
"""
return cls()
class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
def calculate(self, state: Dict) -> float:
"""Calculate the reward for the current state."""
return 0.0
@classmethod
def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor. Should be empty.
:type config: dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
"""
return cls()
class DatabaseFileIntegrity(AbstractReward):
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
def __init__(self, node_uuid: str, folder_name: str, file_name: str) -> None:
"""Initialise the reward component.
:param node_uuid: UUID of the node which contains the database file.
:type node_uuid: str
:param folder_name: folder which contains the database file.
:type folder_name: str
:param file_name: name of the database file.
:type file_name: str
"""
self.location_in_state = [
"network",
"nodes",
node_uuid,
"file_system",
"folders",
folder_name,
"files",
file_name,
]
def calculate(self, state: Dict) -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
:type state: Dict
"""
database_file_state = access_from_nested_dict(state, self.location_in_state)
health_status = database_file_state["health_status"]
if health_status == "corrupted":
return -1
elif health_status == "good":
return 1
else:
return 0
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "DatabaseFileIntegrity":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:return: The reward component.
:rtype: DatabaseFileIntegrity
"""
node_ref = config.get("node_ref")
folder_name = config.get("folder_name")
file_name = config.get("file_name")
if not node_ref:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified"
)
return DummyReward() # TODO: better error handling
if not folder_name:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
if not file_name:
_LOGGER.error(
f"{cls.__name__} could not be initialised from config because file_name parameter was not specified"
)
return DummyReward() # TODO: better error handling
node_uuid = session.ref_map_nodes[node_ref]
if not node_uuid:
_LOGGER.error(
(
f"{cls.__name__} could not be initialised from config because the referenced node could not be "
f"found in the simulation"
)
)
return DummyReward() # TODO: better error handling
return cls(node_uuid=node_uuid, folder_name=folder_name, file_name=file_name)
class WebServer404Penalty(AbstractReward):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
def __init__(self, node_uuid: str, service_uuid: str) -> None:
"""Initialise the reward component.
:param node_uuid: UUID of the node which contains the web server service.
:type node_uuid: str
:param service_uuid: UUID of the web server service.
:type service_uuid: str
"""
self.location_in_state = ["network", "nodes", node_uuid, "services", service_uuid]
def calculate(self, state: Dict) -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
:type state: Dict
"""
web_service_state = access_from_nested_dict(state, self.location_in_state)
if web_service_state is NOT_PRESENT_IN_STATE:
print("error getting web service state")
return 0.0
most_recent_return_code = web_service_state["last_response_status_code"]
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
if most_recent_return_code == 200:
return 1.0
elif most_recent_return_code == 404:
return -1.0
else:
return 0.0
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:return: The reward component.
:rtype: WebServer404Penalty
"""
node_ref = config.get("node_ref")
service_ref = config.get("service_ref")
if not (node_ref and service_ref):
msg = (
f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not "
"found in reward config."
)
_LOGGER.warn(msg)
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
node_uuid = session.ref_map_nodes[node_ref]
service_uuid = session.ref_map_services[service_ref].uuid
if not (node_uuid and service_uuid):
msg = (
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
" found in the simulator."
)
_LOGGER.warn(msg)
return DummyReward() # TODO: consider erroring here as well
return cls(node_uuid=node_uuid, service_uuid=service_uuid)
class RewardFunction:
"""Manages the reward function for the agent."""
__rew_class_identifiers: Dict[str, type[AbstractReward]] = {
"DUMMY": DummyReward,
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
}
def __init__(self):
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
"attribute reward_components keeps track of reward components and the weights assigned to each."
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
"""Add a reward component to the reward function.
:param component: Instance of a reward component.
:type component: AbstractReward
:param weight: Relative weight of the reward component, defaults to 1.0
:type weight: float, optional
"""
self.reward_components.append((component, weight))
def calculate(self, state: Dict) -> float:
"""Calculate the overall reward for the current state.
:param state: The current state of the simulation.
:type state: Dict
"""
total = 0.0
for comp_and_weight in self.reward_components:
comp = comp_and_weight[0]
weight = comp_and_weight[1]
total += weight * comp.calculate(state=state)
return total
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":
"""Create a reward function from a config dictionary.
:param config: dict of options for the reward manager's constructor
:type config: Dict
:param session: Reference to the PrimAITE Session object
:type session: PrimaiteSession
:return: The reward manager.
:rtype: RewardFunction
"""
new = cls()
for rew_component_cfg in config["reward_components"]:
rew_type = rew_component_cfg["type"]
weight = rew_component_cfg.get("weight", 1.0)
rew_class = cls.__rew_class_identifiers[rew_type]
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session)
new.regsiter_component(component=rew_instance, weight=weight)
return new

View File

@@ -0,0 +1,14 @@
"""Agents with predefined behaviours."""
from src.primaite.game.agent.interface import AbstractScriptedAgent
class GreenWebBrowsingAgent(AbstractScriptedAgent):
"""Scripted agent which attempts to send web requests to a target node."""
raise NotImplementedError
class RedDatabaseCorruptingAgent(AbstractScriptedAgent):
"""Scripted agent which attempts to corrupt the database of the target node."""
raise NotImplementedError

View File

@@ -0,0 +1,30 @@
from typing import Any, Dict, Hashable, Sequence
NOT_PRESENT_IN_STATE = object()
"""
Need an object to return when the sim state does not contain a requested value. Cannot use None because sometimes
the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is a sentinel for this purpose.
"""
def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
"""
Access an item from a deeply dictionary with a list of keys.
For example, if the dictionary is {1: 'a', 2: {3: {4: 'b'}}}, then the key [2, 3, 4] would return 'b', and the key
[2, 3] would return {4: 'b'}. Raises a KeyError if specified key does not exist at any level of nesting.
:param dictionary: Deeply nested dictionary
:type dictionary: Dict
:param keys: List of dict keys used to traverse the nested dict. Each item corresponds to one level of depth.
:type keys: List[Hashable]
:return: The value in the dictionary
:rtype: Any
"""
key_list = [*keys] # copy keys to a new list to prevent editing original list
if len(key_list) == 0:
return dictionary
k = key_list.pop(0)
if k not in dictionary:
return NOT_PRESENT_IN_STATE
return access_from_nested_dict(dictionary[k], key_list)

View File

@@ -0,0 +1,471 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple
from arcd_gate.client.gate_client import ActType, GATEClient
from gymnasium import spaces
from gymnasium.core import ActType, ObsType
from gymnasium.spaces.utils import flatten, flatten_space
from pydantic import BaseModel
from primaite import getLogger
from src.primaite.game.agent.actions import ActionManager
from src.primaite.game.agent.interface import AbstractAgent, RandomAgent
from src.primaite.game.agent.observations import ObservationSpace
from src.primaite.game.agent.rewards import RewardFunction
from src.primaite.simulator.network.hardware.base import Link, NIC, Node
from src.primaite.simulator.network.hardware.nodes.computer import Computer
from src.primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from src.primaite.simulator.network.hardware.nodes.server import Server
from src.primaite.simulator.network.hardware.nodes.switch import Switch
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
from src.primaite.simulator.network.transmission.transport_layer import Port
from src.primaite.simulator.sim_container import Simulation
from src.primaite.simulator.system.applications.application import Application
from src.primaite.simulator.system.applications.database_client import DatabaseClient
from src.primaite.simulator.system.applications.web_browser import WebBrowser
from src.primaite.simulator.system.services.database.database_service import DatabaseService
from src.primaite.simulator.system.services.dns.dns_client import DNSClient
from src.primaite.simulator.system.services.dns.dns_server import DNSServer
from src.primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from src.primaite.simulator.system.services.service import Service
from src.primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
class PrimaiteGATEClient(GATEClient):
"""Lightweight wrapper around the GATEClient class that allows PrimAITE to message GATE."""
def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000):
"""
Create a new GATE client for PrimAITE.
:param parent_session: The parent session object.
:type parent_session: PrimaiteSession
:param service_port: The port on which the GATE service is running.
:type service_port: int, optional
"""
super().__init__(service_port=service_port)
self.parent_session: "PrimaiteSession" = parent_session
@property
def rl_framework(self) -> str:
"""The reinforcement learning framework to use."""
return self.parent_session.training_options.rl_framework
@property
def rl_algorithm(self) -> str:
"""The reinforcement learning algorithm to use."""
return self.parent_session.training_options.rl_algorithm
@property
def seed(self) -> int | None:
"""The seed to use for the environment's random number generator."""
return self.parent_session.training_options.seed
@property
def n_learn_episodes(self) -> int:
"""The number of episodes in each learning run."""
return self.parent_session.training_options.n_learn_episodes
@property
def n_learn_steps(self) -> int:
"""The number of steps in each learning episode."""
return self.parent_session.training_options.n_learn_steps
@property
def n_eval_episodes(self) -> int:
"""The number of episodes in each evaluation run."""
return self.parent_session.training_options.n_eval_episodes
@property
def n_eval_steps(self) -> int:
"""The number of steps in each evaluation episode."""
return self.parent_session.training_options.n_eval_steps
@property
def action_space(self) -> spaces.Space:
"""The gym action space of the agent."""
return self.parent_session.rl_agent.action_space.space
@property
def observation_space(self) -> spaces.Space:
"""The gymnasium observation space of the agent."""
return flatten_space(self.parent_session.rl_agent.observation_space.space)
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
"""Take a step in the environment.
This method is called by GATE to advance the simulation by one timestep.
:param action: The agent's action.
:type action: ActType
:return: The observation, reward, terminal flag, truncated flag, and info dictionary.
:rtype: Tuple[ObsType, float, bool, bool, Dict]
"""
self.parent_session.rl_agent.most_recent_action = action
self.parent_session.step()
state = self.parent_session.simulation.describe_state()
obs = self.parent_session.rl_agent.observation_space.observe(state)
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
rew = self.parent_session.rl_agent.reward_function.calculate(state)
term = False
trunc = False
info = {}
return obs, rew, term, trunc, info
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]:
"""Reset the environment.
This method is called when the environment is initialized and at the end of each episode.
:param seed: The seed to use for the environment's random number generator.
:type seed: int, optional
:param options: Additional options for the reset. None are used by PrimAITE but this is included for
compatibility with GATE.
:type options: dict[str, Any], optional
:return: The initial observation and an empty info dictionary.
:rtype: Tuple[ObsType, Dict]
"""
self.parent_session.reset()
state = self.parent_session.simulation.describe_state()
obs = self.parent_session.rl_agent.observation_space.observe(state)
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
return obs, {}
def close(self):
"""Close the session, this will stop the gate client and close the simulation."""
self.parent_session.close()
class PrimaiteSessionOptions(BaseModel):
"""
Global options which are applicable to all of the agents in the game.
Currently this is used to restrict which ports and protocols exist in the world of the simulation.
"""
ports: List[str]
protocols: List[str]
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
rl_framework: str
rl_algorithm: str
seed: Optional[int]
n_learn_episodes: int
n_learn_steps: int
n_eval_episodes: int
n_eval_steps: int
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE."""
def __init__(self):
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
self.rl_agent: AbstractAgent
"""The agent from the list which communicates with GATE to perform reinforcement learning."""
self.step_counter: int = 0
"""Current timestep within the episode."""
self.episode_counter: int = 0
"""Current episode number."""
self.options: PrimaiteSessionOptions
"""Special options that apply for the entire game."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.ref_map_nodes: Dict[str, Node] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
self.ref_map_services: Dict[str, Service] = {}
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
self.ref_map_applications: Dict[str, Application] = {}
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
"""Reference to a GATE Client object, which will send data to GATE service for training RL agent."""
def start_session(self) -> None:
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
self.gate_client.start()
def step(self):
"""
Perform one step of the simulation/agent loop.
This is the main loop of the game. It corresponds to one timestep in the simulation, and one action from each
agent. The steps are as follows:
1. The simulation state is updated.
2. The simulation state is sent to each agent.
3. Each agent converts the state to an observation and calculates a reward.
4. Each agent chooses an action based on the observation.
5. Each agent converts the action to a request.
6. The simulation applies the requests.
"""
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
# currently designed with assumption that all agents act once per step in order
for agent in self.agents:
# 3. primaite session asks simulation to provide initial state
# 4. primate session gives state to all agents
# 5. primaite session asks agents to produce an action based on most recent state
_LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}")
sim_state = self.simulation.describe_state()
# 6. each agent takes most recent state and converts it to CAOS observation
agent_obs = agent.convert_state_to_obs(sim_state)
# 7. meanwhile each agent also takes state and calculates reward
agent_reward = agent.calculate_reward_from_state(sim_state)
# 8. each agent takes observation and applies decision rule to observation to create CAOS
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
# code should live inside of the GATE agent subclass)
# gets action in CAOS format
_LOGGER.debug("Getting agent action")
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
# 9. CAOS action is converted into request (extra information might be needed to enrich
# the request, this is what the execution definition is there for)
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
agent_request = agent.format_request(agent_action, action_options)
# 10. primaite session receives the action from the agents and asks the simulation to apply each
_LOGGER.debug(f"Sending request to simulation: {agent_request}")
self.simulation.apply_request(agent_request)
_LOGGER.debug(f"Initiating simulation step {self.step_counter}")
self.simulation.apply_timestep(self.step_counter)
self.step_counter += 1
def reset(self) -> None:
"""Reset the session, this will reset the simulation."""
return NotImplemented
def close(self) -> None:
"""Close the session, this will stop the gate client and close the simulation."""
return NotImplemented
@classmethod
def from_config(cls, cfg: dict) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent. Used by GATE.
2. game_config: options for the game itself. Used by PrimaiteSession.
3. simulation: defines the network topology and the initial state of the simulation.
The specification for each of the three major areas is described in a separate documentation page.
# TODO: create documentation page and add links to it here.
:param cfg: The config dictionary.
:type cfg: dict
:return: A PrimaiteSession object.
:rtype: PrimaiteSession
"""
sess = cls()
sess.options = PrimaiteSessionOptions(
ports=cfg["game_config"]["ports"],
protocols=cfg["game_config"]["protocols"],
)
sess.training_options = TrainingOptions(**cfg["training_config"])
sim = sess.simulation
net = sim.network
sess.ref_map_nodes: Dict[str, Node] = {}
sess.ref_map_services: Dict[str, Service] = {}
sess.ref_map_links: Dict[str, Link] = {}
nodes_cfg = cfg["simulation"]["network"]["nodes"]
links_cfg = cfg["simulation"]["network"]["links"]
for node_cfg in nodes_cfg:
node_ref = node_cfg["ref"]
n_type = node_cfg["type"]
if n_type == "computer":
new_node = Computer(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=node_cfg["subnet_mask"],
default_gateway=node_cfg["default_gateway"],
dns_server=node_cfg["dns_server"],
)
elif n_type == "server":
new_node = Server(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=node_cfg["subnet_mask"],
default_gateway=node_cfg["default_gateway"],
dns_server=node_cfg.get("dns_server"),
)
elif n_type == "switch":
new_node = Switch(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
elif n_type == "router":
new_node = Router(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
if "ports" in node_cfg:
for port_num, port_cfg in node_cfg["ports"].items():
new_node.configure_port(
port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
)
if "acl" in node_cfg:
for r_num, r_cfg in node_cfg["acl"].items():
# excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
# this: 'r_cfg.get('src_port')'
# Port/IPProtocol. TODO Refactor
new_node.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("ip_address"),
dst_ip_address=r_cfg.get("ip_address"),
position=r_num,
)
else:
print("invalid node type")
if "services" in node_cfg:
for service_cfg in node_cfg["services"]:
service_ref = service_cfg["ref"]
service_type = service_cfg["type"]
service_types_mapping = {
"DNSClient": DNSClient, # key is equal to the 'name' attr of the service class itself.
"DNSServer": DNSServer,
"DatabaseClient": DatabaseClient,
"DatabaseService": DatabaseService,
"WebServer": WebServer,
"DataManipulationBot": DataManipulationBot,
}
if service_type in service_types_mapping:
print(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(service_types_mapping[service_type])
new_service = new_node.software_manager.software[service_type]
sess.ref_map_services[service_ref] = new_service
else:
print(f"service type not found {service_type}")
# service-dependent options
if service_type == "DatabaseClient":
if "options" in service_cfg:
opt = service_cfg["options"]
if "db_server_ip" in opt:
new_service.configure(server_ip_address=IPv4Address(opt["db_server_ip"]))
if service_type == "DNSServer":
if "options" in service_cfg:
opt = service_cfg["options"]
if "domain_mapping" in opt:
for domain, ip in opt["domain_mapping"].items():
new_service.dns_register(domain, ip)
if "applications" in node_cfg:
for application_cfg in node_cfg["applications"]:
application_ref = application_cfg["ref"]
application_type = application_cfg["type"]
application_types_mapping = {
"WebBrowser": WebBrowser,
}
if application_type in application_types_mapping:
new_node.software_manager.install(application_types_mapping[application_type])
new_application = new_node.software_manager.software[application_type]
sess.ref_map_applications[application_ref] = new_application
else:
print(f"application type not found {application_type}")
if "nics" in node_cfg:
for nic_num, nic_cfg in node_cfg["nics"].items():
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
net.add_node(new_node)
new_node.power_on()
sess.ref_map_nodes[
node_ref
] = (
new_node.uuid
) # TODO: fix incosistency with service and link. Node gets added by uuid, but service by object
# 2. create links between nodes
for link_cfg in links_cfg:
node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
if isinstance(node_a, Switch):
endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]]
else:
endpoint_a = node_a.ethernet_port[link_cfg["endpoint_a_port"]]
if isinstance(node_b, Switch):
endpoint_b = node_b.switch_ports[link_cfg["endpoint_b_port"]]
else:
endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]]
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
sess.ref_map_links[link_cfg["ref"]] = new_link.uuid
# 3. create agents
game_cfg = cfg["game_config"]
agents_cfg = game_cfg["agents"]
for agent_cfg in agents_cfg:
agent_ref = agent_cfg["ref"] # noqa: F841
agent_type = agent_cfg["type"]
action_space_cfg = agent_cfg["action_space"]
observation_space_cfg = agent_cfg["observation_space"]
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
# CREATE ACTION SPACE
action_space_cfg["options"]["node_uuids"] = []
# if a list of nodes is defined, convert them from node references to node UUIDs
for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
if "node_ref" in action_node_option:
node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]]
action_space_cfg["options"]["node_uuids"].append(node_uuid)
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
# we will pass node_uuids as a part of the action space config.
# However, it's not possible to specify the node uuids directly in the config, as they are generated
# dynamically, so we have to translate node references to uuids before passing this config on.
if "action_list" in action_space_cfg:
for action_config in action_space_cfg["action_list"]:
if "options" in action_config:
if "target_router_ref" in action_config["options"]:
_target = action_config["options"]["target_router_ref"]
action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target]
action_space = ActionManager.from_config(sess, action_space_cfg)
# CREATE REWARD FUNCTION
rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":
# TODO: implement non-random agents and fix this parsing
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
elif agent_type == "GATERLAgent":
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
sess.rl_agent = new_agent
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
else:
print("agent type not found")
return sess

View File

@@ -1,2 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Network connections between nodes in the simulation."""

View File

@@ -1,114 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The link class."""
from typing import List
from primaite.common.protocol import Protocol
class Link(object):
"""Link class."""
def __init__(self, _id: str, _bandwidth: int, _source_node_name: str, _dest_node_name: str, _services: str) -> None:
"""
Initialise a Link within the simulated network.
:param _id: The IER id
:param _bandwidth: The bandwidth of the link (bps)
:param _source_node_name: The name of the source node
:param _dest_node_name: The name of the destination node
:param _protocols: The protocols to add to the link
"""
self.id: str = _id
self.bandwidth: int = _bandwidth
self.source_node_name: str = _source_node_name
self.dest_node_name: str = _dest_node_name
self.protocol_list: List[Protocol] = []
# Add the default protocols
for protocol_name in _services:
self.add_protocol(protocol_name)
def add_protocol(self, _protocol: str) -> None:
"""
Adds a new protocol to the list of protocols on this link.
Args:
_protocol: The protocol to be added (enum)
"""
self.protocol_list.append(Protocol(_protocol))
def get_id(self) -> str:
"""
Gets link ID.
Returns:
Link ID
"""
return self.id
def get_source_node_name(self) -> str:
"""
Gets source node name.
Returns:
Source node name
"""
return self.source_node_name
def get_dest_node_name(self) -> str:
"""
Gets destination node name.
Returns:
Destination node name
"""
return self.dest_node_name
def get_bandwidth(self) -> int:
"""
Gets bandwidth of link.
Returns:
Link bandwidth (bps)
"""
return self.bandwidth
def get_protocol_list(self) -> List[Protocol]:
"""
Gets list of protocols on this link.
Returns:
List of protocols on this link
"""
return self.protocol_list
def get_current_load(self) -> int:
"""
Gets current total load on this link.
Returns:
Total load on this link (bps)
"""
total_load = 0
for protocol in self.protocol_list:
total_load += protocol.get_load()
return total_load
def add_protocol_load(self, _protocol: str, _load: int) -> None:
"""
Adds a loading to a protocol on this link.
Args:
_protocol: The protocol to load
_load: The amount to load (bps)
"""
for protocol in self.protocol_list:
if protocol.get_name() == _protocol:
protocol.add_load(_load)
else:
pass
def clear_traffic(self) -> None:
"""Clears all traffic on this link."""
for protocol in self.protocol_list:
protocol.clear_load()

View File

@@ -5,17 +5,16 @@ from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.primaite_session import PrimaiteSession
from src.primaite.config.load import load
from src.primaite.game.session import PrimaiteSession
# from src.primaite.primaite_session import PrimaiteSession
_LOGGER = getLogger(__name__)
def run(
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
legacy_training_config: bool = False,
legacy_lay_down_config: bool = False,
config_path: Optional[Union[str, Path]] = "",
) -> None:
"""
Run the PrimAITE Session.
@@ -31,27 +30,17 @@ def run(
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
otherwise False.
"""
session = PrimaiteSession(
training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
)
session.setup()
session.learn()
session.evaluate()
cfg = load(config_path)
sess = PrimaiteSession.from_config(cfg=cfg)
sess.start_session()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tc")
parser.add_argument("--ldc")
parser.add_argument("--load")
parser.add_argument("--config")
args = parser.parse_args()
if args.load:
run(session_path=args.load)
else:
if not args.tc:
_LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
if not args.config:
_LOGGER.error("Please provide a config file using the --config " "argument")
run(session_path=args.config)

View File

@@ -1,2 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Nodes represent network hosts in the simulation."""

View File

@@ -1,208 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""An Active Node (i.e. not an actuator)."""
import logging
from typing import Final
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.config.training_config import TrainingConfig
from primaite.nodes.node import Node
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class ActiveNode(Node):
"""Active Node class."""
def __init__(
self,
node_id: str,
name: str,
node_type: NodeType,
priority: Priority,
hardware_state: HardwareState,
ip_address: str,
software_state: SoftwareState,
file_system_state: FileSystemState,
config_values: TrainingConfig,
) -> None:
"""
Initialise an active node.
:param node_id: The node ID
:param name: The node name
:param node_type: The node type (enum)
:param priority: The node priority (enum)
:param hardware_state: The node Hardware State
:param ip_address: The node IP address
:param software_state: The node Software State
:param file_system_state: The node file system state
:param config_values: The config values
"""
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
self.ip_address: str = ip_address
# Related to Software
self._software_state: SoftwareState = software_state
self.patching_count: int = 0
# Related to File System
self.file_system_state_actual: FileSystemState = file_system_state
self.file_system_state_observed: FileSystemState = file_system_state
self.file_system_scanning: bool = False
self.file_system_scanning_count: int = 0
self.file_system_action_count: int = 0
@property
def software_state(self) -> SoftwareState:
"""
Get the software_state.
:return: The software_state.
"""
return self._software_state
@software_state.setter
def software_state(self, software_state: SoftwareState) -> None:
"""
Get the software_state.
:param software_state: Software State.
"""
if self.hardware_state != HardwareState.OFF:
self._software_state = software_state
if software_state == SoftwareState.PATCHING:
self.patching_count = self.config_values.os_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so OS State cannot be "
f"changed. "
f"Node.node_id:{self.node_id}, "
f"Node.hardware_state:{self.hardware_state}, "
f"Node.software_state:{self._software_state}"
)
def set_software_state_if_not_compromised(self, software_state: SoftwareState) -> None:
"""
Sets Software State if the node is not compromised.
Args:
software_state: Software State
"""
if self.hardware_state != HardwareState.OFF:
if self._software_state != SoftwareState.COMPROMISED:
self._software_state = software_state
if software_state == SoftwareState.PATCHING:
self.patching_count = self.config_values.os_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so OS State cannot be changed."
f"Node.node_id:{self.node_id}, "
f"Node.hardware_state:{self.hardware_state}, "
f"Node.software_state:{self._software_state}"
)
def update_os_patching_status(self) -> None:
"""Updates operating system status based on patching cycle."""
self.patching_count -= 1
if self.patching_count <= 0:
self.patching_count = 0
self._software_state = SoftwareState.GOOD
def set_file_system_state(self, file_system_state: FileSystemState) -> None:
"""
Sets the file system state (actual and observed).
Args:
file_system_state: File system state
"""
if self.hardware_state != HardwareState.OFF:
self.file_system_state_actual = file_system_state
if file_system_state == FileSystemState.REPAIRING:
self.file_system_action_count = self.config_values.file_system_repairing_limit
self.file_system_state_observed = FileSystemState.REPAIRING
elif file_system_state == FileSystemState.RESTORING:
self.file_system_action_count = self.config_values.file_system_restoring_limit
self.file_system_state_observed = FileSystemState.RESTORING
elif file_system_state == FileSystemState.GOOD:
self.file_system_state_observed = FileSystemState.GOOD
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so File System State "
f"cannot be changed. "
f"Node.node_id:{self.node_id}, "
f"Node.hardware_state:{self.hardware_state}, "
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState) -> None:
"""
Sets the file system state (actual and observed) if not in a compromised state.
Use for green PoL to prevent it overturning a compromised state
Args:
file_system_state: File system state
"""
if self.hardware_state != HardwareState.OFF:
if (
self.file_system_state_actual != FileSystemState.CORRUPT
and self.file_system_state_actual != FileSystemState.DESTROYED
):
self.file_system_state_actual = file_system_state
if file_system_state == FileSystemState.REPAIRING:
self.file_system_action_count = self.config_values.file_system_repairing_limit
self.file_system_state_observed = FileSystemState.REPAIRING
elif file_system_state == FileSystemState.RESTORING:
self.file_system_action_count = self.config_values.file_system_restoring_limit
self.file_system_state_observed = FileSystemState.RESTORING
elif file_system_state == FileSystemState.GOOD:
self.file_system_state_observed = FileSystemState.GOOD
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so File System State (if not "
f"compromised) cannot be changed. "
f"Node.node_id:{self.node_id}, "
f"Node.hardware_state:{self.hardware_state}, "
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def start_file_system_scan(self) -> None:
"""Starts a file system scan."""
self.file_system_scanning = True
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
def update_file_system_state(self) -> None:
"""Updates file system status based on scanning/restore/repair cycle."""
# Deprecate both the action count (for restoring or reparing) and the scanning count
self.file_system_action_count -= 1
self.file_system_scanning_count -= 1
# Reparing / Restoring updates
if self.file_system_action_count <= 0:
self.file_system_action_count = 0
if (
self.file_system_state_actual == FileSystemState.REPAIRING
or self.file_system_state_actual == FileSystemState.RESTORING
):
self.file_system_state_actual = FileSystemState.GOOD
self.file_system_state_observed = FileSystemState.GOOD
# Scanning updates
if self.file_system_scanning == True and self.file_system_scanning_count < 0:
self.file_system_state_observed = self.file_system_state_actual
self.file_system_scanning = False
self.file_system_scanning_count = 0
def update_resetting_status(self) -> None:
"""Updates the reset count & makes software and file state to GOOD."""
super().update_resetting_status()
if self.resetting_count <= 0:
self.file_system_state_actual = FileSystemState.GOOD
self.software_state = SoftwareState.GOOD
def update_booting_status(self) -> None:
"""Updates the booting software and file state to GOOD."""
super().update_booting_status()
if self.booting_count <= 0:
self.file_system_state_actual = FileSystemState.GOOD
self.software_state = SoftwareState.GOOD

View File

@@ -1,79 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The base Node class."""
from typing import Final
from primaite.common.enums import HardwareState, NodeType, Priority
from primaite.config.training_config import TrainingConfig
class Node:
"""Node class."""
def __init__(
self,
node_id: str,
name: str,
node_type: NodeType,
priority: Priority,
hardware_state: HardwareState,
config_values: TrainingConfig,
) -> None:
"""
Initialise a node.
:param node_id: The node id.
:param name: The name of the node.
:param node_type: The type of the node.
:param priority: The priority of the node.
:param hardware_state: The state of the node.
:param config_values: Config values.
"""
self.node_id: Final[str] = node_id
self.name: Final[str] = name
self.node_type: Final[NodeType] = node_type
self.priority = priority
self.hardware_state: HardwareState = hardware_state
self.resetting_count: int = 0
self.config_values: TrainingConfig = config_values
self.booting_count: int = 0
self.shutting_down_count: int = 0
def __repr__(self) -> str:
"""Returns the name of the node."""
return self.name
def turn_on(self) -> None:
"""Sets the node state to ON."""
self.hardware_state = HardwareState.BOOTING
self.booting_count = self.config_values.node_booting_duration
def turn_off(self) -> None:
"""Sets the node state to OFF."""
self.hardware_state = HardwareState.OFF
self.shutting_down_count = self.config_values.node_shutdown_duration
def reset(self) -> None:
"""Sets the node state to Resetting and starts the reset count."""
self.hardware_state = HardwareState.RESETTING
self.resetting_count = self.config_values.node_reset_duration
def update_resetting_status(self) -> None:
"""Updates the resetting count."""
self.resetting_count -= 1
if self.resetting_count <= 0:
self.resetting_count = 0
self.hardware_state = HardwareState.ON
def update_booting_status(self) -> None:
"""Updates the booting count."""
self.booting_count -= 1
if self.booting_count <= 0:
self.booting_count = 0
self.hardware_state = HardwareState.ON
def update_shutdown_status(self) -> None:
"""Updates the shutdown count."""
self.shutting_down_count -= 1
if self.shutting_down_count <= 0:
self.shutting_down_count = 0
self.hardware_state = HardwareState.OFF

View File

@@ -1,94 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Defines node behaviour for Green PoL."""
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from primaite.common.enums import FileSystemState, HardwareState, NodePOLType, SoftwareState
class NodeStateInstructionGreen(object):
"""The Node State Instruction class."""
def __init__(
self,
_id: str,
_start_step: int,
_end_step: int,
_node_id: str,
_node_pol_type: "NodePOLType",
_service_name: str,
_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
) -> None:
"""
Initialise the Node State Instruction.
:param _id: The node state instruction id
:param _start_step: The start step of the instruction
:param _end_step: The end step of the instruction
:param _node_id: The id of the associated node
:param _node_pol_type: The pattern of life type
:param _service_name: The service name
:param _state: The state (node or service)
"""
self.id = _id
self.start_step = _start_step
self.end_step = _end_step
self.node_id = _node_id
self.node_pol_type: "NodePOLType" = _node_pol_type
self.service_name: str = _service_name # Not used when not a service instruction
# TODO: confirm type of state
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _state
def get_start_step(self) -> int:
"""
Gets the start step.
Returns:
The start step
"""
return self.start_step
def get_end_step(self) -> int:
"""
Gets the end step.
Returns:
The end step
"""
return self.end_step
def get_node_id(self) -> str:
"""
Gets the node ID.
Returns:
The node ID
"""
return self.node_id
def get_node_pol_type(self) -> "NodePOLType":
"""
Gets the node pattern of life type (enum).
Returns:
The node pattern of life type (enum)
"""
return self.node_pol_type
def get_service_name(self) -> str:
"""
Gets the service name.
Returns:
The service name
"""
return self.service_name
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
"""
Gets the state (node or service).
Returns:
The state (node or service)
"""
return self.state

View File

@@ -1,143 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Defines node behaviour for Green PoL."""
from typing import TYPE_CHECKING, Union
from primaite.common.enums import NodePOLType
if TYPE_CHECKING:
from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState
class NodeStateInstructionRed:
"""The Node State Instruction class."""
def __init__(
self,
_id: str,
_start_step: int,
_end_step: int,
_target_node_id: str,
_pol_initiator: "NodePOLInitiator",
_pol_type: NodePOLType,
pol_protocol: str,
_pol_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
_pol_source_node_id: str,
_pol_source_node_service: str,
_pol_source_node_service_state: str,
) -> None:
"""
Initialise the Node State Instruction for the red agent.
:param _id: The node state instruction id
:param _start_step: The start step of the instruction
:param _end_step: The end step of the instruction
:param _target_node_id: The id of the associated node
:param -pol_initiator: The way the PoL is applied (DIRECT, IER or SERVICE)
:param _pol_type: The pattern of life type
:param pol_protocol: The pattern of life protocol/service affected
:param _pol_state: The state (node or service)
:param _pol_source_node_id: The source node Id (used for initiator type SERVICE)
:param _pol_source_node_service: The source node service (used for initiator type SERVICE)
:param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
"""
self.id: str = _id
self.start_step: int = _start_step
self.end_step: int = _end_step
self.target_node_id: str = _target_node_id
self.initiator: "NodePOLInitiator" = _pol_initiator
self.pol_type: NodePOLType = _pol_type
self.service_name: str = pol_protocol # Not used when not a service instruction
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _pol_state
self.source_node_id: str = _pol_source_node_id
self.source_node_service: str = _pol_source_node_service
self.source_node_service_state = _pol_source_node_service_state
def get_start_step(self) -> int:
"""
Gets the start step.
Returns:
The start step
"""
return self.start_step
def get_end_step(self) -> int:
"""
Gets the end step.
Returns:
The end step
"""
return self.end_step
def get_target_node_id(self) -> str:
"""
Gets the node ID.
Returns:
The node ID
"""
return self.target_node_id
def get_initiator(self) -> "NodePOLInitiator":
"""
Gets the initiator.
Returns:
The initiator
"""
return self.initiator
def get_pol_type(self) -> NodePOLType:
"""
Gets the node pattern of life type (enum).
Returns:
The node pattern of life type (enum)
"""
return self.pol_type
def get_service_name(self) -> str:
"""
Gets the service name.
Returns:
The service name
"""
return self.service_name
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
"""
Gets the state (node or service).
Returns:
The state (node or service)
"""
return self.state
def get_source_node_id(self) -> str:
"""
Gets the source node id (used for initiator type SERVICE).
Returns:
The source node id
"""
return self.source_node_id
def get_source_node_service(self) -> str:
"""
Gets the source node service (used for initiator type SERVICE).
Returns:
The source node service
"""
return self.source_node_service
def get_source_node_service_state(self) -> str:
"""
Gets the source node service state (used for initiator type SERVICE).
Returns:
The source node service state
"""
return self.source_node_service_state

View File

@@ -1,42 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The Passive Node class (i.e. an actuator)."""
from primaite.common.enums import HardwareState, NodeType, Priority
from primaite.config.training_config import TrainingConfig
from primaite.nodes.node import Node
class PassiveNode(Node):
"""The Passive Node class."""
def __init__(
self,
node_id: str,
name: str,
node_type: NodeType,
priority: Priority,
hardware_state: HardwareState,
config_values: TrainingConfig,
) -> None:
"""
Initialise a passive node.
:param node_id: The node id.
:param name: The name of the node.
:param node_type: The type of the node.
:param priority: The priority of the node.
:param hardware_state: The state of the node.
:param config_values: Config values.
"""
# Pass through to Super for now
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
@property
def ip_address(self) -> str:
"""
Gets the node IP address as an empty string.
No concept of IP address for passive nodes for now.
:return: The node IP address.
"""
return ""

View File

@@ -1,190 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""A Service Node (i.e. not an actuator)."""
import logging
from typing import Dict, Final
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
class ServiceNode(ActiveNode):
"""ServiceNode class."""
def __init__(
self,
node_id: str,
name: str,
node_type: NodeType,
priority: Priority,
hardware_state: HardwareState,
ip_address: str,
software_state: SoftwareState,
file_system_state: FileSystemState,
config_values: TrainingConfig,
) -> None:
"""
Initialise a Service Node.
:param node_id: The node ID
:param name: The node name
:param node_type: The node type (enum)
:param priority: The node priority (enum)
:param hardware_state: The node Hardware State
:param ip_address: The node IP address
:param software_state: The node Software State
:param file_system_state: The node file system state
:param config_values: The config values
"""
super().__init__(
node_id,
name,
node_type,
priority,
hardware_state,
ip_address,
software_state,
file_system_state,
config_values,
)
self.services: Dict[str, Service] = {}
def add_service(self, service: Service) -> None:
"""
Adds a service to the node.
:param service: The service to add
"""
self.services[service.name] = service
def has_service(self, protocol_name: str) -> bool:
"""
Indicates whether a service is on a node.
:param protocol_name: The service (protocol)e.
:return: True if service (protocol) is on the node, otherwise False.
"""
for service_key, service_value in self.services.items():
if service_key == protocol_name:
return True
return False
def service_running(self, protocol_name: str) -> bool:
"""
Indicates whether a service is in a running state on the node.
:param protocol_name: The service (protocol)
:return: True if service (protocol) is in a running state on the node, otherwise False.
"""
for service_key, service_value in self.services.items():
if service_key == protocol_name:
if service_value.software_state != SoftwareState.PATCHING:
return True
else:
return False
return False
def service_is_overwhelmed(self, protocol_name: str) -> bool:
"""
Indicates whether a service is in an overwhelmed state on the node.
:param protocol_name: The service (protocol)
:return: True if service (protocol) is in an overwhelmed state on the node, otherwise False.
"""
for service_key, service_value in self.services.items():
if service_key == protocol_name:
if service_value.software_state == SoftwareState.OVERWHELMED:
return True
else:
return False
return False
def set_service_state(self, protocol_name: str, software_state: SoftwareState) -> None:
"""
Sets the software_state of a service (protocol) on the node.
:param protocol_name: The service (protocol).
:param software_state: The software_state.
"""
if self.hardware_state != HardwareState.OFF:
service_key = protocol_name
service_value = self.services.get(service_key)
if service_value:
# Can't set to compromised if you're in a patching state
if (
software_state == SoftwareState.COMPROMISED
and service_value.software_state != SoftwareState.PATCHING
) or software_state != SoftwareState.COMPROMISED:
service_value.software_state = software_state
if software_state == SoftwareState.PATCHING:
service_value.patching_count = self.config_values.service_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so the state of a service "
f"cannot be changed. "
f"Node.node_id:{self.node_id}, "
f"Node.hardware_state:{self.hardware_state}, "
f"Node.services[<key>]:{protocol_name}, "
f"Node.services[<key>].software_state:{software_state}"
)
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState) -> None:
"""
Sets the software_state of a service (protocol) on the node.
Done if the software_state is not "compromised".
:param protocol_name: The service (protocol).
:param software_state: The software_state.
"""
if self.hardware_state != HardwareState.OFF:
service_key = protocol_name
service_value = self.services.get(service_key)
if service_value:
if service_value.software_state != SoftwareState.COMPROMISED:
service_value.software_state = software_state
if software_state == SoftwareState.PATCHING:
service_value.patching_count = self.config_values.service_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so the state of a service "
f"cannot be changed. "
f"Node.node_id:{self.node_id}, "
f"Node.hardware_state:{self.hardware_state}, "
f"Node.services[<key>]:{protocol_name}, "
f"Node.services[<key>].software_state:{software_state}"
)
def get_service_state(self, protocol_name: str) -> SoftwareState:
"""
Gets the state of a service.
:return: The software_state of the service.
"""
service_key = protocol_name
service_value = self.services.get(service_key)
if service_value:
return service_value.software_state
def update_services_patching_status(self) -> None:
"""Updates the patching counter for any service that are patching."""
for service_key, service_value in self.services.items():
if service_value.software_state == SoftwareState.PATCHING:
service_value.reduce_patching_count()
def update_resetting_status(self) -> None:
"""Update resetting counter and set software state if it reached 0."""
super().update_resetting_status()
if self.resetting_count <= 0:
for service in self.services.values():
service.software_state = SoftwareState.GOOD
def update_booting_status(self) -> None:
"""Update booting counter and set software to good if it reached 0."""
super().update_booting_status()
if self.booting_count <= 0:
for service in self.services.values():
service.software_state = SoftwareState.GOOD

View File

View File

@@ -1,107 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.network.networks import arcd_uc2_network\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net = arcd_uc2_network()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### set up some services to test if actions are working"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_serv = net.get_node_by_hostname('database_server')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.system.services.database_service import DatabaseService"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_svc = DatabaseService(file_system=db_serv.file_system)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_serv.install_service(db_svc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_serv.describe_state()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,2 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Pattern of Life- Represents the actions of users on the network."""

View File

@@ -1,264 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Implements Pattern of Life on the network (nodes and links)."""
from typing import Dict
from networkx import MultiGraph, shortest_path
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
_VERBOSE: bool = False
def apply_iers(
network: MultiGraph,
nodes: Dict[str, NodeUnion],
links: Dict[str, Link],
iers: Dict[str, IER],
acl: AccessControlList,
step: int,
) -> None:
"""
Applies IERs to the links (link pattern of life).
Args:
network: The network modelled in the environment
nodes: The nodes within the environment
links: The links within the environment
iers: The IERs to apply to the links
acl: The Access Control List
step: The step number.
"""
if _VERBOSE:
print("Applying IERs")
# Go through each IER and check the conditions for it being applied
# If everything is in place, apply the IER protocol load to the relevant links
for ier_key, ier_value in iers.items():
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
protocol = ier_value.get_protocol()
port = ier_value.get_port()
load = ier_value.get_load()
source_node_id = ier_value.get_source_node_id()
dest_node_id = ier_value.get_dest_node_id()
# Need to set the running status to false first for all IERs
ier_value.set_is_running(False)
source_valid = True
dest_valid = True
acl_block = False
if step >= start_step and step <= stop_step:
# continue --------------------------
# Get the source and destination node for this link
source_node = nodes[source_node_id]
dest_node = nodes[dest_node_id]
# 1. Check the source node situation
# TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch
# doesn't mean it has a software state? It could be a PassiveNode or ActiveNode
if source_node.node_type == NodeType.SWITCH:
# It's a switch
if (
source_node.hardware_state == HardwareState.ON
and source_node.software_state != SoftwareState.PATCHING
):
source_valid = True
else:
# IER no longer valid
source_valid = False
elif source_node.node_type == NodeType.ACTUATOR:
# It's an actuator
# TO DO
pass
else:
# It's not a switch or an actuator (so active node)
if (
source_node.hardware_state == HardwareState.ON
and source_node.software_state != SoftwareState.PATCHING
):
if source_node.has_service(protocol):
if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol):
source_valid = True
else:
source_valid = False
else:
# Do nothing - IER is not valid on this node
# (This shouldn't happen if the IER has been written correctly)
source_valid = False
else:
# Do nothing - IER no longer valid
source_valid = False
# 2. Check the dest node situation
if dest_node.node_type == NodeType.SWITCH:
# It's a switch
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
dest_valid = True
else:
# IER no longer valid
dest_valid = False
elif dest_node.node_type == NodeType.ACTUATOR:
# It's an actuator
pass
else:
# It's not a switch or an actuator (so active node)
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
if dest_node.has_service(protocol):
if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol):
dest_valid = True
else:
dest_valid = False
else:
# Do nothing - IER is not valid on this node
# (This shouldn't happen if the IER has been written correctly)
dest_valid = False
else:
# Do nothing - IER no longer valid
dest_valid = False
# 3. Check that the ACL doesn't block it
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
if acl_block:
if _VERBOSE:
print(
"ACL block on source: "
+ source_node.ip_address
+ ", dest: "
+ dest_node.ip_address
+ ", protocol: "
+ protocol
+ ", port: "
+ port
)
else:
if _VERBOSE:
print("No ACL block")
# Check whether both the source and destination are valid, and there's no ACL block
if source_valid and dest_valid and not acl_block:
# Load up the link(s) with the traffic
if _VERBOSE:
print("Source, Dest and ACL valid")
# Get the shortest path (i.e. nodes) between source and destination
path_node_list = shortest_path(network, source_node, dest_node)
path_node_list_length = len(path_node_list)
path_valid = True
# We might have a switch in the path, so check all nodes are operational
for node in path_node_list:
if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING:
path_valid = False
if path_valid:
if _VERBOSE:
print("Applying IER to link(s)")
count = 0
link_capacity_exceeded = False
# Check that the link capacity is not exceeded by the new load
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
link_id = edge_dict[0].get("id")
link = links[link_id]
# Check whether the new load exceeds the bandwidth
if (link.get_current_load() + load) > link.get_bandwidth():
link_capacity_exceeded = True
if _VERBOSE:
print("Link capacity exceeded")
pass
count += 1
# Check whether the link capacity for any links on this path have been exceeded
if link_capacity_exceeded == False:
# Now apply the new loads to the links
count = 0
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count],
path_node_list[count + 1],
)
link_id = edge_dict[0].get("id")
link = links[link_id]
# Add the load from this IER
link.add_protocol_load(protocol, load)
count += 1
# This IER is now valid, so set it to running
ier_value.set_is_running(True)
else:
# One of the nodes is not operational
if _VERBOSE:
print("Path not valid - one or more nodes not operational")
pass
else:
if _VERBOSE:
print("Source, Dest or ACL were not valid")
pass
# ------------------------------------
else:
# Do nothing - IER no longer valid
pass
def apply_node_pol(
nodes: Dict[str, NodeUnion],
node_pol: Dict[str, NodeStateInstructionGreen],
step: int,
) -> None:
"""
Applies node pattern of life.
Args:
nodes: The nodes within the environment
node_pol: The node pattern of life to apply
step: The step number.
"""
if _VERBOSE:
print("Applying Node PoL")
for key, node_instruction in node_pol.items():
start_step = node_instruction.get_start_step()
stop_step = node_instruction.get_end_step()
node_id = node_instruction.get_node_id()
node_pol_type = node_instruction.get_node_pol_type()
service_name = node_instruction.get_service_name()
state = node_instruction.get_state()
if step >= start_step and step <= stop_step:
# continue --------------------------
node = nodes[node_id]
if node_pol_type == NodePOLType.OPERATING:
# Change hardware state
node.hardware_state = state
elif node_pol_type == NodePOLType.OS:
# Change OS state
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
node.set_software_state_if_not_compromised(state)
elif node_pol_type == NodePOLType.SERVICE:
# Change a service state
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
if isinstance(node, ServiceNode):
node.set_service_state_if_not_compromised(service_name, state)
else:
# Change the file system status
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
node.set_file_system_state_if_not_compromised(state)
else:
# PoL is not valid in this time step
pass

View File

@@ -1,147 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""
Information Exchange Requirements for APE.
Used to represent an information flow from source to destination.
"""
class IER(object):
"""Information Exchange Requirement class."""
def __init__(
self,
_id: str,
_start_step: int,
_end_step: int,
_load: int,
_protocol: str,
_port: str,
_source_node_id: str,
_dest_node_id: str,
_mission_criticality: int,
_running: bool = False,
) -> None:
"""
Initialise an Information Exchange Request.
:param _id: The IER id
:param _start_step: The step when this IER should start
:param _end_step: The step when this IER should end
:param _load: The load this IER should put on a link (bps)
:param _protocol: The protocol of this IER
:param _port: The port this IER runs on
:param _source_node_id: The source node ID
:param _dest_node_id: The destination node ID
:param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
:param _running: Indicates whether the IER is currently running
"""
self.id: str = _id
self.start_step: int = _start_step
self.end_step: int = _end_step
self.source_node_id: str = _source_node_id
self.dest_node_id: str = _dest_node_id
self.load: int = _load
self.protocol: str = _protocol
self.port: str = _port
self.mission_criticality: int = _mission_criticality
self.running: bool = _running
def get_id(self) -> str:
"""
Gets IER ID.
Returns:
IER ID
"""
return self.id
def get_start_step(self) -> int:
"""
Gets IER start step.
Returns:
IER start step
"""
return self.start_step
def get_end_step(self) -> int:
"""
Gets IER end step.
Returns:
IER end step
"""
return self.end_step
def get_load(self) -> int:
"""
Gets IER load.
Returns:
IER load
"""
return self.load
def get_protocol(self) -> str:
"""
Gets IER protocol.
Returns:
IER protocol
"""
return self.protocol
def get_port(self) -> str:
"""
Gets IER port.
Returns:
IER port
"""
return self.port
def get_source_node_id(self) -> str:
"""
Gets IER source node ID.
Returns:
IER source node ID
"""
return self.source_node_id
def get_dest_node_id(self) -> str:
"""
Gets IER destination node ID.
Returns:
IER destination node ID
"""
return self.dest_node_id
def get_is_running(self) -> bool:
"""
Informs whether the IER is currently running.
Returns:
True if running
"""
return self.running
def set_is_running(self, _value: bool) -> None:
"""
Sets the running state of the IER.
Args:
_value: running status
"""
self.running = _value
def get_mission_criticality(self) -> int:
"""
Gets the IER mission criticality (used in the reward function).
Returns:
Mission criticality value (0 lowest to 5 highest)
"""
return self.mission_criticality

View File

@@ -1,353 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
from typing import Dict
from networkx import MultiGraph, shortest_path
from primaite import getLogger
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
_LOGGER = getLogger(__name__)
_VERBOSE: bool = False
def apply_red_agent_iers(
network: MultiGraph,
nodes: Dict[str, NodeUnion],
links: Dict[str, Link],
iers: Dict[str, IER],
acl: AccessControlList,
step: int,
) -> None:
"""
Applies IERs to the links (link POL) resulting from red agent attack.
Args:
network: The network modelled in the environment
nodes: The nodes within the environment
links: The links within the environment
iers: The red agent IERs to apply to the links
acl: The Access Control List
step: The step number.
"""
# Go through each IER and check the conditions for it being applied
# If everything is in place, apply the IER protocol load to the relevant links
for ier_key, ier_value in iers.items():
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
protocol = ier_value.get_protocol()
port = ier_value.get_port()
load = ier_value.get_load()
source_node_id = ier_value.get_source_node_id()
dest_node_id = ier_value.get_dest_node_id()
# Need to set the running status to false first for all IERs
ier_value.set_is_running(False)
source_valid = True
dest_valid = True
acl_block = False
if step >= start_step and step <= stop_step:
# continue --------------------------
# Get the source and destination node for this link
source_node = nodes[source_node_id]
dest_node = nodes[dest_node_id]
# 1. Check the source node situation
if source_node.node_type == NodeType.SWITCH:
# It's a switch
if source_node.hardware_state == HardwareState.ON:
source_valid = True
else:
# IER no longer valid
source_valid = False
elif source_node.node_type == NodeType.ACTUATOR:
# It's an actuator
# TO DO
pass
else:
# It's not a switch or an actuator (so active node)
# TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it
# could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs
# to change according to duck typing.
if source_node.hardware_state == HardwareState.ON:
if source_node.has_service(protocol):
# Red agents IERs can only be valid if the source service is in a compromised state
if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED:
source_valid = True
else:
source_valid = False
else:
# Do nothing - IER is not valid on this node
# (This shouldn't happen if the IER has been written correctly)
source_valid = False
else:
# Do nothing - IER no longer valid
source_valid = False
# 2. Check the dest node situation
if dest_node.node_type == NodeType.SWITCH:
# It's a switch
if dest_node.hardware_state == HardwareState.ON:
dest_valid = True
else:
# IER no longer valid
dest_valid = False
elif dest_node.node_type == NodeType.ACTUATOR:
# It's an actuator
pass
else:
# It's not a switch or an actuator (so active node)
if dest_node.hardware_state == HardwareState.ON:
if dest_node.has_service(protocol):
# We don't care what state the destination service is in for an IER
dest_valid = True
else:
# Do nothing - IER is not valid on this node
# (This shouldn't happen if the IER has been written correctly)
dest_valid = False
else:
# Do nothing - IER no longer valid
dest_valid = False
# 3. Check that the ACL doesn't block it
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
if acl_block:
if _VERBOSE:
print(
"ACL block on source: "
+ source_node.ip_address
+ ", dest: "
+ dest_node.ip_address
+ ", protocol: "
+ protocol
+ ", port: "
+ port
)
else:
if _VERBOSE:
print("No ACL block")
# Check whether both the source and destination are valid, and there's no ACL block
if source_valid and dest_valid and not acl_block:
# Load up the link(s) with the traffic
if _VERBOSE:
print("Source, Dest and ACL valid")
# Get the shortest path (i.e. nodes) between source and destination
path_node_list = shortest_path(network, source_node, dest_node)
path_node_list_length = len(path_node_list)
path_valid = True
# We might have a switch in the path, so check all nodes are operational
# We're assuming here that red agents can get past switches that are patching
for node in path_node_list:
if node.hardware_state != HardwareState.ON:
path_valid = False
if path_valid:
if _VERBOSE:
print("Applying IER to link(s)")
count = 0
link_capacity_exceeded = False
# Check that the link capacity is not exceeded by the new load
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
link_id = edge_dict[0].get("id")
link = links[link_id]
# Check whether the new load exceeds the bandwidth
if (link.get_current_load() + load) > link.get_bandwidth():
link_capacity_exceeded = True
if _VERBOSE:
print("Link capacity exceeded")
pass
count += 1
# Check whether the link capacity for any links on this path have been exceeded
if link_capacity_exceeded == False:
# Now apply the new loads to the links
count = 0
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count],
path_node_list[count + 1],
)
link_id = edge_dict[0].get("id")
link = links[link_id]
# Add the load from this IER
link.add_protocol_load(protocol, load)
count += 1
# This IER is now valid, so set it to running
ier_value.set_is_running(True)
if _VERBOSE:
print("Red IER was allowed to run in step " + str(step))
else:
# One of the nodes is not operational
if _VERBOSE:
print("Path not valid - one or more nodes not operational")
pass
else:
if _VERBOSE:
print("Red IER was NOT allowed to run in step " + str(step))
print("Source, Dest or ACL were not valid")
pass
# ------------------------------------
else:
# Do nothing - IER no longer valid
pass
pass
def apply_red_agent_node_pol(
nodes: Dict[str, NodeUnion],
iers: Dict[str, IER],
node_pol: Dict[str, NodeStateInstructionRed],
step: int,
) -> None:
"""
Applies node pattern of life.
Args:
nodes: The nodes within the environment
iers: The red agent IERs
node_pol: The red agent node pattern of life to apply
step: The step number.
"""
if _VERBOSE:
print("Applying Node Red Agent PoL")
for key, node_instruction in node_pol.items():
start_step = node_instruction.get_start_step()
stop_step = node_instruction.get_end_step()
target_node_id = node_instruction.get_target_node_id()
initiator = node_instruction.get_initiator()
pol_type = node_instruction.get_pol_type()
service_name = node_instruction.get_service_name()
state = node_instruction.get_state()
source_node_id = node_instruction.get_source_node_id()
source_node_service_name = node_instruction.get_source_node_service()
source_node_service_state_value = node_instruction.get_source_node_service_state()
passed_checks = False
if step >= start_step and step <= stop_step:
# continue --------------------------
target_node: NodeUnion = nodes[target_node_id]
# check if the initiator type is a str, and if so, cast it as
# NodePOLInitiator
if isinstance(initiator, str):
initiator = NodePOLInitiator[initiator]
# Based the action taken on the initiator type
if initiator == NodePOLInitiator.DIRECT:
# No conditions required, just apply the change
passed_checks = True
elif initiator == NodePOLInitiator.IER:
# Need to check there is a red IER incoming
passed_checks = is_red_ier_incoming(target_node, iers, pol_type)
elif initiator == NodePOLInitiator.SERVICE:
# Need to check the condition of a service on another node
source_node = nodes[source_node_id]
if source_node.has_service(source_node_service_name):
if (
source_node.get_service_state(source_node_service_name)
== SoftwareState[source_node_service_state_value]
):
passed_checks = True
else:
# Do nothing, no matching state value
pass
else:
# Do nothing, service not on this node
pass
else:
_LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration")
# Only apply the PoL if the checks have passed (based on the initiator type)
if passed_checks:
# Apply the change
if pol_type == NodePOLType.OPERATING:
# Change hardware state
target_node.hardware_state = state
elif pol_type == NodePOLType.OS:
# Change OS state
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.software_state = state
elif pol_type == NodePOLType.SERVICE:
# Change a service state
if isinstance(target_node, ServiceNode):
target_node.set_service_state(service_name, state)
else:
# Change the file system status
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.set_file_system_state(state)
else:
_LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks")
else:
# PoL is not valid in this time step
pass
def is_red_ier_incoming(node: NodeUnion, iers: Dict[str, IER], node_pol_type: NodePOLType) -> bool:
"""Checks if the RED IER is incoming.
:param node: Destination node of the IER
:type node: NodeUnion
:param iers: Directory of IERs
:type iers: Dict[str,IER]
:param node_pol_type: Type of Pattern-Of-Life
:type node_pol_type: NodePOLType
:return: Whether the RED IER is incoming.
:rtype: bool
"""
node_id = node.node_id
for ier_key, ier_value in iers.items():
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
if (
node_pol_type == NodePOLType.OPERATING
or node_pol_type == NodePOLType.OS
or node_pol_type == NodePOLType.FILE
):
# It's looking to change hardware state, file system or SoftwareState, so valid
return True
elif node_pol_type == NodePOLType.SERVICE:
# Check if the service is present on the node and running
ier_protocol = ier_value.get_protocol()
if isinstance(node, ServiceNode):
if node.has_service(ier_protocol):
if node.service_running(ier_protocol):
# Matching service is present and running, so valid
return True
else:
# Service is present, but not running
return False
else:
# Service is not present
return False
else:
# Not a service node
return False
else:
# Shouldn't get here - instruction type is undefined
return False
else:
# The IER destination is not this node, or the IER is not running
return False

View File

@@ -1,228 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, Final, Optional, Tuple, Union
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
# from primaite.agents.rllib import RLlibAgent
from primaite.agents.sb3 import SB3Agent
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.utils.session_metadata_parser import parse_session_metadata
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
_LOGGER = getLogger(__name__)
class PrimaiteSession:
"""
The PrimaiteSession class.
Provides a single learning and evaluation entry point for all training and lay down configurations.
"""
def __init__(
self,
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
legacy_training_config: bool = False,
legacy_lay_down_config: bool = False,
) -> None:
"""
The PrimaiteSession constructor.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
otherwise False.
"""
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = session_path # noqa
self.timestamp_str: str = None # noqa
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
self.legacy_training_config = legacy_training_config
self.legacy_lay_down_config = legacy_lay_down_config
# check if session path is provided
if session_path is not None:
# set load_session to true
self.is_load_session = True
if not isinstance(session_path, Path):
session_path = Path(session_path)
# if a session path is provided, load it
if not session_path.exists():
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path)
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path, str]] = training_config_path
self._training_config: Final[TrainingConfig] = training_config.load(
self._training_config_path, legacy_training_config
)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config) # noqa
def setup(self) -> None:
"""Performs the session setup."""
if self._training_config.agent_framework == AgentFramework.CUSTOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
self._agent_session = HardCodedNodeAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = HardCodedACLAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
raise NotImplementedError
else:
# Invalid AgentIdentifier ActionType combo
raise ValueError
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}")
if self._training_config.action_type == ActionType.NODE:
self._agent_session = DoNothingNodeAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = DoNothingACLAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
raise NotImplementedError
else:
# Invalid AgentIdentifier ActionType combo
raise ValueError
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
self._agent_session = RandomAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
self._agent_session = DummyAgent(
self._training_config_path, self._lay_down_config_path, self.session_path
)
else:
# Invalid AgentFramework AgentIdentifier combo
raise ValueError
elif self._training_config.agent_framework == AgentFramework.SB3:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
# Stable Baselines3 Agent
self._agent_session = SB3Agent(
self._training_config_path,
self._lay_down_config_path,
self.session_path,
self.legacy_training_config,
self.legacy_lay_down_config,
)
# elif self._training_config.agent_framework == AgentFramework.RLLIB:
# _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
# # Ray RLlib Agent
# self._agent_session = RLlibAgent(
# self._training_config_path, self._lay_down_config_path, self.session_path
# )
else:
# Invalid AgentFramework
raise ValueError
self.session_path: Path = self._agent_session.session_path
self.timestamp_str: str = self._agent_session.timestamp_str
self.learning_path: Path = self._agent_session.learning_path
self.evaluation_path: Path = self._agent_session.evaluation_path
def learn(
self,
**kwargs: Any,
) -> None:
"""
Train the agent.
:param kwargs: Any agent-framework specific key word args.
"""
if not self._training_config.session_type == SessionType.EVAL:
self._agent_session.learn(**kwargs)
def evaluate(
self,
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
:param kwargs: Any agent-framework specific key word args.
"""
if not self._training_config.session_type == SessionType.TRAIN:
self._agent_session.evaluate(**kwargs)
def close(self) -> None:
"""Closes the agent."""
self._agent_session.close()
def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the learn av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.learning_path / csv_file)
def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
"""Get the eval av reward per episode from file."""
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.evaluation_path / csv_file)
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the learn all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.learning_path / csv_file)
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
"""Get the eval all transactions from file."""
csv_file = f"all_transactions_{self.timestamp_str}.csv"
return all_transactions_dict(self.evaluation_path / csv_file)
def metadata_file_as_dict(self) -> Dict[str, Any]:
"""Read the session_metadata.json file and return as a dict."""
with open(self.session_path / "session_metadata.json", "r") as file:
return json.load(file)

View File

@@ -1,14 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from primaite import getLogger
_LOGGER = getLogger(__name__)
def run() -> None:
"""Perform the full clean-up."""
pass
if __name__ == "__main__":
run()

View File

@@ -3,7 +3,7 @@ from enum import Enum
from typing import Dict
from primaite import getLogger
from primaite.simulator.core import SimComponent
from src.primaite.simulator.core import SimComponent
_LOGGER = getLogger(__name__)

View File

@@ -1,8 +1,8 @@
from enum import Enum
from typing import Dict, Final, List, Literal, Tuple
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType, SimComponent
from primaite.simulator.domain.account import Account, AccountType
from src.primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType, SimComponent
from src.primaite.simulator.domain.account import Account, AccountType
# placeholder while these objects don't yet exist

View File

@@ -7,8 +7,8 @@ from pathlib import Path
from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
from primaite.simulator.file_system.file_type import FileType, get_file_type_from_extension
from src.primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
from src.primaite.simulator.file_system.file_type import FileType, get_file_type_from_extension
_LOGGER = getLogger(__name__)

View File

@@ -7,11 +7,11 @@ from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file import File
from primaite.simulator.file_system.file_type import FileType
from primaite.simulator.file_system.folder import Folder
from primaite.simulator.system.core.sys_log import SysLog
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
from src.primaite.simulator.file_system.file import File
from src.primaite.simulator.file_system.file_type import FileType
from src.primaite.simulator.file_system.folder import Folder
from src.primaite.simulator.system.core.sys_log import SysLog
_LOGGER = getLogger(__name__)

View File

@@ -6,8 +6,8 @@ from enum import Enum
from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.system.core.sys_log import SysLog
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
from src.primaite.simulator.system.core.sys_log import SysLog
_LOGGER = getLogger(__name__)
@@ -41,19 +41,19 @@ def convert_size(size_bytes: int) -> str:
class FileSystemItemHealthStatus(Enum):
"""Status of the FileSystemItem."""
GOOD = 0
GOOD = 1
"""File/Folder is OK."""
COMPROMISED = 1
COMPROMISED = 2
"""File/Folder is quarantined."""
CORRUPT = 2
CORRUPT = 3
"""File/Folder is corrupted."""
RESTORING = 3
RESTORING = 4
"""File/Folder is in the process of being restored."""
REPAIRING = 3
REPAIRING = 5
"""File/Folder is in the process of being repaired."""
@@ -93,8 +93,8 @@ class FileSystemItemABC(SimComponent):
"""
state = super().describe_state()
state["name"] = self.name
state["status"] = self.health_status.name
state["visible_status"] = self.visible_health_status.name
state["health_status"] = self.health_status.value
state["visible_status"] = self.visible_health_status.value
state["previous_hash"] = self.previous_hash
state["revealed_to_red"] = self.revealed_to_red
return state

View File

@@ -5,9 +5,9 @@ from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.file_system.file import File
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
from src.primaite.simulator.core import RequestManager, RequestType
from src.primaite.simulator.file_system.file import File
from src.primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
_LOGGER = getLogger(__name__)

View File

@@ -6,12 +6,12 @@ from networkx import MultiGraph
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import Link, NIC, Node, SwitchPort
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
from src.primaite.simulator.network.hardware.base import Link, NIC, Node, SwitchPort
from src.primaite.simulator.network.hardware.nodes.computer import Computer
from src.primaite.simulator.network.hardware.nodes.router import Router
from src.primaite.simulator.network.hardware.nodes.server import Server
from src.primaite.simulator.network.hardware.nodes.switch import Switch
_LOGGER = getLogger(__name__)
@@ -160,8 +160,8 @@ class Network(SimComponent):
state = super().describe_state()
state.update(
{
"nodes": {i for i, node in self._node_id_map.items()},
"links": {i: link.describe_state() for i, link in self._link_id_map.items()},
"nodes": {uuid: node.describe_state() for uuid, node in self.nodes.items()},
"links": {uuid: link.describe_state() for uuid, link in self.links.items()},
}
)
return state
@@ -218,7 +218,9 @@ class Network(SimComponent):
_LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}")
self._node_request_manager.remove_request(name=node.uuid)
def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None:
def connect(
self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs
) -> Optional[Link]:
"""
Connect two endpoints on the network by creating a link between their NICs/SwitchPorts.
@@ -245,6 +247,7 @@ class Network(SimComponent):
self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname)
link.parent = self
_LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
return link
def remove_link(self, link: Link) -> None:
"""Disconnect a link from the network.

View File

@@ -10,22 +10,22 @@ from typing import Any, Dict, Literal, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.packet_capture import PacketCapture
from primaite.simulator.system.core.session_manager import SessionManager
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.processes.process import Process
from primaite.simulator.system.services.service import Service
from src.primaite.exceptions import NetworkError
from src.primaite.simulator import SIM_OUTPUT
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
from src.primaite.simulator.domain.account import Account
from src.primaite.simulator.file_system.file_system import FileSystem
from src.primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
from src.primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from src.primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
from src.primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
from src.primaite.simulator.system.applications.application import Application
from src.primaite.simulator.system.core.packet_capture import PacketCapture
from src.primaite.simulator.system.core.session_manager import SessionManager
from src.primaite.simulator.system.core.software_manager import SoftwareManager
from src.primaite.simulator.system.core.sys_log import SysLog
from src.primaite.simulator.system.processes.process import Process
from src.primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
@@ -859,14 +859,14 @@ class ICMP:
class NodeOperatingState(Enum):
"""Enumeration of Node Operating States."""
OFF = 0
"The node is powered off."
ON = 1
"The node is powered on."
SHUTTING_DOWN = 2
"The node is in the process of shutting down."
OFF = 2
"The node is powered off."
BOOTING = 3
"The node is in the process of booting up."
SHUTTING_DOWN = 4
"The node is in the process of shutting down."
class Node(SimComponent):
@@ -962,6 +962,7 @@ class Node(SimComponent):
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
if not kwargs.get("software_manager"):
kwargs["software_manager"] = SoftwareManager(
parent_node=self,
sys_log=kwargs.get("sys_log"),
session_manager=kwargs.get("session_manager"),
file_system=kwargs.get("file_system"),
@@ -1369,7 +1370,8 @@ class Node(SimComponent):
self._service_request_manager.add_request(service.uuid, RequestType(func=service._request_manager))
def uninstall_service(self, service: Service) -> None:
"""Uninstall and completely remove service from this node.
"""
Uninstall and completely remove service from this node.
:param service: Service object that is currently associated with this node.
:type service: Service

View File

@@ -1,7 +1,7 @@
from primaite.simulator.network.hardware.base import NIC, Node
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from src.primaite.simulator.network.hardware.base import NIC, Node
from src.primaite.simulator.system.applications.web_browser import WebBrowser
from src.primaite.simulator.system.services.dns.dns_client import DNSClient
from src.primaite.simulator.system.services.ftp.ftp_client import FTPClient
class Computer(Node):

View File

@@ -7,12 +7,12 @@ from typing import Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
from primaite.simulator.system.core.sys_log import SysLog
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
from src.primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
from src.primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from src.primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
from src.primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
from src.primaite.simulator.system.core.sys_log import SysLog
class ACLAction(Enum):
@@ -58,7 +58,14 @@ class ACLRule(SimComponent):
:return: A dictionary representing the current state.
"""
pass
state = super().describe_state()
state["action"] = self.action.value
state["protocol"] = self.protocol.value if self.protocol else None
state["src_ip_address"] = self.src_ip_address if self.src_ip_address else None
state["src_port"] = self.src_port.value if self.src_port else None
state["dst_ip_address"] = self.dst_ip_address if self.dst_ip_address else None
state["dst_port"] = self.dst_port.value if self.dst_port else None
return state
class AccessControlList(SimComponent):
@@ -104,11 +111,11 @@ class AccessControlList(SimComponent):
RequestType(
func=lambda request, context: self.add_rule(
ACLAction[request[0]],
IPProtocol[request[1]],
IPv4Address[request[2]],
Port[request[3]],
IPv4Address[request[4]],
Port[request[5]],
None if request[1] == "ALL" else IPProtocol[request[1]],
IPv4Address(request[2]),
None if request[3] == "ALL" else Port[request[3]],
IPv4Address(request[4]),
None if request[5] == "ALL" else Port[request[5]],
int(request[6]),
)
),
@@ -123,7 +130,12 @@ class AccessControlList(SimComponent):
:return: A dictionary representing the current state.
"""
pass
state = super().describe_state()
state["implicit_action"] = self.implicit_action.value
state["implicit_rule"] = self.implicit_rule.describe_state()
state["max_acl_rules"] = self.max_acl_rules
state["acl"] = {i: r.describe_state() if isinstance(r, ACLRule) else None for i, r in enumerate(self._acl)}
return state
@property
def acl(self) -> List[Optional[ACLRule]]:
@@ -648,7 +660,10 @@ class Router(Node):
:return: A dictionary representing the current state.
"""
pass
state = super().describe_state()
state["num_ports"] = (self.num_ports,)
state["acl"] = (self.acl.describe_state(),)
return state
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
"""

View File

@@ -1,4 +1,4 @@
from primaite.simulator.network.hardware.nodes.computer import Computer
from src.primaite.simulator.network.hardware.nodes.computer import Computer
class Server(Computer):

View File

@@ -3,10 +3,9 @@ from typing import Dict
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.links.link import Link
from primaite.simulator.network.hardware.base import Node, SwitchPort
from primaite.simulator.network.transmission.data_link_layer import Frame
from src.primaite.exceptions import NetworkError
from src.primaite.simulator.network.hardware.base import Link, Node, SwitchPort
from src.primaite.simulator.network.transmission.data_link_layer import Frame
_LOGGER = getLogger(__name__)
@@ -55,12 +54,11 @@ class Switch(Node):
:return: Current state of this object and child objects.
"""
return {
"uuid": self.uuid,
"num_ports": self.num_ports, # redundant?
"ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()},
"mac_address_table": {mac: port for mac, port in self.mac_address_table.items()},
}
state = super().describe_state()
state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}
state["num_ports"] = self.num_ports # redundant?
state["mac_address_table"] = {mac: port for mac, port in self.mac_address_table.items()}
return state
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
"""

View File

@@ -1,19 +1,19 @@
from ipaddress import IPv4Address
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.web_server.web_server import WebServer
from src.primaite.simulator.network.container import Network
from src.primaite.simulator.network.hardware.base import NIC, NodeOperatingState
from src.primaite.simulator.network.hardware.nodes.computer import Computer
from src.primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from src.primaite.simulator.network.hardware.nodes.server import Server
from src.primaite.simulator.network.hardware.nodes.switch import Switch
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
from src.primaite.simulator.network.transmission.transport_layer import Port
from src.primaite.simulator.system.applications.database_client import DatabaseClient
from src.primaite.simulator.system.services.database.database_service import DatabaseService
from src.primaite.simulator.system.services.dns.dns_server import DNSServer
from src.primaite.simulator.system.services.ftp.ftp_server import FTPServer
from src.primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from src.primaite.simulator.system.services.web_server.web_server import WebServer
def client_server_routed() -> Network:

View File

@@ -5,7 +5,7 @@ from typing import Optional
from pydantic import BaseModel
from primaite.simulator.network.protocols.packet import DataPacket
from src.primaite.simulator.network.protocols.packet import DataPacket
class ARPEntry(BaseModel):

View File

@@ -5,7 +5,7 @@ from typing import Optional
from pydantic import BaseModel
from primaite.simulator.network.protocols.packet import DataPacket
from src.primaite.simulator.network.protocols.packet import DataPacket
class DNSRequest(BaseModel):

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional, Union
from primaite.simulator.network.protocols.packet import DataPacket
from src.primaite.simulator.network.protocols.packet import DataPacket
class FTPCommand(Enum):

View File

@@ -1,6 +1,6 @@
from enum import Enum
from primaite.simulator.network.protocols.packet import DataPacket
from src.primaite.simulator.network.protocols.packet import DataPacket
class HttpRequestMethod(Enum):

View File

@@ -4,12 +4,12 @@ from typing import Any, Optional
from pydantic import BaseModel
from primaite import getLogger
from primaite.simulator.network.protocols.arp import ARPPacket
from primaite.simulator.network.protocols.packet import DataPacket
from primaite.simulator.network.transmission.network_layer import ICMPPacket, IPPacket, IPProtocol
from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader
from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader
from primaite.simulator.network.utils import convert_bytes_to_megabits
from src.primaite.simulator.network.protocols.arp import ARPPacket
from src.primaite.simulator.network.protocols.packet import DataPacket
from src.primaite.simulator.network.transmission.network_layer import ICMPPacket, IPPacket, IPProtocol
from src.primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader
from src.primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader
from src.primaite.simulator.network.utils import convert_bytes_to_megabits
_LOGGER = getLogger(__name__)

View File

@@ -1,8 +1,8 @@
from typing import Dict
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.domain.controller import DomainController
from primaite.simulator.network.container import Network
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
from src.primaite.simulator.domain.controller import DomainController
from src.primaite.simulator.network.container import Network
class Simulation(SimComponent):
@@ -27,6 +27,7 @@ class Simulation(SimComponent):
rm.add_request("network", RequestType(func=self.network._request_manager))
# pass through domain requests to the domain object
rm.add_request("domain", RequestType(func=self.domain._request_manager))
rm.add_request("do_nothing", RequestType(func=lambda request, context: ()))
return rm
def describe_state(self) -> Dict:

Some files were not shown because too many files have changed in this diff Show More