Merge branch 'dev' into bugfix/2299-check_hash_function_corrupts_files_and_folders

This commit is contained in:
Nick Todd
2024-04-23 16:40:02 +01:00
140 changed files with 9846 additions and 5182 deletions

View File

@@ -14,11 +14,11 @@ parameters:
- name: matrix
type: object
default:
- job_name: 'UbuntuPython38'
py: '3.8'
img: 'ubuntu-latest'
every_time: false
publish_coverage: false
# - job_name: 'UbuntuPython38'
# py: '3.8'
# img: 'ubuntu-latest'
# every_time: false
# publish_coverage: false
- job_name: 'UbuntuPython310'
py: '3.10'
img: 'ubuntu-latest'

View File

@@ -6,7 +6,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-added-large-files
args: ['--maxkb=1000']
args: ['--maxkb=5000']
- id: mixed-line-ending
- id: requirements-txt-fixer
- repo: http://github.com/psf/black
@@ -28,3 +28,7 @@ repos:
additional_dependencies:
- flake8-docstrings
- flake8-annotations
- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
hooks:
- id: nbstripout

View File

@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## 3.0.0b9
- Removed deprecated `PrimaiteSession` class.
- Upgraded pydantic to version 2.7.0
- Upgraded Ray to version >= 2.9
- Added ipywidgets to the dependencies
## [Unreleased]
- Made requests fail to reach their target if the node is off
- Added responses to requests
@@ -12,27 +18,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack.
- Changed the data manipulation scenario to include a second green agent on client 1.
- Refactored actions and observations to be configurable via object name, instead of UUID.
- Fixed a bug where ACL rules were not resetting on episode reset.
- Fixed a bug where blue agent's ACL actions were being applied against the wrong IP addresses
- Fixed a bug where deleted files and folders did not reset correctly on episode reset.
- Fixed a bug where service health status was using the actual health state instead of the visible health state
- Fixed a bug where the database file health status was using the incorrect value for negative rewards
- Fixed a bug preventing file actions from reaching their intended file
- Made database patch correctly take 2 timesteps instead of being immediate
- Made database patch only possible when the software is compromised or good, it's no longer possible when the software is OFF or RESETTING
- Temporarily disable the blue agent file delete action due to crashes. This issue is resolved in another branch that will be merged into dev soon.
- Fix a bug where ACLs were not showing up correctly in the observation space.
- Added a notebook which explains Data manipulation scenario, demonstrates the attack, and shows off blue agent's action space, observation space, and reward function.
- Made packet capture and system logging optional (off by default). To turn on, change the io_settings.save_pcap_logs and io_settings.save_sys_logs settings in the config.
- Made observation space flattening optional (on by default). To turn off for an agent, change the agent_settings.flatten_obs setting in the config.
- Fixed an issue where the data manipulation attack was triggered at episode start.
- Fixed a bug where FTP STOR stored an additional copy on the client machine's filesystem
- Fixed a bug where the red agent acted to early
- Fixed the order of service health state
- Fixed an issue where starting a node didn't start the services on it
- Made observation space flattening optional (on by default). To turn off for an agent, change the `agent_settings.flatten_obs` setting in the config.
- Added support for SQL INSERT command.
- Added ability to log each agent's action choices in each step to a JSON file.
### Bug Fixes
- ACL rules were not resetting on episode reset.
- ACLs were not showing up correctly in the observation space.
- Blue agent's ACL actions were being applied against the wrong IP addresses
- Deleted files and folders did not reset correctly on episode reset.
- Service health status was using the actual health state instead of the visible health state
- Database file health status was using the incorrect value for negative rewards
- Preventing file actions from reaching their intended file
- The data manipulation attack was triggered at episode start.
- FTP STOR stored an additional copy on the client machine's filesystem
- The red agent acted to early
- Order of service health state
- Starting a node didn't start the services on it
- Fixed an issue where the services were still able to run even though the node the service is installed on is turned off
### Added
@@ -51,8 +60,12 @@ a Service/Application another machine.
SessionManager.
- Permission System - each action can define criteria that will be used to permit or deny agent actions.
- File System - ability to emulate a node's file system during a simulation
- Example notebooks - There is currently 1 jupyter notebook which walks through using PrimAITE
1. Creating a simulation - this notebook explains how to build up a simulation using the Python package. (WIP)
- Example notebooks - There are 5 jupyter notebook which walk through using PrimAITE
1. Training a Stable Baselines 3 agent
2. Training a single agent system using Ray RLLib
3. Training a multi-agent system Ray RLLib
4. Data manipulation end to end demonstration
5. Data manipulation scenario with customised red agents
- Database:
- `DatabaseClient` and `DatabaseService` created to allow emulation of database actions
- Ability for `DatabaseService` to backup its data to another server via FTP and restore data from backup
@@ -62,7 +75,6 @@ SessionManager.
- DNS Services: `DNSClient` and `DNSServer`
- FTP Services: `FTPClient` and `FTPServer`
- HTTP Services: `WebBrowser` to simulate a web client and `WebServer`
- Fixed an issue where the services were still able to run even though the node the service is installed on is turned off
- NTP Services: `NTPClient` and `NTPServer`
- **RouterNIC Class**: Introduced a new class `RouterNIC`, extending the standard `NIC` functionality. This class is specifically designed for router operations, optimizing the processing and routing of network traffic.
- **Custom Layer-3 Processing**: The `RouterNIC` class includes custom handling for network frames, bypassing standard Node NIC's Layer 3 broadcast/unicast checks. This allows for more efficient routing behavior in network scenarios where router-specific frame processing is required.
@@ -101,6 +113,7 @@ SessionManager.
- Ability to add ``Router``/``Firewall`` ``ACLRule`` via config
- NMNE capturing capabilities to `NetworkInterface` class for detecting and logging Malicious Network Events.
- New `nmne_config` settings in the simulation configuration to enable NMNE capturing and specify keywords such as "DELETE".
- Router-specific SessionManager Implementation: Introduced a specialized version of the SessionManager tailored for router operations. This enhancement enables the SessionManager to determine the routing path by consulting the route table.
### Changed
- Integrated the RouteTable into the Routers frame processing.
@@ -113,7 +126,7 @@ SessionManager.
- Updated all tests to employ the `Network()` class for managing nodes and their connections, ensuring a consistent and structured approach to setting up network topologies in testing scenarios.
- **ACLRule Wildcard Masking**: Updated the `ACLRule` class to support IP ranges using wildcard masking. This enhancement allows for more flexible and granular control over traffic filtering, enabling the specification of broader or more specific IP address ranges in ACL rules.
- Updated `NetworkInterface` documentation to reflect the new NMNE capturing features and how to use them.
- Integration of NMNE capturing functionality within the `NicObservation` class.
- Integration of NMNE capturing functionality within the `NICObservation` class.
- Changed blue action set to enable applying node scan, reset, start, and shutdown to every host in data manipulation scenario
### Removed

View File

@@ -26,7 +26,7 @@ PrimAITE presents the following features:
## Getting Started with PrimAITE
### 💫 Install & Run
### 💫 Installation
**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS.
Currently, the PrimAITE wheel can only be installed from GitHub. This may change in the future with release to PyPi.
@@ -47,11 +47,6 @@ pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/relea
primaite setup
```
**Run:**
``` bash
primaite session
```
#### Unix
@@ -75,12 +70,6 @@ pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/relea
primaite setup
```
**Run:**
``` bash
primaite session
```
### Developer Install from Source
@@ -125,6 +114,29 @@ python3 -m pip install -e .[dev]
primaite setup
```
### Running PrimAITE
Use the provided jupyter notebooks as a starting point to try running PrimAITE. They are automatically copied to your PrimAITE notebook folder when you run `primaite setup`.
#### 1. Activate the virtual environment
##### Windows (Powershell)
```powershell
.\venv\Scripts\activate
```
##### Unix
```bash
source venv/bin/activate
```
#### 2. Open jupyter notebook
```bash
python -m jupyter notebook
```
Then, click the URL provided by the jupyter command to open the jupyter application in your browser. You can also open notebooks in your IDE if supported.
## 📚 Building documentation
The PrimAITE documentation can be built with the following commands:

View File

@@ -1,3 +1,7 @@
# flake8: noqa
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import json
import platform

View File

@@ -48,7 +48,7 @@ class "ActiveNode" as primaite.nodes.active_node.ActiveNode {
file_system_state_actual : GOOD
file_system_state_observed : REPAIRING, RESTORING, GOOD
ip_address : str
patching_count : int
fixing_count : int
software_state
software_state : GOOD
set_file_system_state(file_system_state: FileSystemState) -> None
@@ -353,10 +353,10 @@ class "SB3Agent" as primaite.agents.sb3.SB3Agent {
}
class "Service" as primaite.common.service.Service {
name : str
patching_count : int
fixing_count : int
port : str
software_state : GOOD
reduce_patching_count() -> None
reduce_fixing_count() -> None
}
class "ServiceNode" as primaite.nodes.service_node.ServiceNode {
services : Dict[str, Service]
@@ -455,7 +455,7 @@ class "TrainingConfig" as primaite.config.training_config.TrainingConfig {
sb3_output_verbose_level
scanning : float
seed : Optional[int]
service_patching_duration : int
service_fixing_duration : int
session_type
time_delay : int
from_dict(config_dict: Dict[str, Any]) -> TrainingConfig

BIN
docs/_static/notebooks/extensions.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 68 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

View File

@@ -10,6 +10,7 @@ import datetime
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
import os
import sys
from typing import Any
import furo # noqa
@@ -63,3 +64,21 @@ html_theme = "furo"
html_static_path = ["_static"]
html_theme_options = {"globaltoc_collapse": True, "globaltoc_maxdepth": 2}
html_copy_source = False
def replace_token(app: Any, docname: Any, source: Any):
"""Replaces a token from the list of tokens."""
result = source[0]
for key in app.config.tokens:
result = result.replace(key, app.config.tokens[key])
source[0] = result
tokens = {"{VERSION}": release} # Token VERSION is replaced by the value of the PrimAITE version in the version file
"""Dict containing the tokens that need to be replaced in documentation."""
def setup(app: Any):
"""Custom setup for sphinx."""
app.add_config_value("tokens", {}, True)
app.connect("source-read", replace_token)

View File

@@ -105,6 +105,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
source/getting_started
source/primaite_session
source/example_notebooks
source/simulation
source/game_layer
source/config

View File

@@ -41,11 +41,11 @@ The game layer is built on top of the simulator and it consumes the simulation a
* Hardware State (ON, OFF, RESETTING, SHUTTING_DOWN, BOOTING - enumeration)
Active Nodes also have the following attributes (Class: Active Node):
* IP Address
* Software State (GOOD, PATCHING, COMPROMISED - enumeration)
* Software State (GOOD, FIXING, COMPROMISED - enumeration)
* File System State (GOOD, CORRUPT, DESTROYED, REPAIRING, RESTORING - enumeration)
Service Nodes also have the following attributes (Class: Service Node):
* List of Services (where service is composed of service name and port). There is no theoretical limit on the number of services that can be modelled. Services and protocols are currently intrinsically linked (i.e. a service is an application on a node transmitting traffic of this protocol type)
* Service state (GOOD, PATCHING, COMPROMISED, OVERWHELMED - enumeration)
* Service state (GOOD, FIXING, COMPROMISED, OVERWHELMED - enumeration)
Passive Nodes are currently not used (but may be employed for non IP-based components such as machinery actuators in future releases).
**Links**
Links are modelled both as network edges (networkx) and as Python classes, in order to extend their functionality. Links include the following attributes:
@@ -70,8 +70,8 @@ The game layer is built on top of the simulator and it consumes the simulation a
* Running status (i.e. on / off)
The application of green agent IERs between a source and destination follows a number of rules. Specifically:
1. Does the current simulation time step fall between IER start and end step
2. Is the source node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING)
3. Is the destination node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not PATCHING)
2. Is the source node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not FIXING)
3. Is the destination node operational (both physically and at an O/S level), and is the service (protocol / port) associated with the IER (a) present on this node, and (b) in an operational state (i.e. not FIXING)
4. Are there any Access Control List rules in place that prevent the application of this IER
5. Are all switches in the (OSPF) path between source and destination operational (both physically and at an O/S level)
For red agent IERs, the application of IERs between a source and destination follows a number of subtly different rules. Specifically:
@@ -95,7 +95,7 @@ The game layer is built on top of the simulator and it consumes the simulation a
* Active Nodes and Service Nodes:
* Software State:
* GOOD
* PATCHING - when a status of patching is entered, the node will automatically exit this state after a number of steps (as defined by the osPatchingDuration configuration item) after which it returns to a GOOD state
* FIXING - when a status of FIXING is entered, the node will automatically exit this state after a number of steps (as defined by the osFIXINGDuration configuration item) after which it returns to a GOOD state
* COMPROMISED
* File System State:
* GOOD
@@ -106,7 +106,7 @@ The game layer is built on top of the simulator and it consumes the simulation a
* Service Nodes only:
* Service State (for any associated service):
* GOOD
* PATCHING - when a status of patching is entered, the service will automatically exit this state after a number of steps (as defined by the servicePatchingDuration configuration item) after which it returns to a GOOD state
* FIXING - when a status of FIXING is entered, the service will automatically exit this state after a number of steps (as defined by the serviceFIXINGDuration configuration item) after which it returns to a GOOD state
* COMPROMISED
* OVERWHELMED
Red agent pattern-of-life has an additional feature not found in the green pattern-of-life. This is the ability to influence the state of the attributes of a node via a number of different conditions:
@@ -211,8 +211,8 @@ The game layer is built on top of the simulator and it consumes the simulation a
Hardware State (1=ON, 2=OFF, 3=RESETTING, 4=SHUTTING_DOWN, 5=BOOTING)
Operating System State (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
File System State (0=none, 1=GOOD, 2=CORRUPT, 3=DESTROYED, 4=REPAIRING, 5=RESTORING)
Service1/Protocol1 state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
Service2/Protocol2 state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
Service1/Protocol1 state (0=none, 1=GOOD, 2=FIXING, 3=COMPROMISED)
Service2/Protocol2 state (0=none, 1=GOOD, 2=FIXING, 3=COMPROMISED)
]
(Note that each service available in the network is provided as a column, although not all nodes may utilise all services)
For the links, the following statuses are represented:
@@ -241,8 +241,8 @@ The game layer is built on top of the simulator and it consumes the simulation a
hardware_state (0=none, 1=ON, 2=OFF, 3=RESETTING, 4=SHUTTING_DOWN, 5=BOOTING)
software_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
file_system_state (0=none, 1=GOOD, 2=CORRUPT, 3=DESTROYED, 4=REPAIRING, 5=RESTORING)
service1_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
service2_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
service1_state (0=none, 1=GOOD, 2=FIXING, 3=COMPROMISED)
service2_state (0=none, 1=GOOD, 2=FIXING, 3=COMPROMISED)
]
In a network with three nodes and two services, the full observation space would have 15 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example:
.. code-block::
@@ -278,7 +278,7 @@ The game layer is built on top of the simulator and it consumes the simulation a
3. Any (Agent can take both node-based and ACL-based actions)
The choice of action space used during a training session is determined in the config_[name].yaml file.
**Node-Based**
The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is a Gymnasium spaces.Discrete type, as follows:
The agent is able to influence the status of nodes by switching them off, resetting, or FIXING operating systems and services. In this instance, the action space is a Gymnasium spaces.Discrete type, as follows:
* Dictionary item {... ,1: [x1, x2, x3,x4] ...}
The placeholders inside the list under the key '1' mean the following:
* [0, num nodes] - Node ID (0 = nothing, node ID)

View File

@@ -5,8 +5,7 @@
PrimAITE |VERSION| Configuration
********************************
PrimAITE uses a single configuration file to define everything needed to train and evaluate an RL policy in a custom cybersecurity scenario. This includes the configuration of the network, the scripted or trained agents that interact with the network, as well as settings that define how to perform training in Stable Baselines 3 or Ray RLLib.
The entire config is used by the ``PrimaiteSession`` object for users who wish to let PrimAITE handle the agent definition and training. If you wish to define custom agents and control the training loop yourself, you can use the config with the ``PrimaiteGame``, and ``PrimaiteGymEnv`` objects instead. That way, only the network configuration and agent setup parts of the config are used, and the training section is ignored.
PrimAITE uses a single configuration file to define everything needed to create the training environment for RL agents, including the network, the scripted agents, and the RL agent's action space, observation space, and reward function.
Example Configuration Hierarchy
###############################
@@ -14,8 +13,6 @@ The top level configuration items in a configuration file is as follows
.. code-block:: yaml
training_config:
...
io_settings:
...
game:
@@ -33,7 +30,6 @@ Configurable items
.. toctree::
:maxdepth: 1
configuration/training_config.rst
configuration/io_settings.rst
configuration/game.rst
configuration/agents.rst

View File

@@ -82,7 +82,7 @@ Allows configuration of the chosen observation type. These are optional.
* ``num_services_per_node``, ``num_folders_per_node``, ``num_files_per_folder``, ``num_nics_per_node`` all define the shape of the observation space. The size and shape of the obs space must remain constant, but the number of files, folders, ACL rules, and other components can change within an episode. Therefore padding is performed and these options set the size of the obs space.
* ``nodes``: list of nodes that will be present in this agent's observation space. The ``node_ref`` relates to the human-readable unique reference defined later in the ``simulation`` part of the config. Each node can also be configured with services, and files that should be monitored.
* ``links``: list of links that will be present in this agent's observation space. The ``link_ref`` relates to the human-readable unique reference defined later in the ``simulation`` part of the config.
* ``acl``: configure how the agent reads the access control list on the router in the simulation. ``router_node_ref`` is for selecting which router's ACL table should be used. ``ip_address_order`` sets the encoding of ip addresses as integers within the observation space.
* ``acl``: configure how the agent reads the access control list on the router in the simulation. ``router_node_ref`` is for selecting which router's ACL table should be used. ``ip_list`` sets the encoding of ip addresses as integers within the observation space.
For more information see :py:mod:`primaite.game.agent.observations`

View File

@@ -13,42 +13,12 @@ This section configures how PrimAITE saves data during simulation and training.
.. code-block:: yaml
io_settings:
save_final_model: True
save_checkpoints: False
checkpoint_interval: 10
# save_logs: True
# save_transactions: False
save_agent_actions: True
save_step_metadata: False
save_pcap_logs: False
save_sys_logs: False
``save_final_model``
--------------------
Optional. Default value is ``True``.
Only used if training with PrimaiteSession.
If ``True``, the policy will be saved after the final training iteration.
``save_checkpoints``
--------------------
Optional. Default value is ``False``.
Only used if training with PrimaiteSession.
If ``True``, the policy will be saved periodically during training.
``checkpoint_interval``
-----------------------
Optional. Default value is ``10``.
Only used if training with PrimaiteSession and if ``save_checkpoints`` is ``True``.
Defines how often to save the policy during training.
``save_logs``
-------------

View File

@@ -22,35 +22,35 @@ example firewall
network:
nodes:
- ref: firewall
hostname: firewall
type: firewall
start_up_duration: 0
shut_down_duration: 0
ports:
external_port: # port 1
ip_address: 192.168.20.1
subnet_mask: 255.255.255.0
internal_port: # port 2
ip_address: 192.168.1.2
subnet_mask: 255.255.255.0
dmz_port: # port 3
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
internal_inbound_acl:
hostname: firewall
type: firewall
start_up_duration: 0
shut_down_duration: 0
ports:
external_port: # port 1
ip_address: 192.168.20.1
subnet_mask: 255.255.255.0
internal_port: # port 2
ip_address: 192.168.1.2
subnet_mask: 255.255.255.0
dmz_port: # port 3
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
internal_inbound_acl:
...
internal_outbound_acl:
...
dmz_inbound_acl:
...
dmz_outbound_acl:
...
external_inbound_acl:
...
external_outbound_acl:
...
routes:
...
internal_outbound_acl:
...
dmz_inbound_acl:
...
dmz_outbound_acl:
...
external_inbound_acl:
...
external_outbound_acl:
...
routes:
...
.. include:: common/common_node_attributes.rst

View File

@@ -1,75 +0,0 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
``training_config``
===================
Configuration items relevant to how the Reinforcement Learning agent(s) will be trained.
``training_config`` hierarchy
-----------------------------
.. code-block:: yaml
training_config:
rl_framework: SB3 # or RLLIB_single_agent or RLLIB_multi_agent
rl_algorithm: PPO # or A2C
n_learn_episodes: 5
max_steps_per_episode: 200
n_eval_episodes: 1
deterministic_eval: True
seed: 123
``rl_framework``
----------------
The RL (Reinforcement Learning) Framework to use in the training session
Options available are:
- ``SB3`` (Stable Baselines 3)
- ``RLLIB_single_agent`` (Single Agent Ray RLLib)
- ``RLLIB_multi_agent`` (Multi Agent Ray RLLib)
``rl_algorithm``
----------------
The Reinforcement Learning Algorithm to use in the training session
Options available are:
- ``PPO`` (Proximal Policy Optimisation)
- ``A2C`` (Advantage Actor Critic)
``n_learn_episodes``
--------------------
The number of episodes to train the agent(s).
This should be an integer value above ``0``
``max_steps_per_episode``
-------------------------
The number of steps each episode will last for.
This should be an integer value above ``0``.
``n_eval_episodes``
-------------------
Optional. Default value is ``0``.
The number of evaluation episodes to run the trained agent for.
This should be an integer value above ``0``.
``deterministic_eval``
----------------------
Optional. By default this value is ``False``.
If this is set to ``True``, the agents will act deterministically instead of stochastically.
``seed``
--------
Optional.
The seed is used (alongside ``deterministic_eval``) to reproduce a previous instance of training and evaluation of an RL agent.
The seed should be an integer value.
Useful for debugging.

View File

@@ -0,0 +1,77 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Example Jupyter Notebooks
=========================
There are a few example notebooks included which help with the understanding of PrimAITE's capabilities.
The Jupyter Notebooks can be run via the 2 examples below. These assume that the instructions to install PrimAITE from the :ref:`Getting Started <getting-started>` page is completed as a prerequisite.
Running Jupyter Notebooks
-------------------------
1. Navigate to the PrimAITE directory
.. code-block:: bash
:caption: Unix
cd ~/primaite/{VERSION}
.. code-block:: powershell
:caption: Windows (Powershell)
cd ~\primaite\{VERSION}
2. Run jupyter notebook (the python environment to which you installed PrimAITE must be active)
.. code-block:: bash
:caption: Unix
jupyter notebook
.. code-block:: powershell
:caption: Windows (Powershell)
jupyter notebook
3. Opening the jupyter webpage (optional)
The default web browser may automatically open the webpage. However, if that is not the case, click the link shown in your command prompt output. It should look like this: ``http://localhost:8888/?token=0123456798abc0123456789abc``
4. Navigate to the list of notebooks
The example notebooks are located in ``notebooks/example_notebooks/``. The file system shown in the jupyter webpage is relative to the location in which the ``jupyter notebook`` command was used.
Running Jupyter Notebooks via VSCode
------------------------------------
It is also possible to view the Jupyter notebooks within VSCode.
The best place to start is by opening a notebook file (.ipynb) in VSCode. If using VSCode to view a notebook for the first time, follow the steps below.
Installing extensions
"""""""""""""""""""""
VSCode may need some extensions to be installed if not already done.
To do this, press the "Select Kernel" button on the top right.
This should open a dialog which has the option to install python and jupyter extensions.
.. image:: ../../_static/notebooks/install_extensions.png
:width: 700
:align: center
:alt: :: The top dialog option that appears will automatically install the extensions
The following extensions should now be installed
.. image:: ../../_static/notebooks/extensions.png
:width: 300
:align: center
VSCode will then ask for a Python environment version to use. PrimAITE is compatible with Python versions 3.8 - 3.10
You should now be able to interact with the notebook.

View File

@@ -6,49 +6,82 @@ The Primaite codebase consists of two main modules:
* ``simulator``: The simulation logic including the network topology, the network state, and behaviour of various hardware and software classes.
* ``game``: The agent-training infrastructure which helps reinforcement learning agents interface with the simulation. This includes the observation, action, and rewards, for RL agents, but also scripted deterministic agents. The game layer orchestrates all the interactions between modules.
The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API.
..
TODO: write up these APIs and link them here.
Game layer
----------
The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API.
The game layer is responsible for managing agents and getting them to interface with the simulator correctly. It consists of several components:
PrimAITE Session
^^^^^^^^^^^^^^^^
.. admonition:: Deprecated
:class: deprecated
PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The `session` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality.
``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types.
Agents
^^^^^^
======
All agents inherit from the :py:class:`primaite.game.agent.interface.AbstractAgent` class, which mandates that they have an ObservationManager, ActionManager, and RewardManager. The agent behaviour depends on the type of agent, but there are two main types:
* RL agents action during each step is decided by an appropriate RL algorithm. The agent within PrimAITE just acts to format and forward actions decided by an RL policy.
* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed will be settable.
* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed is settable.
..
TODO: add seed to stochastic scripted agents
Observations
^^^^^^^^^^^^^^^^^^
============
An agent's observations are managed by the ``ObservationManager`` class. It generates observations based on the current simulation state dictionary. It also provides the observation space during initial setup. The data is formatted so it's compatible with ``Gymnasium.spaces``. Observation spaces are composed of one or more components which are defined by the ``AbstractObservation`` base class.
Actions
^^^^^^^
=======
An agent's actions are managed by the ``ActionManager``. It converts actions selected by agents (which are typically integers chosen from a ``gymnasium.spaces.Discrete`` space) into simulation-friendly requests. It also provides the action space during initial setup. Action spaces are composed of one or more components which are defined by the ``AbstractAction`` base class.
Rewards
^^^^^^^
=======
An agent's reward function is managed by the ``RewardManager``. It calculates rewards based on the simulation state (in a way similar to observations). Rewards can be defined as a weighted sum of small reward components. For example, an agents reward can be based on the uptime of a database service plus the loss rate of packets between clients and a web server. The reward components are defined by the AbstractReward base class.
An agent's reward function is managed by the ``RewardManager``. It calculates rewards based on the simulation state (in a way similar to observations). Rewards can be defined as a weighted sum of small reward components. For example, an agents reward can be based on the uptime of a database service plus the loss rate of packets between clients and a web server.
Reward Components
-----------------
Currently implemented are reward components tailored to the data manipulation scenario. View the full API and description of how they work here: :py:module:`primaite.game.agent.reward`.
Reward Sharing
--------------
An agent's reward can be based on rewards of other agents. This is particularly useful for modelling a situation where the blue agent's job is to protect the ability of green agents to perform their pattern-of-life. This can be configured in the YAML file this way:
```yaml
green_agent_1: # this agent sometimes tries to access the webpage, and sometimes the database
# actions, observations, and agent settings go here
reward_function:
reward_components:
# When the webpage loads, the reward goes up by 0.25 when it fails to load, it goes down to -0.25
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_2
# When the database is reachable, the reward goes up by 0.05, when it is unreachable it goes down to -0.05
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
blue_agent:
# actions, observations, and agent settings go here
reward_function:
reward_components:
# When the database file is in a good state, blue's reward is 0.4, when it's in a corrupted state the reward is -0.4
- type: DATABASE_FILE_INTEGRITY
weight: 0.40
options:
node_hostname: database_server
folder_name: database
file_name: database.db
# The green's reward is added onto the blue's reward.
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_2_green_user
```
When defining agent reward sharing, users must be careful to avoid circular references, as that would lead to an infinite calculation loop. PrimAITE will prevent circular dependencies and provide a helpful error message if they are detected in the yaml.

View File

@@ -11,7 +11,7 @@ Getting Started
Pre-Requisites
In order to get **PrimAITE** installed, you will need to have a python version between 3.8 and 3.11 installed. If you don't already have it, this is how to install it:
In order to get **PrimAITE** installed, you will need Python, venv, and pip. If you don't already have them, this is how to install it:
.. code-block:: bash
@@ -30,6 +30,8 @@ In order to get **PrimAITE** installed, you will need to have a python version b
**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS.
Installing PrimAITE has been tested with all supported python versions, venv 20.24.1, and pip 23.
Install PrimAITE
****************
@@ -38,12 +40,12 @@ Install PrimAITE
.. code-block:: bash
:caption: Unix
mkdir ~/primaite/3.0.0
mkdir -p ~/primaite/{VERSION}
.. code-block:: powershell
:caption: Windows (Powershell)
mkdir ~\primaite\3.0.0
mkdir ~\primaite\{VERSION}
2. Navigate to the primaite directory and create a new python virtual environment (venv)
@@ -51,13 +53,13 @@ Install PrimAITE
.. code-block:: bash
:caption: Unix
cd ~/primaite/3.0.0
cd ~/primaite/{VERSION}
python3 -m venv .venv
.. code-block:: powershell
:caption: Windows (Powershell)
cd ~\primaite\3.0.0
cd ~\primaite\{VERSION}
python3 -m venv .venv
attrib +h .venv /s /d # Hides the .venv directory

View File

@@ -1,41 +0,0 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _run a primaite session:
.. admonition:: Deprecated
:class: deprecated
PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The ``session`` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality.
Run a PrimAITE Session
======================
``PrimaiteSession`` allows the user to train or evaluate an RL agent on the primaite simulation with just a config file,
no code required. It manages the lifecycle of a training or evaluation session, including the setup of the environment,
policy, simulator, agents, and IO.
If you want finer control over the RL policy, you can interface with the :py:module::`primaite.session.environment`
module directly without running a session.
Run
---
A PrimAITE session can be started either with the ``primaite session`` command from the cli
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook.
There are two parameters that can be specified:
- ``--config``: The path to the config file to use. If not specified, the default config file is used.
- ``--agent-load-file``: The path to the pre-trained agent to load. If not specified, a new agent is created.
Outputs
-------
Running a session creates a session output directory in your user data folder. The filepath looks like this:
``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node,
the saved agent checkpoints, and final model. The folder also contains a .json file for each episode step that
contains the action, reward, and simulation state. These can be found in
``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/simulation_output/episode_<n>/step_metadata/step_<n>.json``

View File

@@ -25,6 +25,7 @@ Contents
simulation_components/network/nodes/switch
simulation_components/network/nodes/wireless_router
simulation_components/network/nodes/firewall
simulation_components/network/switch
simulation_components/network/network
simulation_components/system/internal_frame_processing
simulation_components/system/sys_log

View File

@@ -73,7 +73,7 @@ Network Interface Classes
- Malicious Network Events Monitoring:
* Enhances network interfaces with the capability to monitor and capture Malicious Network Events (MNEs) based on predefined criteria such as specific keywords or traffic patterns.
* Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NicObservation`` to classify and record network anomalies.
* Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NICObservation`` to classify and record network anomalies.
* Offers an additional layer of security and data analysis, crucial for identifying and mitigating malicious activities within the network infrastructure. Provides vital information for network security analysis and reinforcement learning algorithms.
**WiredNetworkInterface (Connection Type Layer)**

View File

@@ -38,8 +38,9 @@ dependencies = [
"stable-baselines3[extra]==2.1.0",
"tensorflow==2.12.0",
"typer[all]==0.9.0",
"pydantic==2.1.1",
"ray[rllib] == 2.8.0, < 3"
"pydantic==2.7.0",
"ray[rllib] >= 2.9, < 3",
"ipywidgets"
]
[tool.setuptools.dynamic]

View File

@@ -1 +1 @@
3.0.0b6
3.0.0b9dev

View File

@@ -114,23 +114,3 @@ def setup(overwrite_existing: bool = True) -> None:
reset_example_configs.run(overwrite_existing=True)
_LOGGER.info("PrimAITE setup complete!")
@app.command()
def session(
config: Optional[str] = None,
agent_load_file: Optional[str] = None,
) -> None:
"""
Run a PrimAITE session.
:param config: The path to the config file. Optional, if None, the example config will be used.
:type config: Optional[str]
"""
from primaite.config.load import data_manipulation_config_path
from primaite.main import run
if not config:
config = data_manipulation_config_path()
print(config)
run(config_path=config, agent_load_path=agent_load_file)

View File

@@ -1,26 +1,12 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_agent_actions: true
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: true
save_sys_logs: false
game:
max_episode_length: 256
max_episode_length: 128
ports:
- HTTP
- POSTGRES_SERVER
@@ -43,8 +29,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -76,7 +61,14 @@ agents:
reward_function:
reward_components:
- type: DUMMY
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
- ref: client_1_green_user
team: GREEN
@@ -86,8 +78,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -119,7 +110,14 @@ agents:
reward_function:
reward_components:
- type: DUMMY
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
@@ -129,10 +127,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -165,61 +160,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: DNSServer
- node_hostname: web_server
services:
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1:eth-1<->switch_1:eth-8
- router_1:eth-2<->switch_2:eth-8
- switch_1:eth-1<->domain_controller:eth-1
- switch_1:eth-2<->web_server:eth-1
- switch_1:eth-3<->database_server:eth-1
- switch_1:eth-4<->backup_server:eth-1
- switch_1:eth-7<->security_suite:eth-1
- switch_2:eth-1<->client_1:eth-1
- switch_2:eth-2<->client_2:eth-1
- switch_2:eth-7<->security_suite:eth-2
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:
@@ -232,7 +239,7 @@ agents:
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_SERVICE_PATCH
- type: NODE_SERVICE_FIX
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
@@ -246,14 +253,10 @@ agents:
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_hostname: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_hostname: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
- type: ROUTER_ACL_ADDRULE
- type: ROUTER_ACL_REMOVERULE
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
@@ -309,7 +312,7 @@ agents:
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
options:
node_id: 2
folder_id: 0
@@ -327,7 +330,7 @@ agents:
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_PATCH"
action: "NODE_SERVICE_FIX"
options:
node_id: 2
service_id: 0
@@ -337,7 +340,7 @@ agents:
node_id: 2
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
options:
node_id: 2
folder_id: 0
@@ -465,8 +468,9 @@ agents:
node_id: 6
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "NETWORK_ACL_ADDRULE"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 1
permission: 2
source_ip_id: 7 # client 1
@@ -474,9 +478,12 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "NETWORK_ACL_ADDRULE"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 2
permission: 2
source_ip_id: 8 # client 2
@@ -484,9 +491,12 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
48: # old action num: 24 # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 3
permission: 2
source_ip_id: 7 # client 1
@@ -494,9 +504,12 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
49: # old action num: 25 # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 4
permission: 2
source_ip_id: 8 # client 2
@@ -504,9 +517,12 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
50: # old action num: 26
action: "NETWORK_ACL_ADDRULE"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 5
permission: 2
source_ip_id: 7 # client 1
@@ -514,9 +530,12 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
51: # old action num: 27
action: "NETWORK_ACL_ADDRULE"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 6
permission: 2
source_ip_id: 8 # client 2
@@ -524,123 +543,135 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
52: # old action num: 28
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 0
53: # old action num: 29
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 1
54: # old action num: 30
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 2
55: # old action num: 31
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 3
56: # old action num: 32
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 4
57: # old action num: 33
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 5
58: # old action num: 34
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 6
59: # old action num: 35
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 7
60: # old action num: 36
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 8
61: # old action num: 37
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 9
62: # old action num: 38
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 0
nic_id: 0
63: # old action num: 39
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 0
nic_id: 0
64: # old action num: 40
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 1
nic_id: 0
65: # old action num: 41
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 1
nic_id: 0
66: # old action num: 42
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 2
nic_id: 0
67: # old action num: 43
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 2
nic_id: 0
68: # old action num: 44
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 3
nic_id: 0
69: # old action num: 45
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 3
nic_id: 0
70: # old action num: 46
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 0
71: # old action num: 47
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 0
72: # old action num: 48
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
73: # old action num: 49
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
74: # old action num: 50
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 5
nic_id: 0
75: # old action num: 51
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 5
nic_id: 0
76: # old action num: 52
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 6
nic_id: 0
77: # old action num: 53
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 6
nic_id: 0
@@ -672,23 +703,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
@@ -699,22 +722,17 @@ agents:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
- type: SHARED_REWARD
weight: 1.0
options:
node_hostname: client_1
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
agent_name: client_1_green_user
- type: SHARED_REWARD
weight: 1.0
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
agent_name: client_2_green_user
agent_settings:
@@ -732,8 +750,7 @@ simulation:
- DELETE
nodes:
- ref: router_1
hostname: router_1
- hostname: router_1
type: router
num_ports: 5
ports:
@@ -768,74 +785,61 @@ simulation:
action: PERMIT
protocol: ICMP
- ref: switch_1
hostname: switch_1
- hostname: switch_1
type: switch
num_ports: 8
- ref: switch_2
hostname: switch_2
- hostname: switch_2
type: switch
num_ports: 8
- ref: domain_controller
hostname: domain_controller
- hostname: domain_controller
type: server
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- ref: domain_controller_dns_server
type: DNSServer
- type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
hostname: web_server
- hostname: web_server
type: server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
- type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: database_server
hostname: database_server
- hostname: database_server
type: server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: database_service
type: DatabaseService
- type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- ref: database_ftp_client
type: FTPClient
- type: FTPClient
- ref: backup_server
hostname: backup_server
- hostname: backup_server
type: server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: backup_service
type: FTPServer
- type: FTPServer
- ref: security_suite
hostname: security_suite
- hostname: security_suite
type: server
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
@@ -846,110 +850,88 @@ simulation:
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
hostname: client_1
- hostname: client_1
type: computer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- ref: client_1_web_browser
type: WebBrowser
- type: WebBrowser
options:
target_url: http://arcd.com/users/
- ref: client_1_database_client
type: DatabaseClient
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient
- type: DNSClient
- ref: client_2
hostname: client_2
- hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_2_web_browser
type: WebBrowser
- type: WebBrowser
options:
target_url: http://arcd.com/users/
- ref: data_manipulation_bot
type: DataManipulationBot
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- ref: client_2_database_client
type: DatabaseClient
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- ref: client_2_dns_client
type: DNSClient
- type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1
- endpoint_a_hostname: router_1
endpoint_a_port: 1
endpoint_b_ref: switch_1
endpoint_b_hostname: switch_1
endpoint_b_port: 8
- ref: router_1___switch_2
endpoint_a_ref: router_1
- endpoint_a_hostname: router_1
endpoint_a_port: 2
endpoint_b_ref: switch_2
endpoint_b_hostname: switch_2
endpoint_b_port: 8
- ref: switch_1___domain_controller
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 1
endpoint_b_ref: domain_controller
endpoint_b_hostname: domain_controller
endpoint_b_port: 1
- ref: switch_1___web_server
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 2
endpoint_b_ref: web_server
endpoint_b_hostname: web_server
endpoint_b_port: 1
- ref: switch_1___database_server
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 3
endpoint_b_ref: database_server
endpoint_b_hostname: database_server
endpoint_b_port: 1
- ref: switch_1___backup_server
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 4
endpoint_b_ref: backup_server
endpoint_b_hostname: backup_server
endpoint_b_port: 1
- ref: switch_1___security_suite
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_hostname: security_suite
endpoint_b_port: 1
- ref: switch_2___client_1
endpoint_a_ref: switch_2
- endpoint_a_hostname: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_hostname: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
- endpoint_a_hostname: switch_2
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_hostname: client_2
endpoint_b_port: 1
- ref: switch_2___security_suite
endpoint_a_ref: switch_2
- endpoint_a_hostname: switch_2
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_hostname: security_suite
endpoint_b_port: 2

File diff suppressed because it is too large Load Diff

View File

@@ -156,12 +156,12 @@ class NodeServiceEnableAction(NodeServiceAbstractAction):
self.verb: str = "enable"
class NodeServicePatchAction(NodeServiceAbstractAction):
"""Action which patches a service."""
class NodeServiceFixAction(NodeServiceAbstractAction):
"""Action which fixes a service."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb: str = "patch"
self.verb: str = "fix"
class NodeApplicationAbstractAction(AbstractAction):
@@ -195,6 +195,69 @@ class NodeApplicationExecuteAction(NodeApplicationAbstractAction):
self.verb: str = "execute"
class NodeApplicationScanAction(NodeApplicationAbstractAction):
"""Action which scans an application."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications)
self.verb: str = "scan"
class NodeApplicationCloseAction(NodeApplicationAbstractAction):
"""Action which closes an application."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications)
self.verb: str = "close"
class NodeApplicationFixAction(NodeApplicationAbstractAction):
"""Action which fixes an application."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications)
self.verb: str = "fix"
class NodeApplicationInstallAction(AbstractAction):
"""Action which installs an application."""
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
def form_request(self, node_id: int, application_name: str, ip_address: str) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
return ["do_nothing"]
return [
"network",
"node",
node_name,
"software_manager",
"application",
"install",
application_name,
ip_address,
]
class NodeApplicationRemoveAction(AbstractAction):
"""Action which removes/uninstalls an application."""
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
def form_request(self, node_id: int, application_name: str) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
return ["do_nothing"]
return ["network", "node", node_name, "software_manager", "application", "uninstall", application_name]
class NodeFolderAbstractAction(AbstractAction):
"""
Base class for folder actions.
@@ -381,25 +444,22 @@ class NodeResetAction(NodeAbstractAction):
self.verb: str = "reset"
class NetworkACLAddRuleAction(AbstractAction):
class RouterACLAddRuleAction(AbstractAction):
"""Action which adds a rule to a router's ACL."""
def __init__(
self,
manager: "ActionManager",
target_router_hostname: str,
max_acl_rules: int,
num_ips: int,
num_ports: int,
num_protocols: int,
**kwargs,
) -> None:
"""Init method for NetworkACLAddRuleAction.
"""Init method for RouterACLAddRuleAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param target_router_hostname: hostname of the router to which the ACL rule should be added.
:type target_router_hostname: str
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
:type max_acl_rules: int
:param num_ips: Number of IP addresses in the simulation.
@@ -420,14 +480,16 @@ class NetworkACLAddRuleAction(AbstractAction):
"dest_port_id": num_ports,
"protocol_id": num_protocols,
}
self.target_router_name: str = target_router_hostname
def form_request(
self,
target_router_nodename: str,
position: int,
permission: int,
source_ip_id: int,
source_wildcard_id: int,
dest_ip_id: int,
dest_wildcard_id: int,
source_port_id: int,
dest_port_id: int,
protocol_id: int,
@@ -437,7 +499,149 @@ class NetworkACLAddRuleAction(AbstractAction):
permission_str = "UNUSED"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
elif permission == 1:
permission_str = "ALLOW"
permission_str = "PERMIT"
elif permission == 2:
permission_str = "DENY"
else:
_LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if protocol_id == 0:
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
if protocol_id == 1:
protocol = "ALL"
else:
protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)
# subtract 2 to account for UNUSED=0 and ALL=1.
if source_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif source_ip_id == 1:
src_ip = "ALL"
else:
src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
if source_port_id == 0:
return ["do_nothing"] # invalid formulation
elif source_port_id == 1:
src_port = "ALL"
else:
src_port = self.manager.get_port_by_idx(source_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if dest_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif dest_ip_id == 1:
dst_ip = "ALL"
else:
dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
if dest_port_id == 0:
return ["do_nothing"] # invalid formulation
elif dest_port_id == 1:
dst_port = "ALL"
else:
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
return [
"network",
"node",
target_router_nodename,
"acl",
"add_rule",
permission_str,
protocol,
str(src_ip),
src_wildcard,
src_port,
str(dst_ip),
dst_wildcard,
dst_port,
position,
]
class RouterACLRemoveRuleAction(AbstractAction):
"""Action which removes a rule from a router's ACL."""
def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None:
"""Init method for RouterACLRemoveRuleAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
:type max_acl_rules: int
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"position": max_acl_rules}
def form_request(self, target_router_nodename: str, position: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", target_router_nodename, "acl", "remove_rule", position]
class FirewallACLAddRuleAction(AbstractAction):
"""Action which adds a rule to a firewall port's ACL."""
def __init__(
self,
manager: "ActionManager",
max_acl_rules: int,
num_ips: int,
num_ports: int,
num_protocols: int,
**kwargs,
) -> None:
"""Init method for FirewallACLAddRuleAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
:type max_acl_rules: int
:param num_ips: Number of IP addresses in the simulation.
:type num_ips: int
:param num_ports: Number of ports in the simulation.
:type num_ports: int
:param num_protocols: Number of protocols in the simulation.
:type num_protocols: int
"""
super().__init__(manager=manager)
num_permissions = 3
self.shape: Dict[str, int] = {
"position": max_acl_rules,
"permission": num_permissions,
"source_ip_id": num_ips,
"dest_ip_id": num_ips,
"source_port_id": num_ports,
"dest_port_id": num_ports,
"protocol_id": num_protocols,
}
def form_request(
self,
target_firewall_nodename: str,
firewall_port_name: str,
firewall_port_direction: str,
position: int,
permission: int,
source_ip_id: int,
source_wildcard_id: int,
dest_ip_id: int,
dest_wildcard_id: int,
source_port_id: int,
dest_port_id: int,
protocol_id: int,
) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if permission == 0:
permission_str = "UNUSED"
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
elif permission == 1:
permission_str = "PERMIT"
elif permission == 2:
permission_str = "DENY"
else:
@@ -468,7 +672,7 @@ class NetworkACLAddRuleAction(AbstractAction):
src_port = self.manager.get_port_by_idx(source_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
if source_ip_id == 0:
if dest_ip_id == 0:
return ["do_nothing"] # invalid formulation
elif dest_ip_id == 1:
dst_ip = "ALL"
@@ -483,46 +687,60 @@ class NetworkACLAddRuleAction(AbstractAction):
else:
dst_port = self.manager.get_port_by_idx(dest_port_id - 2)
# subtract 2 to account for UNUSED=0, and ALL=1
src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id)
dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id)
return [
"network",
"node",
self.target_router_name,
target_firewall_nodename,
firewall_port_name,
firewall_port_direction,
"acl",
"add_rule",
permission_str,
protocol,
str(src_ip),
src_wildcard,
src_port,
str(dst_ip),
dst_wildcard,
dst_port,
position,
]
class NetworkACLRemoveRuleAction(AbstractAction):
"""Action which removes a rule from a router's ACL."""
class FirewallACLRemoveRuleAction(AbstractAction):
"""Action which removes a rule from a firewall port's ACL."""
def __init__(self, manager: "ActionManager", target_router_hostname: str, max_acl_rules: int, **kwargs) -> None:
"""Init method for NetworkACLRemoveRuleAction.
def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None:
"""Init method for RouterACLRemoveRuleAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param target_router_hostname: Hostname of the router from which the ACL rule should be removed.
:type target_router_hostname: str
:param max_acl_rules: Maximum number of ACL rules that can be added to the router.
:type max_acl_rules: int
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"position": max_acl_rules}
self.target_router_name: str = target_router_hostname
def form_request(self, position: int) -> List[str]:
def form_request(
self, target_firewall_nodename: str, firewall_port_name: str, firewall_port_direction: str, position: int
) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", self.target_router_name, "acl", "remove_rule", position]
return [
"network",
"node",
target_firewall_nodename,
firewall_port_name,
firewall_port_direction,
"acl",
"remove_rule",
position,
]
class NetworkNICAbstractAction(AbstractAction):
class HostNICAbstractAction(AbstractAction):
"""
Abstract base class for NIC actions.
@@ -531,7 +749,7 @@ class NetworkNICAbstractAction(AbstractAction):
"""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
"""Init method for NetworkNICAbstractAction.
"""Init method for HostNICAbstractAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
@@ -553,7 +771,7 @@ class NetworkNICAbstractAction(AbstractAction):
return ["network", "node", node_name, "network_interface", nic_num, self.verb]
class NetworkNICEnableAction(NetworkNICAbstractAction):
class HostNICEnableAction(HostNICAbstractAction):
"""Action which enables a NIC."""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
@@ -561,7 +779,7 @@ class NetworkNICEnableAction(NetworkNICAbstractAction):
self.verb: str = "enable"
class NetworkNICDisableAction(NetworkNICAbstractAction):
class HostNICDisableAction(HostNICAbstractAction):
"""Action which disables a NIC."""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
@@ -569,6 +787,44 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
self.verb: str = "disable"
class NetworkPortEnableAction(AbstractAction):
"""Action which enables are port on a router or a firewall."""
def __init__(self, manager: "ActionManager", max_nics_per_node: int, **kwargs) -> None:
"""Init method for NetworkPortEnableAction.
:param max_nics_per_node: Maximum number of NICs per node.
:type max_nics_per_node: int
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if target_nodename is None or port_id is None:
return ["do_nothing"]
return ["network", "node", target_nodename, "network_interface", port_id, "enable"]
class NetworkPortDisableAction(AbstractAction):
"""Action which disables are port on a router or a firewall."""
def __init__(self, manager: "ActionManager", max_nics_per_node: int, **kwargs) -> None:
"""Init method for NetworkPortDisableAction.
:param max_nics_per_node: Maximum number of NICs per node.
:type max_nics_per_node: int
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if target_nodename is None or port_id is None:
return ["do_nothing"]
return ["network", "node", target_nodename, "network_interface", port_id, "disable"]
class ActionManager:
"""Class which manages the action space for an agent."""
@@ -582,8 +838,13 @@ class ActionManager:
"NODE_SERVICE_RESTART": NodeServiceRestartAction,
"NODE_SERVICE_DISABLE": NodeServiceDisableAction,
"NODE_SERVICE_ENABLE": NodeServiceEnableAction,
"NODE_SERVICE_PATCH": NodeServicePatchAction,
"NODE_SERVICE_FIX": NodeServiceFixAction,
"NODE_APPLICATION_EXECUTE": NodeApplicationExecuteAction,
"NODE_APPLICATION_SCAN": NodeApplicationScanAction,
"NODE_APPLICATION_CLOSE": NodeApplicationCloseAction,
"NODE_APPLICATION_FIX": NodeApplicationFixAction,
"NODE_APPLICATION_INSTALL": NodeApplicationInstallAction,
"NODE_APPLICATION_REMOVE": NodeApplicationRemoveAction,
"NODE_FILE_SCAN": NodeFileScanAction,
"NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
"NODE_FILE_DELETE": NodeFileDeleteAction,
@@ -598,10 +859,14 @@ class ActionManager:
"NODE_SHUTDOWN": NodeShutdownAction,
"NODE_STARTUP": NodeStartupAction,
"NODE_RESET": NodeResetAction,
"NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction,
"NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
"ROUTER_ACL_ADDRULE": RouterACLAddRuleAction,
"ROUTER_ACL_REMOVERULE": RouterACLRemoveRuleAction,
"FIREWALL_ACL_ADDRULE": FirewallACLAddRuleAction,
"FIREWALL_ACL_REMOVERULE": FirewallACLRemoveRuleAction,
"HOST_NIC_ENABLE": HostNICEnableAction,
"HOST_NIC_DISABLE": HostNICDisableAction,
"NETWORK_PORT_ENABLE": NetworkPortEnableAction,
"NETWORK_PORT_DISABLE": NetworkPortDisableAction,
}
"""Dictionary which maps action type strings to the corresponding action class."""
@@ -617,7 +882,8 @@ class ActionManager:
max_acl_rules: int = 10, # allows calculating shape
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port
ip_address_list: List[str] = [], # to allow us to map an index to an ip address.
ip_list: List[str] = [], # to allow us to map an index to an ip address.
wildcard_list: List[str] = [], # to allow mapping from wildcard index to
act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions
) -> None:
"""Init method for ActionManager.
@@ -643,8 +909,8 @@ class ActionManager:
:type protocols: List[str]
:param ports: List of ports that are available in the simulation. Used for calculating action shape.
:type ports: List[str]
:param ip_address_list: List of IP addresses that known to this agent. Used for calculating action shape.
:type ip_address_list: Optional[List[str]]
:param ip_list: List of IP addresses that known to this agent. Used for calculating action shape.
:type ip_list: Optional[List[str]]
:param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
:type act_map: Optional[Dict[int, Dict]]
"""
@@ -705,8 +971,10 @@ class ActionManager:
self.protocols: List[str] = protocols
self.ports: List[str] = ports
self.ip_address_list: List[str] = ip_address_list
self.ip_address_list: List[str] = ip_list
self.wildcard_list: List[str] = wildcard_list
if self.wildcard_list == []:
self.wildcard_list = ["NONE"]
# action_args are settings which are applied to the action space as a whole.
global_action_args = {
"num_nodes": len(self.node_names),
@@ -743,7 +1011,8 @@ class ActionManager:
{0: ("NODE_SERVICE_SCAN", {node_id:0, service_id:2})}
"""
if act_map is None:
self.action_map = self._enumerate_actions()
# raise RuntimeError("Action map must be specified in the config file.")
pass
else:
self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()}
# make sure all numbers between 0 and N are represented as dict keys in action map
@@ -940,6 +1209,24 @@ class ActionManager:
raise RuntimeError(msg)
return self.ip_address_list[ip_idx]
def get_wildcard_by_idx(self, wildcard_idx: int) -> str:
"""
Get the IP wildcard corresponding to the given index.
:param ip_idx: The index of the IP wildcard to retrieve.
:type ip_idx: int
:return: The wildcard address.
:rtype: str
"""
if wildcard_idx >= len(self.wildcard_list):
msg = (
f"Error: agent attempted to perform an action on ip wildcard {wildcard_idx} but this"
f" is out of range for its action space. Wildcard list: {self.wildcard_list}"
)
_LOGGER.error(msg)
raise RuntimeError(msg)
return self.wildcard_list[wildcard_idx]
def get_port_by_idx(self, port_idx: int) -> str:
"""
Get the port corresponding to the given index.
@@ -998,37 +1285,14 @@ class ActionManager:
:return: The constructed ActionManager.
:rtype: ActionManager
"""
# If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from
# the nodes in the simulation.
# TODO: refactor. Options:
# 1: This should be pulled out into it's own function for clarity
# 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to
# go through the nodes here.
ip_address_order = cfg["options"].pop("ip_address_order", {})
ip_address_list = []
for entry in ip_address_order:
node_name = entry["node_name"]
nic_num = entry["nic_num"]
node_obj = game.simulation.network.get_node_by_hostname(node_name)
ip_address = node_obj.network_interface[nic_num].ip_address
ip_address_list.append(ip_address)
if not ip_address_list:
node_names = [n["node_name"] for n in cfg.get("nodes", {})]
for node_name in node_names:
node_obj = game.simulation.network.get_node_by_hostname(node_name)
if node_obj is None:
continue
network_interfaces = node_obj.network_interfaces
for nic_uuid, nic_obj in network_interfaces.items():
ip_address_list.append(nic_obj.ip_address)
if "ip_list" not in cfg["options"]:
cfg["options"]["ip_list"] = []
obj = cls(
actions=cfg["action_list"],
**cfg["options"],
protocols=game.options.protocols,
ports=game.options.ports,
ip_address_list=ip_address_list,
act_map=cfg.get("action_map"),
)

View File

@@ -1,6 +1,6 @@
"""Interface for agents."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, model_validator
@@ -8,11 +8,31 @@ from pydantic import BaseModel, model_validator
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.interface.request import RequestFormat, RequestResponse
if TYPE_CHECKING:
pass
class AgentActionHistoryItem(BaseModel):
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
timestep: int
"""Timestep of this action."""
action: str
"""CAOS Action name."""
parameters: Dict[str, Any]
"""CAOS parameters for the given action."""
request: RequestFormat
"""The request that was sent to the simulation based on the CAOS action chosen."""
response: RequestResponse
"""The response sent back by the simulator for this action."""
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
@@ -90,6 +110,7 @@ class AbstractAgent(ABC):
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
self.agent_settings = agent_settings or AgentSettings()
self.action_history: List[AgentActionHistoryItem] = []
def update_observation(self, state: Dict) -> ObsType:
"""
@@ -109,7 +130,7 @@ class AbstractAgent(ABC):
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.update(state)
return self.reward_function.update(state=state, last_action_response=self.action_history[-1])
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
@@ -120,8 +141,6 @@ class AbstractAgent(ABC):
:param obs: Observation of the environment.
:type obs: ObsType
:param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted?
:type reward: float, optional
:param timestep: The current timestep in the simulation, used for non-RL agents. Optional
:type timestep: int
:return: Action to be taken in the environment.
@@ -138,9 +157,15 @@ class AbstractAgent(ABC):
request = self.action_manager.form_request(action_identifier=action, action_options=options)
return request
def reset_agent_for_episode(self) -> None:
"""Agent reset logic should go here."""
pass
def process_action_response(
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
) -> None:
"""Process the response from the most recent action."""
self.action_history.append(
AgentActionHistoryItem(
timestep=timestep, action=action, parameters=parameters, request=request, response=response
)
)
class AbstractScriptedAgent(AbstractAgent):

View File

@@ -0,0 +1,20 @@
# flake8: noqa
# Pre-import all the observations when we load up the observations module so that they can be resolved by the parser.
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
from primaite.game.agent.observations.firewall_observation import FirewallObservation
from primaite.game.agent.observations.host_observations import HostObservation
from primaite.game.agent.observations.link_observation import LinkObservation, LinksObservation
from primaite.game.agent.observations.nic_observations import NICObservation, PortObservation
from primaite.game.agent.observations.node_observations import NodesObservation
from primaite.game.agent.observations.observation_manager import NestedObservation, NullObservation, ObservationManager
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.router_observation import RouterObservation
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
# fmt: off
__all__ = [
"ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation",
"LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation",
"ObservationManager", "ApplicationObservation", "ServiceObservation",]
# fmt: on

View File

@@ -0,0 +1,187 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
class ACLObservation(AbstractObservation, identifier="ACL"):
"""ACL observation, provides information about access control lists within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for ACLObservation."""
ip_list: Optional[List[IPv4Address]] = None
"""List of IP addresses."""
wildcard_list: Optional[List[str]] = None
"""List of wildcard strings."""
port_list: Optional[List[int]] = None
"""List of port numbers."""
protocol_list: Optional[List[str]] = None
"""List of protocol names."""
num_rules: Optional[int] = None
"""Number of ACL rules."""
def __init__(
self,
where: WhereType,
num_rules: int,
ip_list: List[IPv4Address],
wildcard_list: List[str],
port_list: List[int],
protocol_list: List[str],
) -> None:
"""
Initialise an ACL observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this ACL.
:type where: WhereType
:param num_rules: Number of ACL rules.
:type num_rules: int
:param ip_list: List of IP addresses.
:type ip_list: List[IPv4Address]
:param wildcard_list: List of wildcard strings.
:type wildcard_list: List[str]
:param port_list: List of port numbers.
:type port_list: List[int]
:param protocol_list: List of protocol names.
:type protocol_list: List[str]
"""
self.where = where
self.num_rules: int = num_rules
self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}
self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
self.default_observation: Dict = {
i
+ 1: {
"position": i,
"permission": 0,
"source_ip_id": 0,
"source_wildcard_id": 0,
"source_port_id": 0,
"dest_ip_id": 0,
"dest_wildcard_id": 0,
"dest_port_id": 0,
"protocol_id": 0,
}
for i in range(self.num_rules)
}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing ACL rules.
:rtype: ObsType
"""
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i] = {
"position": i - 1,
"permission": 0,
"source_ip_id": 0,
"source_wildcard_id": 0,
"source_port_id": 0,
"dest_ip_id": 0,
"dest_wildcard_id": 0,
"dest_port_id": 0,
"protocol_id": 0,
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.ip_to_id[src_ip]
dst_ip = rule_state["dst_ip_address"]
dst_node_id = 1 if dst_ip is None else self.ip_to_id[dst_ip]
src_wildcard = rule_state["src_wildcard_mask"]
src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1)
dst_wildcard = rule_state["dst_wildcard_mask"]
dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1)
src_port = rule_state["src_port"]
src_port_id = self.port_to_id.get(src_port, 1)
dst_port = rule_state["dst_port"]
dst_port_id = self.port_to_id.get(dst_port, 1)
protocol = rule_state["protocol"]
protocol_id = self.protocol_to_id.get(protocol, 1)
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_ip_id": src_node_id,
"source_wildcard_id": src_wildcard_id,
"source_port_id": src_port_id,
"dest_ip_id": dst_node_id,
"dest_wildcard_id": dst_wildcard_id,
"dest_port_id": dst_port_id,
"protocol_id": protocol_id,
}
i += 1
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for ACL rules.
:rtype: spaces.Space
"""
return spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
"source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
"source_port_id": spaces.Discrete(len(self.port_to_id) + 2),
"dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
"dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
"dest_port_id": spaces.Discrete(len(self.port_to_id) + 2),
"protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation:
"""
Create an ACL observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the ACL observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this ACL's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed ACL observation instance.
:rtype: ACLObservation
"""
return cls(
where=parent_where + ["acl", "acl"],
num_rules=config.num_rules,
ip_list=config.ip_list,
wildcard_list=config.wildcard_list,
port_list=config.port_list,
protocol_list=config.protocol_list,
)

View File

@@ -1,188 +0,0 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite.game.agent.observations.node_observations import NodeObservation
from primaite.game.agent.observations.observations import (
AbstractObservation,
AclObservation,
ICSObservation,
LinkObservation,
NullObservation,
)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class UC2BlueObservation(AbstractObservation):
"""Container for all observations used by the blue agent in UC2.
TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler
for the purpose of compiling several observation components.
"""
def __init__(
self,
nodes: List[NodeObservation],
links: List[LinkObservation],
acl: AclObservation,
ics: ICSObservation,
where: Optional[List[str]] = None,
) -> None:
"""Initialise UC2 blue observation.
:param nodes: List of node observations
:type nodes: List[NodeObservation]
:param links: List of link observations
:type links: List[LinkObservation]
:param acl: The Access Control List observation
:type acl: AclObservation
:param ics: The ICS observation
:type ics: ICSObservation
:param where: Where in the simulation state dict to find information. Not used in this particular observation
because it only compiles other observations and doesn't contribute any new information, defaults to None
:type where: Optional[List[str]], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.nodes: List[NodeObservation] = nodes
self.links: List[LinkObservation] = links
self.acl: AclObservation = acl
self.ics: ICSObservation = ics
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
"ACL": self.acl.default_observation,
"ICS": self.ics.default_observation,
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
obs = {}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
obs["ACL"] = self.acl.observe(state)
obs["ICS"] = self.ics.observe(state)
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Space
:rtype: spaces.Space
"""
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
"ACL": self.acl.space,
"ICS": self.ics.space,
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
"""Create UC2 blue observation from a config.
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
links, ACL and ICS observations.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed UC2 blue observation
:rtype: UC2BlueObservation
"""
node_configs = config["nodes"]
num_services_per_node = config["num_services_per_node"]
num_folders_per_node = config["num_folders_per_node"]
num_files_per_folder = config["num_files_per_folder"]
num_nics_per_node = config["num_nics_per_node"]
nodes = [
NodeObservation.from_config(
config=n,
game=game,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
num_nics_per_node=num_nics_per_node,
)
for n in node_configs
]
link_configs = config["links"]
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
acl_config = config["acl"]
acl = AclObservation.from_config(config=acl_config, game=game)
ics_config = config["ics"]
ics = ICSObservation.from_config(config=ics_config, game=game)
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
return new
class UC2RedObservation(AbstractObservation):
"""Container for all observations used by the red agent in UC2."""
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
super().__init__()
self.where: Optional[List[str]] = where
self.nodes: List[NodeObservation] = nodes
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
if self.where is None:
return self.default_observation
obs = {}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
"""
Create UC2 red observation from a config.
:param config: Dictionary containing the configuration for this UC2 red observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
node_configs = config["nodes"]
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
return cls(nodes=nodes, where=["network"])
class UC2GreenObservation(NullObservation):
"""Green agent observation. As the green agent's actions don't depend on the observation, this is empty."""
pass

View File

@@ -1,126 +1,168 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Dict, Iterable, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class FileObservation(AbstractObservation, identifier="FILE"):
"""File observation, provides status information about a file within the simulation environment."""
class FileObservation(AbstractObservation):
"""Observation of a file on a node in the network."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for FileObservation."""
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
file_name: str
"""Name of the file, used for querying simulation state dictionary."""
include_num_access: Optional[bool] = None
"""Whether to include the number of accesses to the file in the observation."""
def __init__(self, where: WhereType, include_num_access: bool) -> None:
"""
Initialise file observation.
Initialise a file observation instance.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a file looks like this:
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>,'files',<file_name>]
:type where: Optional[List[str]]
:param where: Where in the simulation state dictionary to find the relevant information for this file.
A typical location for a file might be
['network', 'nodes', <node_hostname>, 'file_system', 'folder', <folder_name>, 'files', <file_name>].
:type where: WhereType
:param include_num_access: Whether to include the number of accesses to the file in the observation.
:type include_num_access: bool
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.default_observation: spaces.Space = {"health_status": 0}
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
self.where: WhereType = where
self.include_num_access: bool = include_num_access
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
self.default_observation: ObsType = {"health_status": 0}
if self.include_num_access:
self.default_observation["num_access"] = 0
:param state: Simulation state dictionary
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
def _categorise_num_access(self, num_access: int) -> int:
"""
Represent number of file accesses as a categorical variable.
:param num_access: Number of file accesses.
:return: Bin number corresponding to the number of accesses.
"""
if num_access > self.high_threshold:
return 3
elif num_access > self.med_threshold:
return 2
elif num_access > self.low_threshold:
return 1
return 0
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation
:rtype: Dict
:return: Observation containing the health status of the file and optionally the number of accesses.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
file_state = access_from_nested_dict(state, self.where)
if file_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"health_status": file_state["visible_status"]}
obs = {"health_status": file_state["visible_status"]}
if self.include_num_access:
obs["num_access"] = self._categorise_num_access(file_state["num_access"])
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:return: Gymnasium space representing the observation space for file status.
:rtype: spaces.Space
"""
return spaces.Dict({"health_status": spaces.Discrete(6)})
space = {"health_status": spaces.Discrete(6)}
if self.include_num_access:
space["num_access"] = spaces.Discrete(4)
return spaces.Dict(space)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
"""Create file observation from a config.
:param config: Dictionary containing the configuration for this file observation.
:type config: Dict
:param game: _description_
:type game: PrimaiteGame
:param parent_where: _description_, defaults to None
:type parent_where: _type_, optional
:return: _description_
:rtype: _type_
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation:
"""
return cls(where=parent_where + ["files", config["file_name"]])
Create a file observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the file observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this file's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed file observation instance.
:rtype: FileObservation
"""
return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access)
class FolderObservation(AbstractObservation):
"""Folder observation, including files inside of the folder."""
class FolderObservation(AbstractObservation, identifier="FOLDER"):
"""Folder observation, provides status information about a folder within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for FolderObservation."""
folder_name: str
"""Name of the folder, used for querying simulation state dictionary."""
files: List[FileObservation.ConfigSchema] = []
"""List of file configurations within the folder."""
num_files: Optional[int] = None
"""Number of spaces for file observations in this folder."""
include_num_access: Optional[bool] = None
"""Whether files in this folder should include the number of accesses in their observation."""
def __init__(
self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2
self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool
) -> None:
"""Initialise folder Observation, including files inside the folder.
"""
Initialise a folder observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
A typical location for a file looks like this:
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
:type where: Optional[List[str]]
:param max_files: As size of the space must remain static, define max files that can be in this folder
, defaults to 5
:type max_files: int, optional
:param file_positions: Defines the positioning within the observation space of particular files. This ensures
that even if new files are created, the existing files will always occupy the same space in the observation
space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the
observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same
name, it will take the position defined in this dict. Defaults to {}
:type file_positions: Dict[int, str], optional
A typical location for a folder might be ['network', 'nodes', <node_hostname>, 'folders', <folder_name>].
:type where: WhereType
:param files: List of file observation instances within the folder.
:type files: Iterable[FileObservation]
:param num_files: Number of files expected in the folder.
:type num_files: int
:param include_num_access: Whether to include the number of accesses to files in the observation.
:type include_num_access: bool
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.where: WhereType = where
self.files: List[FileObservation] = files
while len(self.files) < num_files_per_folder:
self.files.append(FileObservation())
while len(self.files) > num_files_per_folder:
while len(self.files) < num_files:
self.files.append(FileObservation(where=None, include_num_access=include_num_access))
while len(self.files) > num_files:
truncated_file = self.files.pop()
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
_LOGGER.warning(msg)
self.default_observation = {
"health_status": 0,
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
}
if self.files:
self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the health status of the folder and status of files within the folder.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
folder_state = access_from_nested_dict(state, self.where)
if folder_state is NOT_PRESENT_IN_STATE:
return self.default_observation
@@ -130,48 +172,42 @@ class FolderObservation(AbstractObservation):
obs = {}
obs["health_status"] = health_status
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
if self.files:
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:return: Gymnasium space representing the observation space for folder status.
:rtype: spaces.Space
"""
return spaces.Dict(
{
"health_status": spaces.Discrete(6),
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
}
)
shape = {"health_status": spaces.Discrete(6)}
if self.files:
shape["FILES"] = spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)})
return spaces.Dict(shape)
@classmethod
def from_config(
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
) -> "FolderObservation":
"""Create folder observation from a config. Also creates child file observations.
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation:
"""
Create a folder observation from a configuration schema.
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
folder and the files inside of it.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param config: Configuration schema containing the necessary information for the folder observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
parent node. A typical location for a node ``where`` can be:
['network','nodes',<node_hostname>,'file_system']
:type parent_where: Optional[List[str]]
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
observation size) , defaults to 2
:type num_files_per_folder: int, optional
:return: Constructed folder observation
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed folder observation instance.
:rtype: FolderObservation
"""
where = parent_where + ["folders", config["folder_name"]]
where = parent_where + ["file_system", "folders", config.folder_name]
file_configs = config["files"]
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
# pass down shared/common config items
for file_config in config.files:
file_config.include_num_access = config.include_num_access
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)

View File

@@ -0,0 +1,220 @@
from __future__ import annotations
from typing import Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
_LOGGER = getLogger(__name__)
class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
"""Firewall observation, provides status information about a firewall within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for FirewallObservation."""
hostname: str
"""Hostname of the firewall node, used for querying simulation state dictionary."""
ip_list: Optional[List[str]] = None
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
def __init__(
self,
where: WhereType,
ip_list: List[str],
wildcard_list: List[str],
port_list: List[int],
protocol_list: List[str],
num_rules: int,
) -> None:
"""
Initialise a firewall observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this firewall.
A typical location for a firewall might be ['network', 'nodes', <firewall_hostname>].
:type where: WhereType
:param ip_list: List of IP addresses.
:type ip_list: List[str]
:param wildcard_list: List of wildcard rules.
:type wildcard_list: List[str]
:param port_list: List of port numbers.
:type port_list: List[int]
:param protocol_list: List of protocol types.
:type protocol_list: List[str]
:param num_rules: Number of rules configured in the firewall.
:type num_rules: int
"""
self.where: WhereType = where
self.ports: List[PortObservation] = [
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
]
# TODO: check what the port nums are for firewall.
self.internal_inbound_acl = ACLObservation(
where=self.where + ["internal_inbound_acl", "acl"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.internal_outbound_acl = ACLObservation(
where=self.where + ["internal_outbound_acl", "acl"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.dmz_inbound_acl = ACLObservation(
where=self.where + ["dmz_inbound_acl", "acl"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.dmz_outbound_acl = ACLObservation(
where=self.where + ["dmz_outbound_acl", "acl"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.external_inbound_acl = ACLObservation(
where=self.where + ["external_inbound_acl", "acl"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.external_outbound_acl = ACLObservation(
where=self.where + ["external_outbound_acl", "acl"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.default_observation = {
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
"ACL": {
"INTERNAL": {
"INBOUND": self.internal_inbound_acl.default_observation,
"OUTBOUND": self.internal_outbound_acl.default_observation,
},
"DMZ": {
"INBOUND": self.dmz_inbound_acl.default_observation,
"OUTBOUND": self.dmz_outbound_acl.default_observation,
},
"EXTERNAL": {
"INBOUND": self.external_inbound_acl.default_observation,
"OUTBOUND": self.external_outbound_acl.default_observation,
},
},
}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic.
:rtype: ObsType
"""
obs = {
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
"ACL": {
"INTERNAL": {
"INBOUND": self.internal_inbound_acl.observe(state),
"OUTBOUND": self.internal_outbound_acl.observe(state),
},
"DMZ": {
"INBOUND": self.dmz_inbound_acl.observe(state),
"OUTBOUND": self.dmz_outbound_acl.observe(state),
},
"EXTERNAL": {
"INBOUND": self.external_inbound_acl.observe(state),
"OUTBOUND": self.external_outbound_acl.observe(state),
},
},
}
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for firewall status.
:rtype: spaces.Space
"""
space = spaces.Dict(
{
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
"ACL": spaces.Dict(
{
"INTERNAL": spaces.Dict(
{
"INBOUND": self.internal_inbound_acl.space,
"OUTBOUND": self.internal_outbound_acl.space,
}
),
"DMZ": spaces.Dict(
{
"INBOUND": self.dmz_inbound_acl.space,
"OUTBOUND": self.dmz_outbound_acl.space,
}
),
"EXTERNAL": spaces.Dict(
{
"INBOUND": self.external_inbound_acl.space,
"OUTBOUND": self.external_outbound_acl.space,
}
),
}
),
}
)
return space
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation:
"""
Create a firewall observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the firewall observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this firewall's
parent node. A typical location for a node might be ['network', 'nodes', <firewall_hostname>].
:type parent_where: WhereType, optional
:return: Constructed firewall observation instance.
:rtype: FirewallObservation
"""
return cls(
where=parent_where + [config.hostname],
ip_list=config.ip_list,
wildcard_list=config.wildcard_list,
port_list=config.port_list,
protocol_list=config.protocol_list,
num_rules=config.num_rules,
)

View File

@@ -0,0 +1,251 @@
from __future__ import annotations
from typing import Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.file_system_observations import FolderObservation
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
class HostObservation(AbstractObservation, identifier="HOST"):
"""Host observation, provides status information about a host within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for HostObservation."""
hostname: str
"""Hostname of the host, used for querying simulation state dictionary."""
services: List[ServiceObservation.ConfigSchema] = []
"""List of services to observe on the host."""
applications: List[ApplicationObservation.ConfigSchema] = []
"""List of applications to observe on the host."""
folders: List[FolderObservation.ConfigSchema] = []
"""List of folders to observe on the host."""
network_interfaces: List[NICObservation.ConfigSchema] = []
"""List of network interfaces to observe on the host."""
num_services: Optional[int] = None
"""Number of spaces for service observations on this host."""
num_applications: Optional[int] = None
"""Number of spaces for application observations on this host."""
num_folders: Optional[int] = None
"""Number of spaces for folder observations on this host."""
num_files: Optional[int] = None
"""Number of spaces for file observations on this host."""
num_nics: Optional[int] = None
"""Number of spaces for network interface observations on this host."""
include_nmne: Optional[bool] = None
"""Whether network interface observations should include number of malicious network events."""
include_num_access: Optional[bool] = None
"""Whether to include the number of accesses to files observations on this host."""
def __init__(
self,
where: WhereType,
services: List[ServiceObservation],
applications: List[ApplicationObservation],
folders: List[FolderObservation],
network_interfaces: List[NICObservation],
num_services: int,
num_applications: int,
num_folders: int,
num_files: int,
num_nics: int,
include_nmne: bool,
include_num_access: bool,
) -> None:
"""
Initialise a host observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this host.
A typical location for a host might be ['network', 'nodes', <hostname>].
:type where: WhereType
:param services: List of service observations on the host.
:type services: List[ServiceObservation]
:param applications: List of application observations on the host.
:type applications: List[ApplicationObservation]
:param folders: List of folder observations on the host.
:type folders: List[FolderObservation]
:param network_interfaces: List of network interface observations on the host.
:type network_interfaces: List[NICObservation]
:param num_services: Number of services to observe.
:type num_services: int
:param num_applications: Number of applications to observe.
:type num_applications: int
:param num_folders: Number of folders to observe.
:type num_folders: int
:param num_files: Number of files.
:type num_files: int
:param num_nics: Number of network interfaces.
:type num_nics: int
:param include_nmne: Flag to include network metrics and errors.
:type include_nmne: bool
:param include_num_access: Flag to include the number of accesses to files.
:type include_num_access: bool
"""
self.where: WhereType = where
self.include_num_access = include_num_access
# Ensure lists have lengths equal to specified counts by truncating or padding
self.services: List[ServiceObservation] = services
while len(self.services) < num_services:
self.services.append(ServiceObservation(where=None))
while len(self.services) > num_services:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warning(msg)
self.applications: List[ApplicationObservation] = applications
while len(self.applications) < num_applications:
self.applications.append(ApplicationObservation(where=None))
while len(self.applications) > num_applications:
truncated_application = self.applications.pop()
msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}"
_LOGGER.warning(msg)
self.folders: List[FolderObservation] = folders
while len(self.folders) < num_folders:
self.folders.append(
FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access)
)
while len(self.folders) > num_folders:
truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}"
_LOGGER.warning(msg)
self.nics: List[NICObservation] = network_interfaces
while len(self.nics) < num_nics:
self.nics.append(NICObservation(where=None, include_nmne=include_nmne))
while len(self.nics) > num_nics:
truncated_nic = self.nics.pop()
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
_LOGGER.warning(msg)
self.default_observation: ObsType = {
"operating_status": 0,
}
if self.services:
self.default_observation["SERVICES"] = {i + 1: s.default_observation for i, s in enumerate(self.services)}
if self.applications:
self.default_observation["APPLICATIONS"] = {
i + 1: a.default_observation for i, a in enumerate(self.applications)
}
if self.folders:
self.default_observation["FOLDERS"] = {i + 1: f.default_observation for i, f in enumerate(self.folders)}
if self.nics:
self.default_observation["NICS"] = {i + 1: n.default_observation for i, n in enumerate(self.nics)}
if self.include_num_access:
self.default_observation["num_file_creations"] = 0
self.default_observation["num_file_deletions"] = 0
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status information about the host.
:rtype: ObsType
"""
node_state = access_from_nested_dict(state, self.where)
if node_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["operating_status"] = node_state["operating_state"]
if self.services:
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
if self.applications:
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
if self.folders:
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
if self.nics:
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
if self.include_num_access:
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for host status.
:rtype: spaces.Space
"""
shape = {
"operating_status": spaces.Discrete(5),
}
if self.services:
shape["SERVICES"] = spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)})
if self.applications:
shape["APPLICATIONS"] = spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)})
if self.folders:
shape["FOLDERS"] = spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)})
if self.nics:
shape["NICS"] = spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)})
if self.include_num_access:
shape["num_file_creations"] = spaces.Discrete(4)
shape["num_file_deletions"] = spaces.Discrete(4)
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> HostObservation:
"""
Create a host observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the host observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this host.
A typical location might be ['network', 'nodes', <hostname>].
:type parent_where: WhereType, optional
:return: Constructed host observation instance.
:rtype: HostObservation
"""
if parent_where == []:
where = ["network", "nodes", config.hostname]
else:
where = parent_where + [config.hostname]
# Pass down shared/common config items
for folder_config in config.folders:
folder_config.include_num_access = config.include_num_access
folder_config.num_files = config.num_files
for nic_config in config.network_interfaces:
nic_config.include_nmne = config.include_nmne
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders]
nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces]
# If list of network interfaces is not defined, assume we want to
# monitor the first N interfaces. Network interface numbering starts at 1.
count = 1
while len(nics) < config.num_nics:
nic_config = NICObservation.ConfigSchema(nic_num=count, include_nmne=config.include_nmne)
nics.append(NICObservation.from_config(config=nic_config, parent_where=where))
count += 1
return cls(
where=where,
services=services,
applications=applications,
folders=folders,
network_interfaces=nics,
num_services=config.num_services,
num_applications=config.num_applications,
num_folders=config.num_folders,
num_files=config.num_files,
num_nics=config.num_nics,
include_nmne=config.include_nmne,
include_num_access=config.include_num_access,
)

View File

@@ -0,0 +1,152 @@
from __future__ import annotations
from typing import Any, Dict, List
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
class LinkObservation(AbstractObservation, identifier="LINK"):
"""Link observation, providing information about a specific link within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for LinkObservation."""
link_reference: str
"""Reference identifier for the link."""
def __init__(self, where: WhereType) -> None:
"""
Initialise a link observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this link.
A typical location for a link might be ['network', 'links', <link_reference>].
:type where: WhereType
"""
self.where = where
self.default_observation: ObsType = {"PROTOCOLS": {"ALL": 0}}
def observe(self, state: Dict) -> Any:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing information about the link.
:rtype: Any
"""
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
self.where[-1] = "<->".join(self.where[-1].split("<->")[::-1]) # try swapping endpoint A and B
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
utilisation_category = int(utilisation_fraction * 9) + 1
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for link status.
:rtype: spaces.Space
"""
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> LinkObservation:
"""
Create a link observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the link observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this link.
A typical location might be ['network', 'links', <link_reference>].
:type parent_where: WhereType, optional
:return: Constructed link observation instance.
:rtype: LinkObservation
"""
link_reference = config.link_reference
if parent_where == []:
where = ["network", "links", link_reference]
else:
where = parent_where + ["links", link_reference]
return cls(where=where)
class LinksObservation(AbstractObservation, identifier="LINKS"):
"""Collection of link observations representing multiple links within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for LinksObservation."""
link_references: List[str]
"""List of reference identifiers for the links."""
def __init__(self, where: WhereType, links: List[LinkObservation]) -> None:
"""
Initialise a links observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for these links.
A typical location for links might be ['network', 'links'].
:type where: WhereType
:param links: List of link observations.
:type links: List[LinkObservation]
"""
self.where: WhereType = where
self.links: List[LinkObservation] = links
self.default_observation: ObsType = {i + 1: l.default_observation for i, l in enumerate(self.links)}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing information about multiple links.
:rtype: ObsType
"""
return {i + 1: l.observe(state) for i, l in enumerate(self.links)}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for multiple links.
:rtype: spaces.Space
"""
return spaces.Dict({i + 1: l.space for i, l in enumerate(self.links)})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> LinksObservation:
"""
Create a links observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the links observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about these links.
A typical location might be ['network'].
:type parent_where: WhereType, optional
:return: Constructed links observation instance.
:rtype: LinksObservation
"""
where = parent_where + ["network"]
link_cfgs = [LinkObservation.ConfigSchema(link_reference=ref) for ref in config.link_references]
links = [LinkObservation.from_config(c, parent_where=where) for c in link_cfgs]
return cls(where=where, links=links)

View File

@@ -1,97 +1,53 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Dict, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.simulator.network.nmne import CAPTURE_NMNE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class NicObservation(AbstractObservation):
"""Observation of a Network Interface Card (NIC) in the network."""
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
"""Status information about a network interface within the simulation environment."""
low_nmne_threshold: int = 0
"""The minimum number of malicious network events to be considered low."""
med_nmne_threshold: int = 5
"""The minimum number of malicious network events to be considered medium."""
high_nmne_threshold: int = 10
"""The minimum number of malicious network events to be considered high."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NICObservation."""
global CAPTURE_NMNE
@property
def default_observation(self) -> Dict:
"""The default NIC observation dict."""
data = {"nic_status": 0}
if CAPTURE_NMNE:
data.update({"NMNE": {"inbound": 0, "outbound": 0}})
return data
nic_num: int
"""Number of the network interface."""
include_nmne: Optional[bool] = None
"""Whether to include number of malicious network events (NMNE) in the observation."""
def __init__(
self,
where: Optional[Tuple[str]] = None,
low_nmne_threshold: Optional[int] = 0,
med_nmne_threshold: Optional[int] = 5,
high_nmne_threshold: Optional[int] = 10,
where: WhereType,
include_nmne: bool,
) -> None:
"""Initialise NIC observation.
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
example may look like this:
['network','nodes',<node_hostname>,'NICs',<nic_number>]
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
:type where: Optional[Tuple[str]], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
Initialise a network interface observation instance.
global CAPTURE_NMNE
if CAPTURE_NMNE:
:param where: Where in the simulation state dictionary to find the relevant information for this interface.
A typical location for a network interface might be
['network', 'nodes', <node_hostname>, 'NICs', <nic_num>].
:type where: WhereType
:param include_nmne: Flag to determine whether to include NMNE information in the observation.
:type include_nmne: bool
"""
self.where = where
self.include_nmne: bool = include_nmne
self.default_observation: ObsType = {"nic_status": 0}
if self.include_nmne:
self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}})
self.nmne_inbound_last_step: int = 0
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
us find the difference."""
self.nmne_outbound_last_step: int = 0
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
us find the difference."""
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
self._validate_nmne_categories(
low_nmne_threshold=low_nmne_threshold,
med_nmne_threshold=med_nmne_threshold,
high_nmne_threshold=high_nmne_threshold,
)
def _validate_nmne_categories(
self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10
):
"""
Validates the nmne threshold config.
If the configuration is valid, the thresholds will be set, otherwise, an exception is raised.
:param: low_nmne_threshold: The minimum number of malicious network events to be considered low
:param: med_nmne_threshold: The minimum number of malicious network events to be considered medium
:param: high_nmne_threshold: The minimum number of malicious network events to be considered high
"""
if high_nmne_threshold <= med_nmne_threshold:
raise Exception(
f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater "
f"than medium nmne count ({med_nmne_threshold})"
)
if med_nmne_threshold <= low_nmne_threshold:
raise Exception(
f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater "
f"than low nmne count ({low_nmne_threshold})"
)
self.high_nmne_threshold = high_nmne_threshold
self.med_nmne_threshold = med_nmne_threshold
self.low_nmne_threshold = low_nmne_threshold
# TODO: allow these to be configured in yaml
self.high_nmne_threshold = 10
self.med_nmne_threshold = 5
self.low_nmne_threshold = 0
def _categorise_mne_count(self, nmne_count: int) -> int:
"""
@@ -116,73 +72,120 @@ class NicObservation(AbstractObservation):
return 1
return 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status of the network interface and optionally NMNE information.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
return self.default_observation
else:
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
if CAPTURE_NMNE:
obs_dict.update({"NMNE": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
self.nmne_inbound_last_step = inbound_count
self.nmne_outbound_last_step = outbound_count
return obs_dict
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
if self.include_nmne:
obs.update({"NMNE": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
self.nmne_inbound_last_step = inbound_count
self.nmne_outbound_last_step = outbound_count
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for network interface status and NMNE information.
:rtype: spaces.Space
"""
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
if CAPTURE_NMNE:
if self.include_nmne:
space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
return space
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
"""Create NIC observation from a config.
:param config: Dictionary containing the configuration for this NIC observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
:type parent_where: Optional[List[str]]
:return: Constructed NIC observation
:rtype: NicObservation
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation:
"""
low_nmne_threshold = None
med_nmne_threshold = None
high_nmne_threshold = None
Create a network interface observation from a configuration schema.
if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"):
threshold = game.options.thresholds["nmne"]
:param config: Configuration schema containing the necessary information for the network interface observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed network interface observation instance.
:rtype: NICObservation
"""
return cls(where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne)
low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None
med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None
high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None
return cls(
where=parent_where + ["NICs", config["nic_num"]],
low_nmne_threshold=low_nmne_threshold,
med_nmne_threshold=med_nmne_threshold,
high_nmne_threshold=high_nmne_threshold,
)
class PortObservation(AbstractObservation, identifier="PORT"):
"""Port observation, provides status information about a network port within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for PortObservation."""
port_id: int
"""Identifier of the port, used for querying simulation state dictionary."""
def __init__(self, where: WhereType) -> None:
"""
Initialise a port observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this port.
A typical location for a port might be ['network', 'nodes', <node_hostname>, 'NICs', <port_id>].
:type where: WhereType
"""
self.where = where
self.default_observation: ObsType = {"operating_status": 0}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the operating status of the port.
:rtype: ObsType
"""
port_state = access_from_nested_dict(state, self.where)
if port_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"operating_status": 1 if port_state["enabled"] else 2}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for port status.
:rtype: spaces.Space
"""
return spaces.Dict({"operating_status": spaces.Discrete(3)})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation:
"""
Create a port observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the port observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this port's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed port observation instance.
:rtype: PortObservation
"""
return cls(where=parent_where + ["NICs", config.port_id])

View File

@@ -1,200 +1,216 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import model_validator
from primaite import getLogger
from primaite.game.agent.observations.file_system_observations import FolderObservation
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.software_observation import ServiceObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.game.agent.observations.firewall_observation import FirewallObservation
from primaite.game.agent.observations.host_observations import HostObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.observations.router_observation import RouterObservation
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class NodesObservation(AbstractObservation, identifier="NODES"):
"""Nodes observation, provides status information about nodes within the simulation environment."""
class NodeObservation(AbstractObservation):
"""Observation of a node in the network. Includes services, folders and NICs."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NodesObservation."""
hosts: List[HostObservation.ConfigSchema] = []
"""List of configurations for host observations."""
routers: List[RouterObservation.ConfigSchema] = []
"""List of configurations for router observations."""
firewalls: List[FirewallObservation.ConfigSchema] = []
"""List of configurations for firewall observations."""
num_services: Optional[int] = None
"""Number of services."""
num_applications: Optional[int] = None
"""Number of applications."""
num_folders: Optional[int] = None
"""Number of folders."""
num_files: Optional[int] = None
"""Number of files."""
num_nics: Optional[int] = None
"""Number of network interface cards (NICs)."""
include_nmne: Optional[bool] = None
"""Flag to include nmne."""
include_num_access: Optional[bool] = None
"""Flag to include the number of accesses."""
num_ports: Optional[int] = None
"""Number of ports."""
ip_list: Optional[List[str]] = None
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
@model_validator(mode="after")
def force_optional_fields(self) -> NodesObservation.ConfigSchema:
"""Check that options are specified only if they are needed for the nodes that are part of the config."""
# check for hosts:
host_fields = (
self.num_services,
self.num_applications,
self.num_folders,
self.num_files,
self.num_nics,
self.include_nmne,
self.include_num_access,
)
router_fields = (
self.num_ports,
self.ip_list,
self.wildcard_list,
self.port_list,
self.protocol_list,
self.num_rules,
)
firewall_fields = (self.ip_list, self.wildcard_list, self.port_list, self.protocol_list, self.num_rules)
if len(self.hosts) > 0 and any([x is None for x in host_fields]):
raise ValueError("Configuration error: Host observation options were not fully specified.")
if len(self.routers) > 0 and any([x is None for x in router_fields]):
raise ValueError("Configuration error: Router observation options were not fully specified.")
if len(self.firewalls) > 0 and any([x is None for x in firewall_fields]):
raise ValueError("Configuration error: Firewall observation options were not fully specified.")
return self
def __init__(
self,
where: Optional[Tuple[str]] = None,
services: List[ServiceObservation] = [],
folders: List[FolderObservation] = [],
network_interfaces: List[NicObservation] = [],
logon_status: bool = False,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
num_nics_per_node: int = 2,
where: WhereType,
hosts: List[HostObservation],
routers: List[RouterObservation],
firewalls: List[FirewallObservation],
) -> None:
"""
Configurable observation for a node in the simulation.
Initialise a nodes observation instance.
:param where: Where in the simulation state dictionary for find relevant information for this observation.
A typical location for a node looks like this:
['network','nodes',<hostname>]. If empty list, a default null observation will be output, defaults to []
:type where: List[str], optional
:param services: Mapping between position in observation space and service name, defaults to {}
:type services: Dict[int,str], optional
:param max_services: Max number of services that can be presented in observation space for this node
, defaults to 2
:type max_services: int, optional
:param folders: Mapping between position in observation space and folder name, defaults to {}
:type folders: Dict[int,str], optional
:param max_folders: Max number of folders in this node's obs space, defaults to 2
:type max_folders: int, optional
:param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {}
:type network_interfaces: Dict[int,str], optional
:param max_nics: Max number of network interfaces in this node's obs space, defaults to 5
:type max_nics: int, optional
:param where: Where in the simulation state dictionary to find the relevant information for nodes.
A typical location for nodes might be ['network', 'nodes'].
:type where: WhereType
:param hosts: List of host observations.
:type hosts: List[HostObservation]
:param routers: List of router observations.
:type routers: List[RouterObservation]
:param firewalls: List of firewall observations.
:type firewalls: List[FirewallObservation]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.where: WhereType = where
self.services: List[ServiceObservation] = services
while len(self.services) < num_services_per_node:
# add empty service observation without `where` parameter so it always returns default (blank) observation
self.services.append(ServiceObservation())
while len(self.services) > num_services_per_node:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warning(msg)
# truncate service list
self.hosts: List[HostObservation] = hosts
self.routers: List[RouterObservation] = routers
self.firewalls: List[FirewallObservation] = firewalls
self.folders: List[FolderObservation] = folders
# add empty folder observation without `where` parameter that will always return default (blank) observations
while len(self.folders) < num_folders_per_node:
self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
while len(self.folders) > num_folders_per_node:
truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
_LOGGER.warning(msg)
self.network_interfaces: List[NicObservation] = network_interfaces
while len(self.network_interfaces) < num_nics_per_node:
self.network_interfaces.append(NicObservation())
while len(self.network_interfaces) > num_nics_per_node:
truncated_nic = self.network_interfaces.pop()
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
_LOGGER.warning(msg)
self.logon_status: bool = logon_status
self.default_observation: Dict = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
"operating_status": 0,
self.default_observation = {
**{f"HOST{i}": host.default_observation for i, host in enumerate(self.hosts)},
**{f"ROUTER{i}": router.default_observation for i, router in enumerate(self.routers)},
**{f"FIREWALL{i}": firewall.default_observation for i, firewall in enumerate(self.firewalls)},
}
if self.logon_status:
self.default_observation["logon_status"] = 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation
:rtype: Dict
:return: Observation containing status information about nodes.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
node_state = access_from_nested_dict(state, self.where)
if node_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
obs["operating_status"] = node_state["operating_state"]
obs["NICS"] = {
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
obs = {
**{f"HOST{i}": host.observe(state) for i, host in enumerate(self.hosts)},
**{f"ROUTER{i}": router.observe(state) for i, router in enumerate(self.routers)},
**{f"FIREWALL{i}": firewall.observe(state) for i, firewall in enumerate(self.firewalls)},
}
if self.logon_status:
obs["logon_status"] = 0
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
space_shape = {
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NICS": spaces.Dict(
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
),
}
if self.logon_status:
space_shape["logon_status"] = spaces.Discrete(3)
"""
Gymnasium space object describing the observation space shape.
return spaces.Dict(space_shape)
:return: Gymnasium space representing the observation space for nodes.
:rtype: spaces.Space
"""
space = spaces.Dict(
{
**{f"HOST{i}": host.space for i, host in enumerate(self.hosts)},
**{f"ROUTER{i}": router.space for i, router in enumerate(self.routers)},
**{f"FIREWALL{i}": firewall.space for i, firewall in enumerate(self.firewalls)},
}
)
return space
@classmethod
def from_config(
cls,
config: Dict,
game: "PrimaiteGame",
parent_where: Optional[List[str]] = None,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
num_nics_per_node: int = 2,
) -> "NodeObservation":
"""Create node observation from a config. Also creates child service, folder and NIC observations.
:param config: Dictionary containing the configuration for this node observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
network. A typical location for it would be: ['network',]
:type parent_where: Optional[List[str]]
:param num_services_per_node: How many spaces for services are in this node observation (to preserve static
observation size) , defaults to 2
:type num_services_per_node: int, optional
:param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
observation size) , defaults to 2
:type num_folders_per_node: int, optional
:param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
observation size) , defaults to 2
:type num_files_per_folder: int, optional
:return: Constructed node observation
:rtype: NodeObservation
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NodesObservation:
"""
node_hostname = config["node_hostname"]
if parent_where is None:
where = ["network", "nodes", node_hostname]
else:
where = parent_where + ["nodes", node_hostname]
Create a nodes observation from a configuration schema.
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
folder_configs = config.get("folders", {})
folders = [
FolderObservation.from_config(
config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
)
for c in folder_configs
]
# create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}]
network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
logon_status = config.get("logon_status", False)
return cls(
where=where,
services=services,
folders=folders,
network_interfaces=network_interfaces,
logon_status=logon_status,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
num_nics_per_node=num_nics_per_node,
)
:param config: Configuration schema containing the necessary information for nodes observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about nodes.
A typical location for nodes might be ['network', 'nodes'].
:type parent_where: WhereType, optional
:return: Constructed nodes observation instance.
:rtype: NodesObservation
"""
if not parent_where:
where = ["network", "nodes"]
else:
where = parent_where + ["nodes"]
for host_config in config.hosts:
if host_config.num_services is None:
host_config.num_services = config.num_services
if host_config.num_applications is None:
host_config.num_applications = config.num_applications
if host_config.num_folders is None:
host_config.num_folders = config.num_folders
if host_config.num_files is None:
host_config.num_files = config.num_files
if host_config.num_nics is None:
host_config.num_nics = config.num_nics
if host_config.include_nmne is None:
host_config.include_nmne = config.include_nmne
if host_config.include_num_access is None:
host_config.include_num_access = config.include_num_access
for router_config in config.routers:
if router_config.num_ports is None:
router_config.num_ports = config.num_ports
if router_config.ip_list is None:
router_config.ip_list = config.ip_list
if router_config.wildcard_list is None:
router_config.wildcard_list = config.wildcard_list
if router_config.port_list is None:
router_config.port_list = config.port_list
if router_config.protocol_list is None:
router_config.protocol_list = config.protocol_list
if router_config.num_rules is None:
router_config.num_rules = config.num_rules
for firewall_config in config.firewalls:
if firewall_config.ip_list is None:
firewall_config.ip_list = config.ip_list
if firewall_config.wildcard_list is None:
firewall_config.wildcard_list = config.wildcard_list
if firewall_config.port_list is None:
firewall_config.port_list = config.port_list
if firewall_config.protocol_list is None:
firewall_config.protocol_list = config.protocol_list
if firewall_config.num_rules is None:
firewall_config.num_rules = config.num_rules
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls]
return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)

View File

@@ -1,16 +1,142 @@
from typing import Dict, TYPE_CHECKING
from __future__ import annotations
from typing import Any, Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
from primaite.game.agent.observations.agent_observations import (
UC2BlueObservation,
UC2GreenObservation,
UC2RedObservation,
)
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class NestedObservation(AbstractObservation, identifier="CUSTOM"):
"""Observation type that allows combining other observations into a gymnasium.spaces.Dict space."""
class NestedObservationItem(BaseModel):
"""One list item of the config."""
model_config = ConfigDict(extra="forbid")
type: str
"""Select observation class. It maps to the identifier of the obs class by checking the registry."""
label: str
"""Dict key in the final observation space."""
options: Dict
"""Options to pass to the observation class from_config method."""
@model_validator(mode="after")
def check_model(self) -> "NestedObservation.NestedObservationItem":
"""Make sure tha the config options match up with the selected observation type."""
obs_subclass_name = self.type
obs_options = self.options
if obs_subclass_name not in AbstractObservation._registry:
raise ValueError(f"Observation of type {obs_subclass_name} could not be found.")
obs_schema = AbstractObservation._registry[obs_subclass_name].ConfigSchema
try:
obs_schema(**obs_options)
except ValidationError as e:
raise ValueError(f"Observation options did not match schema, got this error: {e}")
return self
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NestedObservation."""
components: List[NestedObservation.NestedObservationItem] = []
"""List of observation components to be part of this space."""
def __init__(self, components: Dict[str, AbstractObservation]) -> None:
"""Initialise nested observation."""
self.components: Dict[str, AbstractObservation] = components
"""Maps label: observation object"""
self.default_observation = {label: obs.default_observation for label, obs in self.components.items()}
"""Default observation is just the default observations of constituents."""
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status information about the host.
:rtype: ObsType
"""
return {label: obs.observe(state) for label, obs in self.components.items()}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the nested observation space.
:rtype: spaces.Space
"""
return spaces.Dict({label: obs.space for label, obs in self.components.items()})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NestedObservation:
"""
Read the Nested observation config and create all defined subcomponents.
Example configuration that utilises NestedObservation:
This lets us have different options for different types of hosts.
```yaml
observation_space:
- type: CUSTOM
options:
components:
- type: HOSTS
label: COMPUTERS # What is the dictionary key called
options:
hosts:
- client_1
- client_2
num_services: 0
num_applications: 5
... # other options
- type: HOSTS
label: SERVERS # What is the dictionary key called
options:
hosts:
- hostname: database_server
- hostname: web_server
num_services: 4
num_applications: 0
num_folders: 2
num_files: 2
```
"""
instances = dict()
for component in config.components:
obs_class = AbstractObservation._registry[component.type]
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options))
instances[component.label] = obs_instance
return cls(components=instances)
class NullObservation(AbstractObservation, identifier="NONE"):
"""Empty observation that acts as a placeholder."""
def __init__(self) -> None:
"""Initialise the empty observation."""
self.default_observation = 0
def observe(self, state: Dict) -> Any:
"""Simply return 0."""
return 0
@property
def space(self) -> spaces.Space:
"""Essentially empty space."""
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: NullObservation.ConfigSchema, parent_where: WhereType = []) -> NullObservation:
"""Instantiate a NullObservation. Accepts parameters to comply with API."""
return cls()
class ObservationManager:
@@ -23,18 +149,15 @@ class ObservationManager:
3. Formatting this information so an agent can use it to make decisions.
"""
# TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed
# to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next
# refactor.
def __init__(self, observation: AbstractObservation) -> None:
def __init__(self, obs: AbstractObservation) -> None:
"""Initialise observation space.
:param observation: Observation object
:type observation: AbstractObservation
"""
self.obs: AbstractObservation = observation
self.obs: AbstractObservation = obs
self.current_observation: ObsType
"""Cached copy of the observation at the time it was most recently calculated."""
def update(self, state: Dict) -> Dict:
"""
@@ -52,22 +175,22 @@ class ObservationManager:
return self.obs.space
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
"""Create observation space from a config.
def from_config(cls, config: Optional[Dict]) -> "ObservationManager":
"""
Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.
It should contain the key 'type' which selects which observation class to use (from a choice of:
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
The other key is 'options' which are passed to the constructor of the selected observation class.
If None, a blank observation space is created.
Otherwise, this must be a Dict with a type field and options field.
type: string that corresponds to one of the observation identifiers that are provided when subclassing
AbstractObservation
options: this must adhere to the chosen observation type's ConfigSchema nested class.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
if config["type"] == "UC2BlueObservation":
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2RedObservation":
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2GreenObservation":
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
else:
raise ValueError("Observation space type invalid")
if config is None:
return cls(NullObservation())
obs_type = config["type"]
obs_class = AbstractObservation._registry[obs_type]
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]))
obs_manager = cls(observation)
return obs_manager

View File

@@ -1,22 +1,48 @@
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, Iterable, Optional, Type, Union
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
WhereType = Optional[Iterable[Union[str, int]]]
class AbstractObservation(ABC):
"""Abstract class for an observation space component."""
class ConfigSchema(ABC, BaseModel):
"""Config schema for observations."""
model_config = ConfigDict(extra="forbid")
_registry: Dict[str, Type["AbstractObservation"]] = {}
"""Registry of observation components, with their name as key.
Automatically populated when subclasses are defined. Used for defining from_config.
"""
def __init__(self) -> None:
"""Initialise an observation. This method must be overwritten."""
self.default_observation: ObsType
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
"""
Register an observation type.
:param identifier: Identifier used to uniquely specify observation component types.
:type identifier: str
:raises ValueError: When attempting to create a component with a name that is already in use.
"""
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Duplicate observation component type {identifier}")
cls._registry[identifier] = cls
@abstractmethod
def observe(self, state: Dict) -> Any:
"""
@@ -37,273 +63,6 @@ class AbstractObservation(ABC):
@classmethod
@abstractmethod
def from_config(cls, config: Dict, game: "PrimaiteGame"):
"""Create this observation space component form a serialised format.
The `game` parameter is for a the PrimaiteGame object that spawns this component.
"""
pass
class LinkObservation(AbstractObservation):
"""Observation of a link in the network."""
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
"Default observation is what should be returned when the link doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise link observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'servics', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
utilisation_category = int(utilisation_fraction * 9) + 1
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
"""Create link observation from a config.
:param config: Dictionary containing the configuration for this link observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed link observation
:rtype: LinkObservation
"""
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
class AclObservation(AbstractObservation):
"""Observation of an Access Control List (ACL) in the network."""
# TODO: should where be optional, and we can use where=None to pad the observation space?
# definitely the current approach does not support tracking files that aren't specified by name, for example
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# this needs adding, but not for the MVP.
def __init__(
self,
node_ip_to_id: Dict[str, int],
ports: List[int],
protocols: List[str],
where: Optional[Tuple[str]] = None,
num_rules: int = 10,
) -> None:
"""Initialise ACL observation.
:param node_ip_to_id: Mapping between IP address and ID.
:type node_ip_to_id: Dict[str, int]
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
:type ports: List[int]
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
:type protocols: list[str]
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
example may look like this:
['network','nodes',<router_hostname>,'acl','acl']
:type where: Optional[Tuple[str]], optional
:param num_rules: , defaults to 10
:type num_rules: int, optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.num_rules: int = num_rules
self.node_to_id: Dict[str, int] = node_ip_to_id
"List of node IP addresses, order in this list determines how they are converted to an ID"
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
"List of ports which are part of the game that define the ordering when converting to an ID"
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
"List of protocols which are part of the game, defines ordering when converting to an ID"
self.default_observation: Dict = {
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i] = {
"position": i - 1,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
src_port = rule_state["src_port"]
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
dst_port = rule_state["dst_port"]
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
protocol = rule_state["protocol"]
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_node_id": src_node_id,
"source_port": src_port_id,
"dest_node_id": dst_node_ip,
"dest_port": dst_port_id,
"protocol": protocol_id,
}
i += 1
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
"""Generate ACL observation from a config.
:param config: Dictionary containing the configuration for this ACL observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Observation object
:rtype: AclObservation
"""
max_acl_rules = config["options"]["max_acl_rules"]
node_ip_to_idx = {}
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_hostname"]
nic_num = ip_map_config["nic_num"]
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
nic_obj = node_obj.network_interface[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_hostname = config["router_hostname"]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_hostname, "acl", "acl"],
num_rules=max_acl_rules,
)
class NullObservation(AbstractObservation):
"""Null observation, returns a single 0 value for the observation space."""
def __init__(self, where: Optional[List[str]] = None):
"""Initialise null observation."""
self.default_observation: Dict = {}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
return 0
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
"""
Create null observation from a config.
The parameters are ignored, they are here to match the signature of the other observation classes.
"""
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation":
"""Create this observation space component form a serialised format."""
return cls()
class ICSObservation(NullObservation):
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
pass

View File

@@ -0,0 +1,145 @@
from __future__ import annotations
from typing import Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
class RouterObservation(AbstractObservation, identifier="ROUTER"):
"""Router observation, provides status information about a router within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for RouterObservation."""
hostname: str
"""Hostname of the router, used for querying simulation state dictionary."""
ports: Optional[List[PortObservation.ConfigSchema]] = None
"""Configuration of port observations for this router."""
num_ports: Optional[int] = None
"""Number of port observations configured for this router."""
acl: Optional[ACLObservation.ConfigSchema] = None
"""Configuration of ACL observation on this router."""
ip_list: Optional[List[str]] = None
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
def __init__(
self,
where: WhereType,
ports: List[PortObservation],
num_ports: int,
acl: ACLObservation,
) -> None:
"""
Initialise a router observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this router.
A typical location for a router might be ['network', 'nodes', <node_hostname>].
:type where: WhereType
:param ports: List of port observations representing the ports of the router.
:type ports: List[PortObservation]
:param num_ports: Number of ports for the router.
:type num_ports: int
:param acl: ACL observation representing the access control list of the router.
:type acl: ACLObservation
"""
self.where: WhereType = where
self.ports: List[PortObservation] = ports
self.acl: ACLObservation = acl
self.num_ports: int = num_ports
while len(self.ports) < num_ports:
self.ports.append(PortObservation(where=None))
while len(self.ports) > num_ports:
self.ports.pop()
msg = "Too many ports in router observation. Truncating."
_LOGGER.warning(msg)
self.default_observation = {
"ACL": self.acl.default_observation,
}
if self.ports:
self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status of ports and ACL configuration of the router.
:rtype: ObsType
"""
router_state = access_from_nested_dict(state, self.where)
if router_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["ACL"] = self.acl.observe(state)
if self.ports:
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for router status.
:rtype: spaces.Space
"""
shape = {"ACL": self.acl.space}
if self.ports:
shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)})
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation:
"""
Create a router observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the router observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this router's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed router observation instance.
:rtype: RouterObservation
"""
where = parent_where + [config.hostname]
if config.acl is None:
config.acl = ACLObservation.ConfigSchema()
if config.acl.num_rules is None:
config.acl.num_rules = config.num_rules
if config.acl.ip_list is None:
config.acl.ip_list = config.ip_list
if config.acl.wildcard_list is None:
config.acl.wildcard_list = config.wildcard_list
if config.acl.port_list is None:
config.acl.port_list = config.port_list
if config.acl.protocol_list is None:
config.acl.protocol_list = config.protocol_list
if config.ports is None:
config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)]
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)

View File

@@ -1,45 +1,43 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Dict
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
"""Service observation, shows status of a service in the simulation environment."""
class ServiceObservation(AbstractObservation):
"""Observation of a service in the network."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for ServiceObservation."""
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
"Default observation is what should be returned when the service doesn't exist."
service_name: str
"""Name of the service, used for querying simulation state dictionary"""
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise service observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'services', <service_name>]`
:type where: Optional[List[str]]
def __init__(self, where: WhereType) -> None:
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
Initialise a service observation instance.
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param where: Where in the simulation state dictionary to find the relevant information for this service.
A typical location for a service might be ['network', 'nodes', <node_hostname>, 'services', <service_name>].
:type where: WhereType
"""
self.where = where
self.default_observation = {"operating_status": 0, "health_status": 0}
:param state: Simulation state dictionary
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation
:rtype: Dict
:return: Observation containing the operating status and health status of the service.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
service_state = access_from_nested_dict(state, self.where)
if service_state is NOT_PRESENT_IN_STATE:
return self.default_observation
@@ -50,114 +48,116 @@ class ServiceObservation(AbstractObservation):
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for service status.
:rtype: spaces.Space
"""
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
@classmethod
def from_config(
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
) -> "ServiceObservation":
"""Create service observation from a config.
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation:
"""
Create a service observation from a configuration schema.
:param config: Dictionary containing the configuration for this service observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
:type parent_where: Optional[List[str]], optional
:return: Constructed service observation
:param config: Configuration schema containing the necessary information for the service observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this service's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed service observation instance.
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", config["service_name"]])
return cls(where=parent_where + ["services", config.service_name])
class ApplicationObservation(AbstractObservation):
"""Observation of an application in the network."""
class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
"""Application observation, shows the status of an application within the simulation environment."""
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0}
"Default observation is what should be returned when the application doesn't exist."
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for ApplicationObservation."""
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise application observation.
application_name: str
"""Name of the application, used for querying simulation state dictionary"""
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'applications', <application_name>]`
:type where: Optional[List[str]]
def __init__(self, where: WhereType) -> None:
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
Initialise an application observation instance.
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param where: Where in the simulation state dictionary to find the relevant information for this application.
A typical location for an application might be
['network', 'nodes', <node_hostname>, 'applications', <application_name>].
:type where: WhereType
"""
self.where = where
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
:param state: Simulation state dictionary
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
def _categorise_num_executions(self, num_executions: int) -> int:
"""
Represent number of file accesses as a categorical variable.
:param num_access: Number of file accesses.
:return: Bin number corresponding to the number of accesses.
"""
if num_executions > self.high_threshold:
return 3
elif num_executions > self.med_threshold:
return 2
elif num_executions > self.low_threshold:
return 1
return 0
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation
:rtype: Dict
:return: Obs containing the operating status, health status, and number of executions of the application.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
app_state = access_from_nested_dict(state, self.where)
if app_state is NOT_PRESENT_IN_STATE:
application_state = access_from_nested_dict(state, self.where)
if application_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {
"operating_status": app_state["operating_state"],
"health_status": app_state["health_state_visible"],
"num_executions": self._categorise_num_executions(app_state["num_executions"]),
"operating_status": application_state["operating_state"],
"health_status": application_state["health_state_visible"],
"num_executions": self._categorise_num_executions(application_state["num_executions"]),
}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for application status.
:rtype: spaces.Space
"""
return spaces.Dict(
{
"operating_status": spaces.Discrete(7),
"health_status": spaces.Discrete(6),
"health_status": spaces.Discrete(5),
"num_executions": spaces.Discrete(4),
}
)
@classmethod
def from_config(
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
) -> "ApplicationObservation":
"""Create application observation from a config.
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation:
"""
Create an application observation from a configuration schema.
:param config: Dictionary containing the configuration for this service observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
:type parent_where: Optional[List[str]], optional
:return: Constructed service observation
:param config: Configuration schema containing the necessary information for the application observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this application's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed application observation instance.
:rtype: ApplicationObservation
"""
return cls(where=parent_where + ["services", config["application_name"]])
@classmethod
def _categorise_num_executions(cls, num_executions: int) -> int:
"""
Categorise the number of executions of an application.
Helps classify the number of application executions into different categories.
Current categories:
- 0: Application is never executed
- 1: Application is executed a low number of times (1-5)
- 2: Application is executed often (6-10)
- 3: Application is executed a high number of times (more than 10)
:param: num_executions: Number of times the application is executed
"""
if num_executions > 10:
return 3
elif num_executions > 5:
return 2
elif num_executions > 0:
return 1
return 0
return cls(where=parent_where + ["applications", config.application_name])

View File

@@ -26,19 +26,25 @@ the structure:
```
"""
from abc import abstractmethod
from typing import Dict, List, Tuple, Type
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from typing_extensions import Never
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.agent.interface import AgentActionHistoryItem
_LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]]
class AbstractReward:
"""Base class for reward function components."""
@abstractmethod
def calculate(self, state: Dict) -> float:
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Calculate the reward for the current state."""
return 0.0
@@ -58,7 +64,7 @@ class AbstractReward:
class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
def calculate(self, state: Dict) -> float:
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Calculate the reward for the current state."""
return 0.0
@@ -98,7 +104,7 @@ class DatabaseFileIntegrity(AbstractReward):
file_name,
]
def calculate(self, state: Dict) -> float:
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
@@ -106,7 +112,7 @@ class DatabaseFileIntegrity(AbstractReward):
"""
database_file_state = access_from_nested_dict(state, self.location_in_state)
if database_file_state is NOT_PRESENT_IN_STATE:
_LOGGER.info(
_LOGGER.debug(
f"Could not calculate {self.__class__} reward because "
"simulation state did not contain enough information."
)
@@ -153,7 +159,7 @@ class WebServer404Penalty(AbstractReward):
"""
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict) -> float:
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
@@ -203,19 +209,30 @@ class WebpageUnavailablePenalty(AbstractReward):
:param node_hostname: Hostname of the node which has the web browser.
:type node_hostname: str
"""
self._node = node_hostname
self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._last_request_failed: bool = False
def calculate(self, state: Dict) -> float:
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""
Calculate the reward based on current simulation state.
Calculate the reward based on current simulation state, and the recent agent action.
:param state: The current state of the simulation.
:type state: Dict
When the green agent requests to execute the browser application, and that request fails, this reward
component will keep track of that information. In that case, it doesn't matter whether the last webpage
had a 200 status code, because there has been an unsuccessful request since.
"""
if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
# If the last request did actually go through, then check if the webpage also loaded
web_browser_state = access_from_nested_dict(state, self.location_in_state)
if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
_LOGGER.info(
_LOGGER.debug(
"Web browser reward could not be calculated because the web browser history on node",
f"{self._node} was not reported in the simulation state. Returning 0.0",
)
@@ -252,19 +269,32 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
:param node_hostname: Hostname of the node where the database client sits.
:type node_hostname: str
"""
self._node = node_hostname
self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._last_request_failed: bool = False
def calculate(self, state: Dict) -> float:
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""
Calculate the reward based on current simulation state.
Calculate the reward based on current simulation state, and the recent agent action.
:param state: The current state of the simulation.
:type state: Dict
When the green agent requests to execute the database client application, and that request fails, this reward
component will keep track of that information. In that case, it doesn't matter whether the last successful
request returned was able to connect to the database server, because there has been an unsuccessful request
since.
"""
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
# If the last request was actually sent, then check if the connection was established.
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
return 0.0
last_connection_successful = db_state["last_connection_successful"]
if last_connection_successful is False:
return -1.0
@@ -284,6 +314,51 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
return cls(node_hostname=node_hostname)
class SharedReward(AbstractReward):
"""Adds another agent's reward to the overall reward."""
def __init__(self, agent_name: Optional[str] = None) -> None:
"""
Initialise the shared reward.
The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
correctly.
:param agent_name: The name whose reward is an input
:type agent_name: Optional[str]
"""
self.agent_name = agent_name
"""Agent whose reward to track."""
def default_callback(agent_name: str) -> Never:
"""
Default callback to prevent calling this reward until it's properly initialised.
SharedReward should not be used until the game layer replaces self.callback with a reference to the
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
an error.
"""
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
self.callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Simply access the other agent's reward and return it."""
return self.callback(self.agent_name)
@classmethod
def from_config(cls, config: Dict) -> "SharedReward":
"""
Build the SharedReward object from config.
:param config: Configuration dictionary
:type config: Dict
"""
agent_name = config.get("agent_name")
return cls(agent_name=agent_name)
class RewardFunction:
"""Manages the reward function for the agent."""
@@ -293,6 +368,7 @@ class RewardFunction:
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
"SHARED_REWARD": SharedReward,
}
"""List of reward class identifiers."""
@@ -313,7 +389,7 @@ class RewardFunction:
"""
self.reward_components.append((component, weight))
def update(self, state: Dict) -> float:
def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Calculate the overall reward for the current state.
:param state: The current state of the simulation.
@@ -323,7 +399,7 @@ class RewardFunction:
for comp_and_weight in self.reward_components:
comp = comp_and_weight[0]
weight = comp_and_weight[1]
total += weight * comp.calculate(state=state)
total += weight * comp.calculate(state=state, last_action_response=last_action_response)
self.current_reward = total
return self.current_reward

View File

@@ -14,7 +14,7 @@ class DataManipulationAgent(AbstractScriptedAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.reset_agent_for_episode()
self.setup_agent()
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
@@ -43,9 +43,8 @@ class DataManipulationAgent(AbstractScriptedAgent):
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
def reset_agent_for_episode(self) -> None:
def setup_agent(self) -> None:
"""Set the next execution timestep when the episode resets."""
super().reset_agent_for_episode()
self._select_start_node()
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)

View File

@@ -1,8 +1,13 @@
from typing import Dict, Tuple
import random
from typing import Dict, Optional, Tuple
from gymnasium.core import ObsType
from pydantic import BaseModel
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction
class RandomAgent(AbstractScriptedAgent):
@@ -19,3 +24,60 @@ class RandomAgent(AbstractScriptedAgent):
:rtype: Tuple[str, Dict]
"""
return self.action_manager.get_action(self.action_manager.space.sample())
class PeriodicAgent(AbstractScriptedAgent):
"""Agent that does nothing most of the time, but executes application at regular intervals (with variance)."""
class Settings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
start_step: int = 20
"The timestep at which an agent begins performing it's actions."
start_variance: int = 5
"Deviation around the start step."
frequency: int = 5
"The number of timesteps to wait between performing actions."
variance: int = 0
"The amount the frequency can randomly change to."
max_executions: int = 999999
"Maximum number of times the agent can execute its action."
def __init__(
self,
agent_name: str,
action_space: ActionManager,
observation_space: ObservationManager,
reward_function: RewardFunction,
settings: Optional[Settings] = None,
) -> None:
"""Initialise PeriodicAgent."""
super().__init__(
agent_name=agent_name,
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
)
self.settings = settings or PeriodicAgent.Settings()
self._set_next_execution_timestep(timestep=self.settings.start_step, variance=self.settings.start_variance)
self.num_executions = 0
def _set_next_execution_timestep(self, timestep: int, variance: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep when the next execute action should be taken.
:type timestep: int
:param variance: Uniform random variance applied to the timestep
:type variance: int
"""
random_increment = random.randint(-variance, variance)
self.next_execution_timestep = timestep + random_increment
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
if timestep == self.next_execution_timestep and self.num_executions < self.settings.max_executions:
self.num_executions += 1
self._set_next_execution_timestep(timestep + self.settings.frequency, self.settings.variance)
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
return "DONOTHING", {}

View File

@@ -0,0 +1,78 @@
import random
from typing import Dict, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
class TAP001(AbstractScriptedAgent):
"""
TAP001 | Mobile Malware -- Ransomware Variant.
Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setup_agent()
next_execution_timestep: int = 0
starting_node_idx: int = 0
installed: bool = False
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
)
self.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Waits until a specific timestep, then attempts to execute the ransomware application.
This application acts a wrapper around the kill-chain, similar to green-analyst and
the previous UC2 data manipulation bot.
:param obs: Current observation for this agent.
:type obs: ObsType
:param timestep: The current simulation timestep, used for scheduling actions
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
if timestep < self.next_execution_timestep:
return "DONOTHING", {}
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
if not self.installed:
self.installed = True
return "NODE_APPLICATION_INSTALL", {
"node_id": self.starting_node_idx,
"application_name": "RansomwareScript",
"ip_address": self.ip_address,
}
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
def setup_agent(self) -> None:
"""Set the next execution timestep when the episode resets."""
self._select_start_node()
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
for n, act in self.action_manager.action_map.items():
if not act[0] == "NODE_APPLICATION_INSTALL":
continue
if act[1]["node_id"] == self.starting_node_idx:
self.ip_address = act[1]["ip_address"]
return
raise RuntimeError("TAP001 agent could not find database server ip address in action map")
def _select_start_node(self) -> None:
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
# we are assuming that every node in the node manager has a data manipulation application at idx 0
num_nodes = len(self.action_manager.node_names)
self.starting_node_idx = random.randint(0, num_nodes - 1)

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Hashable, Sequence
from typing import Any, Dict, Hashable, Optional, Sequence
NOT_PRESENT_IN_STATE = object()
"""
@@ -7,7 +7,7 @@ the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is
"""
def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
def access_from_nested_dict(dictionary: Dict, keys: Optional[Sequence[Hashable]]) -> Any:
"""
Access an item from a deeply dictionary with a list of keys.
@@ -21,6 +21,8 @@ def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
:return: The value in the dictionary
:rtype: Any
"""
if keys is None:
return NOT_PRESENT_IN_STATE
key_list = [*keys] # copy keys to a new list to prevent editing original list
if len(key_list) == 0:
return dictionary

View File

@@ -1,6 +1,6 @@
"""PrimAITE game - Encapsulates the simulation and agents."""
from ipaddress import IPv4Address
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict
@@ -8,22 +8,28 @@ from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.rewards import RewardFunction, SharedReward
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator.network.airspace import AIR_SPACE
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
@@ -41,6 +47,7 @@ APPLICATION_TYPES_MAPPING = {
"DatabaseClient": DatabaseClient,
"DataManipulationBot": DataManipulationBot,
"DoSBot": DoSBot,
"RansomwareScript": RansomwareScript,
}
"""List of available applications that can be installed on nodes in the PrimAITE Simulation."""
@@ -100,21 +107,12 @@ class PrimaiteGame:
self.options: PrimaiteGameOptions
"""Special options that apply for the entire game."""
self.ref_map_nodes: Dict[str, str] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
self.ref_map_services: Dict[str, str] = {}
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
self.ref_map_applications: Dict[str, str] = {}
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
self.ref_map_links: Dict[str, str] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.save_step_metadata: bool = False
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
self._reward_calculation_order: List[str] = [name for name in self.agents]
"""Agent order for reward evaluation, as some rewards can be dependent on other agents' rewards."""
def step(self):
"""
Perform one step of the simulation/agent loop.
@@ -135,49 +133,55 @@ class PrimaiteGame:
"""
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
# Get the current state of the simulation
sim_state = self.get_sim_state()
# Update agents' observations and rewards based on the current state
self.update_agents(sim_state)
self.pre_timestep()
if self.step_counter == 0:
state = self.get_sim_state()
for agent in self.agents.values():
agent.update_observation(state=state)
# Apply all actions to simulation as requests
self.apply_agent_actions()
# Advance timestep
self.advance_timestep()
# Get the current state of the simulation
sim_state = self.get_sim_state()
# Update agents' observations and rewards based on the current state, and the response from the last action
self.update_agents(state=sim_state)
def get_sim_state(self) -> Dict:
"""Get the current state of the simulation."""
return self.simulation.describe_state()
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
for _, agent in self.agents.items():
agent.update_observation(state)
agent.update_reward(state)
for agent_name in self._reward_calculation_order:
agent = self.agents[agent_name]
if self.step_counter > 0: # can't get reward before first action
agent.update_reward(state=state)
agent.update_observation(state=state) # order of this doesn't matter so just use reward order
agent.reward_function.total_reward += agent.reward_function.current_reward
def apply_agent_actions(self) -> Dict[str, Tuple[str, Dict]]:
"""
Apply all actions to simulation as requests.
:return: A recap of each agent's actions, in CAOS format.
:rtype: Dict[str, Tuple[str, Dict]]
"""
agent_actions = {}
def apply_agent_actions(self) -> None:
"""Apply all actions to simulation as requests."""
for _, agent in self.agents.items():
obs = agent.observation_manager.current_observation
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
request = agent.format_request(action_choice, options)
action_choice, parameters = agent.get_action(obs, timestep=self.step_counter)
request = agent.format_request(action_choice, parameters)
response = self.simulation.apply_request(request)
agent_actions[agent.agent_name] = {
"action": action_choice,
"parameters": options,
"response": response.model_dump(),
}
return agent_actions
agent.process_action_response(
timestep=self.step_counter,
action=action_choice,
parameters=parameters,
request=request,
response=response,
)
def pre_timestep(self) -> None:
"""Apply any pre-timestep logic that helps make sure we have the correct observations."""
self.simulation.pre_timestep(self.step_counter)
def advance_timestep(self) -> None:
"""Advance timestep."""
@@ -206,8 +210,8 @@ class PrimaiteGame:
"""Create a PrimaiteGame object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent.
2. game_config: options for the game itself. Used by PrimaiteGame.
1. io_settings: options for logging data during training
2. game_config: options for the game itself, such as agents.
3. simulation: defines the network topology and the initial state of the simulation.
The specification for each of the three major areas is described in a separate documentation page.
@@ -218,6 +222,7 @@ class PrimaiteGame:
:return: A PrimaiteGame object.
:rtype: PrimaiteGame
"""
AIR_SPACE.clear()
game = cls()
game.options = PrimaiteGameOptions(**cfg["game"])
game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False
@@ -233,7 +238,6 @@ class PrimaiteGame:
links_cfg = network_config.get("links", [])
for node_cfg in nodes_cfg:
node_ref = node_cfg["ref"]
n_type = node_cfg["type"]
if n_type == "computer":
new_node = Computer(
@@ -269,18 +273,29 @@ class PrimaiteGame:
new_node = Router.from_config(node_cfg)
elif n_type == "firewall":
new_node = Firewall.from_config(node_cfg)
elif n_type == "wireless_router":
new_node = WirelessRouter.from_config(node_cfg)
elif n_type == "printer":
new_node = Printer(
hostname=node_cfg["hostname"],
ip_address=node_cfg["ip_address"],
subnet_mask=node_cfg["subnet_mask"],
operating_state=NodeOperatingState.ON
if not (p := node_cfg.get("operating_state"))
else NodeOperatingState[p.upper()],
)
else:
_LOGGER.warning(f"invalid node type {n_type} in config")
msg = f"invalid node type {n_type} in config"
_LOGGER.error(msg)
raise ValueError(msg)
if "services" in node_cfg:
for service_cfg in node_cfg["services"]:
new_service = None
service_ref = service_cfg["ref"]
service_type = service_cfg["type"]
if service_type in SERVICE_TYPES_MAPPING:
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(SERVICE_TYPES_MAPPING[service_type])
new_service = new_node.software_manager.software[service_type]
game.ref_map_services[service_ref] = new_service.uuid
# start the service
new_service.start()
@@ -316,13 +331,11 @@ class PrimaiteGame:
if "applications" in node_cfg:
for application_cfg in node_cfg["applications"]:
new_application = None
application_ref = application_cfg["ref"]
application_type = application_cfg["type"]
if application_type in APPLICATION_TYPES_MAPPING:
new_node.software_manager.install(APPLICATION_TYPES_MAPPING[application_type])
new_application = new_node.software_manager.software[application_type]
game.ref_map_applications[application_ref] = new_application.uuid
else:
msg = f"Configuration contains an invalid application type: {application_type}"
_LOGGER.error(msg)
@@ -341,6 +354,19 @@ class PrimaiteGame:
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
)
elif application_type == "RansomwareScript":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")),
server_password=opt.get("server_password"),
payload=opt.get("payload", "ENCRYPT"),
c2_beacon_p_of_success=float(opt.get("c2_beacon_p_of_success", "0.5")),
target_scan_p_of_success=float(opt.get("target_scan_p_of_success", "0.1")),
ransomware_encrypt_p_of_success=float(
opt.get("ransomware_encrypt_p_of_success", "0.1")
),
)
elif application_type == "DatabaseClient":
if "options" in application_cfg:
opt = application_cfg["options"]
@@ -376,7 +402,6 @@ class PrimaiteGame:
# run through the power on step if the node is to be turned on at the start
if new_node.operating_state == NodeOperatingState.ON:
new_node.power_on()
game.ref_map_nodes[node_ref] = new_node.uuid
# set start up and shut down duration
new_node.start_up_duration = int(node_cfg.get("start_up_duration", 3))
@@ -384,8 +409,9 @@ class PrimaiteGame:
# 2. create links between nodes
for link_cfg in links_cfg:
node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
node_b = net.nodes[game.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
node_a = net.get_node_by_hostname(link_cfg["endpoint_a_hostname"])
node_b = net.get_node_by_hostname(link_cfg["endpoint_b_hostname"])
if isinstance(node_a, Switch):
endpoint_a = node_a.network_interface[link_cfg["endpoint_a_port"]]
else:
@@ -394,8 +420,7 @@ class PrimaiteGame:
endpoint_b = node_b.network_interface[link_cfg["endpoint_b_port"]]
else:
endpoint_b = node_b.network_interface[link_cfg["endpoint_b_port"]]
new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
game.ref_map_links[link_cfg["ref"]] = new_link.uuid
net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
# 3. create agents
agents_cfg = cfg.get("agents", [])
@@ -408,7 +433,7 @@ class PrimaiteGame:
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationManager.from_config(observation_space_cfg, game)
obs_space = ObservationManager.from_config(observation_space_cfg)
# CREATE ACTION SPACE
action_space = ActionManager.from_config(game, action_space_cfg)
@@ -427,6 +452,16 @@ class PrimaiteGame:
reward_function=reward_function,
settings=settings,
)
elif agent_type == "PeriodicAgent":
settings = PeriodicAgent.Settings(**agent_cfg.get("settings", {}))
new_agent = PeriodicAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
settings=settings,
)
elif agent_type == "ProxyAgent":
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
new_agent = ProxyAgent(
@@ -447,13 +482,64 @@ class PrimaiteGame:
reward_function=reward_function,
agent_settings=agent_settings,
)
elif agent_type == "TAP001":
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
new_agent = TAP001(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
agent_settings=agent_settings,
)
else:
msg = f"Configuration error: {agent_type} is not a valid agent type."
_LOGGER.error(msg)
raise ValueError(msg)
game.agents[agent_cfg["ref"]] = new_agent
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
game.setup_reward_sharing()
# Set the NMNE capture config
set_nmne_config(network_config.get("nmne_config", {}))
game.update_agents(game.get_sim_state())
return game
def setup_reward_sharing(self):
"""Do necessary setup to enable reward sharing between agents.
This method ensures that there are no cycles in the reward sharing. A cycle would be for example if agent_1
depends on agent_2 and agent_2 depends on agent_1. It would cause an infinite loop.
Also, SharedReward requires us to pass it a callback method that will provide the reward of the agent who is
sharing their reward. This callback is provided by this setup method.
Finally, this method sorts the agents in order in which rewards will be evaluated to make sure that any rewards
that rely on the value of another reward are evaluated later.
:raises RuntimeError: If the reward sharing is specified with a cyclic dependency.
"""
# construct dependency graph in the reward sharing between agents.
graph = {}
for name, agent in self.agents.items():
graph[name] = set()
for comp, weight in agent.reward_function.reward_components:
if isinstance(comp, SharedReward):
comp: SharedReward
graph[name].add(comp.agent_name)
# while constructing the graph, we might as well set up the reward sharing itself.
comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward
# make sure the graph is acyclic. Otherwise we will enter an infinite loop of reward sharing.
if graph_has_cycle(graph):
raise RuntimeError(
(
"Detected cycle in agent reward sharing. Check the agent reward function ",
"configuration: reward sharing can only go one way.",
)
)
# sort the agents so the rewards that depend on other rewards are always evaluated later
self._reward_calculation_order = topological_sort(graph)

View File

@@ -1,4 +1,5 @@
from random import random
from typing import Any, Iterable, Mapping
def simulate_trial(p_of_success: float) -> bool:
@@ -14,3 +15,80 @@ def simulate_trial(p_of_success: float) -> bool:
:returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False.
"""
return random() < p_of_success
def graph_has_cycle(graph: Mapping[Any, Iterable[Any]]) -> bool:
"""Detect cycles in a directed graph.
Provide the graph as a dictionary that describes which nodes are linked. For example:
{0: {1,2}, 1:{2,3}, 3:{0}} here there's a cycle 0 -> 1 -> 3 -> 0
{'a': ('b','c'), c:('b')} here there is no cycle
:param graph: a mapping from node to a set of nodes to which it is connected.
:type graph: Mapping[Any, Iterable[Any]]
:return: Whether the graph has any cycles
:rtype: bool
"""
visited = set()
currently_visiting = set()
def depth_first_search(node: Any) -> bool:
"""Perform depth-first search (DFS) traversal to detect cycles starting from a given node."""
if node in currently_visiting:
return True # Cycle detected
if node in visited:
return False # Already visited, no need to explore further
visited.add(node)
currently_visiting.add(node)
for neighbour in graph.get(node, []):
if depth_first_search(neighbour):
return True # Cycle detected
currently_visiting.remove(node)
return False
# Start DFS traversal from each node
for node in graph:
if depth_first_search(node):
return True # Cycle detected
return False # No cycles found
def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> Iterable[Any]:
"""
Perform topological sorting on a directed graph.
This guarantees that if there's a directed edge from node A to node B, then A appears before B.
:param graph: A dictionary representing the directed graph, where keys are node identifiers
and values are lists of outgoing edges from each node.
:type graph: dict[int, list[Any]]
:return: A topologically sorted list of node identifiers.
:rtype: list[Any]
"""
visited: set[Any] = set()
stack: list[Any] = []
def dfs(node: Any) -> None:
"""
Depth-first search traversal to visit nodes and their neighbors.
:param node: The current node to visit.
:type node: Any
"""
if node in visited:
return
visited.add(node)
for neighbour in graph.get(node, []):
dfs(neighbour)
stack.append(node)
# Perform DFS traversal from each node
for node in graph:
dfs(node)
return stack

View File

@@ -1,7 +1,9 @@
from typing import Dict, ForwardRef, Literal
from typing import Dict, ForwardRef, List, Literal, Union
from pydantic import BaseModel, ConfigDict, StrictBool, validate_call
RequestFormat = List[Union[str, int, float]]
RequestResponse = ForwardRef("RequestResponse")
"""This makes it possible to type-hint RequestResponse.from_bool return type."""

View File

@@ -1,47 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The main PrimAITE session runner module."""
import argparse
from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.config.load import data_manipulation_config_path, load
from primaite.session.session import PrimaiteSession
# from primaite.primaite_session import PrimaiteSession
_LOGGER = getLogger(__name__)
def run(
config_path: Optional[Union[str, Path]] = "",
agent_load_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Run the PrimAITE Session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
otherwise False.
"""
cfg = load(config_path)
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
sess.start_session()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config")
args = parser.parse_args()
if not args.config:
args.config = data_manipulation_config_path()
run(args.config)

View File

@@ -22,6 +22,7 @@
"# Imports\n",
"\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"import yaml\n",
"from pprint import pprint"
@@ -62,12 +63,12 @@
"source": [
"def friendly_output_red_action(info):\n",
" # parse the info dict form step output and write out what the red agent is doing\n",
" red_info = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info[0]\n",
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'DONOTHING':\n",
" red_str = 'DO NOTHING'\n",
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
" client = \"client 1\" if red_info[1]['node_id'] == 0 else \"client 2\"\n",
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
" red_str = f\"ATTACK from {client}\"\n",
" return red_str"
]
@@ -361,7 +362,7 @@
" cfg = yaml.safe_load(f)\n",
" cfg['simulation']['network']\n",
" for node in cfg['simulation']['network']['nodes']:\n",
" if node['ref'] in ['client_1', 'client_2']:\n",
" if node['hostname'] in ['client_1', 'client_2']:\n",
" node['applications'] = change['applications']\n",
"\n",
"env = PrimaiteGymEnv(game_config = cfg)\n",
@@ -406,7 +407,7 @@
" cfg = yaml.safe_load(f)\n",
" cfg['simulation']['network']\n",
" for node in cfg['simulation']['network']['nodes']:\n",
" if node['ref'] in ['client_1', 'client_2']:\n",
" if node['hostname'] in ['client_1', 'client_2']:\n",
" node['applications'] = change['applications']\n",
"\n",
"env = PrimaiteGymEnv(game_config = cfg)\n",

View File

@@ -208,7 +208,7 @@
"|--|--|\n",
"|0|UNUSED|\n",
"|1|GOOD|\n",
"|2|PATCHING|\n",
"|2|FIXING|\n",
"|3|COMPROMISED|\n",
"|4|OVERWHELMED|\n",
"\n",
@@ -352,7 +352,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"tags": []
},
@@ -364,7 +364,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"tags": []
},
@@ -373,7 +373,7 @@
"# Imports\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.game.game import PrimaiteGame\n",
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
"import yaml\n",
"from pprint import pprint\n"
]
@@ -389,162 +389,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-03-13 16:52:48,201: Resetting environment, episode 0, avg. reward: 0.0\n",
"2024-03-13 16:52:48,205: Saving agent action log to C:\\Users\\NickTodd\\primaite\\3.0.0b6\\sessions\\2024-03-13\\16-52-48\\agent_actions\\episode_0.json\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"env created successfully\n",
"{'ACL': {1: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 0,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 2: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 1,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 3: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 2,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 4: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 3,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 5: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 4,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 6: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 5,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 7: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 6,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 8: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 7,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 9: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 8,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0},\n",
" 10: {'dest_node_id': 0,\n",
" 'dest_port': 0,\n",
" 'permission': 0,\n",
" 'position': 9,\n",
" 'protocol': 0,\n",
" 'source_node_id': 0,\n",
" 'source_port': 0}},\n",
" 'ICS': 0,\n",
" 'LINKS': {1: {'PROTOCOLS': {'ALL': 1}},\n",
" 2: {'PROTOCOLS': {'ALL': 1}},\n",
" 3: {'PROTOCOLS': {'ALL': 1}},\n",
" 4: {'PROTOCOLS': {'ALL': 1}},\n",
" 5: {'PROTOCOLS': {'ALL': 1}},\n",
" 6: {'PROTOCOLS': {'ALL': 1}},\n",
" 7: {'PROTOCOLS': {'ALL': 1}},\n",
" 8: {'PROTOCOLS': {'ALL': 1}},\n",
" 9: {'PROTOCOLS': {'ALL': 1}},\n",
" 10: {'PROTOCOLS': {'ALL': 0}}},\n",
" 'NODES': {1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
" 'health_status': 0}},\n",
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 1},\n",
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n",
" 'operating_status': 1},\n",
" 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
" 'health_status': 0}},\n",
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 1},\n",
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n",
" 'operating_status': 1},\n",
" 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}},\n",
" 'health_status': 1}},\n",
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 1},\n",
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
" 'operating_status': 1},\n",
" 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
" 'health_status': 0}},\n",
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 1},\n",
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
" 'operating_status': 1},\n",
" 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
" 'health_status': 0}},\n",
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 1},\n",
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
" 'operating_status': 1},\n",
" 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
" 'health_status': 0}},\n",
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 1},\n",
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
" 'operating_status': 1},\n",
" 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n",
" 'health_status': 0}},\n",
" 'NICS': {1: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 1},\n",
" 2: {'NMNE': {'inbound': 0, 'outbound': 0},\n",
" 'nic_status': 0}},\n",
" 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n",
" 'operating_status': 1}}}\n"
]
}
],
"outputs": [],
"source": [
"# create the env\n",
"with open(data_manipulation_config_path(), 'r') as f:\n",
@@ -565,20 +412,9 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"res = FileSystemItemHealthStatus.GOOD\n",
"res = FileSystemItemHealthStatus.GOOD\n",
"res = FileSystemItemHealthStatus.COMPROMISED\n",
"res = FileSystemItemHealthStatus.COMPROMISED\n"
]
}
],
"outputs": [],
"source": [
"# Test NODE_FOLDER_CHECKHASH\n",
"res = env.game.simulation.network.get_node_by_hostname('database_server').file_system.get_folder(folder_name = 'database').health_status\n",
@@ -618,12 +454,12 @@
"source": [
"def friendly_output_red_action(info):\n",
" # parse the info dict form step output and write out what the red agent is doing\n",
" red_info = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info[0]\n",
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'DONOTHING':\n",
" red_str = 'DO NOTHING'\n",
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
" client = \"client 1\" if red_info[1]['node_id'] == 0 else \"client 2\"\n",
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
" red_str = f\"ATTACK from {client}\"\n",
" return red_str"
]
@@ -643,7 +479,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now the reward is -1, let's have a look at blue agent's observation."
"Now the reward is -0.8, let's have a look at blue agent's observation."
]
},
{
@@ -704,9 +540,9 @@
"source": [
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
"print(f\"step: {env.game.step_counter}\")\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user'].action}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user'].action}\" )\n",
"print(f\"Blue reward:{reward}\" )"
]
},
@@ -714,7 +550,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The patching takes two steps, so the reward hasn't changed yet. Let's do nothing for another timestep, the reward should improve.\n",
"The fixing takes two steps, so the reward hasn't changed yet. Let's do nothing for another timestep, the reward should improve.\n",
"\n",
"The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 1 when both green agents make successful requests.\n",
"\n",
@@ -727,9 +563,9 @@
"metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
"obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
"print(f\"step: {env.game.step_counter}\")\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n",
"print(f\"Blue reward:{reward:.2f}\" )"
@@ -751,24 +587,26 @@
"outputs": [],
"source": [
"env.step(13) # Patch the database\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
"\n",
"env.step(50) # Block client 1\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
"\n",
"env.step(51) # Block client 2\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
"\n",
"for step in range(30):\n",
"while abs(reward - 0.8) > 1e-5:\n",
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
" if env.game.step_counter > 10000:\n",
" break # make sure there's no infinite loop if something went wrong"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, even though the red agent executes an attack, the reward stays at 0.8."
"Now, even though the red agent executes an attack, the reward will stay at 0.8."
]
},
{
@@ -784,7 +622,7 @@
"metadata": {},
"outputs": [],
"source": [
"obs['ACL']"
"obs['NODES']['ROUTER0']"
]
},
{
@@ -800,13 +638,30 @@
"metadata": {},
"outputs": [],
"source": [
"if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
" # client 1 has NMNEs, let's unblock client 2\n",
" env.step(58) # remove ACL rule 6\n",
"elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
" env.step(57) # remove ACL rule 5\n",
"else:\n",
" print(\"something went wrong, neither client has NMNEs\")"
"env.step(58) # Remove the ACL rule that blocks client 1\n",
"env.step(57) # Remove the ACL rule that blocks client 2\n",
"\n",
"tries = 0\n",
"while True:\n",
" tries += 1\n",
" obs, reward, terminated, truncated, info = env.step(0)\n",
"\n",
" if obs['NODES']['HOST5']['NICS'][1]['NMNE']['outbound'] == 1:\n",
" # client 1 has NMNEs, let's block it\n",
" obs, reward, terminated, truncated, info = env.step(50) # block client 1\n",
" print(\"blocking client 1\")\n",
" break\n",
" elif obs['NODES']['HOST6']['NICS'][1]['NMNE']['outbound'] == 1:\n",
" # client 2 has NMNEs, so let's block it\n",
" obs, reward, terminated, truncated, info = env.step(51) # block client 2\n",
" print(\"blocking client 2\")\n",
" break\n",
" if tries>100:\n",
" print(\"Error: NMNE never increased\")\n",
" break\n",
"\n",
"env.step(13) # Patch the database\n",
"print()\n"
]
},
{
@@ -824,14 +679,14 @@
"source": [
"for step in range(30):\n",
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )"
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode."
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode. (except the red agent will move between `client_1` and `client_2`.)"
]
},
{

View File

@@ -1,5 +1,21 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training an SB3 Agent\n",
"\n",
"This notebook will demonstrate how to use primaite to create and train a PPO agent, using a pre-defined configuration file."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### First, we import the inital packages and read in our configuration file."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -27,7 +43,14 @@
"outputs": [],
"source": [
"with open(data_manipulation_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n"
" cfg = yaml.safe_load(f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using the given configuration, we generate the environment our agent will train in."
]
},
{
@@ -40,12 +63,10 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3 import PPO"
"Lets define training parameters for the agent."
]
},
{
@@ -54,7 +75,13 @@
"metadata": {},
"outputs": [],
"source": [
"model = PPO('MlpPolicy', gym)\n"
"from stable_baselines3 import PPO\n",
"\n",
"EPISODE_LEN = 128\n",
"NUM_EPISODES = 10\n",
"NO_STEPS = EPISODE_LEN * NUM_EPISODES\n",
"BATCH_SIZE = 32\n",
"LEARNING_RATE = 3e-4"
]
},
{
@@ -63,7 +90,14 @@
"metadata": {},
"outputs": [],
"source": [
"model.learn(total_timesteps=10)\n"
"model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the agent configured, let's train for our defined number of episodes."
]
},
{
@@ -72,7 +106,14 @@
"metadata": {},
"outputs": [],
"source": [
"model.save(\"deleteme\")"
"model.learn(total_timesteps=NO_STEPS)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's save the agent to a zip file that can be used in future evaluation."
]
},
{
@@ -80,7 +121,44 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"model.save(\"PrimAITE-PPO-example-agent\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we load the saved agent and run it in evaluation mode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eval_model = PPO(\"MlpPolicy\", gym)\n",
"eval_model = PPO.load(\"PrimAITE-PPO-example-agent\", gym)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, evaluate the agent."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3.common.evaluation import evaluate_policy\n",
"\n",
"evaluate_policy(eval_model, gym, n_eval_episodes=10)"
]
}
],
"metadata": {
@@ -99,7 +177,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -26,8 +26,13 @@ class PrimaiteGymEnv(gymnasium.Env):
def __init__(self, game_config: Dict):
"""Initialise the environment."""
super().__init__()
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game_config: Dict = game_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
"""Current game."""
self._agent_name = next(iter(self.game.rl_agents))
@@ -36,9 +41,6 @@ class PrimaiteGymEnv(gymnasium.Env):
self.episode_counter: int = 0
"""Current episode number."""
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
@property
def agent(self) -> ProxyAgent:
"""Grab a fresh reference to the agent object because it will be reinstantiated each episode."""
@@ -46,37 +48,36 @@ class PrimaiteGymEnv(gymnasium.Env):
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen my the RL policy
# make ProxyAgent store the action chosen by the RL policy
step = self.game.step_counter
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
agent_actions = self.game.apply_agent_actions()
self.game.pre_timestep()
self.game.apply_agent_actions()
self.game.advance_timestep()
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
next_obs = self._get_obs() # this doesn't update observation, just gets the current observation
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.game.calculate_truncated()
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
info = {
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
} # tell us what all the agents did for convenience.
if self.game.save_step_metadata:
self._write_step_metadata_json(action, state, reward)
if self.io.settings.save_agent_actions:
self.io.store_agent_actions(
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
)
self._write_step_metadata_json(step, action, state, reward)
return next_obs, reward, terminated, truncated, info
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
def _write_step_metadata_json(self, step: int, action: int, state: Dict, reward: int):
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
path = output_dir / f"step_{step}.json"
data = {
"episode": self.episode_counter,
"step": self.game.step_counter,
"step": step,
"action": int(action),
"reward": int(reward),
"state": state,
@@ -91,13 +92,13 @@ class PrimaiteGymEnv(gymnasium.Env):
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if self.io.settings.save_agent_actions:
self.io.write_agent_actions(episode=self.episode_counter)
self.io.clear_agent_actions()
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
state = self.game.get_sim_state()
self.game.update_agents(state)
self.game.update_agents(state=state)
next_obs = self._get_obs()
info = {}
return next_obs, info
@@ -124,6 +125,12 @@ class PrimaiteGymEnv(gymnasium.Env):
else:
return self.agent.observation_manager.current_observation
def close(self):
"""Close the simulation."""
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
class PrimaiteRayEnv(gymnasium.Env):
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
@@ -147,6 +154,10 @@ class PrimaiteRayEnv(gymnasium.Env):
"""Perform a step in the environment."""
return self.env.step(action)
def close(self):
"""Close the simulation."""
self.env.close()
class PrimaiteRayMARLEnv(MultiAgentEnv):
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
@@ -160,6 +171,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
"""
self.game_config: Dict = env_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.io = PrimaiteIO.from_config(env_config.get("io_settings"))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
"""Reference to the primaite game"""
self._agent_ids = list(self.game.rl_agents.keys())
@@ -179,9 +192,6 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
{name: agent.action_manager.space for name, agent in self.agents.items()}
)
self.io = PrimaiteIO.from_config(env_config.get("io_settings"))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
super().__init__()
@property
@@ -192,8 +202,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
if self.io.settings.save_agent_actions:
self.io.write_agent_actions(episode=self.episode_counter)
self.io.clear_agent_actions()
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
@@ -214,10 +224,12 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
identifier.
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
"""
step = self.game.step_counter
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
agent_actions = self.game.apply_agent_actions()
self.game.pre_timestep()
self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
@@ -235,22 +247,18 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
self._write_step_metadata_json(actions, state, rewards)
if self.io.settings.save_agent_actions:
self.io.store_agent_actions(
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
)
self._write_step_metadata_json(step, actions, state, rewards)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
path = output_dir / f"step_{step}.json"
data = {
"episode": self.episode_counter,
"step": self.game.step_counter,
"step": step,
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
"reward": rewards,
"state": state,
@@ -267,3 +275,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
unflat_obs = agent.observation_manager.current_observation
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs
def close(self):
"""Close the simulation."""
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)

View File

@@ -29,10 +29,12 @@ class PrimaiteIO:
"""Whether to save a log of all agents' actions every step."""
save_step_metadata: bool = False
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
save_pcap_logs: bool = False
save_pcap_logs: bool = True
"""Whether to save PCAP logs."""
save_sys_logs: bool = False
save_sys_logs: bool = True
"""Whether to save system logs."""
write_sys_log_to_terminal: bool = False
"""Whether to write the sys log to the terminal."""
def __init__(self, settings: Optional[Settings] = None) -> None:
"""
@@ -47,8 +49,7 @@ class PrimaiteIO:
SIM_OUTPUT.path = self.session_path / "simulation_output"
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
self.agent_action_log: List[Dict] = []
SIM_OUTPUT.write_sys_log_to_terminal = self.settings.write_sys_log_to_terminal
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
"""Create a folder for the session and return the path to it."""
@@ -72,51 +73,29 @@ class PrimaiteIO:
"""Return the path where agent actions will be saved."""
return self.session_path / "agent_actions" / f"episode_{episode}.json"
def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None:
"""Cache agent actions for a particular step.
:param agent_actions: Dictionary describing actions for any agents that acted in this timestep. The expected
format contains agent identifiers as keys. The keys should map to a tuple of [CAOS action, parameters]
CAOS action is a string representing one the CAOS actions.
parameters is a dict of parameter names and values for that particular CAOS action.
For example:
{
'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}),
'defender': ('DO_NOTHING', {})
}
:type agent_actions: Dict
:param timestep: Simulation timestep when these actions occurred.
:type timestep: int
"""
self.agent_action_log.append(
[
{
"episode": episode,
"timestep": timestep,
"agent_actions": agent_actions,
}
]
)
def write_agent_actions(self, episode: int) -> None:
def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None:
"""Take the contents of the agent action log and write it to a file.
:param episode: Episode number
:type episode: int
"""
data = {}
longest_history = max([len(hist) for hist in agent_actions.values()])
for i in range(longest_history):
data[i] = {"timestep": i, "episode": episode}
data[i].update({name: acts[i] for name, acts in agent_actions.items() if len(acts) > i})
path = self.generate_agent_actions_save_path(episode=episode)
path.parent.mkdir(exist_ok=True, parents=True)
path.touch()
_LOGGER.info(f"Saving agent action log to {path}")
with open(path, "w") as file:
json.dump(self.agent_action_log, fp=file, indent=1)
def clear_agent_actions(self) -> None:
"""Reset the agent action log back to an empty dictionary."""
self.agent_action_log = []
json.dump(data, fp=file, indent=1, default=lambda x: x.model_dump())
@classmethod
def from_config(cls, config: Dict) -> "PrimaiteIO":
"""Create an instance of PrimaiteIO based on a configuration dict."""
new = cls()
config = config or {}
new = cls(settings=cls.Settings(**config))
return new

View File

@@ -1,4 +0,0 @@
from primaite.session.policy.rllib import RaySingleAgentPolicy
from primaite.session.policy.sb3 import SB3Policy
__all__ = ["SB3Policy", "RaySingleAgentPolicy"]

View File

@@ -1,82 +0,0 @@
"""Base class and common logic for RL policies."""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Type, TYPE_CHECKING
if TYPE_CHECKING:
from primaite.session.session import PrimaiteSession, TrainingOptions
class PolicyABC(ABC):
"""Base class for reinforcement learning agents."""
_registry: Dict[str, Type["PolicyABC"]] = {}
"""
Registry of policy types, keyed by name.
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
"""
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
"""
Register a policy subclass.
:param name: Identifier used by from_config to create an instance of the policy.
:type name: str
:raises ValueError: When attempting to create a policy with a duplicate name.
"""
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Duplicate policy name {identifier}")
cls._registry[identifier] = cls
return
@abstractmethod
def __init__(self, session: "PrimaiteSession") -> None:
"""
Initialize a reinforcement learning policy.
:param session: The session context.
:type session: PrimaiteSession
:param agents: The agents to train.
:type agents: List[RLAgent]
"""
self.session: "PrimaiteSession" = session
"""Reference to the session."""
@abstractmethod
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
pass
@abstractmethod
def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None:
"""Evaluate the agent."""
pass
@abstractmethod
def save(self, save_path: Path) -> None:
"""Save the agent."""
pass
@abstractmethod
def load(self) -> None:
"""Load agent from a file."""
pass
def close(self) -> None:
"""Close the agent."""
pass
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
"""
Create an RL policy from a config by calling the relevant subclass's from_config method.
Subclasses should not call super().from_config(), they should just handle creation form config.
"""
# Assume that basically the contents of training_config are passed into here.
# I should really define a config schema class using pydantic.
PolicyType = cls._registry[config.rl_framework]
return PolicyType.from_config(config=config, session=session)

View File

@@ -1,111 +0,0 @@
from pathlib import Path
from typing import Literal, Optional, TYPE_CHECKING
from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.session.session import PrimaiteSession, TrainingOptions
import ray
from ray import air, tune
from ray.rllib.algorithms import ppo
from ray.rllib.algorithms.ppo import PPOConfig
from primaite import getLogger
_LOGGER = getLogger(__name__)
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
"""Single agent RL policy using Ray RLLib."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
super().__init__(session=session)
self.config = {
"env": PrimaiteRayEnv,
"env_config": {"game": session.game},
"disable_env_checking": True,
"num_rollout_workers": 0,
}
ray.shutdown()
ray.init()
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
self.config["training_iterations"] = n_episodes * timesteps_per_episode
self.config["train_batch_size"] = 128
self._algo = ppo.PPO(config=self.config)
_LOGGER.info("Starting RLLIB training session")
self._algo.train()
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
for ep in range(n_episodes):
obs, info = self.session.env.reset()
for step in range(self.session.game.options.max_episode_length):
action = self._algo.compute_single_action(observation=obs, explore=False)
obs, rew, term, trunc, info = self.session.env.step(action)
def save(self, save_path: Path) -> None:
"""Save the policy to a file."""
self._algo.save(save_path)
def load(self, model_path: Path) -> None:
"""Load policy parameters from a file."""
raise NotImplementedError
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
"""Create a policy from a config."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
"""Mutli agent RL policy using Ray RLLib."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None):
"""Initialise multi agent policy wrapper."""
super().__init__(session=session)
self.config = (
PPOConfig()
.environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game})
.rollouts(num_rollout_workers=0)
.multi_agent(
policies={agent.agent_name for agent in session.game.rl_agents},
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
)
.training(train_batch_size=128)
)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
checkpoint_freq = self.session.io_manager.settings.checkpoint_interval
tune.Tuner(
"PPO",
run_config=air.RunConfig(
stop={"training_iteration": n_episodes * timesteps_per_episode},
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq),
),
param_space=self.config,
).fit()
def load(self, model_path: Path) -> None:
"""Load policy parameters from a file."""
return NotImplemented
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate trained policy."""
return NotImplemented
def save(self, save_path: Path) -> None:
"""Save policy parameters to a file."""
return NotImplemented
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy":
"""Create policy from config."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -1,79 +0,0 @@
"""Stable baselines 3 policy."""
from pathlib import Path
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
from stable_baselines3 import A2C, PPO
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
from primaite.session.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.session.session import PrimaiteSession, TrainingOptions
class SB3Policy(PolicyABC, identifier="SB3"):
"""Single agent RL policy using stable baselines 3."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
"""Initialize a stable baselines 3 policy."""
super().__init__(session=session)
self._agent_class: Type[Union[PPO, A2C]]
if algorithm == "PPO":
self._agent_class = PPO
policy = PPO_MLP
elif algorithm == "A2C":
self._agent_class = A2C
policy = A2C_MLP
else:
raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy")
self._agent = self._agent_class(
policy=policy,
env=self.session.env,
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
seed=seed,
)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
if self.session.save_checkpoints:
checkpoint_callback = CheckpointCallback(
save_freq=timesteps_per_episode * self.session.checkpoint_interval,
save_path=self.session.io_manager.generate_model_save_path("sb3"),
name_prefix="sb3_model",
)
else:
checkpoint_callback = None
self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback)
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
_ = evaluate_policy(
self._agent,
self.session.env,
n_eval_episodes=n_episodes,
deterministic=deterministic,
return_episode_rewards=True,
)
def save(self, save_path: Path) -> None:
"""
Save the current policy parameters.
Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please
refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information.
Therefore, this method is only used to save the final model.
"""
self._agent.save(save_path)
def load(self, model_path: Path) -> None:
"""Load agent from a checkpoint."""
self._agent = self._agent_class.load(model_path, env=self.session.env)
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
"""Create an agent from config file."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -1,119 +0,0 @@
# raise DeprecationWarning("This module is deprecated")
from enum import Enum
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.io import PrimaiteIO
# from primaite.game.game import PrimaiteGame
from primaite.session.policy.policy import PolicyABC
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
# checkpoint_freq: Optional[int] = None
deterministic_eval: bool
seed: Optional[int]
n_agents: int
agent_references: List[str]
class SessionMode(Enum):
"""Helper to keep track of the current session mode."""
TRAIN = "train"
EVAL = "eval"
MANUAL = "manual"
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
def __init__(self, game_cfg: Dict):
"""Initialise PrimaiteSession object."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv]
"""The environment that the RL algorithm can consume."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.io_manager: Optional["PrimaiteIO"] = None
"""IO manager for the session."""
self.game_cfg: Dict = game_cfg
"""Primaite Game object for managing main simulation loop and agents."""
self.save_checkpoints: bool = False
"""Whether to save checkpoints."""
self.checkpoint_interval: int = 10
"""If save_checkpoints is true, checkpoints will be saved every checkpoint_interval episodes."""
def start_session(self) -> None:
"""Commence the training/eval session."""
print("Starting Primaite Session")
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def save_models(self) -> None:
"""Save the RL models."""
save_path = self.io_manager.generate_model_save_path("temp_model_name")
self.policy.save(save_path)
@classmethod
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary."""
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_manager = PrimaiteIO.from_config(cfg.get("io_settings", {}))
sess = cls(game_cfg=cfg)
sess.io_manager = io_manager
sess.training_options = TrainingOptions(**cfg["training_config"])
sess.save_checkpoints = cfg.get("io_settings", {}).get("save_checkpoints")
sess.checkpoint_interval = cfg.get("io_settings", {}).get("checkpoint_interval")
# CREATE ENVIRONMENT
if sess.training_options.rl_framework == "RLLIB_single_agent":
sess.env = PrimaiteRayEnv(env_config=cfg)
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
sess.env = PrimaiteRayMARLEnv(env_config=cfg)
elif sess.training_options.rl_framework == "SB3":
sess.env = PrimaiteGymEnv(game_config=cfg)
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
return sess

View File

@@ -14,6 +14,7 @@ class _SimOutput:
)
self.save_pcap_logs: bool = False
self.save_sys_logs: bool = False
self.write_sys_log_to_terminal: bool = False
@property
def path(self) -> Path:

View File

@@ -6,7 +6,7 @@
"source": [
"# Build a simulation using the Python API\n",
"\n",
"Currently, this notbook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects.\n"
"Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects.\n"
]
},
{
@@ -58,7 +58,8 @@
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.network.hardware.base import Node\n"
"from primaite.simulator.network.hardware.nodes.host.computer import Computer\n",
"from primaite.simulator.network.hardware.nodes.host.server import Server"
]
},
{
@@ -67,9 +68,9 @@
"metadata": {},
"outputs": [],
"source": [
"my_pc = Node(hostname=\"primaite_pc\",)\n",
"my_pc = Computer(hostname=\"Computer\", ip_address=\"192.168.1.10\", subnet_mask=\"255.255.255.0\")\n",
"net.add_node(my_pc)\n",
"my_server = Node(hostname=\"google_server\")\n",
"my_server = Server(hostname=\"Server\", ip_address=\"192.168.1.11\", subnet_mask=\"255.255.255.0\")\n",
"net.add_node(my_server)\n"
]
},
@@ -86,7 +87,8 @@
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.network.hardware.base import NIC, Link, Switch\n"
"from primaite.simulator.network.hardware.nodes.host.host_node import NIC\n",
"from primaite.simulator.network.hardware.nodes.network.switch import Switch\n"
]
},
{
@@ -95,19 +97,17 @@
"metadata": {},
"outputs": [],
"source": [
"my_swtich = Switch(hostname=\"switch1\", num_ports=12)\n",
"net.add_node(my_swtich)\n",
"my_switch = Switch(hostname=\"switch1\", num_ports=12)\n",
"net.add_node(my_switch)\n",
"\n",
"pc_nic = NIC(ip_address=\"130.1.1.1\", gateway=\"130.1.1.255\", subnet_mask=\"255.255.255.0\")\n",
"my_pc.connect_nic(pc_nic)\n",
"\n",
"\n",
"server_nic = NIC(ip_address=\"130.1.1.2\", gateway=\"130.1.1.255\", subnet_mask=\"255.255.255.0\")\n",
"my_server.connect_nic(server_nic)\n",
"\n",
"\n",
"net.connect(pc_nic, my_swtich.switch_ports[1])\n",
"net.connect(server_nic, my_swtich.switch_ports[2])\n"
"net.connect(pc_nic, my_switch.network_interface[1])\n",
"net.connect(server_nic, my_switch.network_interface[2])\n"
]
},
{
@@ -124,7 +124,8 @@
"outputs": [],
"source": [
"from primaite.simulator.file_system.file_type import FileType\n",
"from primaite.simulator.file_system.file_system import File"
"from primaite.simulator.file_system.file_system import File\n",
"from primaite.simulator.system.core.sys_log import SysLog"
]
},
{
@@ -134,7 +135,7 @@
"outputs": [],
"source": [
"my_pc_downloads_folder = my_pc.file_system.create_folder(\"downloads\")\n",
"my_pc_downloads_folder.add_file(File(name=\"firefox_installer.zip\",file_type=FileType.ZIP))"
"my_pc_downloads_folder.add_file(File(name=\"firefox_installer.zip\",folder_id=\"Test\", folder_name=\"downloads\" ,file_type=FileType.ZIP, sys_log=SysLog(hostname=\"Test\")))"
]
},
{
@@ -160,9 +161,12 @@
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from primaite.simulator.system.applications.application import Application, ApplicationOperatingState\n",
"from primaite.simulator.system.software import SoftwareHealthState, SoftwareCriticality\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.file_system.file_system import FileSystem\n",
"\n",
"# no applications exist yet so we will create our own.\n",
"class MSPaint(Application):\n",
@@ -176,7 +180,7 @@
"metadata": {},
"outputs": [],
"source": [
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, ports={Port.HTTP}, operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual')"
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port.HTTP, protocol = IPProtocol.NONE,operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)"
]
},
{
@@ -257,9 +261,8 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "03b2013a-b7d1-47ee-b08c-8dab83833720",
"id": "0",
"metadata": {},
"source": [
"# PrimAITE Router Simulation Demo\n",
@@ -12,7 +12,7 @@
},
{
"cell_type": "raw",
"id": "c8bb5698-e746-4e90-9c2f-efe962acdfa0",
"id": "1",
"metadata": {},
"source": [
" +------------+\n",
@@ -48,7 +48,7 @@
},
{
"cell_type": "markdown",
"id": "415d487c-6457-497d-85d6-99439b3541e7",
"id": "2",
"metadata": {},
"source": [
"## The Network\n",
@@ -60,7 +60,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "de57ac8c-5b28-4847-a759-2ceaf5593329",
"id": "3",
"metadata": {
"tags": []
},
@@ -72,7 +72,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a1e2e4df-67c0-4584-ab27-47e2c7c7fcd2",
"id": "4",
"metadata": {
"tags": []
},
@@ -83,7 +83,7 @@
},
{
"cell_type": "markdown",
"id": "fb052c56-e9ca-4093-9115-d0c440b5ff53",
"id": "5",
"metadata": {},
"source": [
"Most of the Network components have a `.show()` function that prints a table of information about that object. We can view the Nodes and Links on the Network by calling `network.show()`."
@@ -92,7 +92,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "cc199741-ef2e-47f5-b2f0-e20049ccf40f",
"id": "6",
"metadata": {
"tags": []
},
@@ -103,7 +103,7 @@
},
{
"cell_type": "markdown",
"id": "76d2b7e9-280b-4741-a8b3-a84bed219fac",
"id": "7",
"metadata": {
"tags": []
},
@@ -115,7 +115,7 @@
},
{
"cell_type": "markdown",
"id": "84113002-843e-4cab-b899-667b50f25f6b",
"id": "8",
"metadata": {},
"source": [
"### Router Nodes\n",
@@ -125,7 +125,7 @@
},
{
"cell_type": "markdown",
"id": "bf63a178-eee5-4669-bf64-13aea7ecf6cb",
"id": "9",
"metadata": {},
"source": [
"Calling `router.show()` displays the Ethernet interfaces on the Router. If you need a table in markdown format, pass `markdown=True`."
@@ -134,7 +134,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e76d1854-961e-438c-b40f-77fd9c3abe38",
"id": "10",
"metadata": {
"tags": []
},
@@ -145,7 +145,7 @@
},
{
"cell_type": "markdown",
"id": "e000540c-687c-4254-870c-1d814603bdbf",
"id": "11",
"metadata": {},
"source": [
"Calling `router.arp.show()` displays the Router ARP Cache."
@@ -154,7 +154,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "92de8b42-92d7-4934-9c12-50bf724c9eb2",
"id": "12",
"metadata": {
"tags": []
},
@@ -165,7 +165,7 @@
},
{
"cell_type": "markdown",
"id": "a9ff7ee8-9482-44de-9039-b684866bdc82",
"id": "13",
"metadata": {},
"source": [
"Calling `router.acl.show()` displays the Access Control List."
@@ -174,7 +174,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "5922282a-d22b-4e55-9176-f3f3654c849f",
"id": "14",
"metadata": {
"tags": []
},
@@ -185,7 +185,7 @@
},
{
"cell_type": "markdown",
"id": "71c87884-f793-4c9f-b004-5b0df86cf585",
"id": "15",
"metadata": {},
"source": [
"Calling `router.router_table.show()` displays the static routes the Router provides."
@@ -194,7 +194,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "327203be-f475-4727-82a1-e992d3b70ed8",
"id": "16",
"metadata": {
"tags": []
},
@@ -205,7 +205,7 @@
},
{
"cell_type": "markdown",
"id": "eef561a8-3d39-4c8b-bbc8-e8b10b8ed25f",
"id": "17",
"metadata": {},
"source": [
"Calling `router.sys_log.show()` displays the Router system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
@@ -214,7 +214,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "3d0aa004-b10c-445f-aaab-340e0e716c74",
"id": "18",
"metadata": {
"tags": []
},
@@ -225,7 +225,7 @@
},
{
"cell_type": "markdown",
"id": "25630c90-c54e-4b5d-8bf4-ad1b0722e126",
"id": "19",
"metadata": {},
"source": [
"### Switch Nodes\n",
@@ -235,16 +235,16 @@
},
{
"cell_type": "markdown",
"id": "4879394d-2981-40de-a229-e19b09a34e6e",
"id": "20",
"metadata": {},
"source": [
"Calling `switch.show()` displays the Switch orts on the Switch."
"Calling `switch.show()` displays the Switch ports on the Switch."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7fd439b-5442-4e9d-9e7d-86dacb77f458",
"id": "21",
"metadata": {
"tags": []
},
@@ -255,29 +255,7 @@
},
{
"cell_type": "markdown",
"id": "beb8dbd6-7250-4ac9-9fa2-d2a9c0e5fd19",
"metadata": {
"tags": []
},
"source": [
"Calling `switch.arp.show()` displays the Switch ARP Cache."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d06e1310-4a77-4315-a59f-cb1b49ca2352",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"network.get_node_by_hostname(\"switch_1\").arp.show()"
]
},
{
"cell_type": "markdown",
"id": "fda75ac3-8123-4234-8f36-86547891d8df",
"id": "22",
"metadata": {},
"source": [
"Calling `switch.sys_log.show()` displays the Switch system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
@@ -286,7 +264,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a0d984b7-a7c1-4bbd-aa5a-9d3caecb08dc",
"id": "23",
"metadata": {
"tags": []
},
@@ -297,7 +275,7 @@
},
{
"cell_type": "markdown",
"id": "2f1d99ad-db4f-4baf-8a35-e1d95f269586",
"id": "24",
"metadata": {},
"source": [
"### Computer/Server Nodes\n",
@@ -307,7 +285,7 @@
},
{
"cell_type": "markdown",
"id": "c9e2251a-1b47-46e5-840f-7fec3e39c5aa",
"id": "25",
"metadata": {
"tags": []
},
@@ -318,7 +296,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "656c37f6-b145-42af-9714-8d2886d0eff8",
"id": "26",
"metadata": {
"tags": []
},
@@ -329,7 +307,7 @@
},
{
"cell_type": "markdown",
"id": "f1097a49-a3da-4d79-a06d-ae8af452918f",
"id": "27",
"metadata": {},
"source": [
"Calling `computer.arp.show()` displays the Computer/Server ARP Cache."
@@ -338,7 +316,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "66b267d6-2308-486a-b9aa-cb8d3bcf0753",
"id": "28",
"metadata": {
"tags": []
},
@@ -349,16 +327,16 @@
},
{
"cell_type": "markdown",
"id": "0d1fcad8-5b1a-4d8b-a49f-aa54a95fcaf0",
"id": "29",
"metadata": {},
"source": [
"Calling `switch.sys_log.show()` displays the Computer/Server system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
"Calling `computer.sys_log.show()` displays the Computer/Server system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b5debe8-ef1b-445d-8fa9-6a45568f21f3",
"id": "30",
"metadata": {
"tags": []
},
@@ -369,7 +347,7 @@
},
{
"cell_type": "markdown",
"id": "fcfa1773-798c-4ada-9318-c3ad928217da",
"id": "31",
"metadata": {},
"source": [
"## Basic Network Comms Check\n",
@@ -380,7 +358,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "495b7de4-b6ce-41a6-9114-f74752ab4491",
"id": "32",
"metadata": {
"tags": []
},
@@ -391,7 +369,7 @@
},
{
"cell_type": "markdown",
"id": "3e13922a-217f-4f4e-99b6-57a07613cade",
"id": "33",
"metadata": {},
"source": [
"We'll first ping client_1's default gateway."
@@ -400,7 +378,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a38abb71-994e-49e8-8f51-e9a550e95b99",
"id": "34",
"metadata": {
"tags": []
},
@@ -412,7 +390,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "8388e1e9-30e3-4534-8e5a-c6e9144149d2",
"id": "35",
"metadata": {
"tags": []
},
@@ -423,7 +401,7 @@
},
{
"cell_type": "markdown",
"id": "02c76d5c-d954-49db-912d-cb9c52f46375",
"id": "36",
"metadata": {},
"source": [
"Next, we'll ping the interface of the 192.168.1.0/24 Network on the Router (port 1)."
@@ -432,7 +410,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ff8e976a-c16b-470c-8923-325713a30d6c",
"id": "37",
"metadata": {
"tags": []
},
@@ -443,7 +421,7 @@
},
{
"cell_type": "markdown",
"id": "80280404-a5ab-452f-8a02-771a0d7496b1",
"id": "38",
"metadata": {},
"source": [
"And finally, we'll ping the web server."
@@ -452,7 +430,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "c4163f8d-6a72-410c-9f5c-4f881b7de45e",
"id": "39",
"metadata": {
"tags": []
},
@@ -463,7 +441,7 @@
},
{
"cell_type": "markdown",
"id": "1194c045-ba77-4427-be30-ed7b5b224850",
"id": "40",
"metadata": {},
"source": [
"To confirm that the ping was received and processed by the web_server, we can view the sys log"
@@ -472,7 +450,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e79a523a-5780-45b6-8798-c434e0e522bd",
"id": "41",
"metadata": {
"tags": []
},
@@ -483,17 +461,17 @@
},
{
"cell_type": "markdown",
"id": "5928f6dd-1006-45e3-99f3-8f311a875faa",
"id": "42",
"metadata": {},
"source": [
"## Advanced Network Usage\n",
"\n",
"We can now use the Network to perform some more advaced things."
"We can now use the Network to perform some more advanced things."
]
},
{
"cell_type": "markdown",
"id": "5e023ef3-7d18-4006-96ee-042a06a481fc",
"id": "43",
"metadata": {},
"source": [
"Let's attempt to prevent client_2 from being able to ping the web server. First, we'll confirm that it can ping the server first..."
@@ -502,7 +480,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "603cf913-e261-49da-a7dd-85e1bb6dec56",
"id": "44",
"metadata": {
"tags": []
},
@@ -513,7 +491,7 @@
},
{
"cell_type": "markdown",
"id": "5cf962a4-20e6-44ae-9748-7fc5267ae111",
"id": "45",
"metadata": {},
"source": [
"If we look at the client_2 sys log we can see that the four ICMP echo requests were sent and four ICMP each replies were received:"
@@ -522,7 +500,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e047de00-3de4-4823-b26a-2c8d64c7a663",
"id": "46",
"metadata": {
"tags": []
},
@@ -533,7 +511,7 @@
},
{
"cell_type": "markdown",
"id": "bdc4741d-6e3e-4aec-a69c-c2e9653bd02c",
"id": "47",
"metadata": {},
"source": [
"Now we'll add an ACL to block ICMP from 192.168.10.22"
@@ -542,7 +520,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "6db355ae-b99a-441b-a2c4-4ffe78f46bff",
"id": "48",
"metadata": {
"tags": []
},
@@ -550,7 +528,7 @@
"source": [
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"from primaite.simulator.network.hardware.nodes.router import ACLAction\n",
"from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n",
"network.get_node_by_hostname(\"router_1\").acl.add_rule(\n",
" action=ACLAction.DENY,\n",
" protocol=IPProtocol.ICMP,\n",
@@ -562,7 +540,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a345e000-8842-4827-af96-adc0fbe390fb",
"id": "49",
"metadata": {
"tags": []
},
@@ -573,7 +551,7 @@
},
{
"cell_type": "markdown",
"id": "3a5bfd9f-04cb-493e-a86c-cd268563a262",
"id": "50",
"metadata": {},
"source": [
"Now we attempt (and fail) to ping the web server"
@@ -582,7 +560,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a4f4ff31-590f-40fb-b13d-efaa8c2720b6",
"id": "51",
"metadata": {
"tags": []
},
@@ -593,7 +571,7 @@
},
{
"cell_type": "markdown",
"id": "83e56497-097b-45cb-964e-b15c72547b38",
"id": "52",
"metadata": {},
"source": [
"We can check that the ping was actually sent by client_2 by viewing the sys log"
@@ -602,7 +580,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f62b8a4e-fd3b-4059-b108-3d4a0b18f2a0",
"id": "53",
"metadata": {
"tags": []
},
@@ -613,7 +591,7 @@
},
{
"cell_type": "markdown",
"id": "c7040311-a879-4620-86a0-55d0774156e5",
"id": "54",
"metadata": {},
"source": [
"We can check the router sys log to see why the traffic was blocked"
@@ -622,7 +600,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "7e53d776-99da-4d2c-a2a7-bd7ce27bff4c",
"id": "55",
"metadata": {
"tags": []
},
@@ -633,7 +611,7 @@
},
{
"cell_type": "markdown",
"id": "aba0bc7d-da57-477b-b34a-3688b5aab2c6",
"id": "56",
"metadata": {},
"source": [
"Now a final check to ensure that client_1 can still ping the web_server."
@@ -642,7 +620,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d542734b-7582-4af7-8254-bda3de50d091",
"id": "57",
"metadata": {
"tags": []
},
@@ -654,7 +632,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d78e9fe3-02c6-4792-944f-5622e26e0412",
"id": "58",
"metadata": {
"tags": []
},

View File

@@ -7,12 +7,10 @@ from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.interface.request import RequestFormat, RequestResponse
_LOGGER = getLogger(__name__)
RequestFormat = List[Union[str, int, float]]
class RequestPermissionValidator(BaseModel):
"""
@@ -228,6 +226,15 @@ class SimComponent(BaseModel):
return
return self._request_manager(request, context)
def pre_timestep(self, timestep: int) -> None:
"""
Apply any logic that needs to happen at the beginning of the timestep to ensure correct observations/rewards.
:param timestep: what's the current time
:type timestep: int
"""
pass
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep evolution to this component.

View File

@@ -103,6 +103,10 @@ class File(FileSystemItemABC):
"""
super().apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
# reset the number of accesses to 0
self.num_access = 0

View File

@@ -427,15 +427,21 @@ class FileSystem(SimComponent):
"""Apply time step to FileSystem and its child folders and files."""
super().apply_timestep(timestep=timestep)
# apply timestep to folders
for folder_id in self.folders:
self.folders[folder_id].apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
# reset number of file creations
self.num_file_creations = 0
# reset number of file deletions
self.num_file_deletions = 0
# apply timestep to folders
for folder_id in self.folders:
self.folders[folder_id].apply_timestep(timestep=timestep)
for folder in self.folders.values():
folder.pre_timestep(timestep)
###############################################################
# Agent actions

View File

@@ -128,6 +128,13 @@ class Folder(FileSystemItemABC):
for file_id in self.files:
self.files[file_id].apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
for file in self.files.values():
file.pre_timestep(timestep)
def _scan_timestep(self) -> None:
"""Apply the scan action timestep."""
if self.scan_countdown >= 0:

View File

@@ -157,7 +157,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
return
if not self._connected_node:
_LOGGER.error(f"Interface {self} cannot be enabled as it is not connected to a Node")
_LOGGER.warning(f"Interface {self} cannot be enabled as it is not connected to a Node")
return
if self._connected_node.operating_state != NodeOperatingState.ON:
@@ -271,7 +271,7 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC)
# Update the state with information from Layer3Interface
state.update(Layer3Interface.describe_state(self))
state["frequency"] = self.frequency
state["frequency"] = self.frequency.value
return state

View File

@@ -1,3 +1,4 @@
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional
import matplotlib.pyplot as plt
@@ -8,6 +9,7 @@ from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import Link, Node, WiredNetworkInterface
from primaite.simulator.network.hardware.nodes.host.server import Printer
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.service import Service
@@ -85,6 +87,16 @@ class Network(SimComponent):
for link_id in self.links:
self.links[link_id].apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
for node in self.nodes.values():
node.pre_timestep(timestep)
for link in self.links.values():
link.pre_timestep(timestep)
@property
def router_nodes(self) -> List[Node]:
"""The Routers in the Network."""
@@ -110,6 +122,16 @@ class Network(SimComponent):
"""The Firewalls in the Network."""
return [node for node in self.nodes.values() if node.__class__.__name__ == "Firewall"]
@property
def printer_nodes(self) -> List[Node]:
"""The printers on the network."""
return [node for node in self.nodes.values() if isinstance(node, Printer)]
@property
def wireless_router_nodes(self) -> List[Node]:
"""The Routers in the Network."""
return [node for node in self.nodes.values() if node.__class__.__name__ == "WirelessRouter"]
def show(self, nodes: bool = True, ip_addresses: bool = True, links: bool = True, markdown: bool = False):
"""
Print tables describing the Network.
@@ -128,6 +150,8 @@ class Network(SimComponent):
"Switch": self.switch_nodes,
"Server": self.server_nodes,
"Computer": self.computer_nodes,
"Printer": self.printer_nodes,
"Wireless Router": self.wireless_router_nodes,
}
if nodes:
table = PrettyTable(["Node", "Type", "Operating State"])
@@ -150,14 +174,17 @@ class Network(SimComponent):
for node in nodes:
for i, port in node.network_interface.items():
if hasattr(port, "ip_address"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
)
if port.ip_address != IPv4Address("127.0.0.1"):
port_str = port.port_name if port.port_name else port.port_num
table.add_row(
[node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway]
)
print(table)
if links:
table = PrettyTable(["Endpoint A", "Endpoint B", "is Up", "Bandwidth (MBits)", "Current Load"])
table = PrettyTable(
["Endpoint A", "A Port", "Endpoint B", "B Port", "is Up", "Bandwidth (MBits)", "Current Load"]
)
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
@@ -170,7 +197,9 @@ class Network(SimComponent):
table.add_row(
[
link.endpoint_a.parent.hostname,
str(link.endpoint_a),
link.endpoint_b.parent.hostname,
str(link.endpoint_b),
link.is_up,
link.bandwidth,
link.current_load_percent,
@@ -208,18 +237,19 @@ class Network(SimComponent):
}
)
# Update the links one-by-one. The key is a 4-tuple of `hostname_a, port_a, hostname_b, port_b`
for uuid, link in self.links.items():
for _, link in self.links.items():
node_a = link.endpoint_a._connected_node
node_b = link.endpoint_b._connected_node
hostname_a = node_a.hostname if node_a else None
hostname_b = node_b.hostname if node_b else None
port_a = link.endpoint_a.port_num
port_b = link.endpoint_b.port_num
state["links"][uuid] = link.describe_state()
state["links"][uuid]["hostname_a"] = hostname_a
state["links"][uuid]["hostname_b"] = hostname_b
state["links"][uuid]["port_a"] = port_a
state["links"][uuid]["port_b"] = port_b
link_key = f"{hostname_a}:eth-{port_a}<->{hostname_b}:eth-{port_b}"
state["links"][link_key] = link.describe_state()
state["links"][link_key]["hostname_a"] = hostname_a
state["links"][link_key]["hostname_b"] = hostname_b
state["links"][link_key]["port_a"] = port_a
state["links"][link_key]["port_b"] = port_b
return state

View File

@@ -9,7 +9,7 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int:
"""
Calculate the minimum number of network switches required to connect a given number of nodes.
@@ -18,7 +18,7 @@ def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
to accommodate all nodes under this constraint.
:param num_nodes: The total number of nodes that need to be connected in the network.
:param max_switch_ports: The maximum number of ports available on each switch. Defaults to 24.
:param max_network_interface: The maximum number of ports available on each switch. Defaults to 24.
:return: The minimum number of switches required to connect all PCs.
@@ -33,11 +33,11 @@ def num_of_switches_required(num_nodes: int, max_switch_ports: int = 24) -> int:
3
"""
# Reduce the effective number of switch ports by 1 to leave space for the router
effective_switch_ports = max_switch_ports - 1
effective_network_interface = max_network_interface - 1
# Calculate the number of fully utilised switches and any additional switch for remaining PCs
full_switches = num_nodes // effective_switch_ports
extra_pcs = num_nodes % effective_switch_ports
full_switches = num_nodes // effective_network_interface
extra_pcs = num_nodes % effective_network_interface
# Return the total number of switches required
return full_switches + (1 if extra_pcs > 0 else 0)
@@ -77,7 +77,7 @@ def create_office_lan(
# Calculate the required number of switches
num_of_switches = num_of_switches_required(num_nodes=num_pcs)
effective_switch_ports = 23 # One port less for router connection
effective_network_interface = 23 # One port less for router connection
if pcs_ip_block_start <= num_of_switches:
raise ValueError(f"pcs_ip_block_start must be greater than the number of required switches {num_of_switches}")
@@ -116,7 +116,7 @@ def create_office_lan(
# Add PCs to the LAN and connect them to switches
for i in range(1, num_pcs + 1):
# Add a new edge switch if the current one is full
if switch_port == effective_switch_ports:
if switch_port == effective_network_interface:
switch_n += 1
switch_port = 0
switch = Switch(hostname=f"switch_edge_{switch_n}_{lan_name}", start_up_duration=0)

View File

@@ -5,7 +5,7 @@ import secrets
from abc import ABC, abstractmethod
from ipaddress import IPv4Address, IPv4Network
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Type, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field
@@ -35,8 +35,11 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.processes.process import Process
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import IOSoftware
from primaite.utils.validators import IPV4Address
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
_LOGGER = getLogger(__name__)
@@ -108,7 +111,7 @@ class NetworkInterface(SimComponent, ABC):
"""Reset the original state of the SimComponent."""
super().setup_for_episode(episode=episode)
self.nmne = {}
if episode and self.pcap:
if episode and self.pcap and SIM_OUTPUT.save_pcap_logs:
self.pcap.current_episode = episode
self.pcap.setup_logger()
self.enable()
@@ -261,6 +264,9 @@ class NetworkInterface(SimComponent, ABC):
"""
return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}"
def __hash__(self) -> int:
return hash(self.uuid)
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep evolution to this component.
@@ -297,7 +303,7 @@ class WiredNetworkInterface(NetworkInterface, ABC):
return True
if not self._connected_node:
_LOGGER.error(f"Interface {self} cannot be enabled as it is not connected to a Node")
_LOGGER.warning(f"Interface {self} cannot be enabled as it is not connected to a Node")
return False
if self._connected_node.operating_state != NodeOperatingState.ON:
@@ -343,11 +349,11 @@ class WiredNetworkInterface(NetworkInterface, ABC):
:param link: The Link instance to connect to this network interface.
"""
if self._connected_link:
_LOGGER.error(f"Cannot connect Link to network interface {self} as it already has a connection")
_LOGGER.warning(f"Cannot connect Link to network interface {self} as it already has a connection")
return
if self._connected_link == link:
_LOGGER.error(f"Cannot connect Link to network interface {self} as it is already connected")
_LOGGER.warning(f"Cannot connect Link to network interface {self} as it is already connected")
return
self._connected_link = link
@@ -519,12 +525,10 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
"""
super().enable()
try:
pass
self._connected_node.default_gateway_hello()
return True
except AttributeError:
pass
return False
return True
@abstractmethod
def receive_frame(self, frame: Frame) -> bool:
@@ -660,6 +664,10 @@ class Link(SimComponent):
def apply_timestep(self, timestep: int) -> None:
"""Apply a timestep to the simulation."""
super().apply_timestep(timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
self.current_load = 0.0
@@ -845,12 +853,62 @@ class Node(SimComponent):
)
rm.add_request("os", RequestType(func=self._os_request_manager, validator=_node_is_on))
self._software_request_manager = RequestManager()
rm.add_request("software_manager", RequestType(func=self._software_request_manager, validator=_node_is_on))
self._application_manager = RequestManager()
self._software_request_manager.add_request(
name="application", request_type=RequestType(func=self._application_manager)
)
self._application_manager.add_request(
name="install",
request_type=RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.application_install_action(
application=self._read_application_type(request[0]), ip_address=request[1]
)
)
),
)
self._application_manager.add_request(
name="uninstall",
request_type=RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.application_uninstall_action(application=self._read_application_type(request[0]))
)
),
)
return rm
def _install_system_software(self):
"""Install System Software - software that is usually provided with the OS."""
pass
def _read_application_type(self, application_class_str: str) -> Type[IOSoftwareClass]:
"""Wrapper that converts the string from the request manager into the appropriate class for the application."""
if application_class_str == "DoSBot":
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
return DoSBot
elif application_class_str == "DataManipulationBot":
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import (
DataManipulationBot,
)
return DataManipulationBot
elif application_class_str == "WebBrowser":
from primaite.simulator.system.applications.web_browser import WebBrowser
return WebBrowser
elif application_class_str == "RansomwareScript":
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
return RansomwareScript
else:
return 0
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -891,8 +949,9 @@ class Node(SimComponent):
table.align = "l"
table.title = f"{self.hostname} Open Ports"
for port in self.software_manager.get_open_ports():
table.add_row([port.value, port.name])
print(table)
if port.value > 0:
table.add_row([port.value, port.name])
print(table.get_string(sortby="Port"))
@property
def has_enabled_network_interface(self) -> bool:
@@ -917,12 +976,15 @@ class Node(SimComponent):
table.align = "l"
table.title = f"{self.hostname} Network Interface Cards"
for port, network_interface in self.network_interface.items():
ip_address = ""
if hasattr(network_interface, "ip_address"):
ip_address = f"{network_interface.ip_address}/{network_interface.ip_network.prefixlen}"
table.add_row(
[
port,
type(network_interface),
network_interface.__class__.__name__,
network_interface.mac_address,
f"{network_interface.ip_address}/{network_interface.ip_network.prefixlen}",
ip_address,
network_interface.speed,
"Enabled" if network_interface.enabled else "Disabled",
]
@@ -1023,6 +1085,23 @@ class Node(SimComponent):
self.file_system.apply_timestep(timestep=timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
for network_interface in self.network_interfaces.values():
network_interface.pre_timestep(timestep=timestep)
for process_id in self.processes:
self.processes[process_id].pre_timestep(timestep=timestep)
for service_id in self.services:
self.services[service_id].pre_timestep(timestep=timestep)
for application_id in self.applications:
self.applications[application_id].pre_timestep(timestep=timestep)
self.file_system.pre_timestep(timestep=timestep)
def scan(self) -> bool:
"""
Scan the node and all the items within it.
@@ -1259,6 +1338,77 @@ class Node(SimComponent):
_LOGGER.info(f"Removed application {application.name} from node {self.hostname}")
self._application_request_manager.remove_request(application.name)
def application_install_action(self, application: Application, ip_address: Optional[str] = None) -> bool:
"""
Install an application on this node and configure it.
This method is useful for allowing agents to take this action.
:param application: Application object that has not been installed on any node yet.
:type application: Application
:param ip_address: IP address used to configure the application
(target IP for the DoSBot or server IP for the DataManipulationBot)
:type ip_address: str
:return: True if the application is installed successfully, otherwise False.
"""
if application in self:
_LOGGER.warning(
f"Can't add application {application.__name__}" + f"to node {self.hostname}. It's already installed."
)
return True
self.software_manager.install(application)
application_instance = self.software_manager.software.get(str(application.__name__))
self.applications[application_instance.uuid] = application_instance
self.sys_log.info(f"Installed application {application_instance.name}")
_LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}")
self._application_request_manager.add_request(
application_instance.name, RequestType(func=application_instance._request_manager)
)
# Configure application if additional parameters are given
if ip_address:
if application_instance.name == "DoSBot":
application_instance.configure(target_ip_address=IPv4Address(ip_address))
elif application_instance.name == "DataManipulationBot":
application_instance.configure(server_ip_address=IPv4Address(ip_address))
elif application_instance.name == "RansomwareScript":
application_instance.configure(server_ip_address=IPv4Address(ip_address))
else:
pass
if application_instance.name in self.software_manager.software:
return True
else:
return False
def application_uninstall_action(self, application: Application) -> bool:
"""
Uninstall and completely remove application from this node.
This method is useful for allowing agents to take this action.
:param application: Application object that is currently associated with this node.
:type application: Application
:return: True if the application is uninstalled successfully, otherwise False.
"""
if application.__name__ not in self.software_manager.software:
_LOGGER.warning(
f"Can't remove application {application.__name__}" + f"from node {self.hostname}. It's not installed."
)
return True
application_instance = self.software_manager.software.get(
str(application.__name__)
) # This works because we can't have two applications with the same name on the same node
# self.uninstall_application(application_instance)
self.software_manager.uninstall(application_instance.name)
if application_instance.name not in self.software_manager.software:
return True
else:
return False
def _shut_down_actions(self):
"""Actions to perform when the node is shut down."""
# Turn off all the services in the node
@@ -1290,4 +1440,6 @@ class Node(SimComponent):
def __contains__(self, item: Any) -> bool:
if isinstance(item, Service):
return item.uuid in self.services
elif isinstance(item, Application):
return item.uuid in self.applications
return None

View File

@@ -316,6 +316,16 @@ class HostNode(Node):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask))
@property
def arp(self) -> Optional[ARP]:
"""
Return the ARP Cache of the HostNode.
:return: ARP Cache for given HostNode
:rtype: Optional[ARP]
"""
return self.software_manager.software.get("ARP")
def _install_system_software(self):
"""
Installs the system software and network services typically found on an operating system.

View File

@@ -28,3 +28,9 @@ class Server(HostNode):
* Applications:
* Web Browser
"""
class Printer(HostNode):
"""Printer? I don't even know her!."""
# TODO: Implement printer-specific behaviour

View File

@@ -1,9 +1,10 @@
from ipaddress import IPv4Address
from typing import Dict, Final, Optional, Union
from typing import Dict, Final, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.router import (
AccessControlList,
@@ -67,22 +68,34 @@ class Firewall(Router):
:ivar str hostname: The Firewall hostname.
"""
internal_inbound_acl: Optional[AccessControlList] = None
internal_inbound_acl: AccessControlList = Field(
default_factory=lambda: AccessControlList(name="Internal Inbound", implicit_action=ACLAction.DENY)
)
"""Access Control List for managing entering the internal network."""
internal_outbound_acl: Optional[AccessControlList] = None
internal_outbound_acl: AccessControlList = Field(
default_factory=lambda: AccessControlList(name="Internal Outbound", implicit_action=ACLAction.DENY)
)
"""Access Control List for managing traffic leaving the internal network."""
dmz_inbound_acl: Optional[AccessControlList] = None
dmz_inbound_acl: AccessControlList = Field(
default_factory=lambda: AccessControlList(name="DMZ Inbound", implicit_action=ACLAction.DENY)
)
"""Access Control List for managing traffic entering the DMZ."""
dmz_outbound_acl: Optional[AccessControlList] = None
dmz_outbound_acl: AccessControlList = Field(
default_factory=lambda: AccessControlList(name="DMZ Outbound", implicit_action=ACLAction.DENY)
)
"""Access Control List for managing traffic leaving the DMZ."""
external_inbound_acl: Optional[AccessControlList] = None
external_inbound_acl: AccessControlList = Field(
default_factory=lambda: AccessControlList(name="External Inbound", implicit_action=ACLAction.PERMIT)
)
"""Access Control List for managing traffic entering from an external network."""
external_outbound_acl: Optional[AccessControlList] = None
external_outbound_acl: AccessControlList = Field(
default_factory=lambda: AccessControlList(name="External Outbound", implicit_action=ACLAction.PERMIT)
)
"""Access Control List for managing traffic leaving towards an external network."""
def __init__(self, hostname: str, **kwargs):
@@ -100,29 +113,85 @@ class Firewall(Router):
self.connect_nic(
RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="dmz")
)
# Update ACL objects with firewall's hostname and syslog to allow accurate logging
self.internal_inbound_acl.sys_log = kwargs["sys_log"]
self.internal_inbound_acl.name = f"{hostname} - Internal Inbound"
# Initialise ACLs for internal and dmz interfaces with a default DENY policy
self.internal_inbound_acl = AccessControlList(
sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=f"{hostname} - Internal Inbound"
self.internal_outbound_acl.sys_log = kwargs["sys_log"]
self.internal_outbound_acl.name = f"{hostname} - Internal Outbound"
self.dmz_inbound_acl.sys_log = kwargs["sys_log"]
self.dmz_inbound_acl.name = f"{hostname} - DMZ Inbound"
self.dmz_outbound_acl.sys_log = kwargs["sys_log"]
self.dmz_outbound_acl.name = f"{hostname} - DMZ Outbound"
self.external_inbound_acl.sys_log = kwargs["sys_log"]
self.external_inbound_acl.name = f"{hostname} - External Inbound"
self.external_outbound_acl.sys_log = kwargs["sys_log"]
self.external_outbound_acl.name = f"{hostname} - External Outbound"
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
self._internal_acl_request_manager = RequestManager()
rm.add_request("internal", RequestType(func=self._internal_acl_request_manager))
self._dmz_acl_request_manager = RequestManager()
rm.add_request("dmz", RequestType(func=self._dmz_acl_request_manager))
self._external_acl_request_manager = RequestManager()
rm.add_request("external", RequestType(func=self._external_acl_request_manager))
self._internal_inbound_acl_request_manager = RequestManager()
self._internal_outbound_acl_request_manager = RequestManager()
self._internal_acl_request_manager.add_request(
"inbound", RequestType(func=self._internal_inbound_acl_request_manager)
)
self.internal_outbound_acl = AccessControlList(
sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=f"{hostname} - Internal Outbound"
)
self.dmz_inbound_acl = AccessControlList(
sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=f"{hostname} - DMZ Inbound"
)
self.dmz_outbound_acl = AccessControlList(
sys_log=kwargs["sys_log"], implicit_action=ACLAction.DENY, name=f"{hostname} - DMZ Outbound"
self._internal_acl_request_manager.add_request(
"outbound", RequestType(func=self._internal_outbound_acl_request_manager)
)
# external ACLs should have a default PERMIT policy
self.external_inbound_acl = AccessControlList(
sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Inbound"
self.dmz_inbound_acl_request_manager = RequestManager()
self.dmz_outbound_acl_request_manager = RequestManager()
self._dmz_acl_request_manager.add_request("inbound", RequestType(func=self.dmz_inbound_acl_request_manager))
self._dmz_acl_request_manager.add_request("outbound", RequestType(func=self.dmz_outbound_acl_request_manager))
self.external_inbound_acl_request_manager = RequestManager()
self.external_outbound_acl_request_manager = RequestManager()
self._external_acl_request_manager.add_request(
"inbound", RequestType(func=self.external_inbound_acl_request_manager)
)
self.external_outbound_acl = AccessControlList(
sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Outbound"
self._external_acl_request_manager.add_request(
"outbound", RequestType(func=self.external_outbound_acl_request_manager)
)
self._internal_inbound_acl_request_manager.add_request(
"acl", RequestType(func=self.internal_inbound_acl._request_manager)
)
self._internal_outbound_acl_request_manager.add_request(
"acl", RequestType(func=self.internal_outbound_acl._request_manager)
)
self.dmz_inbound_acl_request_manager.add_request("acl", RequestType(func=self.dmz_inbound_acl._request_manager))
self.dmz_outbound_acl_request_manager.add_request(
"acl", RequestType(func=self.dmz_outbound_acl._request_manager)
)
self.external_inbound_acl_request_manager.add_request(
"acl", RequestType(func=self.external_inbound_acl._request_manager)
)
self.external_outbound_acl_request_manager.add_request(
"acl", RequestType(func=self.external_outbound_acl._request_manager)
)
return rm
def describe_state(self) -> Dict:
"""
Describes the current state of the Firewall.
@@ -530,7 +599,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -543,7 +614,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -556,7 +629,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -569,7 +644,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -582,7 +659,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
@@ -595,7 +674,9 @@ class Firewall(Router):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)

View File

@@ -1,7 +1,9 @@
from abc import abstractmethod
from typing import Optional
from primaite.simulator.network.hardware.base import NetworkInterface, Node
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.system.services.arp.arp import ARP
class NetworkNode(Node):
@@ -28,3 +30,13 @@ class NetworkNode(Node):
:type from_network_interface: NetworkInterface
"""
pass
@property
def arp(self) -> Optional[ARP]:
"""
Return the ARP Cache of the NetworkNode.
:return: ARP Cache for given NetworkNode
:rtype: Optional[ARP]
"""
return self.software_manager.software.get("ARP")

View File

@@ -18,6 +18,7 @@ from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.session_manager import SessionManager
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.arp.arp import ARP
from primaite.simulator.system.services.icmp.icmp import ICMP
@@ -147,8 +148,10 @@ class ACLRule(SimComponent):
state["action"] = self.action.value
state["protocol"] = self.protocol.name if self.protocol else None
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
state["src_wildcard_mask"] = str(self.src_wildcard_mask) if self.src_wildcard_mask else None
state["src_port"] = self.src_port.name if self.src_port else None
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
state["dst_wildcard_mask"] = str(self.dst_wildcard_mask) if self.dst_wildcard_mask else None
state["dst_port"] = self.dst_port.name if self.dst_port else None
state["match_count"] = self.match_count
return state
@@ -275,7 +278,7 @@ class AccessControlList(SimComponent):
:ivar int max_acl_rules: The maximum number of ACL rules that can be added to the list. Defaults to 25.
"""
sys_log: SysLog
sys_log: Optional[SysLog] = None
implicit_action: ACLAction
implicit_rule: ACLRule
max_acl_rules: int = 25
@@ -319,10 +322,12 @@ class AccessControlList(SimComponent):
action=ACLAction[request[0]],
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
src_port=None if request[3] == "ALL" else Port[request[3]],
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
dst_port=None if request[5] == "ALL" else Port[request[5]],
position=int(request[6]),
src_wildcard_mask=None if request[3] == "NONE" else IPv4Address(request[3]),
src_port=None if request[4] == "ALL" else Port[request[4]],
dst_ip_address=None if request[5] == "ALL" else IPv4Address(request[5]),
dst_wildcard_mask=None if request[6] == "NONE" else IPv4Address(request[6]),
dst_port=None if request[7] == "ALL" else Port[request[7]],
position=int(request[8]),
)
)
),
@@ -624,11 +629,12 @@ class RouteTable(SimComponent):
"""
pass
@validate_call()
def add_route(
self,
address: Union[IPv4Address, str],
subnet_mask: Union[IPv4Address, str],
next_hop_ip_address: Union[IPv4Address, str],
address: Union[IPV4Address, str],
subnet_mask: Union[IPV4Address, str],
next_hop_ip_address: Union[IPV4Address, str],
metric: float = 0.0,
):
"""
@@ -647,7 +653,8 @@ class RouteTable(SimComponent):
)
self.routes.append(route)
def set_default_route_next_hop_ip_address(self, ip_address: IPv4Address):
@validate_call()
def set_default_route_next_hop_ip_address(self, ip_address: IPV4Address):
"""
Sets the next-hop IP address for the default route in a routing table.
@@ -660,7 +667,7 @@ class RouteTable(SimComponent):
"""
if not self.default_route:
self.default_route = RouteEntry(
ip_address=IPv4Address("0.0.0.0"),
address=IPv4Address("0.0.0.0"),
subnet_mask=IPv4Address("0.0.0.0"),
next_hop_ip_address=ip_address,
)
@@ -767,6 +774,13 @@ class RouterARP(ARP):
is_reattempt=True,
is_default_route_attempt=is_default_route_attempt,
)
elif route and route == self.router.route_table.default_route:
self.send_arp_request(self.router.route_table.default_route.next_hop_ip_address)
return self._get_arp_cache_mac_address(
ip_address=self.router.route_table.default_route.next_hop_ip_address,
is_reattempt=True,
is_default_route_attempt=True,
)
else:
if self.router.route_table.default_route:
if not is_default_route_attempt:
@@ -817,6 +831,12 @@ class RouterARP(ARP):
return network_interface
if not is_reattempt:
if self.router.ip_is_in_router_interface_subnet(ip_address):
self.send_arp_request(ip_address)
return self._get_arp_cache_network_interface(
ip_address=ip_address, is_reattempt=True, is_default_route_attempt=is_default_route_attempt
)
route = self.router.route_table.find_best_route(ip_address)
if route and route != self.router.route_table.default_route:
self.send_arp_request(route.next_hop_ip_address)
@@ -825,6 +845,13 @@ class RouterARP(ARP):
is_reattempt=True,
is_default_route_attempt=is_default_route_attempt,
)
elif route and route == self.router.route_table.default_route:
self.send_arp_request(self.router.route_table.default_route.next_hop_ip_address)
return self._get_arp_cache_network_interface(
ip_address=self.router.route_table.default_route.next_hop_ip_address,
is_reattempt=True,
is_default_route_attempt=True,
)
else:
if self.router.route_table.default_route:
if not is_default_route_attempt:
@@ -1016,6 +1043,144 @@ class RouterInterface(IPWiredNetworkInterface):
return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}/{self.ip_address}"
class RouterSessionManager(SessionManager):
"""
Manages network sessions, including session creation, lookup, and communication with other components.
The RouterSessionManager is a Router/Firewall specific implementation of SessionManager. It overrides the
resolve_outbound_network_interface and resolve_outbound_transmission_details functions, allowing them to leverage
the route table instead of the default gateway.
:param sys_log: A reference to the system log component.
"""
def resolve_outbound_network_interface(self, dst_ip_address: IPv4Address) -> Optional[RouterInterface]:
"""
Resolves the appropriate outbound network interface for a given destination IP address.
This method determines the most suitable network interface for sending a packet to the specified
destination IP address. It considers only enabled network interfaces and checks if the destination
IP address falls within the subnet of each interface. If no suitable local network interface is found,
the method defaults to performing a route table look-up to determine if there is a dedicated route or a default
route it can use.
The search process prioritises local network interfaces based on the IP network to which they belong.
If the destination IP address does not match any local subnet, the method assumes that the destination
is outside the local network and hence, routes the packet according to route table look-up.
:param dst_ip_address: The destination IP address for which the outbound interface is to be resolved.
:type dst_ip_address: IPv4Address
:return: The network interface through which the packet should be sent to reach the destination IP address,
or the default gateway's network interface if the destination is not within any local subnet.
:rtype: Optional[RouterInterface]
"""
network_interface = super().resolve_outbound_network_interface(dst_ip_address)
if not network_interface:
route = self.node.route_table.find_best_route(dst_ip_address)
if not route:
return None
network_interface = super().resolve_outbound_network_interface(route.next_hop_ip_address)
return network_interface
def resolve_outbound_transmission_details(
self,
dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None,
src_port: Optional[Port] = None,
dst_port: Optional[Port] = None,
protocol: Optional[IPProtocol] = None,
session_id: Optional[str] = None,
) -> Tuple[
Optional[RouterInterface],
Optional[str],
IPv4Address,
Optional[Port],
Optional[Port],
Optional[IPProtocol],
bool,
]:
"""
Resolves the necessary details for outbound transmission based on the provided parameters.
This method determines whether the payload should be broadcast or unicast based on the destination IP address
and resolves the outbound network interface and destination MAC address accordingly.
The method first checks if `session_id` is provided and uses the session details if available. For broadcast
transmissions, it finds a suitable network interface and uses a broadcast MAC address. For unicast
transmissions, it attempts to resolve the destination MAC address using ARP and finds the appropriate
outbound network interface. If the destination IP address is outside the local network and no specific MAC
address is resolved, it defaults to performing a route table look-up to determine if there is a dedicated route
or a default route it can use.
:param dst_ip_address: The destination IP address or network. If an IPv4Network is provided, the method
treats the transmission as a broadcast to that network. Optional.
:type dst_ip_address: Optional[Union[IPv4Address, IPv4Network]]
:param src_port: The source port number for the transmission. Optional.
:type src_port: Optional[Port]
:param dst_port: The destination port number for the transmission. Optional.
:type dst_port: Optional[Port]
:param protocol: The IP protocol to be used for the transmission. Optional.
:type protocol: Optional[IPProtocol]
:param session_id: The session ID associated with the transmission. If provided, the session details override
other parameters. Optional.
:type session_id: Optional[str]
:return: A tuple containing the resolved outbound network interface, destination MAC address, destination IP
address, source port, destination port, protocol, and a boolean indicating whether the transmission is a
broadcast.
:rtype: Tuple[Optional[RouterInterface], Optional[str], IPv4Address, Optional[Port], Optional[Port],
Optional[IPProtocol], bool]
"""
if dst_ip_address and not isinstance(dst_ip_address, (IPv4Address, IPv4Network)):
dst_ip_address = IPv4Address(dst_ip_address)
is_broadcast = False
outbound_network_interface = None
dst_mac_address = None
# Use session details if session_id is provided
if session_id:
session = self.sessions_by_uuid[session_id]
dst_ip_address = session.with_ip_address
protocol = session.protocol
src_port = session.src_port
dst_port = session.dst_port
# Determine if the payload is for broadcast or unicast
# Handle broadcast transmission
if isinstance(dst_ip_address, IPv4Network):
is_broadcast = True
dst_ip_address = dst_ip_address.broadcast_address
if dst_ip_address:
# Find a suitable NIC for the broadcast
for network_interface in self.node.network_interfaces.values():
if dst_ip_address in network_interface.ip_network and network_interface.enabled:
dst_mac_address = "ff:ff:ff:ff:ff:ff"
outbound_network_interface = network_interface
break
else:
# Resolve MAC address for unicast transmission
use_route_table = True
for network_interface in self.node.network_interfaces.values():
if dst_ip_address in network_interface.ip_network and network_interface.enabled:
dst_mac_address = self.software_manager.arp.get_arp_cache_mac_address(dst_ip_address)
break
if dst_mac_address:
use_route_table = False
outbound_network_interface = self.software_manager.arp.get_arp_cache_network_interface(dst_ip_address)
if use_route_table:
route = self.node.route_table.find_best_route(dst_ip_address)
if not route:
raise Exception("cannot use route to resolve outbound details")
dst_mac_address = self.software_manager.arp.get_arp_cache_mac_address(route.next_hop_ip_address)
outbound_network_interface = self.software_manager.arp.get_arp_cache_network_interface(
route.next_hop_ip_address
)
return outbound_network_interface, dst_mac_address, dst_ip_address, src_port, dst_port, protocol, is_broadcast
class Router(NetworkNode):
"""
Represents a network router, managing routing and forwarding of IP packets across network interfaces.
@@ -1049,6 +1214,10 @@ class Router(NetworkNode):
if not kwargs.get("route_table"):
kwargs["route_table"] = RouteTable(sys_log=kwargs["sys_log"])
super().__init__(hostname=hostname, num_ports=num_ports, **kwargs)
self.session_manager = RouterSessionManager(sys_log=self.sys_log)
self.session_manager.node = self
self.software_manager.session_manager = self.session_manager
self.session_manager.software_manager = self.software_manager
for i in range(1, self.num_ports + 1):
network_interface = RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")
self.connect_nic(network_interface)
@@ -1068,8 +1237,7 @@ class Router(NetworkNode):
icmp: RouterICMP = self.software_manager.icmp # noqa
icmp.router = self
self.software_manager.install(RouterARP)
arp: RouterARP = self.software_manager.arp # noqa
arp.router = self
self.arp.router = self
def _set_default_acl(self):
"""
@@ -1313,6 +1481,8 @@ class Router(NetworkNode):
frame.ethernet.src_mac_addr = network_interface.mac_address
frame.ethernet.dst_mac_addr = target_mac
network_interface.send_frame(frame)
else:
self.sys_log.error(f"Frame dropped as there is no route to {frame.ip.dst_ip_address}")
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
"""
@@ -1393,6 +1563,13 @@ class Router(NetworkNode):
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
- src_ip_address (str, optional): IP address octet written in base 10
- dst_ip_address (str, optional): IP address octet written in base 10
- routes (list[dict]): List of route dicts with values:
- address (str): The destination address of the route.
- subnet_mask (str): The subnet mask of the route.
- next_hop_ip_address (str): The next hop IP for the route.
- metric (int): The metric of the route. Optional.
- default_route:
- next_hop_ip_address (str): The next hop IP for the route.
Example config:
```
@@ -1403,6 +1580,10 @@ class Router(NetworkNode):
1: {
'ip_address' : '192.168.1.1',
'subnet_mask' : '255.255.255.0',
},
2: {
'ip_address' : '192.168.0.1',
'subnet_mask' : '255.255.255.252',
}
},
'acl' : {
@@ -1410,6 +1591,10 @@ class Router(NetworkNode):
22: {'action': 'PERMIT', 'src_port': 'ARP', 'dst_port': 'ARP'},
23: {'action': 'PERMIT', 'protocol': 'ICMP'},
},
'routes' : [
{'address': '192.168.0.0', 'subnet_mask': '255.255.255.0', 'next_hop_ip_address': '192.168.1.2'}
],
'default_route': {'next_hop_ip_address': '192.168.0.2'}
}
```
@@ -1440,7 +1625,9 @@ class Router(NetworkNode):
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_ip_address=r_cfg.get("dst_ip"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
if "routes" in cfg:
@@ -1451,4 +1638,8 @@ class Router(NetworkNode):
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
if "default_route" in cfg:
next_hop_ip_address = cfg["default_route"].get("next_hop_ip_address", None)
if next_hop_ip_address:
router.route_table.set_default_route_next_hop_ip_address(next_hop_ip_address)
return router

View File

@@ -100,13 +100,8 @@ class Switch(NetworkNode):
def __init__(self, **kwargs):
super().__init__(**kwargs)
if not self.network_interface:
self.network_interface = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.network_interface.items():
port._connected_node = self
port.port_num = port_num
port.parent = self
port.port_num = port_num
for i in range(1, self.num_ports + 1):
self.connect_nic(SwitchPort())
def show(self, markdown: bool = False):
"""

View File

@@ -1,10 +1,14 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Union
from pydantic import validate_call
from primaite.simulator.network.airspace import AirSpaceFrequency, IPWirelessNetworkInterface
from primaite.simulator.network.hardware.nodes.network.router import Router, RouterInterface
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterInterface
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.utils.validators import IPV4Address
@@ -209,3 +213,68 @@ class WirelessRouter(Router):
raise NotImplementedError(
"Please use the 'configure_wireless_access_point' and 'configure_router_interface' functions."
)
@classmethod
def from_config(cls, cfg: Dict) -> "WirelessRouter":
"""Generate the wireless router from config.
Schema:
- hostname (str): unique name for this router.
- router_interface (dict): The values should be another dict specifying
- ip_address (str)
- subnet_mask (str)
- wireless_access_point (dict): Dict with
- ip address,
- subnet mask,
- frequency, (string: either WIFI_2_4 or WIFI_5)
- acl (dict): Dict with integers from 1 - max_acl_rules as keys. The key defines the position within the ACL
where the rule will be added (lower number is resolved first). The values should describe valid ACL
Rules as:
- action (str): either PERMIT or DENY
- src_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
- dst_port (str, optional): the named port such as HTTP, HTTPS, or POSTGRES_SERVER
- protocol (str, optional): the named IP protocol such as ICMP, TCP, or UDP
- src_ip_address (str, optional): IP address octet written in base 10
- dst_ip_address (str, optional): IP address octet written in base 10
:param cfg: Config dictionary
:type cfg: Dict
:return: WirelessRouter instance.
:rtype: WirelessRouter
"""
operating_state = (
NodeOperatingState.ON if not (p := cfg.get("operating_state")) else NodeOperatingState[p.upper()]
)
router = cls(hostname=cfg["hostname"], operating_state=operating_state)
if "router_interface" in cfg:
ip_address = cfg["router_interface"]["ip_address"]
subnet_mask = cfg["router_interface"]["subnet_mask"]
router.configure_router_interface(ip_address=ip_address, subnet_mask=subnet_mask)
if "wireless_access_point" in cfg:
ip_address = cfg["wireless_access_point"]["ip_address"]
subnet_mask = cfg["wireless_access_point"]["subnet_mask"]
frequency = AirSpaceFrequency[cfg["wireless_access_point"]["frequency"]]
router.configure_wireless_access_point(ip_address=ip_address, subnet_mask=subnet_mask, frequency=frequency)
if "acl" in cfg:
for r_num, r_cfg in cfg["acl"].items():
router.acl.add_rule(
action=ACLAction[r_cfg["action"]],
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("src_ip"),
dst_ip_address=r_cfg.get("dst_ip"),
src_wildcard_mask=r_cfg.get("src_wildcard_mask"),
dst_wildcard_mask=r_cfg.get("dst_wildcard_mask"),
position=r_num,
)
if "routes" in cfg:
for route in cfg.get("routes"):
router.route_table.add_route(
address=IPv4Address(route.get("address")),
subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")),
next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")),
metric=float(route.get("metric", 0)),
)
return router

View File

@@ -6,7 +6,7 @@ CAPTURE_NMNE: bool = True
NMNE_CAPTURE_KEYWORDS: List[str] = []
"""List of keywords to identify malicious network events."""
# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
CAPTURE_BY_DIRECTION: Final[bool] = True
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
CAPTURE_BY_IP_ADDRESS: Final[bool] = False

View File

@@ -8,7 +8,7 @@ from primaite.simulator.network.protocols.icmp import ICMPPacket
from primaite.simulator.network.protocols.packet import DataPacket
from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol
from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader
from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader, UDPHeader
from primaite.simulator.network.utils import convert_bytes_to_megabits
_LOGGER = getLogger(__name__)
@@ -141,3 +141,37 @@ class Frame(BaseModel):
def size_Mbits(self) -> float: # noqa - Keep it as MBits as this is how they're expressed
"""The daa transfer size of the Frame in Mbits."""
return convert_bytes_to_megabits(self.size)
@property
def is_broadcast(self) -> bool:
"""
Determines if the Frame is a broadcast frame.
A Frame is considered a broadcast frame if the destination MAC address is set to the broadcast address
"ff:ff:ff:ff:ff:ff".
:return: True if the destination MAC address is a broadcast address, otherwise False.
"""
return self.ethernet.dst_mac_addr.lower() == "ff:ff:ff:ff:ff:ff"
@property
def is_arp(self) -> bool:
"""
Checks if the Frame is an ARP (Address Resolution Protocol) packet.
This is determined by checking if the destination port of the TCP header is equal to the ARP port.
:return: True if the Frame is an ARP packet, otherwise False.
"""
return self.udp.dst_port == Port.ARP
@property
def is_icmp(self) -> bool:
"""
Determines if the Frame is an ICMP (Internet Control Message Protocol) packet.
This check is performed by verifying if the 'icmp' attribute of the Frame instance is present (not None).
:return: True if the Frame is an ICMP packet (i.e., has an ICMP header), otherwise False.
"""
return self.icmp is not None

View File

@@ -11,6 +11,9 @@ class Port(Enum):
.. _List of Ports:
"""
UNUSED = -1
"An unused port stub."
NONE = 0
"Place holder for a non-port."
WOL = 9

View File

@@ -63,3 +63,8 @@ class Simulation(SimComponent):
"""Apply a timestep to the simulation."""
super().apply_timestep(timestep)
self.network.apply_timestep(timestep)
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
self.network.pre_timestep(timestep)

View File

@@ -3,6 +3,8 @@ from enum import Enum
from typing import Any, Dict, Set
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
_LOGGER = getLogger(__name__)
@@ -38,6 +40,17 @@ class Application(IOSoftware):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request("close", RequestType(func=lambda request, context: RequestResponse.from_bool(self.close())))
return rm
@abstractmethod
def describe_state(self) -> Dict:
"""
@@ -67,7 +80,10 @@ class Application(IOSoftware):
"""
super().apply_timestep(timestep=timestep)
self.num_executions = 0 # reset number of executions
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
self.num_executions = 0
def _can_perform_action(self) -> bool:
"""
@@ -83,7 +99,7 @@ class Application(IOSoftware):
if self.operating_state is not self.operating_state.RUNNING:
# service is not running
_LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}")
_LOGGER.debug(f"Cannot perform action: {self.name} is {self.operating_state.name}")
return False
return True
@@ -104,11 +120,12 @@ class Application(IOSoftware):
"""The main application loop."""
pass
def close(self) -> None:
def close(self) -> bool:
"""Close the Application."""
if self.operating_state == ApplicationOperatingState.RUNNING:
self.sys_log.info(f"Closed Application{self.name}")
self.operating_state = ApplicationOperatingState.CLOSED
return True
def install(self) -> None:
"""Install Application."""

View File

@@ -29,6 +29,9 @@ class DatabaseClient(Application):
_query_success_tracker: Dict[str, bool] = {}
_last_connection_successful: Optional[bool] = None
"""Keep track of connections that were established or verified during this step. Used for rewards."""
last_query_response: Optional[Dict] = None
"""Keep track of the latest query response. Used to determine rewards."""
_server_connection_id: Optional[str] = None
def __init__(self, **kwargs):
kwargs["name"] = "DatabaseClient"
@@ -49,10 +52,9 @@ class DatabaseClient(Application):
def execute(self) -> bool:
"""Execution definition for db client: perform a select query."""
self.num_executions += 1 # trying to connect counts as an execution
if self.connections:
can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1])
else:
can_connect = self.check_connection(connection_id=str(uuid4()))
if not self._server_connection_id:
self.connect()
can_connect = self.check_connection(connection_id=self._server_connection_id)
self._last_connection_successful = can_connect
return can_connect
@@ -78,17 +80,21 @@ class DatabaseClient(Application):
self.server_password = server_password
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
def connect(self, connection_id: Optional[str] = None) -> bool:
def connect(self) -> bool:
"""Connect to a Database Service."""
if not self._can_perform_action():
return False
if not connection_id:
connection_id = str(uuid4())
if not self._server_connection_id:
self._server_connection_id = str(uuid4())
self.connected = self._connect(
server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id
server_ip_address=self.server_ip_address,
password=self.server_password,
connection_id=self._server_connection_id,
)
if not self.connected:
self._server_connection_id = None
return self.connected
def check_connection(self, connection_id: str) -> bool:
@@ -123,7 +129,7 @@ class DatabaseClient(Application):
:type: is_reattempt: Optional[bool]
"""
if is_reattempt:
if self.connections.get(connection_id):
if self._server_connection_id:
self.sys_log.info(
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
)
@@ -147,31 +153,28 @@ class DatabaseClient(Application):
server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
)
def disconnect(self, connection_id: Optional[str] = None) -> bool:
def disconnect(self) -> bool:
"""Disconnect from the Database Service."""
if not self._can_perform_action():
self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
return False
# if there are no connections - nothing to disconnect
if not len(self.connections):
if not self._server_connection_id:
self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.")
return False
# if no connection provided, disconnect the first connection
if not connection_id:
connection_id = list(self.connections.keys())[0]
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect", "connection_id": connection_id},
payload={"type": "disconnect", "connection_id": self._server_connection_id},
dest_ip_address=self.server_ip_address,
dest_port=self.port,
)
self.remove_connection(connection_id=connection_id)
self.remove_connection(connection_id=self._server_connection_id)
self.sys_log.info(
f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
f"{self.name}: DatabaseClient disconnected {self._server_connection_id} from {self.server_ip_address}"
)
self.connected = False
@@ -219,18 +222,23 @@ class DatabaseClient(Application):
if not self._can_perform_action():
return False
if connection_id is None:
if self.connections:
connection_id = list(self.connections.keys())[-1]
# TODO: if the most recent connection dies, it should be automatically cleared.
else:
connection_id = str(uuid4())
# reset last query response
self.last_query_response = None
if not self.connections.get(connection_id):
if not self.connect(connection_id=connection_id):
return False
connection_id: str
if not connection_id:
connection_id = self._server_connection_id
if not connection_id:
self.connect()
connection_id = self._server_connection_id
if not connection_id:
msg = "Cannot run sql query, could not establish connection with the server."
self.parent.sys_log.error(msg)
return False
# Initialise the tracker of this ID to False
uuid = str(uuid4())
self._query_success_tracker[uuid] = False
return self._query(sql=sql, query_id=uuid, connection_id=connection_id)
@@ -252,6 +260,7 @@ class DatabaseClient(Application):
# add connection
self.add_connection(connection_id=payload.get("connection_id"), session_id=session_id)
elif payload["type"] == "sql":
self.last_query_response = payload
query_id = payload.get("uuid")
status_code = payload.get("status_code")
self._query_success_tracker[query_id] = status_code == 200

View File

@@ -0,0 +1,316 @@
from enum import IntEnum
from ipaddress import IPv4Address
from typing import Dict, Optional
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
_LOGGER = getLogger(__name__)
class RansomwareAttackStage(IntEnum):
"""
Enumeration representing different attack stages of the ransomware script.
This enumeration defines the various stages a data manipulation attack can be in during its lifecycle
in the simulation.
Each stage represents a specific phase in the attack process.
"""
NOT_STARTED = 0
"Indicates that the attack has not started yet."
DOWNLOAD = 1
"Installing the Encryption Script - Testing"
INSTALL = 2
"The stage where logon procedures are simulated."
ACTIVATE = 3
"Operating Status Changes"
PROPAGATE = 4
"Represents the stage of performing a horizontal port scan on the target."
COMMAND_AND_CONTROL = 5
"Represents the stage of setting up a rely C2 Beacon (Not Implemented)"
PAYLOAD = 6
"Stage of actively attacking the target."
SUCCEEDED = 7
"Indicates the attack has been successfully completed."
FAILED = 8
"Signifies that the attack has failed."
class RansomwareScript(Application):
"""Ransomware Kill Chain - Designed to be used by the TAP001 Agent on the example layout Network.
:ivar payload: The attack stage query payload. (Default Corrupt)
:ivar target_scan_p_of_success: The probability of success for the target scan stage.
:ivar c2_beacon_p_of_success: The probability of success for the c2_beacon stage
:ivar ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
:ivar repeat: Whether to repeat attacking once finished.
"""
server_ip_address: Optional[IPv4Address] = None
"""IP address of node which hosts the database."""
server_password: Optional[str] = None
"""Password required to access the database."""
payload: Optional[str] = "ENCRYPT"
"Payload String for the payload stage"
target_scan_p_of_success: float = 0.9
"Probability of the target scan succeeding: Default 0.9"
c2_beacon_p_of_success: float = 0.9
"Probability of the c2 beacon setup stage succeeding: Default 0.9"
ransomware_encrypt_p_of_success: float = 0.9
"Probability of the ransomware attack succeeding: Default 0.9"
repeat: bool = False
"If true, the Denial of Service bot will keep performing the attack."
attack_stage: RansomwareAttackStage = RansomwareAttackStage.NOT_STARTED
"The ransomware attack stage. See RansomwareAttackStage Class"
def __init__(self, **kwargs):
kwargs["name"] = "RansomwareScript"
kwargs["port"] = Port.NONE
kwargs["protocol"] = IPProtocol.NONE
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
return state
@property
def _host_db_client(self) -> DatabaseClient:
"""Return the database client that is installed on the same machine as the Ransomware Script."""
db_client = self.software_manager.software.get("DatabaseClient")
if db_client is None:
_LOGGER.info(f"{self.__class__.__name__} cannot find a database client on its host.")
return db_client
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request(
name="execute",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())),
)
return rm
def _activate(self):
"""
Simulate the install process as the initial stage of the attack.
Advances the attack stage to 'ACTIVATE' attack state.
"""
if self.attack_stage == RansomwareAttackStage.INSTALL:
self.sys_log.info(f"{self.name}: Activated!")
self.attack_stage = RansomwareAttackStage.ACTIVATE
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep to the bot, triggering the application loop.
:param timestep: The timestep value to update the bot's state.
"""
pass
def run(self) -> bool:
"""Calls the parent classes execute method before starting the application loop."""
super().run()
return True
def _application_loop(self) -> bool:
"""
The main application loop of the script, handling the attack process.
This is the core loop where the bot sequentially goes through the stages of the attack.
"""
if not self._can_perform_action():
return False
if self.server_ip_address and self.payload:
self.sys_log.info(f"{self.name}: Running")
self.attack_stage = RansomwareAttackStage.NOT_STARTED
self._local_download()
self._install()
self._activate()
self._perform_target_scan()
self._setup_beacon()
self._perform_ransomware_encrypt()
if self.repeat and self.attack_stage in (
RansomwareAttackStage.SUCCEEDED,
RansomwareAttackStage.FAILED,
):
self.attack_stage = RansomwareAttackStage.NOT_STARTED
return True
else:
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
return False
def configure(
self,
server_ip_address: IPv4Address,
server_password: Optional[str] = None,
payload: Optional[str] = None,
target_scan_p_of_success: Optional[float] = None,
c2_beacon_p_of_success: Optional[float] = None,
ransomware_encrypt_p_of_success: Optional[float] = None,
repeat: bool = True,
):
"""
Configure the Ransomware Script to communicate with a DatabaseService.
:param server_ip_address: The IP address of the Node the DatabaseService is on.
:param server_password: The password on the DatabaseService.
:param payload: The attack stage query (Encrypt / Delete)
:param target_scan_p_of_success: The probability of success for the target scan stage.
:param c2_beacon_p_of_success: The probability of success for the c2_beacon stage
:param ransomware_encrypt_p_of_success: The probability of success for the ransomware 'attack' (encrypt) stage.
:param repeat: Whether to repeat attacking once finished.
"""
if server_ip_address:
self.server_ip_address = server_ip_address
if server_password:
self.server_password = server_password
if payload:
self.payload = payload
if target_scan_p_of_success:
self.target_scan_p_of_success = target_scan_p_of_success
if c2_beacon_p_of_success:
self.c2_beacon_p_of_success = c2_beacon_p_of_success
if ransomware_encrypt_p_of_success:
self.ransomware_encrypt_p_of_success = ransomware_encrypt_p_of_success
if repeat:
self.repeat = repeat
self.sys_log.info(
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, "
f"{repeat=}."
)
def _install(self):
"""
Simulate the install stage in the kill-chain.
Advances the attack stage to 'ACTIVATE' if successful.
From this attack stage onwards.
the ransomware application is now visible from this point onwardin the observation space.
"""
if self.attack_stage == RansomwareAttackStage.DOWNLOAD:
self.sys_log.info(f"{self.name}: Malware installed on the local file system")
downloads_folder = self.file_system.get_folder(folder_name="downloads")
ransomware_file = downloads_folder.get_file(file_name="ransom_script.pdf")
ransomware_file.num_access += 1
self.attack_stage = RansomwareAttackStage.INSTALL
def _setup_beacon(self):
"""
Simulates setting up a c2 beacon; currently a pseudo step for increasing red variance.
Advances the attack stage to 'COMMAND AND CONTROL` if successful.
:param p_of_sucess: Probability of a successful c2 setup (Advancing this step),
by default the success rate is 0.5
"""
if self.attack_stage == RansomwareAttackStage.PROPAGATE:
self.sys_log.info(f"{self.name} Attempting to set up C&C Beacon - Scan 1/2")
if simulate_trial(self.c2_beacon_p_of_success):
self.sys_log.info(f"{self.name} C&C Successful setup - Scan 2/2")
c2c_setup = True # TODO Implement the c2c step via an FTP Application/Service
if c2c_setup:
self.attack_stage = RansomwareAttackStage.COMMAND_AND_CONTROL
def _perform_target_scan(self):
"""
Perform a simulated port scan to check for open SQL ports.
Advances the attack stage to `PROPAGATE` if successful.
:param p_of_success: Probability of successful port scan, by default 0.1.
"""
if self.attack_stage == RansomwareAttackStage.ACTIVATE:
# perform a port scan to identify that the SQL port is open on the server
self.sys_log.info(f"{self.name}: Scanning for vulnerable databases - Scan 0/2")
if simulate_trial(self.target_scan_p_of_success):
self.sys_log.info(f"{self.name}: Found a target database! Scan 1/2")
port_is_open = True # TODO Implement a NNME Triggering scan as a seperate Red Application
if port_is_open:
self.attack_stage = RansomwareAttackStage.PROPAGATE
def attack(self) -> bool:
"""Perform the attack steps after opening the application."""
if not self._can_perform_action():
_LOGGER.debug("Ransomware application is unable to perform it's actions.")
self.run()
self.num_executions += 1
return self._application_loop()
def _perform_ransomware_encrypt(self):
"""
Execute the Ransomware Encrypt payload on the target.
Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful.
:param p_of_success: Probability of successfully performing ransomware encryption, by default 0.1.
"""
if self._host_db_client is None:
self.sys_log.info(f"{self.name}: Failed to connect to db_client - Ransomware Script")
self.attack_stage = RansomwareAttackStage.FAILED
return
self._host_db_client.server_ip_address = self.server_ip_address
self._host_db_client.server_password = self.server_password
if self.attack_stage == RansomwareAttackStage.COMMAND_AND_CONTROL:
if simulate_trial(self.ransomware_encrypt_p_of_success):
self.sys_log.info(f"{self.name}: Attempting to launch payload")
if not len(self._host_db_client.connections):
self._host_db_client.connect()
if len(self._host_db_client.connections):
self._host_db_client.query(self.payload)
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
attack_successful = True
if attack_successful:
self.sys_log.info(f"{self.name}: Payload Successful")
self.attack_stage = RansomwareAttackStage.SUCCEEDED
else:
self.sys_log.info(f"{self.name}: Payload failed")
self.attack_stage = RansomwareAttackStage.FAILED
else:
self.sys_log.error("Attack Attempted to launch too quickly")
self.attack_stage = RansomwareAttackStage.FAILED
def _local_download(self):
"""Downloads itself via the onto the local file_system."""
if self.attack_stage == RansomwareAttackStage.NOT_STARTED:
if self._local_download_verify():
self.attack_stage = RansomwareAttackStage.DOWNLOAD
else:
self.sys_log.info("Malware failed to create a installation location")
self.attack_stage = RansomwareAttackStage.FAILED
else:
self.sys_log.info("Malware failed to download")
self.attack_stage = RansomwareAttackStage.FAILED
def _local_download_verify(self) -> bool:
"""Verifies a download location - Creates one if needed."""
for folder in self.file_system.folders:
if self.file_system.folders[folder].name == "downloads":
self.file_system.num_file_creations += 1
return True
self.file_system.create_folder("downloads")
self.file_system.create_file(folder_name="downloads", file_name="ransom_script.pdf")
return True

View File

@@ -49,8 +49,9 @@ class PacketCapture:
self.current_episode: int = 1
self.setup_logger(outbound=False)
self.setup_logger(outbound=True)
if SIM_OUTPUT.save_pcap_logs:
self.setup_logger(outbound=False)
self.setup_logger(outbound=True)
def setup_logger(self, outbound: bool = False):
"""Set up the logger configuration."""
@@ -108,8 +109,9 @@ class PacketCapture:
:param frame: The PCAP frame to capture.
"""
msg = frame.model_dump_json()
self.inbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
if SIM_OUTPUT.save_pcap_logs:
msg = frame.model_dump_json()
self.inbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
def capture_outbound(self, frame): # noqa - I'll have a circular import and cant use if TYPE_CHECKING ;(
"""
@@ -117,5 +119,6 @@ class PacketCapture:
:param frame: The PCAP frame to capture.
"""
msg = frame.model_dump_json()
self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
if SIM_OUTPUT.save_pcap_logs:
msg = frame.model_dump_json()
self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL

View File

@@ -72,7 +72,6 @@ class SessionManager:
Manages network sessions, including session creation, lookup, and communication with other components.
:param sys_log: A reference to the system log component.
:param arp_cache: A reference to the ARP cache component.
"""
def __init__(self, sys_log: SysLog):

View File

@@ -88,6 +88,10 @@ class SysLog:
root.mkdir(exist_ok=True, parents=True)
return root / f"{self.hostname}_sys.log"
def _write_to_terminal(self, msg: str, level: str, to_terminal: bool = False):
if to_terminal or SIM_OUTPUT.write_sys_log_to_terminal:
print(f"{self.hostname}: ({level}) {msg}")
def debug(self, msg: str, to_terminal: bool = False):
"""
Logs a message with the DEBUG level.
@@ -97,8 +101,7 @@ class SysLog:
"""
if SIM_OUTPUT.save_sys_logs:
self.logger.debug(msg)
if to_terminal:
print(msg)
self._write_to_terminal(msg, "DEBUG", to_terminal)
def info(self, msg: str, to_terminal: bool = False):
"""
@@ -109,8 +112,7 @@ class SysLog:
"""
if SIM_OUTPUT.save_sys_logs:
self.logger.info(msg)
if to_terminal:
print(msg)
self._write_to_terminal(msg, "INFO", to_terminal)
def warning(self, msg: str, to_terminal: bool = False):
"""
@@ -121,8 +123,7 @@ class SysLog:
"""
if SIM_OUTPUT.save_sys_logs:
self.logger.warning(msg)
if to_terminal:
print(msg)
self._write_to_terminal(msg, "WARNING", to_terminal)
def error(self, msg: str, to_terminal: bool = False):
"""
@@ -133,8 +134,7 @@ class SysLog:
"""
if SIM_OUTPUT.save_sys_logs:
self.logger.error(msg)
if to_terminal:
print(msg)
self._write_to_terminal(msg, "ERROR", to_terminal)
def critical(self, msg: str, to_terminal: bool = False):
"""
@@ -145,5 +145,4 @@ class SysLog:
"""
if SIM_OUTPUT.save_sys_logs:
self.logger.critical(msg)
if to_terminal:
print(msg)
self._write_to_terminal(msg, "CRITICAL", to_terminal)

View File

@@ -65,6 +65,10 @@ class ARP(Service):
"""Clears the arp cache."""
self.arp.clear()
def get_default_gateway_network_interface(self) -> Optional[NetworkInterface]:
"""Not used at the parent ARP level. Should return None when there is no override by child class."""
return None
def add_arp_cache_entry(
self, ip_address: IPV4Address, mac_address: str, network_interface: NetworkInterface, override: bool = False
):

View File

@@ -104,14 +104,30 @@ class DatabaseService(Service):
self.sys_log.error("Unable to restore database backup.")
return False
old_visible_state = SoftwareHealthState.GOOD
# get db file regardless of whether or not it was deleted
db_file = self.file_system.get_file(folder_name="database", file_name="database.db", include_deleted=True)
if db_file is None:
self.sys_log.error("Database file not initialised.")
return False
# if the file was deleted, get the old visible health state
if db_file.deleted:
old_visible_state = db_file.visible_health_status
else:
old_visible_state = self.db_file.visible_health_status
self.file_system.delete_file(folder_name="database", file_name="database.db")
# replace db file
self.file_system.delete_file(folder_name="database", file_name="database.db")
self.file_system.copy_file(src_folder_name="downloads", src_file_name="database.db", dst_folder_name="database")
if self.db_file is None:
self.sys_log.error("Copying database backup failed.")
return False
self.db_file.visible_health_status = old_visible_state
self.set_health_state(SoftwareHealthState.GOOD)
return True
@@ -125,8 +141,7 @@ class DatabaseService(Service):
"""Returns the database file."""
return self.file_system.get_file(folder_name="database", file_name="database.db")
@property
def folder(self) -> Folder:
def _return_database_folder(self) -> Folder:
"""Returns the database folder."""
return self.file_system.get_folder_by_id(self.db_file.folder_id)
@@ -171,7 +186,10 @@ class DatabaseService(Service):
}
def _process_sql(
self, query: Literal["SELECT", "DELETE", "INSERT"], query_id: str, connection_id: Optional[str] = None
self,
query: Literal["SELECT", "DELETE", "INSERT", "ENCRYPT"],
query_id: str,
connection_id: Optional[str] = None,
) -> Dict[str, Union[int, List[Any]]]:
"""
Executes the given SQL query and returns the result.
@@ -180,6 +198,7 @@ class DatabaseService(Service):
- SELECT : returns the data
- DELETE : deletes the data
- INSERT : inserts the data
- ENCRYPT : corrupts the data
:param query: The SQL query to be executed.
:return: Dictionary containing status code and data fetched.
@@ -188,10 +207,18 @@ class DatabaseService(Service):
if not self.db_file:
self.sys_log.info(f"{self.name}: Failed to run {query} because the database file is missing.")
return {"status_code": 404, "data": False}
return {"status_code": 404, "type": "sql", "data": False}
if query == "SELECT":
if self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
if self.db_file.health_status == FileSystemItemHealthStatus.CORRUPT:
return {
"status_code": 200,
"type": "sql",
"data": False,
"uuid": query_id,
"connection_id": connection_id,
}
elif self.db_file.health_status == FileSystemItemHealthStatus.GOOD:
return {
"status_code": 200,
"type": "sql",
@@ -200,7 +227,7 @@ class DatabaseService(Service):
"connection_id": connection_id,
}
else:
return {"status_code": 404, "data": False}
return {"status_code": 404, "type": "sql", "data": False}
elif query == "DELETE":
self.db_file.health_status = FileSystemItemHealthStatus.COMPROMISED
return {
@@ -210,6 +237,20 @@ class DatabaseService(Service):
"uuid": query_id,
"connection_id": connection_id,
}
elif query == "ENCRYPT":
self.file_system.num_file_creations += 1
self.db_file.health_status = FileSystemItemHealthStatus.CORRUPT
self.db_file.num_access += 1
database_folder = self._return_database_folder()
database_folder.health_status = FileSystemItemHealthStatus.CORRUPT
self.file_system.num_file_deletions += 1
return {
"status_code": 200,
"type": "sql",
"data": False,
"uuid": query_id,
"connection_id": connection_id,
}
elif query == "INSERT":
if self.health_state_actual == SoftwareHealthState.GOOD:
return {
@@ -220,7 +261,7 @@ class DatabaseService(Service):
"connection_id": connection_id,
}
else:
return {"status_code": 404, "data": False}
return {"status_code": 404, "type": "sql", "data": False}
elif query == "SELECT * FROM pg_stat_activity":
# Check if the connection is active.
if self.health_state_actual == SoftwareHealthState.GOOD:
@@ -304,8 +345,8 @@ class DatabaseService(Service):
self.backup_database()
return super().apply_timestep(timestep)
def _update_patch_status(self) -> None:
"""Perform a database restore when the patching countdown is finished."""
super()._update_patch_status()
if self._patching_countdown is None:
def _update_fix_status(self) -> None:
"""Perform a database restore when the FIXING countdown is finished."""
super()._update_fix_status()
if self._fixing_countdown is None:
self.restore_backup()

View File

@@ -87,13 +87,9 @@ class NTPClient(Service):
:return: True if successful, False otherwise.
"""
if not isinstance(payload, NTPPacket):
_LOGGER.debug(f"{payload} is not a NTPPacket")
_LOGGER.debug(f"{self.name}: Failed to parse NTP update")
return False
if payload.ntp_reply.ntp_datetime:
self.sys_log.info(
f"{self.name}: \
Received time update from NTP server{payload.ntp_reply.ntp_datetime}"
)
self.time = payload.ntp_reply.ntp_datetime
return True
@@ -124,5 +120,3 @@ class NTPClient(Service):
if self.operating_state == ServiceOperatingState.RUNNING:
# request time from server
self.request_time()
else:
self.sys_log.debug(f"{self.name} ntp client not running")

View File

@@ -59,7 +59,7 @@ class Service(IOSoftware):
if self.operating_state is not ServiceOperatingState.RUNNING:
# service is not running
_LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}")
_LOGGER.debug(f"Cannot perform action: {self.name} is {self.operating_state.name}")
return False
return True

View File

@@ -43,8 +43,8 @@ class SoftwareHealthState(Enum):
"Unused state."
GOOD = 1
"The software is in a good and healthy condition."
PATCHING = 2
"The software is undergoing patching or updates."
FIXING = 2
"The software is undergoing FIXING or updates."
COMPROMISED = 3
"The software's security has been compromised."
OVERWHELMED = 4
@@ -82,13 +82,13 @@ class Software(SimComponent):
"The health state of the software visible to the red agent."
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
"The criticality level of the software."
patching_count: int = 0
fixing_count: int = 0
"The count of patches applied to the software, defaults to 0."
scanning_count: int = 0
"The count of times the software has been scanned, defaults to 0."
revealed_to_red: bool = False
"Indicates if the software has been revealed to red agent, defaults is False."
software_manager: "SoftwareManager" = None
software_manager: Optional["SoftwareManager"] = None
"An instance of Software Manager that is used by the parent node."
sys_log: SysLog = None
"An instance of SysLog that is used by the parent node."
@@ -96,9 +96,9 @@ class Software(SimComponent):
"The FileSystem of the Node the Software is installed on."
folder: Optional[Folder] = None
"The folder on the file system the Software uses."
patching_duration: int = 2
fixing_duration: int = 2
"The number of ticks it takes to patch the software."
_patching_countdown: Optional[int] = None
_fixing_countdown: Optional[int] = None
"Current number of ticks left to patch the software."
def _init_request_manager(self) -> RequestManager:
@@ -117,9 +117,9 @@ class Software(SimComponent):
),
)
rm.add_request(
"patch",
"fix",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.patch()),
func=lambda request, context: RequestResponse.from_bool(self.fix()),
),
)
rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())))
@@ -149,7 +149,7 @@ class Software(SimComponent):
"health_state_actual": self.health_state_actual.value,
"health_state_visible": self.health_state_visible.value,
"criticality": self.criticality.value,
"patching_count": self.patching_count,
"fixing_count": self.fixing_count,
"scanning_count": self.scanning_count,
"revealed_to_red": self.revealed_to_red,
}
@@ -194,21 +194,21 @@ class Software(SimComponent):
self.health_state_visible = self.health_state_actual
return True
def patch(self) -> bool:
"""Perform a patch on the software."""
def fix(self) -> bool:
"""Perform a fix on the software."""
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
self._patching_countdown = self.patching_duration
self.set_health_state(SoftwareHealthState.PATCHING)
self._fixing_countdown = self.fixing_duration
self.set_health_state(SoftwareHealthState.FIXING)
return True
return False
def _update_patch_status(self) -> None:
"""Update the patch status of the software."""
self._patching_countdown -= 1
if self._patching_countdown <= 0:
def _update_fix_status(self) -> None:
"""Update the fix status of the software."""
self._fixing_countdown -= 1
if self._fixing_countdown <= 0:
self.set_health_state(SoftwareHealthState.GOOD)
self._patching_countdown = None
self.patching_count += 1
self._fixing_countdown = None
self.fixing_count += 1
def reveal_to_red(self) -> None:
"""Reveals the software to the red agent."""
@@ -221,8 +221,12 @@ class Software(SimComponent):
:param timestep: The current timestep of the simulation.
"""
super().apply_timestep(timestep)
if self.health_state_actual == SoftwareHealthState.PATCHING:
self._update_patch_status()
if self.health_state_actual == SoftwareHealthState.FIXING:
self._update_fix_status()
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
class IOSoftware(Software):

View File

@@ -1,3 +1,7 @@
# flake8: noqa
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import json
from pathlib import Path

View File

@@ -1,3 +1,7 @@
# flake8: noqa
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Any, Dict, Tuple, Union

View File

@@ -1,3 +1,7 @@
# flake8: noqa
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import csv
from logging import Logger

View File

@@ -1,12 +1,3 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
se3ed: 333 # Purposeful typo to check that error is raised with bad configuration.
n_learn_steps: 2560
n_eval_episodes: 5
game:
ports:
- ARP
@@ -22,8 +13,7 @@ agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -50,10 +40,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -86,63 +73,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1:eth-1<->switch_1:eth-8
- router_1:eth-2<->switch_2:eth-8
- switch_1:eth-1<->domain_controller:eth-1
- switch_1:eth-2<->web_server:eth-1
- switch_1:eth-3<->database_server:eth-1
- switch_1:eth-4<->backup_server:eth-1
- switch_1:eth-7<->security_suite:eth-1
- switch_2:eth-1<->client_1:eth-1
- switch_2:eth-2<->client_2:eth-1
- switch_2:eth-7<->security_suite:eth-2
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:
@@ -155,7 +152,7 @@ agents:
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_SERVICE_PATCH
- type: NODE_SERVICE_FIX
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
@@ -169,14 +166,10 @@ agents:
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_hostname: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_hostname: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
- type: ROUTER_ACL_ADDRULE
- type: ROUTER_ACL_REMOVERULE
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
@@ -250,7 +243,7 @@ agents:
folder_id: 1
file_id: 0
13:
action: "NODE_SERVICE_PATCH"
action: "NODE_SERVICE_FIX"
options:
node_id: 2
service_id: 0
@@ -291,8 +284,9 @@ agents:
options:
node_id: 5
22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite)
action: "NETWORK_ACL_ADDRULE"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 1
permission: 2
source_ip_id: 7 # client 1
@@ -300,9 +294,12 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
action: "NETWORK_ACL_ADDRULE"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 2
permission: 2
source_ip_id: 8 # client 2
@@ -310,9 +307,12 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
24: # block tcp traffic from client 1 to web app
action: "NETWORK_ACL_ADDRULE"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 3
permission: 2
source_ip_id: 7 # client 1
@@ -320,9 +320,12 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
25: # block tcp traffic from client 2 to web app
action: "NETWORK_ACL_ADDRULE"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 4
permission: 2
source_ip_id: 8 # client 2
@@ -330,9 +333,12 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
26:
action: "NETWORK_ACL_ADDRULE"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 5
permission: 2
source_ip_id: 7 # client 1
@@ -340,9 +346,12 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
27:
action: "NETWORK_ACL_ADDRULE"
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 6
permission: 2
source_ip_id: 8 # client 2
@@ -350,123 +359,135 @@ agents:
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
28:
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 0
29:
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 1
30:
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 2
31:
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 3
32:
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 4
33:
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 5
34:
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 6
35:
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 7
36:
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 8
37:
action: "NETWORK_ACL_REMOVERULE"
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 9
38:
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 0
nic_id: 0
39:
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 0
nic_id: 0
40:
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 1
nic_id: 0
41:
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 1
nic_id: 0
42:
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 2
nic_id: 0
43:
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 2
nic_id: 0
44:
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 3
nic_id: 0
45:
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 3
nic_id: 0
46:
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 0
47:
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 0
48:
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
49:
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
50:
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 5
nic_id: 0
51:
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 5
nic_id: 0
52:
action: "NETWORK_NIC_DISABLE"
action: "HOST_NIC_DISABLE"
options:
node_id: 6
nic_id: 0
53:
action: "NETWORK_NIC_ENABLE"
action: "HOST_NIC_ENABLE"
options:
node_id: 6
nic_id: 0
@@ -487,23 +508,15 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_name: domain_controller
nic_num: 1
- node_name: web_server
nic_num: 1
- node_name: database_server
nic_num: 1
- node_name: backup_server
nic_num: 1
- node_name: security_suite
nic_num: 1
- node_name: client_1
nic_num: 1
- node_name: client_2
nic_num: 1
- node_name: security_suite
nic_num: 2
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:
@@ -533,8 +546,7 @@ simulation:
network:
nodes:
- ref: router_1
type: router
- type: router
hostname: router_1
num_ports: 5
ports:
@@ -561,70 +573,58 @@ simulation:
action: PERMIT
protocol: ICMP
- ref: switch_1
type: switch
- type: switch
hostname: switch_1
num_ports: 8
- ref: switch_2
type: switch
- type: switch
hostname: switch_2
num_ports: 8
- ref: domain_controller
type: server
- type: server
hostname: domain_controller
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- ref: domain_controller_dns_server
type: DNSServer
- type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
type: server
- type: server
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_web_service
type: WebServer
- type: WebServer
applications:
- ref: web_server_database_client
type: DatabaseClient
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: database_server
type: server
- type: server
hostname: database_server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: database_service
type: DatabaseService
- type: DatabaseService
- ref: backup_server
type: server
- type: server
hostname: backup_server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: backup_service
type: FTPServer
- type: FTPServer
- ref: security_suite
type: server
- type: server
hostname: security_suite
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
@@ -635,87 +635,71 @@ simulation:
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
type: computer
- type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient
- type: DNSClient
- ref: client_2
type: computer
- type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_2_web_browser
type: WebBrowser
- type: WebBrowser
services:
- ref: client_2_dns_client
type: DNSClient
- type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1
- endpoint_a_hostname: router_1
endpoint_a_port: 1
endpoint_b_ref: switch_1
endpoint_b_hostname: switch_1
endpoint_b_port: 8
- ref: router_1___switch_2
endpoint_a_ref: router_1
- endpoint_a_hostname: router_1
endpoint_a_port: 2
endpoint_b_ref: switch_2
endpoint_b_hostname: switch_2
endpoint_b_port: 8
- ref: switch_1___domain_controller
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 1
endpoint_b_ref: domain_controller
endpoint_b_hostname: domain_controller
endpoint_b_port: 1
- ref: switch_1___web_server
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 2
endpoint_b_ref: web_server
endpoint_b_hostname: web_server
endpoint_b_port: 1
- ref: switch_1___database_server
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 3
endpoint_b_ref: database_server
endpoint_b_hostname: database_server
endpoint_b_port: 1
- ref: switch_1___backup_server
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 4
endpoint_b_ref: backup_server
endpoint_b_hostname: backup_server
endpoint_b_port: 1
- ref: switch_1___security_suite
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_hostname: security_suite
endpoint_b_port: 1
- ref: switch_2___client_1
endpoint_a_ref: switch_2
- endpoint_a_hostname: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_hostname: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
- endpoint_a_hostname: switch_2
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_hostname: client_2
endpoint_b_port: 1
- ref: switch_2___security_suite
endpoint_a_ref: switch_2
- endpoint_a_hostname: switch_2
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_hostname: security_suite
endpoint_b_port: 2

View File

@@ -5,21 +5,7 @@
# -------------- -------------- --------------
#
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true
@@ -41,12 +27,20 @@ agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
options:
nodes:
- node_name: client_2
@@ -71,8 +65,7 @@ simulation:
network:
nodes:
- ref: firewall
type: firewall
- type: firewall
hostname: firewall
start_up_duration: 0
shut_down_duration: 0
@@ -125,25 +118,21 @@ simulation:
action: PERMIT
protocol: ICMP
- ref: switch_1
type: switch
- type: switch
hostname: switch_1
num_ports: 8
- ref: switch_2
type: switch
- type: switch
hostname: switch_2
num_ports: 8
- ref: client_1
type: computer
- type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
# pre installed services and applications
- ref: client_2
type: computer
- type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
@@ -152,23 +141,19 @@ simulation:
# pre installed services and applications
links:
- ref: switch_1___client_1
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_hostname: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
- endpoint_a_hostname: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_2
endpoint_b_hostname: client_2
endpoint_b_port: 1
- ref: switch_1___firewall
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 2
endpoint_b_ref: firewall
endpoint_b_hostname: firewall
endpoint_b_port: 1
- ref: switch_2___firewall
endpoint_a_ref: switch_2
- endpoint_a_hostname: switch_2
endpoint_a_port: 2
endpoint_b_ref: firewall
endpoint_b_hostname: firewall
endpoint_b_port: 2

View File

@@ -4,22 +4,7 @@
# | client_1 |------| switch_1 |------| client_2 |
# -------------- -------------- --------------
#
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true
@@ -41,12 +26,20 @@ agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
options:
nodes:
- node_name: client_2
@@ -71,79 +64,64 @@ simulation:
network:
nodes:
- ref: switch_1
type: switch
- type: switch
hostname: switch_1
num_ports: 8
- ref: client_1
- hostname: client_1
type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_1_web_browser
type: WebBrowser
- type: WebBrowser
options:
target_url: http://arcd.com/users/
- ref: client_1_database_client
type: DatabaseClient
- type: DatabaseClient
options:
db_server_ip: 192.168.1.10
server_password: arcd
- ref: data_manipulation_bot
type: DataManipulationBot
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.21
server_password: arcd
- ref: dos_bot
type: DoSBot
- type: DoSBot
options:
target_ip_address: 192.168.10.21
payload: SPOOF DATA
port_scan_p_of_success: 0.8
services:
- ref: client_1_dns_client
type: DNSClient
- type: DNSClient
options:
dns_server: 192.168.1.10
- ref: client_1_dns_server
type: DNSServer
- type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.10
- ref: client_1_database_service
type: DatabaseService
- type: DatabaseService
options:
backup_server_ip: 192.168.1.10
- ref: client_1_web_service
type: WebServer
- ref: client_1_ftp_server
type: FTPServer
- type: WebServer
- type: FTPServer
options:
server_password: arcd
- ref: client_1_ntp_client
type: NTPClient
- type: NTPClient
options:
ntp_server_ip: 192.168.1.10
- ref: client_1_ntp_server
type: NTPServer
- ref: client_2
- type: NTPServer
- hostname: client_2
type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
# pre installed services and applications
- ref: client_3
- hostname: client_3
type: computer
hostname: client_3
ip_address: 192.168.10.23
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
@@ -154,13 +132,11 @@ simulation:
# pre installed services and applications
links:
- ref: switch_1___client_1
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_hostname: client_1
endpoint_b_port: 1
- ref: switch_1___client_2
endpoint_a_ref: switch_1
- endpoint_a_hostname: switch_1
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_hostname: client_2
endpoint_b_port: 1

Some files were not shown because too many files have changed in this diff Show More