Merged PR 545: Merge release/3.3 into main

Related work items: #2681, #2686, #2689, #2718, #2720, #2721, #2736, #2748, #2768, #2769, #2772, #2779, #2781, #2799, #2826, #2837, #2844
This commit is contained in:
Marek Wolan
2024-09-18 08:21:58 +00:00
143 changed files with 24135 additions and 585 deletions

View File

@@ -102,9 +102,7 @@ stages:
version: '2.1.x'
- script: |
coverage run -m --source=primaite pytest -v -o junit_family=xunit2 --junitxml=junit/test-results.xml --cov-fail-under=80
coverage xml -o coverage.xml -i
coverage html -d htmlcov -i
python run_test_and_coverage.py
displayName: 'Run tests and code coverage'
# Run the notebooks

View File

@@ -5,6 +5,37 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [3.3.0] - 2024-09-04
### Added
- Random Number Generator Seeding by specifying a random number seed in the config file.
- Implemented Terminal service class, providing a generic terminal simulation.
- Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes.
- Added actions to establish SSH connections, send commands remotely and terminate SSH connections.
- Added actions to change users' passwords.
- Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the
main port they're assigned.
- Added two new red applications: ``C2Beacon`` and ``C2Server`` which aim to simulate malicious network infrastructure.
Refer to the ``Command and Control Application Suite E2E Demonstration`` notebook for more information.
- Added reward calculation details to AgentHistoryItem.
- Added a new Privilege-Escalation-and Data-Loss-Example.ipynb notebook with a realistic cyber scenario focusing on
internal privilege escalation and data loss through the manipulation of SSH access and Access Control Lists (ACLs).
### Changed
- File and folder observations can now be configured to always show the true health status, or require scanning like before.
- It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where agent don't issue the corresponding action. Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty`
- Node observations can now be configured to show the number of active local and remote logins.
### Fixed
- Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config)
- Updated `SoftwareManager` `install` and `uninstall` to handle all functionality that was being done at the `install`
and `uninstall` methods in the `Node` class.
- Updated the `receive_payload_from_session_manager` method in `SoftwareManager` so that it now sends a copy of the
payload to any software listening on the destination port of the `Frame`.
### Removed
- Removed the `install` and `uninstall` methods in the `Node` class.
## [3.2.0] - 2024-07-18
@@ -17,7 +48,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Tests to verify that airspace bandwidth is applied correctly and can be configured via YAML
- Agent logging for agents' internal decision logic
- Action masking in all PrimAITE environments
### Changed
- Application registry was moved to the `Application` class and now updates automatically when Application is subclassed
- Databases can no longer respond to request while performing a backup
@@ -27,6 +57,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Frame `size` attribute now includes both core size and payload size in bytes
- The `speed` attribute of `NetworkInterface` has been changed from `int` to `float`
- Tidied up CHANGELOG
- Enhanced `AirSpace` logic to block transmissions that would exceed the available capacity.
- Updated `_can_transmit` function in `Link` to account for current load and total bandwidth capacity, ensuring transmissions do not exceed limits.
### Fixed
- Links and airspaces can no longer transmit data if this would exceed their bandwidth

View File

@@ -24,6 +24,8 @@ PrimAITE presents the following features:
- Support for multiple agents, each having their own customisable observation space, action space, and reward function definition, and either deterministic or RL-directed behaviour
Whilst PrimAITE ships with a number of example modelled scenarios (a.k.a. Use Cases), it has not been developed to mandate the solving of a single cyber challenge, and instead provides a highly flexible environment application that can be extended and reconfigured by the user to suit their specific cyber defence training and evaluation needs. PrimAITE provides default networks, red agent and green agent behaviour, reward functions, and action / observation space configuration, all of which can be utilised out of the box, but which ultimately can (and in some instances should) be built upon and / or reconfigured to meet the needs of different defensive agent developers. The PrimAITE user guide provides comprehensive instruction on all PrimAITE features, functionality and components, and can be consulted in order to help guide users in any reconfiguration or enhancements they wish to undertake; a library of example Jupyter notebooks are also provided to support such work.
## Getting Started with PrimAITE
### 💫 Installation

View File

@@ -5,7 +5,7 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Final, Tuple
from report import build_benchmark_md_report
from report import build_benchmark_md_report, md2pdf
from stable_baselines3 import PPO
import primaite
@@ -159,6 +159,13 @@ def run(
learning_rate: float = 3e-4,
) -> None:
"""Run the PrimAITE benchmark."""
# generate report folder
v_str = f"v{primaite.__version__}"
version_result_dir = _RESULTS_ROOT / v_str
version_result_dir.mkdir(exist_ok=True, parents=True)
output_path = version_result_dir / f"PrimAITE {v_str} Benchmark Report.md"
benchmark_start_time = datetime.now()
session_metadata_dict = {}
@@ -193,6 +200,12 @@ def run(
session_metadata=session_metadata_dict,
config_path=data_manipulation_config_path(),
results_root_path=_RESULTS_ROOT,
output_path=output_path,
)
md2pdf(
md_path=output_path,
pdf_path=str(output_path).replace(".md", ".pdf"),
css_path="static/styles.css",
)

View File

@@ -2,6 +2,7 @@
import json
import sys
from datetime import datetime
from os import PathLike
from pathlib import Path
from typing import Dict, Optional
@@ -14,7 +15,7 @@ from utils import _get_system_info
import primaite
PLOT_CONFIG = {
"size": {"auto_size": False, "width": 1500, "height": 900},
"size": {"auto_size": False, "width": 800, "height": 640},
"template": "plotly_white",
"range_slider": False,
}
@@ -144,6 +145,20 @@ def _plot_benchmark_metadata(
yaxis={"title": "Total Reward"},
title=title,
)
fig.update_layout(
legend=dict(
yanchor="top",
y=0.99,
xanchor="left",
x=0.01,
bgcolor="rgba(255,255,255,0.3)",
)
)
for trace in fig["data"]:
if trace["name"].startswith("Session"):
trace["showlegend"] = False
fig["data"][0]["name"] = "Individual Sessions"
fig["data"][0]["showlegend"] = True
return fig
@@ -194,6 +209,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
title=title,
)
fig["data"][0]["showlegend"] = True
fig.update_layout(legend=dict(yanchor="top", y=-0.2, xanchor="left", x=0.01, orientation="h"))
return fig
@@ -234,10 +250,7 @@ def _plot_av_s_per_100_steps_10_nodes(
"""
major_v = primaite.__version__.split(".")[0]
title = f"Performance of Minor and Bugfix Releases for Major Version {major_v}"
subtitle = (
f"Average Training Time per 100 Steps on 10 Nodes "
f"(target: <= {PLOT_CONFIG['av_s_per_100_steps_10_nodes_benchmark_threshold']} seconds)"
)
subtitle = "Average Training Time per 100 Steps on 10 Nodes "
title = f"{title} <br><sub>{subtitle}</sub>"
layout = go.Layout(
@@ -250,24 +263,12 @@ def _plot_av_s_per_100_steps_10_nodes(
versions = sorted(list(version_times_dict.keys()))
times = [version_times_dict[version] for version in versions]
av_s_per_100_steps_10_nodes_benchmark_threshold = PLOT_CONFIG["av_s_per_100_steps_10_nodes_benchmark_threshold"]
# Calculate the appropriate maximum y-axis value
max_y_axis_value = max(max(times), av_s_per_100_steps_10_nodes_benchmark_threshold) + 1
fig.add_trace(
go.Bar(
x=versions,
y=times,
text=times,
textposition="auto",
)
)
fig.add_trace(go.Bar(x=versions, y=times, text=times, textposition="auto", texttemplate="%{y:.3f}"))
fig.update_layout(
xaxis_title="PrimAITE Version",
yaxis_title="Avg Time per 100 Steps on 10 Nodes (seconds)",
yaxis=dict(range=[0, max_y_axis_value]),
title=title,
)
@@ -275,7 +276,11 @@ def _plot_av_s_per_100_steps_10_nodes(
def build_benchmark_md_report(
benchmark_start_time: datetime, session_metadata: Dict, config_path: Path, results_root_path: Path
benchmark_start_time: datetime,
session_metadata: Dict,
config_path: Path,
results_root_path: Path,
output_path: PathLike,
) -> None:
"""
Generates a Markdown report for a benchmarking session, documenting performance metrics and graphs.
@@ -327,7 +332,7 @@ def build_benchmark_md_report(
data = benchmark_metadata_dict
primaite_version = data["primaite_version"]
with open(version_result_dir / f"PrimAITE v{primaite_version} Benchmark Report.md", "w") as file:
with open(output_path, "w") as file:
# Title
file.write(f"# PrimAITE v{primaite_version} Learning Benchmark\n")
file.write("## PrimAITE Dev Team\n")
@@ -401,3 +406,15 @@ def build_benchmark_md_report(
f"![Performance of Minor and Bugfix Releases for Major Version {major_v}]"
f"({performance_benchmark_plot_path.name})\n"
)
def md2pdf(md_path: PathLike, pdf_path: PathLike, css_path: PathLike) -> None:
"""Generate PDF version of Markdown report."""
from md2pdf.core import md2pdf
md2pdf(
pdf_file_path=pdf_path,
md_file_path=md_path,
base_url=Path(md_path).parent,
css_file_path=css_path,
)

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 316 KiB

After

Width:  |  Height:  |  Size: 295 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

View File

@@ -1,10 +1,10 @@
# PrimAITE v3.0.0 Learning Benchmark
# PrimAITE v3.3.0 Learning Benchmark
## PrimAITE Dev Team
### 2024-07-20
### 2024-09-02
---
## 1 Introduction
PrimAITE v3.0.0 was benchmarked automatically upon release. Learning rate metrics were captured to be referenced during system-level testing and user acceptance testing (UAT).
PrimAITE v3.3.0 was benchmarked automatically upon release. Learning rate metrics were captured to be referenced during system-level testing and user acceptance testing (UAT).
The benchmarking process consists of running 5 training session using the same config file. Each session trains an agent for 1000 episodes, with each episode consisting of 128 steps.
The total reward per episode from each session is captured. This is then used to calculate an caverage total reward per episode from the 5 individual sessions for smoothing. Finally, a 25-widow rolling average of the average total reward per session is calculated for further smoothing.
## 2 System Information
@@ -26,12 +26,12 @@ The total reward per episode from each session is captured. This is then used to
- **Total Sessions:** 5
- **Total Episodes:** 5005
- **Total Steps:** 640000
- **Av Session Duration (s):** 1452.5910
- **Av Step Duration (s):** 0.0454
- **Av Duration per 100 Steps per 10 Nodes (s):** 4.5393
- **Av Session Duration (s):** 1458.2831
- **Av Step Duration (s):** 0.0456
- **Av Duration per 100 Steps per 10 Nodes (s):** 4.5571
## 4 Graphs
### 4.1 v3.0.0 Learning Benchmark Plot
![PrimAITE 3.0.0 Learning Benchmark Plot](PrimAITE v3.0.0 Learning Benchmark.png)
### 4.1 v3.3.0 Learning Benchmark Plot
![PrimAITE 3.3.0 Learning Benchmark Plot](PrimAITE v3.3.0 Learning Benchmark.png)
### 4.2 Learning Benchmark of Minor and Bugfix Releases for Major Version 3
![Learning Benchmark of Minor and Bugfix Releases for Major Version 3](PrimAITE Learning Benchmark of Minor and Bugfix Releases for Major Version 3.png)
### 4.3 Performance of Minor and Bugfix Releases for Major Version 3

Binary file not shown.

After

Width:  |  Height:  |  Size: 156 KiB

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,34 @@
body {
font-family: 'Arial', sans-serif;
line-height: 1.6;
/* margin: 1cm; */
}
h1, h2, h3, h4, h5, h6 {
font-weight: bold;
/* margin: 1em 0; */
}
p {
/* margin: 0.5em 0; */
}
ul, ol {
margin: 1em 0;
padding-left: 1.5em;
}
pre {
background: #f4f4f4;
padding: 0.5em;
overflow-x: auto;
}
img {
max-width: 100%;
height: auto;
}
table {
width: 100%;
border-collapse: collapse;
margin: 1em 0;
}
th, td {
padding: 0.5em;
border: 1px solid #ddd;
}

BIN
docs/_static/c2_sequence.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

View File

@@ -25,6 +25,7 @@ What is PrimAITE?
source/game_layer
source/simulation
source/config
source/rewards
source/customising_scenarios
source/varying_config_files
source/environment
@@ -59,6 +60,8 @@ The ARCD Primary-level AI Training Environment (**PrimAITE**) provides an effect
- Modelling background (green) pattern-of-life;
- Operates at machine-speed to enable fast training cycles via Reinforcement Learning (RL).
PrimAITE has been designed as an extensible environment and toolkit to support the development, test, training and evaluation of AI-based cyber defensive agents. Whilst PrimAITE ships with a number of example modelled scenarios (a.k.a. Use Cases), it has not been developed to mandate the solving of a single cyber challenge, and instead provides a highly flexible environment application that can be extended and reconfigured by the user to suit their specific cyber defence training and evaluation needs. PrimAITE provides default networks, red agent and green agent behaviour, reward functions, and action / observation space configuration, all of which can be utilised out of the box, but which ultimately can (and in some instances should) be built upon and / or reconfigured to meet the needs of different defensive agent developers. The PrimAITE user guide provides comprehensive instruction on all PrimAITE features, functionality and components, and can be consulted in order to help guide users in any reconfiguration or enhancements they wish to undertake; a library of example Jupyter notebooks are also provided to support such work.
Features
^^^^^^^^

View File

@@ -9,6 +9,8 @@ about which actions are invalid based on the current environment state. For inst
software on a node that is turned off. Therefore, if an agent has a NODE_SOFTWARE_INSTALL in it's action map for that node,
the action mask will show `0` in the corresponding entry.
*Note: just because an action is available in the action mask does not mean it will be successful when executed. It just means it's possible to try to execute the action at this time.*
Configuration
=============
Action masking is supported for agents that use the `ProxyAgent` class (the class used for connecting to RL algorithms).
@@ -23,95 +25,121 @@ The following logic is applied:
+==========================================+=====================================================================+
| **DONOTHING** | Always Possible. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_SERVICE_SCAN** | Node is on. Service is running. |
| **NODE_SERVICE_SCAN** | Node is on. Service is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_SERVICE_STOP** | Node is on. Service is running. |
| **NODE_SERVICE_STOP** | Node is on. Service is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_SERVICE_START** | Node is on. Service is stopped. |
| **NODE_SERVICE_START** | Node is on. Service is stopped. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_SERVICE_PAUSE** | Node is on. Service is running. |
| **NODE_SERVICE_PAUSE** | Node is on. Service is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_SERVICE_RESUME** | Node is on. Service is paused. |
| **NODE_SERVICE_RESUME** | Node is on. Service is paused. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_SERVICE_RESTART** | Node is on. Service is running. |
| **NODE_SERVICE_RESTART** | Node is on. Service is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_SERVICE_DISABLE** | Node is on. |
| **NODE_SERVICE_DISABLE** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_SERVICE_ENABLE** | Node is on. Service is disabled. |
| **NODE_SERVICE_ENABLE** | Node is on. Service is disabled. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_SERVICE_FIX** | Node is on. Service is running. |
| **NODE_SERVICE_FIX** | Node is on. Service is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_APPLICATION_EXECUTE** | Node is on. |
| **NODE_APPLICATION_EXECUTE** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_APPLICATION_SCAN** | Node is on. Application is running. |
| **NODE_APPLICATION_SCAN** | Node is on. Application is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_APPLICATION_CLOSE** | Node is on. Application is running. |
| **NODE_APPLICATION_CLOSE** | Node is on. Application is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_APPLICATION_FIX** | Node is on. Application is running. |
| **NODE_APPLICATION_FIX** | Node is on. Application is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_APPLICATION_INSTALL** | Node is on. |
| **NODE_APPLICATION_INSTALL** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_APPLICATION_REMOVE** | Node is on. |
| **NODE_APPLICATION_REMOVE** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FILE_SCAN** | Node is on. File exists. File not deleted. |
| **NODE_FILE_SCAN** | Node is on. File exists. File not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FILE_CREATE** | Node is on. |
| **NODE_FILE_CREATE** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FILE_CHECKHASH** | Node is on. File exists. File not deleted. |
| **NODE_FILE_CHECKHASH** | Node is on. File exists. File not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FILE_DELETE** | Node is on. File exists. |
| **NODE_FILE_DELETE** | Node is on. File exists. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FILE_REPAIR** | Node is on. File exists. File not deleted. |
| **NODE_FILE_REPAIR** | Node is on. File exists. File not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FILE_RESTORE** | Node is on. File exists. File is deleted. |
| **NODE_FILE_RESTORE** | Node is on. File exists. File is deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FILE_CORRUPT** | Node is on. File exists. File not deleted. |
| **NODE_FILE_CORRUPT** | Node is on. File exists. File not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FILE_ACCESS** | Node is on. File exists. File not deleted. |
| **NODE_FILE_ACCESS** | Node is on. File exists. File not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FOLDER_CREATE** | Node is on. |
| **NODE_FOLDER_CREATE** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FOLDER_SCAN** | Node is on. Folder exists. Folder not deleted. |
| **NODE_FOLDER_SCAN** | Node is on. Folder exists. Folder not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FOLDER_CHECKHASH** | Node is on. Folder exists. Folder not deleted. |
| **NODE_FOLDER_CHECKHASH** | Node is on. Folder exists. Folder not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FOLDER_REPAIR** | Node is on. Folder exists. Folder not deleted. |
| **NODE_FOLDER_REPAIR** | Node is on. Folder exists. Folder not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_FOLDER_RESTORE** | Node is on. Folder exists. Folder is deleted. |
| **NODE_FOLDER_RESTORE** | Node is on. Folder exists. Folder is deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_OS_SCAN** | Node is on. |
| **NODE_OS_SCAN** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_NIC_ENABLE** | NIC is disabled. Node is on. |
| **HOST_NIC_ENABLE** | NIC is disabled. Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_NIC_DISABLE** | NIC is enabled. Node is on. |
| **HOST_NIC_DISABLE** | NIC is enabled. Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_SHUTDOWN** | Node is on. |
| **NODE_SHUTDOWN** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_STARTUP** | Node is off. |
| **NODE_STARTUP** | Node is off. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_RESET** | Node is on. |
| **NODE_RESET** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_NMAP_PING_SCAN** | Node is on. |
| **NODE_NMAP_PING_SCAN** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_NMAP_PORT_SCAN** | Node is on. |
| **NODE_NMAP_PORT_SCAN** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_HOST_NMAP_NETWORK_SERVICE_RECON** | Node is on. |
| **NODE_NMAP_NETWORK_SERVICE_RECON** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_ROUTER_PORT_ENABLE** | Router is on. |
| **NETWORK_PORT_ENABLE** | Node is on. Router is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_ROUTER_PORT_DISABLE** | Router is on. |
| **NETWORK_PORT_DISABLE** | Router is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_ROUTER_ACL_ADDRULE** | Router is on. |
| **ROUTER_ACL_ADDRULE** | Router is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_ROUTER_ACL_REMOVERULE** | Router is on. |
| **ROUTER_ACL_REMOVERULE** | Router is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FIREWALL_PORT_ENABLE** | Firewall is on. |
| **FIREWALL_ACL_ADDRULE** | Firewall is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FIREWALL_PORT_DISABLE** | Firewall is on. |
| **FIREWALL_ACL_REMOVERULE** | Firewall is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FIREWALL_ACL_ADDRULE** | Firewall is on. |
| **NODE_NMAP_PING_SCAN** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FIREWALL_ACL_REMOVERULE** | Firewall is on. |
| **NODE_NMAP_PORT_SCAN** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_NMAP_NETWORK_SERVICE_RECON** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **CONFIGURE_DATABASE_CLIENT** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **CONFIGURE_RANSOMWARE_SCRIPT** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **CONFIGURE_DOSBOT** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **CONFIGURE_C2_BEACON** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **C2_SERVER_RANSOMWARE_LAUNCH** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **C2_SERVER_RANSOMWARE_CONFIGURE** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **C2_SERVER_TERMINAL_COMMAND** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **C2_SERVER_DATA_EXFILTRATE** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_ACCOUNTS_CHANGE_PASSWORD** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **SSH_TO_REMOTE** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **SESSIONS_REMOTE_LOGOFF** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_SEND_REMOTE_COMMAND** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+

View File

@@ -172,3 +172,8 @@ The amount of timesteps that the frequency can randomly change.
---------------
If ``True``, gymnasium flattening will be performed on the observation space before sending to the agent. Set this to ``True`` if your agent does not support nested observation spaces.
``Agent History``
-----------------
Agents will record their action log for each step. This is a summary of what the agent did, along with response information from requests within the simulation.

View File

@@ -28,6 +28,7 @@ This section defines high-level settings that apply across the game, currently i
high: 10
medium: 5
low: 0
seed: 1
``max_episode_length``
----------------------
@@ -54,3 +55,8 @@ See :ref:`List of IPProtocols <List of IPProtocols>` for a list of protocols.
--------------
These are used to determine the thresholds of high, medium and low categories for counted observation occurrences.
``seed``
--------
Used to configure the random seeds used within PrimAITE, ensuring determinism within episode/session runs. If empty or set to -1, no seed is set.

View File

@@ -53,3 +53,27 @@ The number of time steps required to occur in order for the node to cycle from `
Optional. Default value is ``3``.
The number of time steps required to occur in order for the node to cycle from ``ON`` to ``SHUTTING_DOWN`` and then finally ``OFF``.
``users``
---------
The list of pre-existing users that are additional to the default admin user (``username=admin``, ``password=admin``).
Additional users are configured as an array and must contain a ``username``, ``password``, and can contain an optional
boolean ``is_admin``.
Example of adding two additional users to a node:
.. code-block:: yaml
simulation:
network:
nodes:
- hostname: [hostname]
type: [Node Type]
users:
- username: jane.doe
password: '1234'
is_admin: true
- username: john.doe
password: password_1
is_admin: false

View File

@@ -7,7 +7,7 @@
+===================+=========+====================================+=======================================================================================================+====================================================================+
| gymnasium | 0.28.1 | MIT License | A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym). | https://farama.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| ipywidgets | 8.1.3 | BSD License | Jupyter interactive widgets | http://jupyter.org |
| ipywidgets | 8.1.5 | BSD License | Jupyter interactive widgets | http://jupyter.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| jupyterlab | 3.6.1 | BSD License | JupyterLab computational environment | https://jupyter.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
@@ -23,7 +23,7 @@
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| plotly | 5.15.0 | MIT License | An open-source, interactive data visualization library for Python | https://plotly.com/python/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| polars | 0.18.4 | MIT License | Blazingly fast DataFrame library | https://www.pola.rs/ |
| polars | 0.20.30 | MIT License | Blazingly fast DataFrame library | https://www.pola.rs/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| prettytable | 3.8.0 | BSD License (BSD (3 clause)) | A simple Python library for easily displaying tabular data in a visually appealing ASCII table format | https://github.com/jazzband/prettytable |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
@@ -31,7 +31,7 @@
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| PyYAML | 6.0 | MIT License | YAML parser and emitter for Python | https://pyyaml.org/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| ray | 2.23.0 | Apache 2.0 | Ray provides a simple, universal API for building distributed applications. | https://github.com/ray-project/ray |
| ray | 2.32.0 | Apache 2.0 | Ray provides a simple, universal API for building distributed applications. | https://github.com/ray-project/ray |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| stable-baselines3 | 2.1.0 | MIT | Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms. | https://github.com/DLR-RM/stable-baselines3 |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
@@ -39,7 +39,7 @@
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| typer | 0.9.0 | MIT License | Typer, build great CLIs. Easy to code. Based on Python type hints. | https://github.com/tiangolo/typer |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| Deepdiff | 7.0.1 | MIT License | Deep difference of dictionaries, iterables, strings, and any other object objects. | https://github.com/seperman/deepdiff |
| Deepdiff | 8.0.1 | MIT License | Deep difference of dictionaries, iterables, strings, and any other object objects. | https://github.com/seperman/deepdiff |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| sb3_contrib | 2.3.0 | MIT License | Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code (Action Masking)| https://github.com/Stable-Baselines-Team/stable-baselines3-contrib |
| sb3_contrib | 2.1.0 | MIT License | Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code (Action Masking)| https://github.com/Stable-Baselines-Team/stable-baselines3-contrib |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+

126
docs/source/rewards.rst Normal file
View File

@@ -0,0 +1,126 @@
.. only:: comment
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
Rewards
#######
Rewards in PrimAITE are based on a system of individual components that react to events in the simulation. An agent's reward function is calculated as the weighted sum of several reward components.
Some rewards, such as the ``GreenAdminDatabaseUnreachablePenalty``, can be marked as 'sticky' in their configuration. Setting this to ``True`` will mean that they continue to output the same value after an event until another event of that type.
In the instance of the ``GreenAdminDatabaseUnreachablePenalty``, the database admin reward will stay negative until the next successful database request is made, even if the database admin agents do nothing and the database returns a good state.
Components
**********
The following API pages describe the use of each reward component and the possible configuration options. An example of configuring each via yaml is also provided.
:py:class:`primaite.game.agent.rewards.DummyReward`
.. code-block:: yaml
agents:
- ref: agent_name
# ...
reward_function:
reward_components:
- type: DUMMY
weight: 1.0
:py:class:`primaite.game.agent.rewards.DatabaseFileIntegrity`
.. code-block:: yaml
agents:
- ref: agent_name
# ...
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 1.0
options:
node_hostname: server_1
folder_name: database
file_name: database.db
:py:class:`primaite.game.agent.rewards.WebServer404Penalty`
.. code-block:: yaml
agents:
- ref: agent_name
# ...
reward_function:
reward_components:
- type: WEB_SERVER_404_PENALTY
node_hostname: web_server
weight: 1.0
options:
service_name: WebService
sticky: false
:py:class:`primaite.game.agent.rewards.WebpageUnavailablePenalty`
.. code-block:: yaml
agents:
- ref: agent_name
# ...
reward_function:
reward_components:
- type: WEBPAGE_UNAVAILABLE_PENALTY
node_hostname: computer_1
weight: 1.0
options:
sticky: false
:py:class:`primaite.game.agent.rewards.GreenAdminDatabaseUnreachablePenalty`
.. code-block:: yaml
agents:
- ref: agent_name
# ...
reward_function:
reward_components:
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 1.0
options:
node_hostname: admin_pc_1
sticky: false
:py:class:`primaite.game.agent.rewards.SharedReward`
.. code-block:: yaml
agents:
- ref: scripted_agent
# ...
- ref: agent_name
# ...
reward_function:
reward_components:
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: scripted_agent
:py:class:`primaite.game.agent.rewards.ActionPenalty`
.. code-block:: yaml
agents:
- ref: agent_name
# ...
reward_function:
reward_components:
- type: ACTION_PENALTY
weight: 1.0
options:
action_penalty: -0.3
do_nothing_penalty: 0.0

View File

@@ -97,8 +97,8 @@ Node Behaviours/Functions
- **receive_frame()**: Handles the processing of incoming network frames.
- **apply_timestep()**: Advances the state of the node according to the simulation timestep.
- **power_on()**: Initiates the node, enabling all connected Network Interfaces and starting all Services and
Applications, taking into account the `start_up_duration`.
- **power_off()**: Stops the node's operations, adhering to the `shut_down_duration`.
Applications, taking into account the ``start_up_duration``.
- **power_off()**: Stops the node's operations, adhering to the ``shut_down_duration``.
- **ping()**: Sends ICMP echo requests to a specified IP address to test connectivity.
- **has_enabled_network_interface()**: Checks if the node has any network interfaces enabled, facilitating network
communication.
@@ -109,3 +109,205 @@ Node Behaviours/Functions
The Node class handles installation of system software, network connectivity, frame processing, system logging, and
power states. It establishes baseline functionality while allowing subclassing to model specific node types like hosts,
routers, firewalls etc. The flexible architecture enables composing complex network topologies.
User, UserManager, and UserSessionManager
=========================================
The ``base.py`` module also includes essential classes for managing users and their sessions within the PrimAITE
simulation. These are the ``User``, ``UserManager``, and ``UserSessionManager`` classes. The base ``Node`` class comes
with ``UserManager``, and ``UserSessionManager`` classes pre-installed.
User Class
----------
The ``User`` class represents a user in the system. It includes attributes such as ``username``, ``password``,
``disabled``, and ``is_admin`` to define the user's credentials and status.
Example Usage
^^^^^^^^^^^^^
Creating a user:
.. code-block:: python
user = User(username="john_doe", password="12345")
UserManager Class
-----------------
The ``UserManager`` class handles user management tasks such as creating users, authenticating them, changing passwords,
and enabling or disabling user accounts. It maintains a dictionary of users and provides methods to manage them
effectively.
Example Usage
^^^^^^^^^^^^^
Creating a ``UserManager`` instance and adding a user:
.. code-block:: python
user_manager = UserManager()
user_manager.add_user(username="john_doe", password="12345")
Authenticating a user:
.. code-block:: python
user = user_manager.authenticate_user(username="john_doe", password="12345")
UserSessionManager Class
------------------------
The ``UserSessionManager`` class manages user sessions, including local and remote sessions. It handles session creation,
timeouts, and provides methods for logging users in and out.
Example Usage
^^^^^^^^^^^^^
Creating a ``UserSessionManager`` instance and logging a user in locally:
.. code-block:: python
session_manager = UserSessionManager()
session_id = session_manager.local_login(username="john_doe", password="12345")
Logging a user out:
.. code-block:: python
session_manager.local_logout()
Practical Examples
------------------
Below are unit tests which act as practical examples illustrating how to use the ``User``, ``UserManager``, and
``UserSessionManager`` classes within the context of a client-server network simulation.
Setting up a Client-Server Network
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
from typing import Tuple
from uuid import uuid4
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
@pytest.fixture(scope="function")
def client_server_network() -> Tuple[Computer, Server, Network]:
network = Network()
client = Computer(
hostname="client",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
client.power_on()
server = Server(
hostname="server",
ip_address="192.168.1.3",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
server.power_on()
network.connect(client.network_interface[1], server.network_interface[1])
return client, server, network
Local Login Success
^^^^^^^^^^^^^^^^^^^
.. code-block:: python
def test_local_login_success(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
Local Login Failure
^^^^^^^^^^^^^^^^^^^
.. code-block:: python
def test_local_login_failure(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
client.user_session_manager.local_login(username="jane.doe", password="12345")
assert not client.user_session_manager.local_user_logged_in
Adding a New User and Successful Local Login
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
def test_new_user_local_login_success(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
client.user_manager.add_user(username="jane.doe", password="12345")
client.user_session_manager.local_login(username="jane.doe", password="12345")
assert client.user_session_manager.local_user_logged_in
Clearing Previous Login on New Local Login
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
def test_new_local_login_clears_previous_login(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
current_session_id = client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "admin"
client.user_manager.add_user(username="jane.doe", password="12345")
new_session_id = client.user_session_manager.local_login(username="jane.doe", password="12345")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "jane.doe"
assert new_session_id != current_session_id
Persistent Login for the Same User
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
def test_new_local_login_attempt_same_uses_persists(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
current_session_id = client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "admin"
new_session_id = client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "admin"
assert new_session_id == current_session_id

View File

@@ -49,3 +49,5 @@ fundamental network operations:
5. **NTP (Network Time Protocol) Client:** Synchronises the host's clock with network time servers.
6. **Web Browser:** A simulated application that allows the host to request and display web content.
7. **Terminal:** A simulated service that allows the host to connect to remote hosts and execute commands.

View File

@@ -3,7 +3,7 @@
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
######
Router
Wireless Router
######
The ``WirelessRouter`` class extends the functionality of the standard ``Router`` class within PrimAITE,

View File

@@ -0,0 +1,319 @@
.. only:: comment
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
.. _C2_Suite:
Command and Control Application Suite
#####################################
Comprising of two applications, the Command and Control (C2) suite intends to introduce
malicious network architecture and further the realism of red agents within PrimAITE.
Overview:
=========
These two new classes give red agents a cyber realistic way of leveraging the capabilities of the ``Terminal`` application whilst introducing more opportunities for the blue agent(s) to notice and subvert a red agent during an episode.
For a more in-depth look at the command and control applications then please refer to the ``C2-E2E-Notebook``.
``C2 Server``
"""""""""""""
The C2 Server application is intended to represent the malicious infrastructure already under the control of an adversary.
The C2 Server is configured to listen and await ``keep alive`` traffic from a C2 beacon. Once received the C2 Server is able to send and receive C2 commands.
Currently, the C2 Server offers four commands:
+---------------------+---------------------------------------------------------------------------+
|C2 Command | Meaning |
+=====================+===========================================================================+
|RANSOMWARE_CONFIGURE | Configures an installed ransomware script based on the passed parameters. |
+---------------------+---------------------------------------------------------------------------+
|RANSOMWARE_LAUNCH | Launches the installed ransomware script. |
+---------------------+---------------------------------------------------------------------------+
|DATA_EXFILTRATION | Copies a target file from a remote node to the C2 Beacon & Server via FTP |
+---------------------+---------------------------------------------------------------------------+
|TERMINAL | Executes a command via the terminal installed on the C2 Beacons Host. |
+---------------------+---------------------------------------------------------------------------+
It's important to note that in order to keep PrimAITE realistic from a cyber perspective,
the C2 Server application should never be visible or actionable upon directly by the blue agent.
This is because in the real world, C2 servers are hosted on ephemeral public domains that would not be accessible by private network blue agent.
Therefore granting blue agent(s) the ability to perform counter measures directly against the application would be unrealistic.
It is more accurate to see the host that the C2 Beacon is installed on as being able to route to the C2 Server (Internet Access).
``C2 Beacon``
"""""""""""""
The C2 Beacon application is intended to represent malware that is used to establish and maintain contact to a C2 Server within a compromised network.
A C2 Beacon will need to be first configured with the C2 Server IP Address which can be done via the ``configure`` method.
Once installed and configured; the C2 beacon can establish connection with the C2 Server via executing the application.
This will send an initial ``keep alive`` to the given C2 Server (The C2 Server IPv4Address must be given upon C2 Beacon configuration).
Which is then resolved and responded by another ``Keep Alive`` by the C2 server back to the C2 beacon to confirm connection.
The C2 Beacon will send out periodic keep alive based on its configuration parameters to configure it's active connection with the C2 server.
It's recommended that a C2 Beacon is installed and configured mid episode by a Red Agent for a more cyber realistic simulation.
Usage
=====
As mentioned, the C2 Suite is intended to grant Red Agents further flexibility whilst also expanding a blue agent's observation space.
Adding to this, the following behaviour of the C2 beacon can be configured by users for increased domain randomisation:
+---------------------+---------------------------------------------------------------------------+
|Configuration Option | Option Meaning |
+=====================+===========================================================================+
|c2_server_ip_address | The IP Address of the C2 Server. (The C2 Server must be running) |
+---------------------+---------------------------------------------------------------------------+
|keep_alive_frequency | How often should the C2 Beacon confirm it's connection in timesteps. |
+---------------------+---------------------------------------------------------------------------+
|masquerade_protocol | What protocol should the C2 traffic masquerade as? (HTTP, FTP or DNS) |
+---------------------+---------------------------------------------------------------------------+
|masquerade_port | What port should the C2 traffic use? (TCP or UDP) |
+---------------------+---------------------------------------------------------------------------+
Implementation
==============
Both applications inherit from an abstract C2 which handles the keep alive functionality and main logic.
However, each host implements it's own receive methods.
- The ``C2 Beacon`` is responsible for the following logic:
- Establishes and confirms connection to the C2 Server via sending ``C2Payload.KEEP_ALIVE``.
- Receives and executes C2 Commands given by the C2 Server via ``C2Payload.INPUT``.
- Returns the RequestResponse of the C2 Commands executed back the C2 Server via ``C2Payload.OUTPUT``.
- The ``C2 Server`` is responsible for the following logic:
- Listens and resolves connection to a C2 Beacon via responding to ``C2Payload.KEEP_ALIVE``.
- Sends C2 Commands to the C2 Beacon via ``C2Payload.INPUT``.
- Receives the RequestResponse of the C2 Commands executed by C2 Beacon via ``C2Payload.OUTPUT``.
The sequence diagram below clarifies the functionality of both applications:
.. image:: ../../../../_static/c2_sequence.png
:width: 1000
:align: center
For further details and more in-depth examples please refer to the ``Command-&-Control notebook``
Examples
========
Python
""""""
.. code-block:: python
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.applications.red_applications.c2.c2_server import C2Command, C2Server
from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon
# Network Setup
network = Network()
switch = Switch(hostname="switch", start_up_duration=0, num_ports=4)
switch.power_on()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a.power_on()
network.connect(node_a.network_interface[1], switch.network_interface[1])
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b.power_on()
network.connect(node_b.network_interface[1], switch.network_interface[2])
node_c = Computer(hostname="node_c", ip_address="192.168.0.12", subnet_mask="255.255.255.0", start_up_duration=0)
node_c.power_on()
network.connect(node_c.network_interface[1], switch.network_interface[3])
node_c.software_manager.install(software_class=DatabaseService)
node_b.software_manager.install(software_class=DatabaseClient)
node_b.software_manager.install(software_class=RansomwareScript)
node_a.software_manager.install(software_class=C2Server)
# C2 Application objects
c2_server_host: Computer = network.get_node_by_hostname("node_a")
c2_beacon_host: Computer = network.get_node_by_hostname("node_b")
c2_server: C2Server = c2_server_host.software_manager.software["C2Server"]
c2_beacon: C2Beacon = c2_beacon_host.software_manager.software["C2Beacon"]
# Configuring the C2 Beacon
c2_beacon.configure(c2_server_ip_address="192.168.0.10", keep_alive_frequency=5)
# Launching the C2 Server (Needs to be running in order to listen for connections)
c2_server.run()
# Establishing connection
c2_beacon.establish()
# Example command: Creating a file
file_create_command = {
"commands": [
["file_system", "create", "folder", "test_folder"],
["file_system", "create", "file", "test_folder", "example_file", "True"],
],
"username": "admin",
"password": "admin",
"ip_address": None,
}
c2_server.send_command(C2Command.TERMINAL, command_options=file_create_command)
# Example command: Installing and configuring Ransomware:
ransomware_installation_command = { "commands": [
["software_manager","application","install","RansomwareScript"],
],
"username": "admin",
"password": "admin",
"ip_address": None,
}
c2_server.send_command(given_command=C2Command.TERMINAL, command_options=ransomware_installation_command)
ransomware_config = {"server_ip_address": "192.168.0.12"}
c2_server.send_command(given_command=C2Command.RANSOMWARE_CONFIGURE, command_options=ransomware_config)
c2_beacon_host.software_manager.show()
# Example command: File Exfiltration
data_exfil_options = {
"username": "admin",
"password": "admin",
"ip_address": None,
"target_ip_address": "192.168.0.12",
"target_file_name": "database.db",
"target_folder_name": "database",
}
c2_server.send_command(given_command=C2Command.DATA_EXFILTRATION, command_options=data_exfil_options)
# Example command: Launching Ransomware
c2_server.send_command(given_command=C2Command.RANSOMWARE_LAUNCH, command_options={})
Via Configuration
"""""""""""""""""
.. code-block:: yaml
simulation:
network:
nodes:
- ref: example_computer_1
hostname: computer_a
type: computer
...
applications:
type: C2Server
...
hostname: computer_b
type: computer
...
# A C2 Beacon will not automatically connection to a C2 Server.
# Either an agent must use application_execute.
# Or a if using the simulation layer - .establish().
applications:
type: C2Beacon
options:
c2_server_ip_address: ...
keep_alive_frequency: 5
masquerade_protocol: tcp
masquerade_port: http
listen_on_ports:
- 80
- 53
- 21
C2 Beacon Configuration
=======================
``c2_server_ip_address``
""""""""""""""""""""""""
IP address of the ``C2Server`` that the C2 Beacon will use to establish connection.
This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``.
``Keep Alive Frequency``
""""""""""""""""""""""""
How often should the C2 Beacon confirm it's connection in timesteps.
For example, if the keep alive Frequency is set to one then every single timestep
the C2 connection will be confirmed.
It's worth noting that this may be a useful option when investigating
network blue agent observation space.
This must be a valid integer i.e ``10``. Defaults to ``5``.
``Masquerade Protocol``
"""""""""""""""""""""""
The protocol that the C2 Beacon will use to communicate to the C2 Server with.
Currently only ``TCP`` and ``UDP`` are valid masquerade protocol options.
It's worth noting that this may be a useful option to bypass ACL rules.
This must be a string i.e *UDP*. Defaults to ``TCP``.
*Please refer to the ``IPProtocol`` class for further reference.*
``Masquerade Port``
"""""""""""""""""""
What port that the C2 Beacon will use to communicate to the C2 Server with.
Currently only ``FTP``, ``HTTP`` and ``DNS`` are valid masquerade port options.
It's worth noting that this may be a useful option to bypass ACL rules.
This must be a string i.e ``DNS``. Defaults to ``HTTP``.
*Please refer to the ``IPProtocol`` class for further reference.*
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`
C2 Server Configuration
=======================
*The C2 Server does not currently offer any unique configuration options and will configure itself to match the C2 beacon's network behaviour.*
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`

View File

@@ -158,10 +158,6 @@ If not using the data manipulation bot manually, it needs to be used with a data
Configuration
=============
.. include:: ../common/common_configuration.rst
.. |SOFTWARE_NAME| replace:: DataManipulationBot
.. |SOFTWARE_NAME_BACKTICK| replace:: ``DataManipulationBot``
``server_ip``
"""""""""""""
@@ -203,3 +199,8 @@ Optional. Default value is ``0.1``.
The chance of the ``DataManipulationBot`` to succeed with a data manipulation attack.
This must be a float value between ``0`` and ``1``.
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`

View File

@@ -90,11 +90,6 @@ Via Configuration
Configuration
=============
.. include:: ../common/common_configuration.rst
.. |SOFTWARE_NAME| replace:: DatabaseClient
.. |SOFTWARE_NAME_BACKTICK| replace:: ``DatabaseClient``
``db_server_ip``
""""""""""""""""
@@ -109,3 +104,8 @@ This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.25
Optional. Default value is ``None``.
The password that the ``DatabaseClient`` will use to access the :ref:`DatabaseService`.
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`

View File

@@ -98,11 +98,6 @@ Via Configuration
Configuration
=============
.. include:: ../common/common_configuration.rst
.. |SOFTWARE_NAME| replace:: DoSBot
.. |SOFTWARE_NAME_BACKTICK| replace:: ``DoSBot``
``target_ip_address``
"""""""""""""""""""""
@@ -161,3 +156,8 @@ Optional. Default value is ``1000``.
The maximum number of sessions the ``DoSBot`` is able to make.
This must be an integer value equal to or greater than ``0``.
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`

View File

@@ -346,10 +346,8 @@ Perform a full box scan on all ports, over both TCP and UDP, on a whole subnet:
| 192.168.1.13 | 219 | ARP | UDP |
+--------------+------+-----------------+----------+
Configuration
=============
.. include:: ../common/common_configuration.rst
``Common Attributes``
"""""""""""""""""""""
.. |SOFTWARE_NAME| replace:: NMAP
.. |SOFTWARE_NAME_BACKTICK| replace:: ``NMAP``
See :ref:`Common Configuration`

View File

@@ -72,10 +72,6 @@ Configuration
The RansomwareScript inherits configuration options such as ``fix_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``.
.. include:: ../common/common_configuration.rst
.. |SOFTWARE_NAME| replace:: RansomwareScript
.. |SOFTWARE_NAME_BACKTICK| replace:: ``RansomwareScript``
``server_ip``
"""""""""""""
@@ -83,3 +79,8 @@ The RansomwareScript inherits configuration options such as ``fix_duration`` fro
IP address of the :ref:`DatabaseService` which the ``RansomwareScript`` will encrypt.
This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``.
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`

View File

@@ -92,10 +92,6 @@ Via Configuration
Configuration
=============
.. include:: ../common/common_configuration.rst
.. |SOFTWARE_NAME| replace:: WebBrowser
.. |SOFTWARE_NAME_BACKTICK| replace:: ``WebBrowser``
``target_url``
""""""""""""""
@@ -109,3 +105,9 @@ The domain ``arcd.com`` can be matched by
- http://arcd.com/
- http://arcd.com/users/
- arcd.com
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`

View File

@@ -2,26 +2,56 @@
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
``ref``
=======
.. _Common Configuration:
Human readable name used as reference for the |SOFTWARE_NAME_BACKTICK|. Not used in code.
Common Configuration
""""""""""""""""""""
``type``
========
ref
"""
The type of software that should be added. To add |SOFTWARE_NAME| this must be |SOFTWARE_NAME_BACKTICK|.
Human readable name used as reference for the software class. Not used in code.
``options``
===========
type
""""
The configuration options are the attributes that fall under the options for an application.
The type of software that should be added. To add the required software, this must be it's name.
options
"""""""
The configuration options are the attributes that fall under the options for an application or service.
``fix_duration``
""""""""""""""""
fix_duration
""""""""""""
Optional. Default value is ``2``.
The number of timesteps the |SOFTWARE_NAME| will remain in a ``FIXING`` state before going into a ``GOOD`` state.
The number of timesteps the software will remain in a ``FIXING`` state before going into a ``GOOD`` state.
listen_on_ports
^^^^^^^^^^^^^^^
Optional. The set of ports to listen on. This is in addition to the main port the software is designated. This can either be
the string name of ports or the port integers
Example:
.. code-block:: yaml
simulation:
network:
nodes:
- hostname: [hostname]
type: [Node Type]
services:
- type: [Service Type]
options:
listen_on_ports:
- 631
applications:
- type: [Application Type]
options:
listen_on_ports:
- SMB

View File

@@ -94,11 +94,6 @@ Via Configuration
Configuration
=============
.. include:: ../common/common_configuration.rst
.. |SOFTWARE_NAME| replace:: DatabaseService
.. |SOFTWARE_NAME_BACKTICK| replace:: ``DatabaseService``
``backup_server_ip``
""""""""""""""""""""
@@ -114,3 +109,8 @@ This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.25
Optional. Default value is ``None``.
The password that needs to be provided by connecting clients in order to create a successful connection.
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`

View File

@@ -84,10 +84,6 @@ Via Configuration
Configuration
=============
.. include:: ../common/common_configuration.rst
.. |SOFTWARE_NAME| replace:: DNSClient
.. |SOFTWARE_NAME_BACKTICK| replace:: ``DNSClient``
``dns_server``
""""""""""""""
@@ -97,3 +93,8 @@ Optional. Default value is ``None``.
The IP Address of the :ref:`DNSServer`.
This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``.
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`

View File

@@ -83,16 +83,17 @@ Via Configuration
Configuration
=============
.. include:: ../common/common_configuration.rst
.. |SOFTWARE_NAME| replace:: DNSServer
.. |SOFTWARE_NAME_BACKTICK| replace:: ``DNSServer``
domain_mapping
""""""""""""""
``domain_mapping``
""""""""""""""""""
Domain mapping takes the domain and IP Addresses as a key-value pairs i.e.
If the domain is "arcd.com" and the IP Address attributed to the domain is 192.168.0.10, then the value should be ``arcd.com: 192.168.0.10``
The key must be a string and the IP Address must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``.
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`

View File

@@ -83,7 +83,7 @@ Via Configuration
Configuration
=============
.. include:: ../common/common_configuration.rst
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
.. |SOFTWARE_NAME| replace:: FTPClient
.. |SOFTWARE_NAME_BACKTICK| replace:: ``FTPClient``
See :ref:`Common Configuration`

View File

@@ -81,14 +81,14 @@ Via Configuration
Configuration
=============
.. include:: ../common/common_configuration.rst
.. |SOFTWARE_NAME| replace:: FTPServer
.. |SOFTWARE_NAME_BACKTICK| replace:: ``FTPServer``
``server_password``
"""""""""""""""""""
Optional. Default value is ``None``.
The password that needs to be provided by a connecting :ref:`FTPClient` in order to create a successful connection.
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`

View File

@@ -80,11 +80,6 @@ Via Configuration
Configuration
=============
.. include:: ../common/common_configuration.rst
.. |SOFTWARE_NAME| replace:: NTPClient
.. |SOFTWARE_NAME_BACKTICK| replace:: ``NTPClient``
``ntp_server_ip``
"""""""""""""""""
@@ -93,3 +88,8 @@ Optional. Default value is ``None``.
The IP address of an NTP Server which provides a time that the ``NTPClient`` can synchronise to.
This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``.
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`

View File

@@ -75,10 +75,8 @@ Via Configuration
- ref: ntp_server
type: NTPServer
Configuration
=============
.. include:: ../common/common_configuration.rst
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
.. |SOFTWARE_NAME| replace:: NTPServer
.. |SOFTWARE_NAME_BACKTICK| replace:: ``NTPServer``
See :ref:`Common Configuration`

View File

@@ -0,0 +1,181 @@
.. only:: comment
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
.. _Terminal:
Terminal
########
The ``Terminal.py`` class provides a generic terminal simulation, by extending the base Service class within PrimAITE. The aim of this is to act as the primary entrypoint for Nodes within the environment.
Overview
========
The Terminal service uses Secure Socket (SSH) as the communication method between terminals. They operate on port 22, and are part of the services automatically
installed on Nodes when they are instantiated.
Key capabilities
""""""""""""""""
- Ensures packets are matched to an existing session
- Simulates common Terminal processes/commands.
- Leverages the Service base class for install/uninstall, status tracking etc.
Implementation
""""""""""""""
- Manages remote connections in a dictionary by session ID.
- Processes commands, forwarding to the ``RequestManager`` or ``SessionManager`` where appropriate.
- Extends Service class.
- A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook.
Usage
"""""
- Pre-Installs on all ``Nodes`` (with the exception of ``Switches``).
- Terminal Clients connect, execute commands and disconnect from remote nodes.
- Ensures that users are logged in to the component before executing any commands.
- Service runs on SSH port 22 by default.
Usage
=====
The below code examples demonstrate how to create a terminal, a remote terminal, and how to send a basic application install command to a remote node.
Python
""""""
.. code-block:: python
from ipaddress import IPv4Address
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
client = Computer(
hostname="client",
ip_address="192.168.10.21",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
operating_state=NodeOperatingState.ON,
)
terminal: Terminal = client.software_manager.software.get("Terminal")
Creating Remote Terminal Connection
"""""""""""""""""""""""""""""""""""
.. code-block:: python
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
network = Network()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a.power_on()
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
Executing a basic application install command
"""""""""""""""""""""""""""""""""""""""""""""
.. code-block:: python
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
network = Network()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a.power_on()
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"])
Creating a folder on a remote node
""""""""""""""""""""""""""""""""""
.. code-block:: python
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
network = Network()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a.power_on()
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
term_a_term_b_remote_connection.execute(["file_system", "create", "folder", "downloads"])
Disconnect from Remote Node
"""""""""""""""""""""""""""
.. code-block:: python
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
network = Network()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a.power_on()
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
term_a_term_b_remote_connection.disconnect()
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
See :ref:`Common Configuration`

View File

@@ -75,10 +75,8 @@ Via Configuration
- ref: web_server
type: WebServer
Configuration
=============
.. include:: ../common/common_configuration.rst
``Common Attributes``
^^^^^^^^^^^^^^^^^^^^^
.. |SOFTWARE_NAME| replace:: WebServer
.. |SOFTWARE_NAME_BACKTICK| replace:: ``WebServer``
See :ref:`Common Configuration`

View File

@@ -2,6 +2,8 @@
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
.. _software:
Software
========
@@ -63,3 +65,10 @@ Processes
#########
`To be implemented`
Common Software Configuration
#############################
Below is a list of the common configuration items within Software components of PrimAITE:
.. include:: common/common_configuration.rst

View File

@@ -52,7 +52,7 @@ license-files = ["LICENSE"]
[project.optional-dependencies]
rl = [
"ray[rllib] >= 2.20.0, < 2.33",
"ray[rllib] >= 2.20.0, <2.33",
"tensorflow==2.12.0",
"stable-baselines3[extra]==2.1.0",
"sb3-contrib==2.1.0",
@@ -75,7 +75,8 @@ dev = [
"wheel==0.38.4",
"nbsphinx==0.9.4",
"nbmake==1.5.4",
"pytest-xdist==3.3.1"
"pytest-xdist==3.3.1",
"md2pdf",
]
[project.scripts]

22
run_test_and_coverage.py Normal file
View File

@@ -0,0 +1,22 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import subprocess
import sys
from typing import Any
def run_command(command: Any):
"""Runs a command and returns the exit code."""
result = subprocess.run(command, shell=True)
if result.returncode != 0:
sys.exit(result.returncode)
# Run pytest with coverage
run_command(
"coverage run -m --source=primaite pytest -v -o junit_family=xunit2 "
"--junitxml=junit/test-results.xml --cov-fail-under=80"
)
# Generate coverage reports if tests passed
run_command("coverage xml -o coverage.xml -i")
run_command("coverage html -d htmlcov -i")

View File

@@ -1 +1 @@
3.2.0
3.3.0

View File

@@ -25,7 +25,7 @@ simulation:
db_server_ip: 10.10.1.11
- type: WebBrowser
options:
target_url: http://sometech.ai
target_url: http://sometech.ai/users/
- hostname: pc_2
type: computer
@@ -39,7 +39,7 @@ simulation:
db_server_ip: 10.10.1.11
- type: WebBrowser
options:
target_url: http://sometech.ai
target_url: http://sometech.ai/users/
- hostname: server_1
type: server
@@ -221,7 +221,7 @@ simulation:
subnet_mask: 255.255.255.0
acl:
2: # Allow the some_tech_web_srv to connect to the Database Service on some_tech_db_srv
11: # Allow the some_tech_web_srv to connect to the Database Service on some_tech_db_srv
action: PERMIT
src_ip: 94.10.180.6
src_wildcard_mask: 0.0.0.0
@@ -229,7 +229,7 @@ simulation:
dst_ip: 10.10.1.11
dst_wildcard_mask: 0.0.0.0
dst_port: POSTGRES_SERVER
3: # Allow the Database Service on some_tech_db_srv to respond to some_tech_web_srv
12: # Allow the Database Service on some_tech_db_srv to respond to some_tech_web_srv
action: PERMIT
src_ip: 10.10.1.11
src_wildcard_mask: 0.0.0.0
@@ -237,7 +237,7 @@ simulation:
dst_ip: 94.10.180.6
dst_wildcard_mask: 0.0.0.0
dst_port: POSTGRES_SERVER
4: # Prevent the Junior engineer from downloading files from the some_tech_storage_srv over FTP
13: # Prevent the Junior engineer from downloading files from the some_tech_storage_srv over FTP
action: DENY
src_ip: 10.10.2.12
src_wildcard_mask: 0.0.0.0
@@ -245,33 +245,41 @@ simulation:
dst_ip: 10.10.1.12
dst_wildcard_mask: 0.0.0.0
dst_port: FTP
5: # Allow communication between Engineering and the DB & Storage subnet
14: # Prevent the Junior engineer from connecting to some_tech_storage_srv over SSH
action: DENY
src_ip: 10.10.2.12
src_wildcard_mask: 0.0.0.0
src_port: SSH
dst_ip: 10.10.1.12
dst_wildcard_mask: 0.0.0.0
dst_port: SSH
15: # Allow communication between Engineering and the DB & Storage subnet
action: PERMIT
src_ip: 10.10.2.0
src_wildcard_mask: 0.0.0.255
dst_ip: 10.10.1.0
dst_wildcard_mask: 0.0.0.255
6: # Allow communication between the DB & Storage subnet and Engineering
16: # Allow communication between the DB & Storage subnet and Engineering
action: PERMIT
src_ip: 10.10.1.0
src_wildcard_mask: 0.0.0.255
dst_ip: 10.10.2.0
dst_wildcard_mask: 0.0.0.255
7: # Allow the SomeTech network to use HTTP
17: # Allow the SomeTech network to use HTTP
action: PERMIT
src_port: HTTP
dst_port: HTTP
8: # Allow the SomeTech internal network to use ARP
18: # Allow the SomeTech internal network to use ARP
action: PERMIT
src_ip: 10.10.0.0
src_wildcard_mask: 0.0.255.255
src_port: ARP
9: # Allow the SomeTech internal network to use ICMP
19: # Allow the SomeTech internal network to use ICMP
action: PERMIT
src_ip: 10.10.0.0
src_wildcard_mask: 0.0.255.255
protocol: ICMP
10:
21:
action: PERMIT
src_ip: 94.10.180.6
src_wildcard_mask: 0.0.0.0
@@ -279,10 +287,14 @@ simulation:
dst_ip: 10.10.0.0
dst_wildcard_mask: 0.0.255.255
dst_port: HTTP
11: # Permit SomeTech to use DNS
22: # Permit SomeTech to use DNS
action: PERMIT
src_port: DNS
dst_port: DNS
23: # Permit SomeTech to use SSH
action: PERMIT
src_port: SSH
dst_port: SSH
default_route: # Default route to all external networks
next_hop_ip_address: 10.10.4.2 # NI int on some_tech_fw
@@ -332,7 +344,7 @@ simulation:
db_server_ip: 10.10.1.11
- type: WebBrowser
options:
target_url: http://sometech.ai
target_url: http://sometech.ai/users/
- hostname: some_tech_snr_dev_pc
type: computer
@@ -346,7 +358,7 @@ simulation:
db_server_ip: 10.10.1.11
- type: WebBrowser
options:
target_url: http://sometech.ai
target_url: http://sometech.ai/users/
- hostname: some_tech_jnr_dev_pc
type: computer
@@ -360,7 +372,7 @@ simulation:
db_server_ip: 10.10.1.11
- type: WebBrowser
options:
target_url: http://sometech.ai
target_url: http://sometech.ai/users/
links:
# Home/Office Lan Links

View File

@@ -129,6 +129,10 @@ agents:
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- hostname: client
type: computer

View File

@@ -1071,6 +1071,247 @@ class NodeNetworkServiceReconAction(AbstractAction):
]
class ConfigureC2BeaconAction(AbstractAction):
"""Action which configures a C2 Beacon based on the parameters given."""
class _Opts(BaseModel):
"""Schema for options that can be passed to this action."""
c2_server_ip_address: str
keep_alive_frequency: int = Field(default=5, ge=1)
masquerade_protocol: str = Field(default="TCP")
masquerade_port: str = Field(default="HTTP")
@field_validator(
"c2_server_ip_address",
"keep_alive_frequency",
"masquerade_protocol",
"masquerade_port",
mode="before",
)
@classmethod
def not_none(cls, v: str, info: ValidationInfo) -> int:
"""If None is passed, use the default value instead."""
if v is None:
return cls.model_fields[info.field_name].default
return v
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: int, config: Dict) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
return ["do_nothing"]
config = ConfigureC2BeaconAction._Opts(
c2_server_ip_address=config["c2_server_ip_address"],
keep_alive_frequency=config["keep_alive_frequency"],
masquerade_port=config["masquerade_port"],
masquerade_protocol=config["masquerade_protocol"],
)
ConfigureC2BeaconAction._Opts.model_validate(config) # check that options adhere to schema
return ["network", "node", node_name, "application", "C2Beacon", "configure", config.__dict__]
class NodeAccountsChangePasswordAction(AbstractAction):
"""Action which changes the password for a user."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: str, username: str, current_password: str, new_password: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return [
"network",
"node",
node_name,
"service",
"UserManager",
"change_password",
username,
current_password,
new_password,
]
class NodeSessionsRemoteLoginAction(AbstractAction):
"""Action which performs a remote session login."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: str, username: str, password: str, remote_ip: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return [
"network",
"node",
node_name,
"service",
"Terminal",
"ssh_to_remote",
username,
password,
remote_ip,
]
class NodeSessionsRemoteLogoutAction(AbstractAction):
"""Action which performs a remote session logout."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: str, remote_ip: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return ["network", "node", node_name, "service", "Terminal", "remote_logoff", remote_ip]
class RansomwareConfigureC2ServerAction(AbstractAction):
"""Action which sends a command from the C2 Server to the C2 Beacon which configures a local RansomwareScript."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: int, config: Dict) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
return ["do_nothing"]
# Using the ransomware scripts model to validate.
ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema
return ["network", "node", node_name, "application", "C2Server", "ransomware_configure", config]
class RansomwareLaunchC2ServerAction(AbstractAction):
"""Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: int) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
return ["do_nothing"]
# This action currently doesn't require any further configuration options.
return ["network", "node", node_name, "application", "C2Server", "ransomware_launch"]
class ExfiltrationC2ServerAction(AbstractAction):
"""Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server."""
class _Opts(BaseModel):
"""Schema for options that can be passed to this action."""
username: Optional[str]
password: Optional[str]
target_ip_address: str
target_file_name: str
target_folder_name: str
exfiltration_folder_name: Optional[str]
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(
self,
node_id: int,
account: dict,
target_ip_address: str,
target_file_name: str,
target_folder_name: str,
exfiltration_folder_name: Optional[str],
) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
return ["do_nothing"]
command_model = {
"target_file_name": target_file_name,
"target_folder_name": target_folder_name,
"exfiltration_folder_name": exfiltration_folder_name,
"target_ip_address": target_ip_address,
"username": account["username"],
"password": account["password"],
}
ExfiltrationC2ServerAction._Opts.model_validate(command_model)
return ["network", "node", node_name, "application", "C2Server", "exfiltrate", command_model]
class NodeSendRemoteCommandAction(AbstractAction):
"""Action which sends a terminal command to a remote node via SSH."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return [
"network",
"node",
node_name,
"service",
"Terminal",
"send_remote_command",
remote_ip,
{"command": command},
]
class TerminalC2ServerAction(AbstractAction):
"""Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed."""
class _Opts(BaseModel):
"""Schema for options that can be passed to this action."""
commands: Union[List[RequestFormat], RequestFormat]
ip_address: Optional[str]
username: Optional[str]
password: Optional[str]
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: int, commands: List, ip_address: Optional[str], account: dict) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
return ["do_nothing"]
command_model = {
"commands": commands,
"ip_address": ip_address,
"username": account["username"],
"password": account["password"],
}
TerminalC2ServerAction._Opts.model_validate(command_model)
return ["network", "node", node_name, "application", "C2Server", "terminal_command", command_model]
class RansomwareLaunchC2ServerAction(AbstractAction):
"""Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript."""
def __init__(self, manager: "ActionManager", **kwargs) -> None:
super().__init__(manager=manager)
def form_request(self, node_id: int) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
return ["do_nothing"]
# This action currently doesn't require any further configuration options.
return ["network", "node", node_name, "application", "C2Server", "ransomware_launch"]
class ActionManager:
"""Class which manages the action space for an agent."""
@@ -1122,6 +1363,15 @@ class ActionManager:
"CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction,
"CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction,
"CONFIGURE_DOSBOT": ConfigureDoSBotAction,
"CONFIGURE_C2_BEACON": ConfigureC2BeaconAction,
"C2_SERVER_RANSOMWARE_LAUNCH": RansomwareLaunchC2ServerAction,
"C2_SERVER_RANSOMWARE_CONFIGURE": RansomwareConfigureC2ServerAction,
"C2_SERVER_TERMINAL_COMMAND": TerminalC2ServerAction,
"C2_SERVER_DATA_EXFILTRATE": ExfiltrationC2ServerAction,
"NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction,
"SSH_TO_REMOTE": NodeSessionsRemoteLoginAction,
"SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction,
"NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction,
}
"""Dictionary which maps action type strings to the corresponding action class."""

View File

@@ -36,6 +36,8 @@ class AgentHistoryItem(BaseModel):
reward: Optional[float] = None
reward_info: Dict[str, Any] = {}
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""

View File

@@ -23,8 +23,10 @@ class FileObservation(AbstractObservation, identifier="FILE"):
"""Name of the file, used for querying simulation state dictionary."""
include_num_access: Optional[bool] = None
"""Whether to include the number of accesses to the file in the observation."""
file_system_requires_scan: Optional[bool] = None
"""If True, the file must be scanned to update the health state. Tf False, the true state is always shown."""
def __init__(self, where: WhereType, include_num_access: bool) -> None:
def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None:
"""
Initialise a file observation instance.
@@ -34,9 +36,13 @@ class FileObservation(AbstractObservation, identifier="FILE"):
:type where: WhereType
:param include_num_access: Whether to include the number of accesses to the file in the observation.
:type include_num_access: bool
:param file_system_requires_scan: If True, the file must be scanned to update the health state. Tf False,
the true state is always shown.
:type file_system_requires_scan: bool
"""
self.where: WhereType = where
self.include_num_access: bool = include_num_access
self.file_system_requires_scan: bool = file_system_requires_scan
self.default_observation: ObsType = {"health_status": 0}
if self.include_num_access:
@@ -74,7 +80,11 @@ class FileObservation(AbstractObservation, identifier="FILE"):
file_state = access_from_nested_dict(state, self.where)
if file_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {"health_status": file_state["visible_status"]}
if self.file_system_requires_scan:
health_status = file_state["visible_status"]
else:
health_status = file_state["health_status"]
obs = {"health_status": health_status}
if self.include_num_access:
obs["num_access"] = self._categorise_num_access(file_state["num_access"])
return obs
@@ -104,8 +114,15 @@ class FileObservation(AbstractObservation, identifier="FILE"):
:type parent_where: WhereType, optional
:return: Constructed file observation instance.
:rtype: FileObservation
:param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False,
the true state is always shown.
:type file_system_requires_scan: bool
"""
return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access)
return cls(
where=parent_where + ["files", config.file_name],
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
)
class FolderObservation(AbstractObservation, identifier="FOLDER"):
@@ -122,9 +139,16 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
"""Number of spaces for file observations in this folder."""
include_num_access: Optional[bool] = None
"""Whether files in this folder should include the number of accesses in their observation."""
file_system_requires_scan: Optional[bool] = None
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
def __init__(
self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool
self,
where: WhereType,
files: Iterable[FileObservation],
num_files: int,
include_num_access: bool,
file_system_requires_scan: bool,
) -> None:
"""
Initialise a folder observation instance.
@@ -138,12 +162,23 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
:type num_files: int
:param include_num_access: Whether to include the number of accesses to files in the observation.
:type include_num_access: bool
:param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False,
the true state is always shown.
:type file_system_requires_scan: bool
"""
self.where: WhereType = where
self.file_system_requires_scan: bool = file_system_requires_scan
self.files: List[FileObservation] = files
while len(self.files) < num_files:
self.files.append(FileObservation(where=None, include_num_access=include_num_access))
self.files.append(
FileObservation(
where=None,
include_num_access=include_num_access,
file_system_requires_scan=self.file_system_requires_scan,
)
)
while len(self.files) > num_files:
truncated_file = self.files.pop()
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
@@ -168,7 +203,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
if folder_state is NOT_PRESENT_IN_STATE:
return self.default_observation
health_status = folder_state["health_status"]
if self.file_system_requires_scan:
health_status = folder_state["visible_status"]
else:
health_status = folder_state["health_status"]
obs = {}
@@ -209,6 +247,13 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
# pass down shared/common config items
for file_config in config.files:
file_config.include_num_access = config.include_num_access
file_config.file_system_requires_scan = config.file_system_requires_scan
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)
return cls(
where=where,
files=files,
num_files=config.num_files,
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
)

View File

@@ -10,6 +10,7 @@ from primaite import getLogger
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
@@ -32,6 +33,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
include_users: Optional[bool] = True
"""If True, report user session information."""
def __init__(
self,
@@ -41,6 +44,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
port_list: List[int],
protocol_list: List[str],
num_rules: int,
include_users: bool,
) -> None:
"""
Initialise a firewall observation instance.
@@ -58,9 +62,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:type protocol_list: List[str]
:param num_rules: Number of rules configured in the firewall.
:type num_rules: int
:param include_users: If True, report user session information.
:type include_users: bool
"""
self.where: WhereType = where
self.include_users: bool = include_users
self.max_users: int = 3
"""Maximum number of remote sessions observable, excess sessions are truncated."""
self.ports: List[PortObservation] = [
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
]
@@ -142,6 +150,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic.
:rtype: ObsType
"""
firewall_state = access_from_nested_dict(state, self.where)
if firewall_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
"ACL": {
@@ -159,6 +170,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
},
},
}
if self.include_users:
sess = firewall_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -218,4 +235,5 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
port_list=config.port_list,
protocol_list=config.protocol_list,
num_rules=config.num_rules,
include_users=config.include_users,
)

View File

@@ -48,6 +48,12 @@ class HostObservation(AbstractObservation, identifier="HOST"):
"""A dict containing which traffic types are to be included in the observation."""
include_num_access: Optional[bool] = None
"""Whether to include the number of accesses to files observations on this host."""
file_system_requires_scan: Optional[bool] = None
"""
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
"""
include_users: Optional[bool] = True
"""If True, report user session information."""
def __init__(
self,
@@ -64,6 +70,8 @@ class HostObservation(AbstractObservation, identifier="HOST"):
include_nmne: bool,
monitored_traffic: Optional[Dict],
include_num_access: bool,
file_system_requires_scan: bool,
include_users: bool,
) -> None:
"""
Initialise a host observation instance.
@@ -95,10 +103,18 @@ class HostObservation(AbstractObservation, identifier="HOST"):
:type monitored_traffic: Dict
:param include_num_access: Flag to include the number of accesses to files.
:type include_num_access: bool
:param file_system_requires_scan: If True, the files and folders must be scanned to update the health state.
If False, the true state is always shown.
:type file_system_requires_scan: bool
:param include_users: If True, report user session information.
:type include_users: bool
"""
self.where: WhereType = where
self.include_num_access = include_num_access
self.include_users = include_users
self.max_users: int = 3
"""Maximum number of remote sessions observable, excess sessions are truncated."""
# Ensure lists have lengths equal to specified counts by truncating or padding
self.services: List[ServiceObservation] = services
@@ -120,7 +136,13 @@ class HostObservation(AbstractObservation, identifier="HOST"):
self.folders: List[FolderObservation] = folders
while len(self.folders) < num_folders:
self.folders.append(
FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access)
FolderObservation(
where=None,
files=[],
num_files=num_files,
include_num_access=include_num_access,
file_system_requires_scan=file_system_requires_scan,
)
)
while len(self.folders) > num_folders:
truncated_folder = self.folders.pop()
@@ -151,6 +173,8 @@ class HostObservation(AbstractObservation, identifier="HOST"):
if self.include_num_access:
self.default_observation["num_file_creations"] = 0
self.default_observation["num_file_deletions"] = 0
if self.include_users:
self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0}
def observe(self, state: Dict) -> ObsType:
"""
@@ -178,6 +202,12 @@ class HostObservation(AbstractObservation, identifier="HOST"):
if self.include_num_access:
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
if self.include_users:
sess = node_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -202,6 +232,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
if self.include_num_access:
shape["num_file_creations"] = spaces.Discrete(4)
shape["num_file_deletions"] = spaces.Discrete(4)
if self.include_users:
shape["users"] = spaces.Dict(
{"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)}
)
return spaces.Dict(shape)
@classmethod
@@ -226,6 +260,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
for folder_config in config.folders:
folder_config.include_num_access = config.include_num_access
folder_config.num_files = config.num_files
folder_config.file_system_requires_scan = config.file_system_requires_scan
for nic_config in config.network_interfaces:
nic_config.include_nmne = config.include_nmne
@@ -257,4 +292,6 @@ class HostObservation(AbstractObservation, identifier="HOST"):
include_nmne=config.include_nmne,
monitored_traffic=config.monitored_traffic,
include_num_access=config.include_num_access,
file_system_requires_scan=config.file_system_requires_scan,
include_users=config.include_users,
)

View File

@@ -44,6 +44,10 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
"""A dict containing which traffic types are to be included in the observation."""
include_num_access: Optional[bool] = None
"""Flag to include the number of accesses."""
file_system_requires_scan: bool = True
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
include_users: Optional[bool] = True
"""If True, report user session information."""
num_ports: Optional[int] = None
"""Number of ports."""
ip_list: Optional[List[str]] = None
@@ -187,6 +191,10 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
host_config.monitored_traffic = config.monitored_traffic
if host_config.include_num_access is None:
host_config.include_num_access = config.include_num_access
if host_config.file_system_requires_scan is None:
host_config.file_system_requires_scan = config.file_system_requires_scan
if host_config.include_users is None:
host_config.include_users = config.include_users
for router_config in config.routers:
if router_config.num_ports is None:
@@ -201,6 +209,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
router_config.protocol_list = config.protocol_list
if router_config.num_rules is None:
router_config.num_rules = config.num_rules
if router_config.include_users is None:
router_config.include_users = config.include_users
for firewall_config in config.firewalls:
if firewall_config.ip_list is None:
@@ -213,6 +223,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
firewall_config.protocol_list = config.protocol_list
if firewall_config.num_rules is None:
firewall_config.num_rules = config.num_rules
if firewall_config.include_users is None:
firewall_config.include_users = config.include_users
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]

View File

@@ -39,6 +39,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
include_users: Optional[bool] = True
"""If True, report user session information."""
def __init__(
self,
@@ -46,6 +48,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
ports: List[PortObservation],
num_ports: int,
acl: ACLObservation,
include_users: bool,
) -> None:
"""
Initialise a router observation instance.
@@ -59,12 +62,16 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
:type num_ports: int
:param acl: ACL observation representing the access control list of the router.
:type acl: ACLObservation
:param include_users: If True, report user session information.
:type include_users: bool
"""
self.where: WhereType = where
self.ports: List[PortObservation] = ports
self.acl: ACLObservation = acl
self.num_ports: int = num_ports
self.include_users: bool = include_users
self.max_users: int = 3
"""Maximum number of remote sessions observable, excess sessions are truncated."""
while len(self.ports) < num_ports:
self.ports.append(PortObservation(where=None))
while len(self.ports) > num_ports:
@@ -95,6 +102,12 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
obs["ACL"] = self.acl.observe(state)
if self.ports:
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
if self.include_users:
sess = router_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -143,4 +156,4 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl, include_users=config.include_users)

View File

@@ -47,7 +47,15 @@ class AbstractReward:
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
"""Calculate the reward for the current state.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return 0.0
@classmethod
@@ -67,7 +75,15 @@ class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
"""Calculate the reward for the current state.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return 0.0
@classmethod
@@ -109,8 +125,12 @@ class DatabaseFileIntegrity(AbstractReward):
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
database_file_state = access_from_nested_dict(state, self.location_in_state)
if database_file_state is NOT_PRESENT_IN_STATE:
@@ -151,33 +171,52 @@ class DatabaseFileIntegrity(AbstractReward):
class WebServer404Penalty(AbstractReward):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
def __init__(self, node_hostname: str, service_name: str) -> None:
def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
"""Initialise the reward component.
:param node_hostname: Hostname of the node which contains the web server service.
:type node_hostname: str
:param service_name: Name of the web server service.
:type service_name: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
web_service_state = access_from_nested_dict(state, self.location_in_state)
# if webserver is no longer installed on the node, return 0
if web_service_state is NOT_PRESENT_IN_STATE:
return 0.0
most_recent_return_code = web_service_state["last_response_status_code"]
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
if most_recent_return_code == 200:
return 1.0
elif most_recent_return_code == 404:
return -1.0
else:
return 0.0
codes = web_service_state.get("response_codes_this_timestep")
if codes:
def status2rew(status: int) -> int:
"""Map status codes to reward values."""
return 1.0 if status == 200 else -1.0 if status == 404 else 0.0
self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average
elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0
self.reward = 0.0
else: # skip calculating if sticky and no new codes. instead, reuse last step's value
pass
return self.reward
@classmethod
def from_config(cls, config: Dict) -> "WebServer404Penalty":
@@ -197,23 +236,29 @@ class WebServer404Penalty(AbstractReward):
)
_LOGGER.warning(msg)
raise ValueError(msg)
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, service_name=service_name)
return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky)
class WebpageUnavailablePenalty(AbstractReward):
"""Penalises the agent when the web browser fails to fetch a webpage."""
def __init__(self, node_hostname: str) -> None:
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
:param node_hostname: Hostname of the node which has the web browser.
:type node_hostname: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._last_request_failed: bool = False
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
@@ -222,32 +267,50 @@ class WebpageUnavailablePenalty(AbstractReward):
When the green agent requests to execute the browser application, and that request fails, this reward
component will keep track of that information. In that case, it doesn't matter whether the last webpage
had a 200 status code, because there has been an unsuccessful request since.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
# If the last request did actually go through, then check if the webpage also loaded
web_browser_state = access_from_nested_dict(state, self.location_in_state)
if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
if web_browser_state is NOT_PRESENT_IN_STATE:
self.reward = 0.0
# check if the most recent action was to request the webpage
request_attempted = last_action_response.request == [
"network",
"node",
self._node,
"application",
"WebBrowser",
"execute",
]
# skip calculating if sticky and no new codes, reusing last step value
if not request_attempted and self.sticky:
return self.reward
if last_action_response.response.status != "success":
self.reward = -1.0
elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]:
_LOGGER.debug(
"Web browser reward could not be calculated because the web browser history on node",
f"{self._node} was not reported in the simulation state. Returning 0.0",
)
return 0.0 # 0 if the web browser cannot be found
if not web_browser_state["history"]:
return 0.0 # 0 if no requests have been attempted yet
outcome = web_browser_state["history"][-1]["outcome"]
if outcome == "PENDING":
return 0.0 # 0 if a request was attempted but not yet resolved
elif outcome == 200:
return 1.0 # 1 for successful request
else: # includes failure codes and SERVER_UNREACHABLE
return -1.0 # -1 for failure
self.reward = 0.0
else:
outcome = web_browser_state["history"][-1]["outcome"]
if outcome == "PENDING":
self.reward = 0.0 # 0 if a request was attempted but not yet resolved
elif outcome == 200:
self.reward = 1.0 # 1 for successful request
else: # includes failure codes and SERVER_UNREACHABLE
self.reward = -1.0 # -1 for failure
return self.reward
@classmethod
def from_config(cls, config: dict) -> AbstractReward:
@@ -258,22 +321,28 @@ class WebpageUnavailablePenalty(AbstractReward):
:type config: Dict
"""
node_hostname = config.get("node_hostname")
return cls(node_hostname=node_hostname)
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, sticky=sticky)
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
"""Penalises the agent when the green db clients fail to connect to the database."""
def __init__(self, node_hostname: str) -> None:
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
:param node_hostname: Hostname of the node where the database client sits.
:type node_hostname: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._last_request_failed: bool = False
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
@@ -283,26 +352,33 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
component will keep track of that information. In that case, it doesn't matter whether the last successful
request returned was able to connect to the database server, because there has been an unsuccessful request
since.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
request_attempted = last_action_response.request == [
"network",
"node",
self._node,
"application",
"DatabaseClient",
"execute",
]
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
if request_attempted: # if agent makes request, always recalculate fresh value
last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status}
self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
elif not self.sticky: # if no new request and not sticky, set reward to 0
last_action_response.reward_info = {"connection_attempt_status": "n/a"}
self.reward = 0.0
else: # if no new request and sticky, reuse reward value from last step
last_action_response.reward_info = {"connection_attempt_status": "n/a"}
pass
# If the last request was actually sent, then check if the connection was established.
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
return 0.0
last_connection_successful = db_state["last_connection_successful"]
if last_connection_successful is False:
return -1.0
elif last_connection_successful is True:
return 1.0
return 0.0
return self.reward
@classmethod
def from_config(cls, config: Dict) -> AbstractReward:
@@ -313,7 +389,8 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
:type config: Dict
"""
node_hostname = config.get("node_hostname")
return cls(node_hostname=node_hostname)
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, sticky=sticky)
class SharedReward(AbstractReward):
@@ -346,7 +423,15 @@ class SharedReward(AbstractReward):
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Simply access the other agent's reward and return it."""
"""Simply access the other agent's reward and return it.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return self.callback(self.agent_name)
@classmethod
@@ -379,7 +464,15 @@ class ActionPenalty(AbstractReward):
self.do_nothing_penalty = do_nothing_penalty
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the penalty to be applied."""
"""Calculate the penalty to be applied.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
if last_action_response.action == "DONOTHING":
return self.do_nothing_penalty
else:
@@ -436,6 +529,7 @@ class RewardFunction:
weight = comp_and_weight[1]
total += weight * comp.calculate(state=state, last_action_response=last_action_response)
self.current_reward = total
return self.current_reward
@classmethod

View File

@@ -22,8 +22,6 @@ class ProbabilisticAgent(AbstractScriptedAgent):
"""Strict validation."""
action_probabilities: Dict[int, float]
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
random_seed: Optional[int] = None
"""Random seed. If set, each episode the agent will choose the same random sequence of actions."""
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
@@ -59,17 +57,18 @@ class ProbabilisticAgent(AbstractScriptedAgent):
num_actions = len(action_space.action_map)
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
# If seed not specified, set it to None so that numpy chooses a random one.
settings.setdefault("random_seed")
# The random number seed for np.random is dependent on whether a random number seed is set
# in the config file. If there is one it is processed by set_random_seed() in environment.py
# and as a consequence the the sequence of rng_seed's used here will be repeatable.
self.settings = ProbabilisticAgent.Settings(**settings)
self.rng = np.random.default_rng(self.settings.random_seed)
rng_seed = np.random.randint(0, 65535)
self.rng = np.random.default_rng(rng_seed)
# convert probabilities from
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
super().__init__(agent_name, action_space, observation_space, reward_function)
self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
"""PrimAITE game - Encapsulates the simulation and agents."""
from ipaddress import IPv4Address
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
import numpy as np
from pydantic import BaseModel, ConfigDict
@@ -18,7 +18,7 @@ from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.airspace import AirSpaceFrequency
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
@@ -26,11 +26,14 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient # noqa: F401
from primaite.simulator.system.applications.red_applications.c2.c2_beacon import C2Beacon # noqa: F401
from primaite.simulator.system.applications.red_applications.c2.c2_server import C2Server # noqa: F401
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import ( # noqa: F401
DataManipulationBot,
)
@@ -44,7 +47,10 @@ from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.system.services.web_server.web_server import WebServer
from primaite.simulator.system.software import Software
_LOGGER = getLogger(__name__)
@@ -57,6 +63,7 @@ SERVICE_TYPES_MAPPING = {
"FTPServer": FTPServer,
"NTPClient": NTPClient,
"NTPServer": NTPServer,
"Terminal": Terminal,
}
"""List of available services that can be installed on nodes in the PrimAITE Simulation."""
@@ -70,6 +77,8 @@ class PrimaiteGameOptions(BaseModel):
model_config = ConfigDict(extra="forbid")
seed: int = None
"""Random number seed for RNGs."""
max_episode_length: int = 256
"""Maximum number of episodes for the PrimAITE game."""
ports: List[str]
@@ -264,9 +273,12 @@ class PrimaiteGame:
nodes_cfg = network_config.get("nodes", [])
links_cfg = network_config.get("links", [])
# Set the NMNE capture config
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
for node_cfg in nodes_cfg:
n_type = node_cfg["type"]
new_node = None
if n_type == "computer":
new_node = Computer(
hostname=node_cfg["hostname"],
@@ -316,6 +328,25 @@ class PrimaiteGame:
msg = f"invalid node type {n_type} in config"
_LOGGER.error(msg)
raise ValueError(msg)
if "users" in node_cfg and new_node.software_manager.software.get("UserManager"):
user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa
for user_cfg in node_cfg["users"]:
user_manager.add_user(**user_cfg, bypass_can_perform_action=True)
def _set_software_listen_on_ports(software: Union[Software, Service], software_cfg: Dict):
"""Set listener ports on software."""
listen_on_ports = []
for port_id in set(software_cfg.get("options", {}).get("listen_on_ports", [])):
port = None
if isinstance(port_id, int):
port = Port(port_id)
elif isinstance(port_id, str):
port = Port[port_id]
if port:
listen_on_ports.append(port)
software.listen_on_ports = set(listen_on_ports)
if "services" in node_cfg:
for service_cfg in node_cfg["services"]:
new_service = None
@@ -329,6 +360,7 @@ class PrimaiteGame:
if "fix_duration" in service_cfg.get("options", {}):
new_service.fixing_duration = service_cfg["options"]["fix_duration"]
_set_software_listen_on_ports(new_service, service_cfg)
# start the service
new_service.start()
else:
@@ -378,6 +410,8 @@ class PrimaiteGame:
_LOGGER.error(msg)
raise ValueError(msg)
_set_software_listen_on_ports(new_application, application_cfg)
# run the application
new_application.run()
@@ -422,6 +456,15 @@ class PrimaiteGame:
dos_intensity=float(opt.get("dos_intensity", "1.0")),
max_sessions=int(opt.get("max_sessions", "1000")),
)
elif application_type == "C2Beacon":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")),
keep_alive_frequency=(opt.get("keep_alive_frequency", 5)),
masquerade_protocol=IPProtocol[(opt.get("masquerade_protocol", IPProtocol.TCP))],
masquerade_port=Port[(opt.get("masquerade_port", Port.HTTP))],
)
if "network_interfaces" in node_cfg:
for nic_num, nic_cfg in node_cfg["network_interfaces"].items():
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
@@ -533,10 +576,7 @@ class PrimaiteGame:
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
game.setup_reward_sharing()
# Set the NMNE capture config
set_nmne_config(network_config.get("nmne_config", {}))
game.update_agents(game.get_sim_state())
return game
def setup_reward_sharing(self):

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,607 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simulating Privilege Escalation and Data Loss Using SSH and ACLs Manipulation\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"## Overview\n",
"\n",
"This Jupyter notebook demonstrates a cyber scenario focusing on internal privilege escalation and data loss through the manipulation of SSH access and Access Control Lists (ACLs). The scenario is designed to model and visualise how a disgruntled junior engineer might exploit internal network vulnerabilities and social engineering of account credentials to escalate privileges and cause significant data loss and disruption to services.\n",
"\n",
"## Scenario Description\n",
"\n",
"This simulation utilises the PrimAITE demo network, focussing specifically on five nodes:\n",
"\n",
"<a href=\"_package_data/primaite_demo_network.png\" target=\"_blank\">\n",
" <img src=\"_package_data/primaite_demo_network.png\" alt=\"Description of Image\" style=\"width:100%; max-width:450px;\">\n",
"</a>\n",
"\n",
"\n",
"- **SomeTech Developer PC (`some_tech_jnr_dev_pc`)**: The workstation used by the junior engineer.\n",
"- **SomeTech Core Router (`some_tech_rt`)**: A critical network device that controls access between nodes.\n",
"- **SomeTech PostgreSQL Database Server (`some_tech_db_srv`)**: Hosts the companys critical database.\n",
"- **SomeTech Storage Server (`some_tech_storage_srv`)**: Stores important files and database backups.\n",
"- **SomeTech Web Server (`some_tech_web_srv`)**: Serves the companys website.\n",
"\n",
"By default, the junior developer PC is restricted from connecting to the storage server via FTP or SSH due to ACL rules that permit only senior members of the engineering team to access these services.\n",
"\n",
"The goal of the scenario is to simulate how the junior engineer, after gaining unauthorised access to the core router, manipulates ACL rules to escalate privileges and delete critical data.\n",
"\n",
"### Key Actions Simulated\n",
"\n",
"1. **Privilege Escalation**: The junior engineer uses social engineering to obtain login credentials for the core router, SSHs into the router, and modifies the ACL rules to allow SSH access from their PC to the storage server.\n",
"2. **Remote Access**: The junior engineer then uses the newly gained SSH access to connect to the storage server from their PC. This step is crucial for executing further actions, such as deleting files.\n",
"3. **File Deletion**: With SSH access to the storage server, the engineer deletes the backup file from the storage server and subsequently removes critical data from the PostgreSQL database, bringing down the sometech.ai website.\n",
"4. **Website Impact Verification:** After the deletion of the database backup, the scenario checks the sometech.ai website's status to confirm it has been brought down due to the data loss.\n",
"5. **Database Restore Failure:** An attempt is made to restore the deleted backup, demonstrating that the restoration fails and highlighting the severity of the data loss.\n",
"\n",
"### Notes:\n",
"- The demo will utilise CAOS (Common Action and Observation Space) actions wherever they are available. For actions where a CAOS action does not yet exist, the action will be performed manually on the node/service.\n",
"- This notebook will be updated to incorporate new CAOS actions as they become supported."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The Scenario"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import yaml\n",
"\n",
"from primaite import PRIMAITE_PATHS\n",
"from primaite.game.game import PrimaiteGame\n",
"from primaite.simulator.network.hardware.nodes.host.computer import Computer\n",
"from primaite.simulator.network.hardware.nodes.network.router import Router\n",
"from primaite.simulator.network.hardware.nodes.host.server import Server\n",
"from primaite.simulator.system.applications.database_client import DatabaseClient\n",
"from primaite.simulator.system.applications.web_browser import WebBrowser\n",
"from primaite.simulator.system.services.database.database_service import DatabaseService"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the network configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"path = PRIMAITE_PATHS.user_config_path / \"example_config\" / \"multi_lan_internet_network_example.yaml\"\n",
"\n",
"with open(path, \"r\") as file:\n",
" cfg = yaml.safe_load(file)\n",
"\n",
" game = PrimaiteGame.from_config(cfg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Capture some of the nodes from the network to observe actions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"some_tech_jnr_dev_pc: Computer = game.simulation.network.get_node_by_hostname(\"some_tech_jnr_dev_pc\")\n",
"some_tech_jnr_dev_db_client: DatabaseClient = some_tech_jnr_dev_pc.software_manager.software[\"DatabaseClient\"]\n",
"some_tech_jnr_dev_web_browser: WebBrowser = some_tech_jnr_dev_pc.software_manager.software[\"WebBrowser\"]\n",
"some_tech_rt: Router = game.simulation.network.get_node_by_hostname(\"some_tech_rt\")\n",
"some_tech_db_srv: Server = game.simulation.network.get_node_by_hostname(\"some_tech_db_srv\")\n",
"some_tech_db_service: DatabaseService = some_tech_db_srv.software_manager.software[\"DatabaseService\"]\n",
"some_tech_storage_srv: Server = game.simulation.network.get_node_by_hostname(\"some_tech_storage_srv\")\n",
"some_tech_web_srv: Server = game.simulation.network.get_node_by_hostname(\"some_tech_web_srv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Perform a Database Backup and Inspect the Storage Server\n",
"\n",
"At this stage, a backup of the PostgreSQL database is created and the storage servers file system is inspected. This step ensures that a backup file is present and correctly stored in the storage server before any further actions are taken. The inspection of the file system allows verification of the backups existence and health, establishing a baseline that will later be used to confirm the success of the subsequent deletion actions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"some_tech_storage_srv.file_system.show(full=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"some_tech_db_service.backup_database()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"some_tech_storage_srv.file_system.show(full=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Extract the folder name containing the database backup file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"db_backup_folder = [folder.name for folder in some_tech_storage_srv.file_system.folders.values() if folder.name != \"root\"][0]\n",
"db_backup_folder"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"## Check That the Junior Engineer Cannot SSH into the Storage Server\n",
"\n",
"This step verifies that the junior engineer is currently restricted from SSH access to the storage server. By attempting to establish an SSH connection from the junior engineers workstation to the storage server, this action confirms that the current ACL rules on the core router correctly prevents unauthorised access. It sets up the necessary conditions to later validate the effectiveness of the privilege escalation by demonstrating the initial access restrictions.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"caos_action = [\n",
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
" \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
"]\n",
"game.simulation.apply_request(caos_action)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Confirm That the Website Is Up by Executing the Web Browser on the Junior Engineer's Machine\n",
"\n",
"In this step, we verify that the sometech.ai website is operational before any malicious activities begin. By executing the web browser application on the junior engineers machine, we ensure that the website is accessible and functioning correctly. This establishes a baseline for the websites status, allowing us to later assess the impact of the subsequent actions, such as database deletion, on the website's availability.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"caos_action = [\"network\", \"node\", \"some_tech_jnr_dev_pc\", \"application\", \"WebBrowser\", \"execute\"]\n",
"game.simulation.apply_request(caos_action)"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"## Exploit Core Router to Add ACL for SSH Access\n",
"\n",
"At this point, the junior engineer exploits a vulnerability in the core router by obtaining the login credentials through social engineering. With SSH access to the core router, the engineer modifies the ACL rules to permit SSH connections from their machine to the storage server. This action is crucial as it will enable the engineer to remotely access the storage server and execute further malicious activities.\n",
"\n",
"Interestingly, if we inspect the `active_remote_sessions` on the SomeTech core routers `UserSessionManager`, we'll see an active session appear. This active session would pop up in the observation space."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"game.get_sim_state()[\"network\"][\"nodes\"][\"some_tech_rt\"][\"services\"][\"UserSessionManager\"][\"active_remote_sessions\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"caos_action = [\n",
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
" \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_rt.network_interface[4].ip_address)\n",
"]\n",
"game.simulation.apply_request(caos_action)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"game.get_sim_state()[\"network\"][\"nodes\"][\"some_tech_rt\"][\"services\"][\"UserSessionManager\"][\"active_remote_sessions\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inspect the ACL Table Before Adding the New Rule\n",
"\n",
"Before making any changes, we first examine the current Access Control List (ACL) table on the core router. This inspection provides a snapshot of the existing rules that govern network traffic, including permissions and restrictions related to SSH access. Understanding this baseline is crucial for verifying the effect of new rules, ensuring that changes can be accurately assessed for their impact on network security and access controls.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"some_tech_rt.acl.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"caos_action = [\n",
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
" \"service\", \"Terminal\", \"send_remote_command\", str(some_tech_rt.network_interface[4].ip_address),\n",
" {\n",
" \"command\": [\n",
" \"acl\", \"add_rule\", \"PERMIT\", \"TCP\",\n",
" str(some_tech_jnr_dev_pc.network_interface[1].ip_address), \"0.0.0.0\", \"SSH\",\n",
" str(some_tech_storage_srv.network_interface[1].ip_address), \"0.0.0.0\", \"SSH\",\n",
" 1\n",
" ]\n",
" }\n",
"]\n",
"\n",
"game.simulation.apply_request(caos_action)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Verify That the New ACL Rule Has Been Added\n",
"\n",
"After updating the ACL rules on the core router, we need to confirm that the new rule has been successfully applied. This verification involves inspecting the ACL table again to ensure that the new rule allowing SSH access from the junior engineers PC to the storage server is present. This step is critical to ensure that the modification was executed correctly and that the junior engineer now has the intended access."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"some_tech_rt.acl.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Terminate Remote Session on Core Router\n",
"\n",
"After successfully adding the ACL rule to allow SSH access to the storage server, the junior engineer terminates the remote session on the core router. The termination of the session is a strategic move to avoid leaving an active remote login open while maintaining the newly granted access privileges for future use."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"caos_action = [\n",
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
" \"service\", \"Terminal\", \"remote_logoff\", str(some_tech_rt.network_interface[4].ip_address)\n",
"]\n",
"game.simulation.apply_request(caos_action)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Confirm the termination of the remote session"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"game.get_sim_state()[\"network\"][\"nodes\"][\"some_tech_rt\"][\"services\"][\"UserSessionManager\"][\"active_remote_sessions\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SSH into Storage Server and Delete Database Backup\n",
"\n",
"With the newly added ACL rule, the junior engineer can now SSH into the storage server from their machine. The engineer proceeds to delete the critical database backup file stored on the server. This action is pivotal in the attack, as it directly impacts the availability of essential data and sets the stage for subsequent data loss and disruption of services.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"caos_action = [\n",
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
" \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
"]\n",
"game.simulation.apply_request(caos_action)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"caos_action = [\n",
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
" \"service\", \"Terminal\", \"send_remote_command\", str(some_tech_storage_srv.network_interface[1].ip_address),\n",
" {\n",
" \"command\": [\n",
" \"file_system\", \"delete\", \"file\", db_backup_folder, \"database.db\"\n",
" ]\n",
" }\n",
"]\n",
"\n",
"game.simulation.apply_request(caos_action)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Verify that the database backup file has been deleted"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"some_tech_storage_srv.file_system.show(full=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Delete Critical Data from the PostgreSQL Database\n",
"\n",
"In this part of the scenario, the junior engineer manually interacts with the PostgreSQL database to delete critical data. The deletion of critical data from the database has significant implications, leading to the loss of essential information and affecting the availability of the sometech.ai website.\n",
"\n",
"* Since the CAOS framework does not support ad-hoc or dynamic SQL queries for database services, this action must be performed manually."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Again, confirm that the sometech.ai website is up by executing the web browser on the junior engineer's machine"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"caos_action = [\"network\", \"node\", \"some_tech_jnr_dev_pc\", \"application\", \"WebBrowser\", \"execute\"]\n",
"game.simulation.apply_request(caos_action)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Set the server IP address and open a new DB connection"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"some_tech_jnr_dev_db_client.server_ip_address = some_tech_db_srv.network_interface[1].ip_address\n",
"some_tech_jnr_dev_db_connection = some_tech_jnr_dev_db_client.get_new_connection()\n",
"some_tech_jnr_dev_db_connection"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"##### Send the DELETE query"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"some_tech_jnr_dev_db_connection.query(\"DELETE\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Confirm that the actions have brought the sometech.ai website down"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"caos_action = [\"network\", \"node\", \"some_tech_jnr_dev_pc\", \"application\", \"WebBrowser\", \"execute\"]\n",
"game.simulation.apply_request(caos_action)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Attempt to Restore Database Backup\n",
"\n",
"In this final section, an attempt is made to restore the database backup that was deleted earlier. The action is performed using the `some_tech_db_service.restore_backup()` method. This will demonstrate the impact of the data loss and confirm that the backup restoration fails, highlighting the severity of the disruption caused."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"some_tech_db_service.restore_backup()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## End of Scenario Summary\n",
"\n",
"In this simulation, we modelled a cyber attack scenario where a disgruntled junior engineer exploits internal network vulnerabilities to escalate privileges, causing significant data loss and disruption of services. The following key actions were performed:\n",
"\n",
"1. **Privilege Escalation:** The junior engineer used social engineering to obtain the login credentials for the core router. They remotely accessed the router via SSH and modified the ACL rules to grant SSH access from their machine to the storage server.\n",
"\n",
"2. **Remote Access:** With the modified ACLs in place, the engineer was able to SSH into the storage server from their machine. This access enabled them to interact with the storage server and perform further actions.\n",
"\n",
"3. **File & Data Deletion:** The engineer used SSH remote access to delete a critical database backup file from the storage server. Subsequently, they executed a SQL command to delete critical data from the PostgreSQL database, which resulted in the disruption of the sometech.ai website.\n",
"\n",
"4. **Website Status Verification:** After the deletion of the database backup, the website's status was checked to confirm that it had been brought down due to the data loss.\n",
"\n",
"5. **Database Restore Failure:** An attempt to restore the deleted backup was made to demonstrate that the restoration process failed, highlighting the severity of the data loss.\n",
"\n",
"**Verification and Outcomes:**\n",
"\n",
"- **Initial State Verification:** The backup file was confirmed to be present on the storage server before any actions were taken. The junior engineer's inability to SSH into the storage server initially confirmed that ACL restrictions were in effect.\n",
"\n",
"- **Privilege Escalation Confirmation:** The successful modification of the ACL rules was verified by checking the router's ACL table.\n",
"\n",
"- **Remote Access Verification:** After the ACL modification, the engineer successfully SSH'd into the storage server from their PC. The file system inspection confirmed that the backup file was accessible and could be deleted.\n",
"\n",
"- **File Deletion Confirmation:** The deletion of the backup file was confirmed by inspecting the storage server's file system after the operation. The backup file was marked as deleted, validating that the deletion command was executed.\n",
"\n",
"- **Database and Website Impact:** The deletion of the database backup was followed by a DELETE query executed on the PostgreSQL database. The website's functionality was subsequently checked using a web browser, confirming that the sometech.ai website was down due to the data loss.\n",
"\n",
"- **Restore Attempt Verification:** An attempt to restore the deleted database backup was made, and it was confirmed that the restoration failed, highlighting the impact of the data deletion."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,224 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Terminal Processing\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook serves as a guide on the functionality and use of the new Terminal simulation component.\n",
"\n",
"The Terminal service comes pre-installed on most Nodes (The exception being Switches, as these are currently dumb). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.system.services.terminal.terminal import Terminal\n",
"from primaite.simulator.network.container import Network\n",
"from primaite.simulator.network.hardware.nodes.host.computer import Computer\n",
"from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript\n",
"from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection\n",
"\n",
"def basic_network() -> Network:\n",
" \"\"\"Utility function for creating a default network to demonstrate Terminal functionality\"\"\"\n",
" network = Network()\n",
" node_a = Computer(hostname=\"node_a\", ip_address=\"192.168.0.10\", subnet_mask=\"255.255.255.0\", start_up_duration=0)\n",
" node_a.power_on()\n",
" node_b = Computer(hostname=\"node_b\", ip_address=\"192.168.0.11\", subnet_mask=\"255.255.255.0\", start_up_duration=0)\n",
" node_b.power_on()\n",
" network.connect(node_a.network_interface[1], node_b.network_interface[1])\n",
" return network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The terminal can be accessed from a `Node` via the `software_manager` as demonstrated below. \n",
"\n",
"In the example, we have a basic network consisting of two computers, connected to form a basic network."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"network: Network = basic_network()\n",
"computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n",
"terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n",
"computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n",
"terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To be able to send commands from `node_a` to `node_b`, you will need to `login` to `node_b` first, using valid user credentials. In the example below, we are remotely logging in to the 'admin' account on `node_b`, from `node_a`. \n",
"If you are not logged in, any commands sent will be rejected by the remote.\n",
"\n",
"Remote Logins return a RemoteTerminalConnection object, which can be used for sending commands to the remote node. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Login to the remote (node_b) from local (node_a)\n",
"term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"admin\", ip_address=\"192.168.0.11\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can view all active connections to a terminal through use of the `show()` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminal_b.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The new connection object allows us to forward commands to be executed on the target node. The example below demonstrates how you can remotely install an application on the target node."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"RansomwareScript\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.software_manager.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to create a downloads folder. \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display the current state of the file system on computer_b\n",
"computer_b.file_system.show()\n",
"\n",
"# Send command\n",
"term_a_term_b_remote_connection.execute([\"file_system\", \"create\", \"folder\", \"downloads\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The resultant call to `computer_b.file_system.show()` shows that the new folder has been created."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.file_system.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When finished, the connection can be closed by calling the `disconnect` function of the Remote Client object"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display active connection\n",
"terminal_a.show()\n",
"terminal_b.show()\n",
"\n",
"term_a_term_b_remote_connection.disconnect()\n",
"\n",
"terminal_a.show()\n",
"terminal_b.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Disconnected Terminal sessions will no longer show in the node's Terminal connection list, but will be under the historic sessions in the `user_session_manager`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.user_session_manager.show(include_historic=True, include_session_id=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -62,6 +62,7 @@
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
" .evaluation(evaluation_duration=1)\n",
" )\n"
]
},
@@ -82,6 +83,22 @@
"algo = config.build()\n",
"results = algo.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluate the results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eval = algo.evaluate()"
]
}
],
"metadata": {

View File

@@ -55,6 +55,7 @@
" .environment(env=PrimaiteRayEnv, env_config=env_config)\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
" .evaluation(evaluation_duration=1)\n",
")\n"
]
},
@@ -74,6 +75,22 @@
"algo = config.build()\n",
"results = algo.train()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluate the results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eval = algo.evaluate()"
]
}
],
"metadata": {

View File

@@ -199,7 +199,7 @@
"metadata": {},
"source": [
"### Episode 0\n",
"Let' run the episodes to verify that the agents are changing as expected. In episode 0, there should be no green or red agents, just the defender blue agent."
"Let's run the episodes to verify that the agents are changing as expected. In episode 0, there should be no green or red agents, just the defender blue agent."
]
},
{

Binary file not shown.

After

Width:  |  Height:  |  Size: 334 KiB

View File

@@ -1,5 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import json
import random
import sys
from os import PathLike
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
@@ -17,6 +19,36 @@ from primaite.simulator.system.core.packet_capture import PacketCapture
_LOGGER = getLogger(__name__)
# Check torch is installed
try:
import torch as th
except ModuleNotFoundError:
_LOGGER.debug("Torch not available for importing")
def set_random_seed(seed: int) -> Union[None, int]:
"""
Set random number generators.
:param seed: int
"""
if seed is None or seed == -1:
return None
elif seed < -1:
raise ValueError("Invalid random number seed")
# Seed python RNG
random.seed(seed)
# Seed numpy RNG
np.random.seed(seed)
# Seed the RNG for all devices (both CPU and CUDA)
# if torch not installed don't set random seed.
if sys.modules["torch"]:
th.manual_seed(seed)
th.backends.cudnn.deterministic = True
th.backends.cudnn.benchmark = False
return seed
class PrimaiteGymEnv(gymnasium.Env):
"""
@@ -31,6 +63,9 @@ class PrimaiteGymEnv(gymnasium.Env):
super().__init__()
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
"""Get RNG seed from config file. NB: Must be before game instantiation."""
self.seed = set_random_seed(self.seed)
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
@@ -42,6 +77,8 @@ class PrimaiteGymEnv(gymnasium.Env):
self.total_reward_per_episode: Dict[int, float] = {}
"""Average rewards of agents per episode."""
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
def action_masks(self) -> np.ndarray:
"""
Return the action mask for the agent.
@@ -108,6 +145,8 @@ class PrimaiteGymEnv(gymnasium.Env):
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if seed is not None:
set_random_seed(seed)
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
if self.io.settings.save_agent_actions:

View File

@@ -63,6 +63,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
super().reset() # Ensure PRNG seed is set everywhere
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
@@ -176,6 +177,7 @@ class PrimaiteRayEnv(gymnasium.Env):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
super().reset() # Ensure PRNG seed is set everywhere
if self.env.agent.action_masking:
obs, *_ = self.env.reset(seed=seed)
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}

View File

@@ -60,13 +60,13 @@ class AirSpaceFrequency(Enum):
@property
def maximum_data_rate_bps(self) -> float:
"""
Retrieves the maximum data transmission rate in bits per second (bps) for the frequency.
Retrieves the maximum data transmission rate in bits per second (bps).
The maximum rates are predefined for known frequencies:
- For WIFI_2_4, it returns 100,000,000 bps (100 Mbps).
- For WIFI_5, it returns 500,000,000 bps (500 Mbps).
The maximum rates are predefined for frequencies.:
- WIFI 2.4 supports 100,000,000 bps
- WIFI 5 supports 500,000,000 bps
:return: The maximum data rate in bits per second. If the frequency is not recognized, returns 0.0.
:return: The maximum data rate in bits per second.
"""
if self == AirSpaceFrequency.WIFI_2_4:
return 100_000_000.0 # 100 Megabits per second

View File

@@ -6,12 +6,11 @@ import secrets
from abc import ABC, abstractmethod
from ipaddress import IPv4Address, IPv4Network
from pathlib import Path
from typing import Any, Dict, Optional, TypeVar, Union
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validate_call
import primaite.simulator.network.nmne
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.interface.request import RequestResponse
@@ -20,17 +19,10 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.nmne import (
CAPTURE_BY_DIRECTION,
CAPTURE_BY_IP_ADDRESS,
CAPTURE_BY_KEYWORD,
CAPTURE_BY_PORT,
CAPTURE_BY_PROTOCOL,
CAPTURE_NMNE,
NMNE_CAPTURE_KEYWORDS,
)
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.packet_capture import PacketCapture
from primaite.simulator.system.core.session_manager import SessionManager
@@ -38,7 +30,8 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.processes.process import Process
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import IOSoftware
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.system.software import IOSoftware, Software
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
from primaite.utils.validators import IPV4Address
@@ -108,8 +101,11 @@ class NetworkInterface(SimComponent, ABC):
pcap: Optional[PacketCapture] = None
"A PacketCapture instance for capturing and analysing packets passing through this interface."
nmne_config: ClassVar[NMNEConfig] = NMNEConfig()
"A dataclass defining malicious network events to be captured."
nmne: Dict = Field(default_factory=lambda: {})
"A dict containing details of the number of malicious network events captured."
"A dict containing details of the number of malicious events captured."
traffic: Dict = Field(default_factory=lambda: {})
"A dict containing details of the inbound and outbound traffic by port and protocol."
@@ -167,8 +163,8 @@ class NetworkInterface(SimComponent, ABC):
"enabled": self.enabled,
}
)
if CAPTURE_NMNE:
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
if self.nmne_config and self.nmne_config.capture_nmne:
state.update({"nmne": self.nmne})
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
return state
@@ -201,7 +197,7 @@ class NetworkInterface(SimComponent, ABC):
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
"""
# Exit function if NMNE capturing is disabled
if not CAPTURE_NMNE:
if not (self.nmne_config and self.nmne_config.capture_nmne):
return
# Initialise basic frame data variables
@@ -222,27 +218,27 @@ class NetworkInterface(SimComponent, ABC):
frame_str = str(frame.payload)
# Proceed only if any NMNE keyword is present in the frame payload
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords):
# Start with the root of the NMNE capture structure
current_level = self.nmne
# Update NMNE structure based on enabled settings
if CAPTURE_BY_DIRECTION:
if self.nmne_config.capture_by_direction:
# Set or get the dictionary for the current direction
current_level = current_level.setdefault("direction", {})
current_level = current_level.setdefault(direction, {})
if CAPTURE_BY_IP_ADDRESS:
if self.nmne_config.capture_by_ip_address:
# Set or get the dictionary for the current IP address
current_level = current_level.setdefault("ip_address", {})
current_level = current_level.setdefault(ip_address, {})
if CAPTURE_BY_PROTOCOL:
if self.nmne_config.capture_by_protocol:
# Set or get the dictionary for the current protocol
current_level = current_level.setdefault("protocol", {})
current_level = current_level.setdefault(protocol, {})
if CAPTURE_BY_PORT:
if self.nmne_config.capture_by_port:
# Set or get the dictionary for the current port
current_level = current_level.setdefault("port", {})
current_level = current_level.setdefault(port, {})
@@ -251,8 +247,8 @@ class NetworkInterface(SimComponent, ABC):
keyword_level = current_level.setdefault("keywords", {})
# Increment the count for detected keywords in the payload
if CAPTURE_BY_KEYWORD:
for keyword in NMNE_CAPTURE_KEYWORDS:
if self.nmne_config.capture_by_keyword:
for keyword in self.nmne_config.nmne_capture_keywords:
if keyword in frame_str:
# Update the count for each keyword found
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
@@ -794,6 +790,685 @@ class Link(SimComponent):
self.current_load = 0.0
class User(SimComponent):
"""
Represents a user in the PrimAITE system.
:ivar username: The username of the user
:ivar password: The password of the user
:ivar disabled: Boolean flag indicating whether the user is disabled
:ivar is_admin: Boolean flag indicating whether the user has admin privileges
"""
username: str
"""The username of the user"""
password: str
"""The password of the user"""
disabled: bool = False
"""Boolean flag indicating whether the user is disabled"""
is_admin: bool = False
"""Boolean flag indicating whether the user has admin privileges"""
num_of_logins: int = 0
"""Counts the number of the User has logged in"""
def describe_state(self) -> Dict:
"""
Returns a dictionary representing the current state of the user.
:return: A dict containing the state of the user
"""
return self.model_dump()
class UserManager(Service):
"""
Manages users within the PrimAITE system, handling creation, authentication, and administration.
:param users: A dictionary of all users by their usernames
:param admins: A dictionary of admin users by their usernames
:param disabled_admins: A dictionary of currently disabled admin users by their usernames
"""
users: Dict[str, User] = {}
def __init__(self, **kwargs):
"""
Initializes a UserManager instanc.
:param username: The username for the default admin user
:param password: The password for the default admin user
"""
kwargs["name"] = "UserManager"
kwargs["port"] = Port.NONE
kwargs["protocol"] = IPProtocol.NONE
super().__init__(**kwargs)
self.start()
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
# todo add doc about requeest schemas
rm.add_request(
"change_password",
RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.change_user_password(username=request[0], current_password=request[1], new_password=request[2])
)
),
)
return rm
def describe_state(self) -> Dict:
"""
Returns the state of the UserManager along with the number of users and admins.
:return: A dict containing detailed state information
"""
state = super().describe_state()
state.update({"total_users": len(self.users), "total_admins": len(self.admins) + len(self.disabled_admins)})
state["users"] = {k: v.describe_state() for k, v in self.users.items()}
return state
def show(self, markdown: bool = False):
"""
Display the Users.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
table = PrettyTable(["Username", "Admin", "Disabled"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} User Manager"
for user in self.users.values():
table.add_row([user.username, user.is_admin, user.disabled])
print(table.get_string(sortby="Username"))
@property
def non_admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all enabled non-admin users.
:return: A dictionary with usernames as keys and User objects as values for non-admin, non-disabled users.
"""
return {k: v for k, v in self.users.items() if not v.is_admin and not v.disabled}
@property
def disabled_non_admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all disabled non-admin users.
:return: A dictionary with usernames as keys and User objects as values for non-admin, disabled users.
"""
return {k: v for k, v in self.users.items() if not v.is_admin and v.disabled}
@property
def admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all enabled admin users.
:return: A dictionary with usernames as keys and User objects as values for admin, non-disabled users.
"""
return {k: v for k, v in self.users.items() if v.is_admin and not v.disabled}
@property
def disabled_admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all disabled admin users.
:return: A dictionary with usernames as keys and User objects as values for admin, disabled users.
"""
return {k: v for k, v in self.users.items() if v.is_admin and v.disabled}
def install(self) -> None:
"""Setup default user during first-time installation."""
self.add_user(username="admin", password="admin", is_admin=True, bypass_can_perform_action=True)
def _is_last_admin(self, username: str) -> bool:
return username in self.admins and len(self.admins) == 1
def add_user(
self, username: str, password: str, is_admin: bool = False, bypass_can_perform_action: bool = False
) -> bool:
"""
Adds a new user to the system.
:param username: The username for the new user
:param password: The password for the new user
:param is_admin: Flag indicating if the new user is an admin
:return: True if user was successfully added, False otherwise
"""
if not bypass_can_perform_action and not self._can_perform_action():
return False
if username in self.users:
self.sys_log.info(f"{self.name}: Failed to create new user {username} as this user name already exists")
return False
user = User(username=username, password=password, is_admin=is_admin)
self.users[username] = user
self.sys_log.info(f"{self.name}: Added new {'admin' if is_admin else 'user'}: {username}")
return True
def authenticate_user(self, username: str, password: str) -> Optional[User]:
"""
Authenticates a user's login attempt.
:param username: The username of the user trying to log in
:param password: The password provided by the user
:return: The User object if authentication is successful, None otherwise
"""
if not self._can_perform_action():
return None
user = self.users.get(username)
if user and not user.disabled and user.password == password:
self.sys_log.info(f"{self.name}: User authenticated: {username}")
return user
self.sys_log.info(f"{self.name}: Authentication failed for: {username}")
return None
def change_user_password(self, username: str, current_password: str, new_password: str) -> bool:
"""
Changes a user's password.
:param username: The username of the user changing their password
:param current_password: The current password of the user
:param new_password: The new password for the user
:return: True if the password was changed successfully, False otherwise
"""
if not self._can_perform_action():
return False
user = self.users.get(username)
if user and user.password == current_password:
user.password = new_password
self.sys_log.info(f"{self.name}: Password changed for {username}")
self._user_session_manager._logout_user(user=user)
return True
self.sys_log.info(f"{self.name}: Password change failed for {username}")
return False
def disable_user(self, username: str) -> bool:
"""
Disables a user account, preventing them from logging in.
:param username: The username of the user to disable
:return: True if the user was disabled successfully, False otherwise
"""
if not self._can_perform_action():
return False
if username in self.users and not self.users[username].disabled:
if self._is_last_admin(username):
self.sys_log.info(f"{self.name}: Cannot disable User {username} as they are the only enabled admin")
return False
self.users[username].disabled = True
self.sys_log.info(f"{self.name}: User disabled: {username}")
return True
self.sys_log.info(f"{self.name}: Failed to disable user: {username}")
return False
def enable_user(self, username: str) -> bool:
"""
Enables a previously disabled user account.
:param username: The username of the user to enable
:return: True if the user was enabled successfully, False otherwise
"""
if username in self.users and self.users[username].disabled:
self.users[username].disabled = False
self.sys_log.info(f"{self.name}: User enabled: {username}")
return True
self.sys_log.info(f"{self.name}: Failed to enable user: {username}")
return False
@property
def _user_session_manager(self) -> "UserSessionManager":
return self.software_manager.software["UserSessionManager"] # noqa
class UserSession(SimComponent):
"""
Represents a user session on the Node.
This class manages the state of a user session, including the user, session start, last active step,
and end step. It also indicates whether the session is local.
:ivar user: The user associated with this session.
:ivar start_step: The timestep when the session was started.
:ivar last_active_step: The last timestep when the session was active.
:ivar end_step: The timestep when the session ended, if applicable.
:ivar local: Indicates if the session is local. Defaults to True.
"""
user: User
"""The user associated with this session."""
start_step: int
"""The timestep when the session was started."""
last_active_step: int
"""The last timestep when the session was active."""
end_step: Optional[int] = None
"""The timestep when the session ended, if applicable."""
local: bool = True
"""Indicates if the session is a local session or a remote session. Defaults to True as a local session."""
@classmethod
def create(cls, user: User, timestep: int) -> UserSession:
"""
Creates a new instance of UserSession.
This class method initialises a user session with the given user and timestep.
:param user: The user associated with this session.
:param timestep: The timestep when the session is created.
:return: An instance of UserSession.
"""
user.num_of_logins += 1
return UserSession(user=user, start_step=timestep, last_active_step=timestep)
def describe_state(self) -> Dict:
"""
Describes the current state of the user session.
:return: A dictionary representing the state of the user session.
"""
return self.model_dump()
class RemoteUserSession(UserSession):
"""
Represents a remote user session on the Node.
This class extends the UserSession class to include additional attributes and methods specific to remote sessions.
:ivar remote_ip_address: The IP address of the remote user.
:ivar local: Indicates that this is not a local session. Always set to False.
"""
remote_ip_address: IPV4Address
"""The IP address of the remote user."""
local: bool = False
"""Indicates that this is not a local session. Always set to False."""
@classmethod
def create(cls, user: User, timestep: int, remote_ip_address: IPV4Address) -> RemoteUserSession: # noqa
"""
Creates a new instance of RemoteUserSession.
This class method initialises a remote user session with the given user, timestep, and remote IP address.
:param user: The user associated with this session.
:param timestep: The timestep when the session is created.
:param remote_ip_address: The IP address of the remote user.
:return: An instance of RemoteUserSession.
"""
return RemoteUserSession(
user=user, start_step=timestep, last_active_step=timestep, remote_ip_address=remote_ip_address
)
def describe_state(self) -> Dict:
"""
Describes the current state of the remote user session.
This method extends the base describe_state method to include the remote IP address.
:return: A dictionary representing the state of the remote user session.
"""
state = super().describe_state()
state["remote_ip_address"] = str(self.remote_ip_address)
return state
class UserSessionManager(Service):
"""
Manages user sessions on a Node, including local and remote sessions.
This class handles authentication, session management, and session timeouts for users interacting with the Node.
"""
local_session: Optional[UserSession] = None
"""The current local user session, if any."""
remote_sessions: Dict[str, RemoteUserSession] = {}
"""A dictionary of active remote user sessions."""
historic_sessions: List[UserSession] = Field(default_factory=list)
"""A list of historic user sessions."""
local_session_timeout_steps: int = 30
"""The number of steps before a local session times out due to inactivity."""
remote_session_timeout_steps: int = 30
"""The number of steps before a remote session times out due to inactivity."""
max_remote_sessions: int = 3
"""The maximum number of concurrent remote sessions allowed."""
current_timestep: int = 0
"""The current timestep in the simulation."""
def __init__(self, **kwargs):
"""
Initializes a UserSessionManager instance.
:param username: The username for the default admin user
:param password: The password for the default admin user
"""
kwargs["name"] = "UserSessionManager"
kwargs["port"] = Port.NONE
kwargs["protocol"] = IPProtocol.NONE
super().__init__(**kwargs)
self.start()
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse:
"""Request should take the form [username, password, remote_ip_address]."""
username, password, remote_ip_address = request
response = RequestResponse.from_bool(self.remote_login(username, password, remote_ip_address))
response.data = {"remote_hostname": self.parent.hostname, "username": username}
return response
rm.add_request("remote_login", RequestType(func=_remote_login))
rm.add_request(
"remote_logout",
RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.remote_logout(remote_session_id=request[0])
)
),
)
return rm
def show(self, markdown: bool = False, include_session_id: bool = False, include_historic: bool = False):
"""
Displays a table of the user sessions on the Node.
:param markdown: Whether to display the table in markdown format.
:param include_session_id: Whether to include session IDs in the table.
:param include_historic: Whether to include historic sessions in the table.
"""
headers = ["Session ID", "Username", "Type", "Remote IP", "Start Step", "Step Last Active", "End Step"]
if not include_session_id:
headers = headers[1:]
table = PrettyTable(headers)
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.parent.hostname} User Sessions"
def _add_session_to_table(user_session: UserSession):
"""
Adds a user session to the table for display.
This helper function determines whether the session is local or remote and formats the session data
accordingly. It then adds the session data to the table.
:param user_session: The user session to add to the table.
"""
session_type = "local"
remote_ip = ""
if isinstance(user_session, RemoteUserSession):
session_type = "remote"
remote_ip = str(user_session.remote_ip_address)
data = [
user_session.uuid,
user_session.user.username,
session_type,
remote_ip,
user_session.start_step,
user_session.last_active_step,
user_session.end_step if user_session.end_step else "",
]
if not include_session_id:
data = data[1:]
table.add_row(data)
if self.local_session is not None:
_add_session_to_table(self.local_session)
for user_session in self.remote_sessions.values():
_add_session_to_table(user_session)
if include_historic:
for user_session in self.historic_sessions:
_add_session_to_table(user_session)
print(table.get_string(sortby="Step Last Active", reversesort=True))
def describe_state(self) -> Dict:
"""
Describes the current state of the UserSessionManager.
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["current_local_user"] = None if not self.local_session else self.local_session.user.username
state["active_remote_sessions"] = list(self.remote_sessions.keys())
return state
@property
def _user_manager(self) -> UserManager:
"""
Returns the UserManager instance.
:return: The UserManager instance.
"""
return self.software_manager.software["UserManager"] # noqa
def pre_timestep(self, timestep: int) -> None:
"""Apply any pre-timestep logic that helps make sure we have the correct observations."""
self.current_timestep = timestep
inactive_sessions: list = []
if self.local_session:
if self.local_session.last_active_step + self.local_session_timeout_steps <= timestep:
inactive_sessions.append(self.local_session)
for session in self.remote_sessions:
remote_session = self.remote_sessions[session]
if remote_session.last_active_step + self.remote_session_timeout_steps <= timestep:
inactive_sessions.append(remote_session)
for sessions in inactive_sessions:
self._timeout_session(sessions)
def _timeout_session(self, session: UserSession) -> None:
"""
Handles session timeout logic.
:param session: The session to be timed out.
"""
session.end_step = self.current_timestep
session_identity = session.user.username
if session.local:
self.local_session = None
session_type = "Local"
else:
self.remote_sessions.pop(session.uuid)
session_type = "Remote"
session_identity = f"{session_identity} {session.remote_ip_address}"
self.parent.terminal._connections.pop(session.uuid)
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "user_timeout", "connection_id": session.uuid},
dest_port=Port.SSH,
dest_ip_address=session.remote_ip_address,
)
self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity")
@property
def remote_session_limit_reached(self) -> bool:
"""
Checks if the maximum number of remote sessions has been reached.
:return: True if the limit is reached, otherwise False.
"""
return len(self.remote_sessions) >= self.max_remote_sessions
def validate_remote_session_uuid(self, remote_session_id: str) -> bool:
"""
Validates if a given remote session ID exists.
:param remote_session_id: The remote session ID to validate.
:return: True if the session ID exists, otherwise False.
"""
return remote_session_id in self.remote_sessions
def _login(
self, username: str, password: str, local: bool = True, remote_ip_address: Optional[IPv4Address] = None
) -> Optional[str]:
"""
Logs a user in either locally or remotely.
:param username: The username of the account.
:param password: The password of the account.
:param local: Whether the login is local or remote.
:param remote_ip_address: The remote IP address for remote login.
:return: The session ID if login is successful, otherwise None.
"""
if not self._can_perform_action():
return None
user = self._user_manager.authenticate_user(username=username, password=password)
if not user:
self.sys_log.info(f"{self.name}: Incorrect username or password")
return None
session_id = None
if local:
create_new_session = True
if self.local_session:
if self.local_session.user != user:
# logout the current user
self.local_logout()
else:
# not required as existing logged-in user attempting to re-login
create_new_session = False
if create_new_session:
self.local_session = UserSession.create(user=user, timestep=self.current_timestep)
session_id = self.local_session.uuid
else:
if not self.remote_session_limit_reached:
remote_session = RemoteUserSession.create(
user=user, timestep=self.current_timestep, remote_ip_address=remote_ip_address
)
session_id = remote_session.uuid
self.remote_sessions[session_id] = remote_session
self.sys_log.info(f"{self.name}: User {user.username} logged in")
return session_id
def local_login(self, username: str, password: str) -> Optional[str]:
"""
Logs a user in locally.
:param username: The username of the account.
:param password: The password of the account.
:return: The session ID if login is successful, otherwise None.
"""
return self._login(username=username, password=password, local=True)
@validate_call()
def remote_login(self, username: str, password: str, remote_ip_address: IPV4Address) -> Optional[str]:
"""
Logs a user in remotely.
:param username: The username of the account.
:param password: The password of the account.
:param remote_ip_address: The remote IP address for the remote login.
:return: The session ID if login is successful, otherwise None.
"""
return self._login(username=username, password=password, local=False, remote_ip_address=remote_ip_address)
def _logout(self, local: bool = True, remote_session_id: Optional[str] = None) -> bool:
"""
Logs a user out either locally or remotely.
:param local: Whether the logout is local or remote.
:param remote_session_id: The remote session ID for remote logout.
:return: True if logout successful, otherwise False.
"""
if not self._can_perform_action():
return False
session = None
if local and self.local_session:
session = self.local_session
session.end_step = self.current_timestep
self.local_session = None
if not local and remote_session_id:
self.parent.terminal._disconnect(remote_session_id)
session = self.remote_sessions.pop(remote_session_id)
if session:
self.historic_sessions.append(session)
self.sys_log.info(f"{self.name}: User {session.user.username} logged out")
return True
return False
def local_logout(self) -> bool:
"""
Logs out the current local user.
:return: True if logout successful, otherwise False.
"""
return self._logout(local=True)
def remote_logout(self, remote_session_id: str) -> bool:
"""
Logs out a remote user by session ID.
:param remote_session_id: The remote session ID.
:return: True if logout successful, otherwise False.
"""
return self._logout(local=False, remote_session_id=remote_session_id)
def _logout_user(self, user: Union[str, User]) -> bool:
"""End a user session by username or user object."""
if isinstance(user, str):
user = self._user_manager.users[user] # grab user object from username
for sess_id, session in self.remote_sessions.items():
if session.user is user:
self._logout(local=False, remote_session_id=sess_id)
return True
if self.local_user_logged_in and self.local_session.user is user:
self.local_logout()
return True
return False
@property
def local_user_logged_in(self) -> bool:
"""
Checks if a local user is currently logged in.
:return: True if a local user is logged in, otherwise False.
"""
return self.local_session is not None
class Node(SimComponent):
"""
A basic Node class that represents a node on the network.
@@ -861,11 +1536,14 @@ class Node(SimComponent):
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {}
"Base system software that must be preinstalled."
def __init__(self, **kwargs):
"""
Initialize the Node with various components and managers.
This method initializes the ARP cache, ICMP handler, session manager, and software manager if they are not
This method initialises the ARP cache, ICMP handler, session manager, and software manager if they are not
provided.
"""
if not kwargs.get("sys_log"):
@@ -885,9 +1563,45 @@ class Node(SimComponent):
dns_server=kwargs.get("dns_server"),
)
super().__init__(**kwargs)
self._install_system_software()
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
self._install_system_software()
@property
def user_manager(self) -> Optional[UserManager]:
"""The Nodes User Manager."""
return self.software_manager.software.get("UserManager") # noqa
@property
def user_session_manager(self) -> Optional[UserSessionManager]:
"""The Nodes User Session Manager."""
return self.software_manager.software.get("UserSessionManager") # noqa
@property
def terminal(self) -> Optional[Terminal]:
"""The Nodes Terminal."""
return self.software_manager.software.get("Terminal")
def local_login(self, username: str, password: str) -> Optional[str]:
"""
Attempt to log in to the node uas a local user.
This method attempts to authenticate a local user with the given username and password. If successful, it
returns a session token. If authentication fails, it returns None.
:param username: The username of the account attempting to log in.
:param password: The password of the account attempting to log in.
:return: A session token if the login is successful, otherwise None.
"""
return self.user_session_manager.local_login(username, password)
def local_logout(self) -> None:
"""
Log out the current local user from the node.
This method ends the current local user's session and invalidates the session token.
"""
return self.user_session_manager.local_logout()
def ip_is_network_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool:
"""
@@ -942,7 +1656,7 @@ class Node(SimComponent):
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not turned on."
return f"Cannot perform request on node '{self.node.hostname}' because it is not powered on."
class _NodeIsOffValidator(RequestPermissionValidator):
"""
@@ -1091,10 +1805,6 @@ class Node(SimComponent):
return rm
def _install_system_software(self):
"""Install System Software - software that is usually provided with the OS."""
pass
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -1173,7 +1883,7 @@ class Node(SimComponent):
ip_address,
network_interface.speed,
"Enabled" if network_interface.enabled else "Disabled",
network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled",
network_interface.nmne if network_interface.nmne_config.capture_nmne else "Disabled",
]
)
print(table)
@@ -1455,74 +2165,6 @@ class Node(SimComponent):
else:
return
def install_service(self, service: Service) -> None:
"""
Install a service on this node.
:param service: Service instance that has not been installed on any node yet.
:type service: Service
"""
if service in self:
_LOGGER.warning(f"Can't add service {service.name} to node {self.hostname}. It's already installed.")
return
self.services[service.uuid] = service
service.parent = self
service.install() # Perform any additional setup, such as creating files for this service on the node.
self.sys_log.info(f"Installed service {service.name}")
_LOGGER.debug(f"Added service {service.name} to node {self.hostname}")
self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager))
def uninstall_service(self, service: Service) -> None:
"""
Uninstall and completely remove service from this node.
:param service: Service object that is currently associated with this node.
:type service: Service
"""
if service not in self:
_LOGGER.warning(f"Can't remove service {service.name} from node {self.hostname}. It's not installed.")
return
service.uninstall() # Perform additional teardown, such as removing files or restarting the machine.
self.services.pop(service.uuid)
service.parent = None
self.sys_log.info(f"Uninstalled service {service.name}")
self._service_request_manager.remove_request(service.name)
def install_application(self, application: Application) -> None:
"""
Install an application on this node.
:param application: Application instance that has not been installed on any node yet.
:type application: Application
"""
if application in self:
_LOGGER.warning(
f"Can't add application {application.name} to node {self.hostname}. It's already installed."
)
return
self.applications[application.uuid] = application
application.parent = self
self.sys_log.info(f"Installed application {application.name}")
_LOGGER.debug(f"Added application {application.name} to node {self.hostname}")
self._application_request_manager.add_request(application.name, RequestType(func=application._request_manager))
def uninstall_application(self, application: Application) -> None:
"""
Uninstall and completely remove application from this node.
:param application: Application object that is currently associated with this node.
:type application: Application
"""
if application not in self:
_LOGGER.warning(
f"Can't remove application {application.name} from node {self.hostname}. It's not installed."
)
return
self.applications.pop(application.uuid)
application.parent = None
self.sys_log.info(f"Uninstalled application {application.name}")
self._application_request_manager.remove_request(application.name)
def _shut_down_actions(self):
"""Actions to perform when the node is shut down."""
# Turn off all the services in the node
@@ -1551,6 +2193,11 @@ class Node(SimComponent):
# for process_id in self.processes:
# self.processes[process_id]
def _install_system_software(self) -> None:
"""Preinstall required software."""
for _, software_class in self.SYSTEM_SOFTWARE.items():
self.software_manager.install(software_class)
def __contains__(self, item: Any) -> bool:
if isinstance(item, Service):
return item.uuid in self.services

View File

@@ -5,7 +5,13 @@ from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Optional
from primaite import getLogger
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, Link, Node
from primaite.simulator.network.hardware.base import (
IPWiredNetworkInterface,
Link,
Node,
UserManager,
UserSessionManager,
)
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.system.applications.application import ApplicationOperatingState
@@ -15,6 +21,7 @@ from primaite.simulator.system.services.arp.arp import ARP, ARPPacket
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.icmp.icmp import ICMP
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.utils.validators import IPV4Address
_LOGGER = getLogger(__name__)
@@ -292,6 +299,7 @@ class HostNode(Node):
* DNS (Domain Name System) Client: Resolves domain names to IP addresses.
* FTP (File Transfer Protocol) Client: Enables file transfers between the host and FTP servers.
* NTP (Network Time Protocol) Client: Synchronizes the system clock with NTP servers.
* Terminal Client: Handles SSH requests between HostNode and external components.
Applications:
------------
@@ -306,6 +314,9 @@ class HostNode(Node):
"NTPClient": NTPClient,
"WebBrowser": WebBrowser,
"NMAP": NMAP,
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
"Terminal": Terminal,
}
"""List of system software that is automatically installed on nodes."""
@@ -338,18 +349,6 @@ class HostNode(Node):
"""
return self.software_manager.software.get("ARP")
def _install_system_software(self):
"""
Installs the system software and network services typically found on an operating system.
This method equips the host with essential network services and applications, preparing it for various
network-related tasks and operations.
"""
for _, software_class in self.SYSTEM_SOFTWARE.items():
self.software_manager.install(software_class)
super()._install_system_software()
def default_gateway_hello(self):
"""
Sends a hello message to the default gateway to establish connectivity and resolve the gateway's MAC address.

View File

@@ -4,14 +4,14 @@ from __future__ import annotations
import secrets
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, UserManager, UserSessionManager
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.network.protocols.arp import ARPPacket
@@ -24,6 +24,7 @@ from primaite.simulator.system.core.session_manager import SessionManager
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.arp.arp import ARP
from primaite.simulator.system.services.icmp.icmp import ICMP
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.utils.validators import IPV4Address
@@ -1200,6 +1201,12 @@ class Router(NetworkNode):
RouteTable, RouterARP, and RouterICMP services.
"""
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
"Terminal": Terminal,
}
num_ports: int
network_interfaces: Dict[str, RouterInterface] = {}
"The Router Interfaces on the node."
@@ -1235,6 +1242,7 @@ class Router(NetworkNode):
resolution within the network. These services are crucial for the router's operation, enabling it to manage
network traffic efficiently.
"""
super()._install_system_software()
self.software_manager.install(RouterICMP)
icmp: RouterICMP = self.software_manager.icmp # noqa
icmp.router = self

View File

@@ -108,6 +108,9 @@ class Switch(NetworkNode):
for i in range(1, self.num_ports + 1):
self.connect_nic(SwitchPort())
def _install_system_software(self):
pass
def show(self, markdown: bool = False):
"""
Prints a table of the SwitchPorts on the Switch.

View File

@@ -1,48 +1,25 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Dict, Final, List
from typing import List
CAPTURE_NMNE: bool = True
"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True."""
NMNE_CAPTURE_KEYWORDS: List[str] = []
"""List of keywords to identify malicious network events."""
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
CAPTURE_BY_DIRECTION: Final[bool] = True
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
"""Flag to determine if captures should be organized by source or destination IP address."""
CAPTURE_BY_PROTOCOL: Final[bool] = False
"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP)."""
CAPTURE_BY_PORT: Final[bool] = False
"""Flag to determine if captures should be organized by source or destination port."""
CAPTURE_BY_KEYWORD: Final[bool] = False
"""Flag to determine if captures should be filtered and categorised based on specific keywords."""
from pydantic import BaseModel, ConfigDict
def set_nmne_config(nmne_config: Dict):
"""
Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary.
class NMNEConfig(BaseModel):
"""Store all the information to perform NMNE operations."""
This function updates global settings related to NMNE capture, including whether to capture NMNEs and what
keywords to use for identifying NMNEs.
model_config = ConfigDict(extra="forbid")
The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary,
and maintains type integrity by checking the types of the provided values.
:param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include:
"capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings)
to specify keywords for NMNE identification.
"""
global NMNE_CAPTURE_KEYWORDS
global CAPTURE_NMNE
# Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect
CAPTURE_NMNE = nmne_config.get("capture_nmne", False)
if not isinstance(CAPTURE_NMNE, bool):
CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean
# Update the NMNE capture keywords, appending new keywords if provided
NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", [])
if not isinstance(NMNE_CAPTURE_KEYWORDS, list):
NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list
capture_nmne: bool = False
"""Indicates whether Malicious Network Events (MNEs) should be captured."""
nmne_capture_keywords: List[str] = []
"""List of keywords to identify malicious network events."""
capture_by_direction: bool = True
"""Captures should be organized by traffic direction (inbound/outbound)."""
capture_by_ip_address: bool = False
"""Captures should be organized by source or destination IP address."""
capture_by_protocol: bool = False
"""Captures should be organized by network protocol (e.g., TCP, UDP)."""
capture_by_port: bool = False
"""Captures should be organized by source or destination port."""
capture_by_keyword: bool = False
"""Captures should be filtered and categorised based on specific keywords."""

View File

@@ -4,7 +4,7 @@ from enum import Enum
from typing import Union
from pydantic import BaseModel, field_validator, validate_call
from pydantic_core.core_schema import FieldValidationInfo
from pydantic_core.core_schema import ValidationInfo
from primaite import getLogger
@@ -96,7 +96,7 @@ class ICMPPacket(BaseModel):
@field_validator("icmp_code") # noqa
@classmethod
def _icmp_type_must_have_icmp_code(cls, v: int, info: FieldValidationInfo) -> int:
def _icmp_type_must_have_icmp_code(cls, v: int, info: ValidationInfo) -> int:
"""Validates the icmp_type and icmp_code."""
icmp_type = info.data["icmp_type"]
if get_icmp_type_code_description(icmp_type, v):

View File

@@ -0,0 +1,23 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import Enum
from typing import Optional
from primaite.simulator.network.protocols.packet import DataPacket
class MasqueradePacket(DataPacket):
"""Represents an generic malicious packet that is masquerading as another protocol."""
masquerade_protocol: Enum # The 'Masquerade' protocol that is currently in use
masquerade_port: Enum # The 'Masquerade' port that is currently in use
class C2Packet(MasqueradePacket):
"""Represents C2 suite communications packets."""
payload_type: Enum # The type of C2 traffic (e.g keep alive, command or command out)
command: Optional[Enum] = None # Used to pass the actual C2 Command in C2 INPUT
keep_alive_frequency: int

View File

@@ -0,0 +1,89 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import IntEnum
from typing import Optional
from primaite.interface.request import RequestResponse
from primaite.simulator.network.protocols.packet import DataPacket
class SSHTransportMessage(IntEnum):
"""
Enum list of Transport layer messages that can be handled by the simulation.
Each msg value is equivalent to the real-world.
"""
SSH_MSG_USERAUTH_REQUEST = 50
"""Requests User Authentication."""
SSH_MSG_USERAUTH_FAILURE = 51
"""Indicates User Authentication failed."""
SSH_MSG_USERAUTH_SUCCESS = 52
"""Indicates User Authentication was successful."""
SSH_MSG_SERVICE_REQUEST = 24
"""Requests a service - such as executing a command."""
# These two msgs are invented for primAITE however are modelled on reality
SSH_MSG_SERVICE_FAILED = 25
"""Indicates that the requested service failed."""
SSH_MSG_SERVICE_SUCCESS = 26
"""Indicates that the requested service was successful."""
class SSHConnectionMessage(IntEnum):
"""Int Enum list of all SSH's connection protocol messages that can be handled by the simulation."""
SSH_MSG_CHANNEL_OPEN = 80
"""Requests an open channel - Used in combination with SSH_MSG_USERAUTH_REQUEST."""
SSH_MSG_CHANNEL_OPEN_CONFIRMATION = 81
"""Confirms an open channel."""
SSH_MSG_CHANNEL_OPEN_FAILED = 82
"""Indicates that channel opening failure."""
SSH_MSG_CHANNEL_DATA = 84
"""Indicates that data is being sent through the channel."""
SSH_MSG_CHANNEL_CLOSE = 87
"""Closes the channel."""
class SSHUserCredentials(DataPacket):
"""Hold Username and Password in SSH Packets."""
username: str
"""Username for login"""
password: str
"""Password for login"""
class SSHPacket(DataPacket):
"""Represents an SSHPacket."""
transport_message: SSHTransportMessage
"""Message Transport Type"""
connection_message: SSHConnectionMessage
"""Message Connection Status"""
user_account: Optional[SSHUserCredentials] = None
"""User Account Credentials if passed"""
connection_request_uuid: Optional[str] = None
"""Connection Request UUID used when establishing a remote connection"""
connection_uuid: Optional[str] = None
"""Connection UUID used when validating a remote connection"""
ssh_output: Optional[RequestResponse] = None
"""RequestResponse from Request Manager"""
ssh_command: Optional[list] = None
"""Request String"""

View File

@@ -214,3 +214,21 @@ class Application(IOSoftware):
f"Cannot perform request on application '{self.application.name}' because it is not in the "
f"{self.state.name} state."
)
def _can_perform_network_action(self) -> bool:
"""
Checks if the application can perform outbound network actions.
First confirms application suitability via the can_perform_action method.
Then confirms that the host has an enabled NIC that can be used for outbound traffic.
:return: True if outbound network actions can be performed, otherwise False.
:rtype bool:
"""
if not super()._can_perform_action():
return False
for nic in self.software_manager.node.network_interface.values():
if nic.enabled:
return True
return False

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from uuid import uuid4
from prettytable import MARKDOWN, PrettyTable
@@ -54,6 +54,12 @@ class DatabaseClientConnection(BaseModel):
if self.client and self.is_active:
self.client._disconnect(self.connection_id) # noqa
def __str__(self) -> str:
return f"{self.__class__.__name__}(connection_id='{self.connection_id}', is_active={self.is_active})"
def __repr__(self) -> str:
return str(self)
class DatabaseClient(Application, identifier="DatabaseClient"):
"""
@@ -67,7 +73,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
server_ip_address: Optional[IPv4Address] = None
server_password: Optional[str] = None
_last_connection_successful: Optional[bool] = None
_query_success_tracker: Dict[str, bool] = {}
"""Keep track of connections that were established or verified during this step. Used for rewards."""
last_query_response: Optional[Dict] = None
@@ -76,7 +81,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
"""Connection ID to the Database Server."""
client_connections: Dict[str, DatabaseClientConnection] = {}
"""Keep track of active connections to Database Server."""
_client_connection_requests: Dict[str, Optional[str]] = {}
_client_connection_requests: Dict[str, Optional[Union[str, DatabaseClientConnection]]] = {}
"""Dictionary of connection requests to Database Server."""
connected: bool = False
"""Boolean Value for whether connected to DB Server."""
@@ -129,8 +134,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
:return: A dictionary representing the current state.
"""
state = super().describe_state()
# list of connections that were established or verified during this step.
state["last_connection_successful"] = self._last_connection_successful
return state
def show(self, markdown: bool = False):
@@ -187,7 +190,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
return False
return self._query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
def _check_client_connection(self, connection_id: str) -> bool:
def _validate_client_connection_request(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
return True if connection_id in self._client_connection_requests else False
@@ -211,23 +214,28 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
:type: is_reattempt: Optional[bool]
"""
if is_reattempt:
valid_connection = self._check_client_connection(connection_id=connection_request_id)
if valid_connection:
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
if valid_connection_request:
database_client_connection = self._client_connection_requests.pop(connection_request_id)
self.sys_log.info(
f"{self.name}: DatabaseClient connection to {server_ip_address} authorised."
f"Connection Request ID was {connection_request_id}."
)
self.connected = True
self._last_connection_successful = True
return database_client_connection
if isinstance(database_client_connection, DatabaseClientConnection):
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} authorised. "
f"Using connection id {database_client_connection}"
)
self.connected = True
return database_client_connection
else:
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined"
)
return None
else:
self.sys_log.warning(
f"{self.name}: DatabaseClient connection to {server_ip_address} declined."
f"Connection Request ID was {connection_request_id}."
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined "
f"due to unknown client-side connection request id"
)
self._last_connection_successful = False
return None
payload = {"type": "connect_request", "password": password, "connection_request_id": connection_request_id}
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
@@ -300,9 +308,14 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
"""
if not self._can_perform_action():
return None
connection_request_id = str(uuid4())
self._client_connection_requests[connection_request_id] = None
self.sys_log.info(
f"{self.name}: Sending new connection request ({connection_request_id}) to {self.server_ip_address}"
)
return self._connect(
server_ip_address=self.server_ip_address,
password=self.server_password,
@@ -339,10 +352,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
success = self._query_success_tracker.get(query_id)
if success:
self.sys_log.info(f"{self.name}: Query successful {sql}")
self._last_connection_successful = True
return True
self.sys_log.error(f"{self.name}: Unable to run query {sql}")
self._last_connection_successful = False
return False
else:
software_manager: SoftwareManager = self.software_manager

View File

@@ -0,0 +1,64 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Optional, Union
from pydantic import BaseModel, Field, field_validator, ValidationInfo
from primaite.interface.request import RequestFormat
class CommandOpts(BaseModel):
"""A C2 Pydantic Schema acting as a base class for all C2 Commands."""
@field_validator("payload", "exfiltration_folder_name", "ip_address", mode="before", check_fields=False)
@classmethod
def not_none(cls, v: str, info: ValidationInfo) -> int:
"""If None is passed, use the default value instead."""
if v is None:
return cls.model_fields[info.field_name].default
return v
class RansomwareOpts(CommandOpts):
"""A Pydantic Schema for the Ransomware Configuration command options."""
server_ip_address: str
"""The IP Address of the target database that the RansomwareScript will attack."""
payload: str = Field(default="ENCRYPT")
"""The malicious payload to be used to attack the target database."""
class RemoteOpts(CommandOpts):
"""A base C2 Pydantic Schema for all C2 Commands that require a terminal connection."""
ip_address: Optional[str] = Field(default=None)
"""The IP address of a remote host. If this field defaults to None then a local session is used."""
username: str
"""A Username of a valid user account. Used to login into both remote and local hosts."""
password: str
"""A Password of a valid user account. Used to login into both remote and local hosts."""
class ExfilOpts(RemoteOpts):
"""A Pydantic Schema for the C2 Data Exfiltration command options."""
target_ip_address: str
"""The IP address of the target host that will be the target of the exfiltration."""
target_file_name: str
"""The name of the file that is attempting to be exfiltrated."""
target_folder_name: str
"""The name of the remote folder which contains the target file."""
exfiltration_folder_name: str = Field(default="exfiltration_folder")
"""The name of C2 Suite folder used to store the target file. Defaults to ``exfiltration_folder``"""
class TerminalOpts(RemoteOpts):
"""A Pydantic Schema for the C2 Terminal command options."""
commands: Union[list[RequestFormat], RequestFormat]
"""A list or individual Terminal Command. Please refer to the RequestResponse system for further info."""

View File

@@ -0,0 +1,488 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from enum import Enum
from ipaddress import IPv4Address
from typing import Dict, Optional, Union
from pydantic import BaseModel, Field, validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.file_system.file_system import FileSystem, Folder
from primaite.simulator.network.protocols.masquerade import C2Packet
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
from primaite.simulator.system.core.session_manager import Session
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
class C2Command(Enum):
"""Enumerations representing the different commands the C2 suite currently supports."""
RANSOMWARE_CONFIGURE = "Ransomware Configure"
"Instructs the c2 beacon to configure the ransomware with the provided options."
RANSOMWARE_LAUNCH = "Ransomware Launch"
"Instructs the c2 beacon to execute the installed ransomware."
DATA_EXFILTRATION = "Data Exfiltration"
"Instructs the c2 beacon to attempt to return a file to the C2 Server."
TERMINAL = "Terminal"
"Instructs the c2 beacon to execute the provided terminal command."
class C2Payload(Enum):
"""Represents the different types of command and control payloads."""
KEEP_ALIVE = "keep_alive"
"""C2 Keep Alive payload. Used by the C2 beacon and C2 Server to confirm their connection."""
INPUT = "input_command"
"""C2 Input Command payload. Used by the C2 Server to send a command to the c2 beacon."""
OUTPUT = "output_command"
"""C2 Output Command. Used by the C2 Beacon to send the results of a Input command to the c2 server."""
class AbstractC2(Application, identifier="AbstractC2"):
"""
An abstract command and control (c2) application.
Extends the application class to provide base functionality for c2 suite applications
such as c2 beacons and c2 servers.
Provides the base methods for handling ``Keep Alive`` connections, configuring masquerade ports and protocols
as well as providing the abstract methods for sending, receiving and parsing commands.
Defaults to masquerading as HTTP (Port 80) via TCP.
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
"""
c2_connection_active: bool = False
"""Indicates if the c2 server and c2 beacon are currently connected."""
c2_remote_connection: IPv4Address = None
"""The IPv4 Address of the remote c2 connection. (Either the IP of the beacon or the server)."""
c2_session: Session = None
"""The currently active session that the C2 Traffic is using. Set after establishing connection."""
keep_alive_inactivity: int = 0
"""Indicates how many timesteps since the last time the c2 application received a keep alive."""
class _C2Opts(BaseModel):
"""A Pydantic Schema for the different C2 configuration options."""
keep_alive_frequency: int = Field(default=5, ge=1)
"""The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon."""
masquerade_protocol: IPProtocol = Field(default=IPProtocol.TCP)
"""The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP."""
masquerade_port: Port = Field(default=Port.HTTP)
"""The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP."""
c2_config: _C2Opts = _C2Opts()
"""
Holds the current configuration settings of the C2 Suite.
The C2 beacon initialise this class through it's internal configure method.
The C2 Server when receiving a keep alive will initialise it's own configuration
to match that of the configuration settings passed in the keep alive through _resolve keep alive.
If the C2 Beacon is reconfigured then a new keep alive is set which causes the
C2 beacon to reconfigure it's configuration settings.
"""
def _craft_packet(
self, c2_payload: C2Payload, c2_command: Optional[C2Command] = None, command_options: Optional[Dict] = {}
) -> C2Packet:
"""
Creates and returns a Masquerade Packet using the parameters given.
The packet uses the current c2 configuration and parameters given
to construct the base networking information such as the masquerade
protocol/port. Additionally all C2 Traffic packets pass the currently
in use C2 configuration. This ensures that the all C2 applications
can keep their configuration in sync.
:param c2_payload: The type of C2 Traffic ot be sent
:type c2_payload: C2Payload
:param c2_command: The C2 command to be sent to the C2 Beacon.
:type c2_command: C2Command.
:param command_options: The relevant C2 Beacon parameters.F
:type command_options: Dict
:return: Returns the construct C2Packet
:rtype: C2Packet
"""
constructed_packet = C2Packet(
masquerade_protocol=self.c2_config.masquerade_protocol,
masquerade_port=self.c2_config.masquerade_port,
keep_alive_frequency=self.c2_config.keep_alive_frequency,
payload_type=c2_payload,
command=c2_command,
payload=command_options,
)
return constructed_packet
def describe_state(self) -> Dict:
"""
Describe the state of the C2 application.
:return: A dictionary representation of the C2 application's state.
:rtype: Dict
"""
return super().describe_state()
def __init__(self, **kwargs):
"""Initialise the C2 applications to by default listen for HTTP traffic."""
kwargs["listen_on_ports"] = {Port.HTTP, Port.FTP, Port.DNS}
kwargs["port"] = Port.NONE
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
@property
def _host_ftp_client(self) -> Optional[FTPClient]:
"""Return the FTPClient that is installed C2 Application's host.
This method confirms that the FTP Client is functional via the ._can_perform_action
method. If the FTP Client service is not in a suitable state (e.g disabled/pause)
then this method will return None.
(The FTP Client service is installed by default)
:return: An FTPClient object is successful, else None
:rtype: union[FTPClient, None]
"""
ftp_client: Union[FTPClient, None] = self.software_manager.software.get("FTPClient")
if ftp_client is None:
self.sys_log.warning(f"{self.__class__.__name__}: No FTPClient. Attempting to install.")
self.software_manager.install(FTPClient)
ftp_client = self.software_manager.software.get("FTPClient")
# Force start if the service is stopped.
if ftp_client.operating_state == ServiceOperatingState.STOPPED:
if not ftp_client.start():
self.sys_log.warning(f"{self.__class__.__name__}: cannot start the FTP Client.")
if not ftp_client._can_perform_action():
self.sys_log.error(f"{self.__class__.__name__}: is unable to use the FTP service on its host.")
return
return ftp_client
@property
def _host_ftp_server(self) -> Optional[FTPServer]:
"""
Returns the FTP Server that is installed C2 Application's host.
If a FTPServer is not installed then this method will attempt to install one.
:return: An FTPServer object is successful, else None
:rtype: Optional[FTPServer]
"""
ftp_server: Optional[FTPServer] = self.software_manager.software.get("FTPServer")
if ftp_server is None:
self.sys_log.warning(f"{self.__class__.__name__}:No FTPServer installed. Attempting to install FTPServer.")
self.software_manager.install(FTPServer)
ftp_server = self.software_manager.software.get("FTPServer")
# Force start if the service is stopped.
if ftp_server.operating_state == ServiceOperatingState.STOPPED:
if not ftp_server.start():
self.sys_log.warning(f"{self.__class__.__name__}: cannot start the FTP Server.")
if not ftp_server._can_perform_action():
self.sys_log.error(f"{self.__class__.__name__}: is unable use FTP Server service on its host.")
return
return ftp_server
# Getter property for the get_exfiltration_folder method ()
@property
def _host_file_system(self) -> FileSystem:
"""Return the C2 Host's filesystem (Used for exfiltration related commands) ."""
host_file_system: FileSystem = self.software_manager.file_system
if host_file_system is None:
self.sys_log.error(f"{self.__class__.__name__}: does not seem to have a file system!")
return host_file_system
def get_exfiltration_folder(self, folder_name: Optional[str] = "exfiltration_folder") -> Optional[Folder]:
"""Return a folder used for storing exfiltrated data. Otherwise returns None."""
if self._host_file_system is None:
return
exfiltration_folder: Union[Folder, None] = self._host_file_system.get_folder(folder_name)
if exfiltration_folder is None:
self.sys_log.info(f"{self.__class__.__name__}: Creating a exfiltration folder.")
return self._host_file_system.create_folder(folder_name=folder_name)
return exfiltration_folder
# Validate call ensures we are only handling Masquerade Packets.
@validate_call
def _handle_c2_payload(self, payload: C2Packet, session_id: Optional[str] = None) -> bool:
"""Handles masquerade payloads for both c2 beacons and c2 servers.
Currently, the C2 application suite can handle the following payloads:
KEEP ALIVE:
Establishes or confirms connection from the C2 Beacon to the C2 server.
Sent by both C2 beacons and C2 Servers.
INPUT:
Contains a c2 command which must be executed by the C2 beacon.
Sent by C2 Servers and received by C2 Beacons.
OUTPUT:
Contains the output of a c2 command which must be returned to the C2 Server.
Sent by C2 Beacons and received by C2 Servers
The payload is passed to a different method dependant on the payload type.
:param payload: The C2 Payload to be parsed and handled.
:return: True if the c2 payload was handled successfully, False otherwise.
:rtype: Bool
"""
if payload.payload_type == C2Payload.KEEP_ALIVE:
self.sys_log.info(f"{self.name} received a KEEP ALIVE payload.")
return self._handle_keep_alive(payload, session_id)
elif payload.payload_type == C2Payload.INPUT:
self.sys_log.info(f"{self.name} received an INPUT COMMAND payload.")
return self._handle_command_input(payload, session_id)
elif payload.payload_type == C2Payload.OUTPUT:
self.sys_log.info(f"{self.name} received an OUTPUT COMMAND payload.")
return self._handle_command_output(payload)
else:
self.sys_log.warning(
f"{self.name} received an unexpected c2 payload:{payload.payload_type}. Dropping Packet."
)
return False
@abstractmethod
def _handle_command_output(payload):
"""Abstract Method: Used in C2 server to parse and receive the output of commands sent to the c2 beacon."""
pass
@abstractmethod
def _handle_command_input(payload):
"""Abstract Method: Used in C2 beacon to parse and handle commands received from the c2 server."""
pass
@abstractmethod
def _handle_keep_alive(self, payload: C2Packet, session_id: Optional[str]) -> bool:
"""Abstract Method: Each C2 suite handles ``C2Payload.KEEP_ALIVE`` differently."""
pass
# from_network_interface=from_network_interface
def receive(self, payload: any, session_id: Optional[str] = None, **kwargs) -> bool:
"""Receives masquerade packets. Used by both c2 server and c2 beacon.
Defining the `Receive` method so that the application can receive packets via the session manager.
These packets are then immediately handed to ._handle_c2_payload.
:param payload: The Masquerade Packet to be received.
:type payload: C2Packet
:param session_id: The transport session_id that the payload is originating from.
:type session_id: str
:return: Returns a bool if the traffic was received correctly (See _handle_c2_payload.)
:rtype: bool
"""
if not isinstance(payload, C2Packet):
self.sys_log.warning(f"{self.name}: Payload is not an C2Packet")
self.sys_log.debug(f"{self.name}: {payload}")
return False
return self._handle_c2_payload(payload, session_id)
def _send_keep_alive(self, session_id: Optional[str]) -> bool:
"""Sends a C2 keep alive payload to the self.remote_connection IPv4 Address.
Used by both the c2 client and the s2 server for establishing and confirming connection.
This method also contains some additional validation to ensure that the C2 applications
are correctly configured before sending any traffic.
:param session_id: The transport session_id that the payload is originating from.
:type session_id: str
:returns: Returns True if a send alive was successfully sent. False otherwise.
:rtype bool:
"""
# Checking that the c2 application is capable of connecting to remote.
# Purely a safety guard clause.
if not (connection_status := self._check_connection()[0]):
self.sys_log.warning(
f"{self.name}: Unable to send keep alive due to c2 connection status: {connection_status}."
)
return False
# Passing our current C2 configuration in remain in sync.
keep_alive_packet = self._craft_packet(c2_payload=C2Payload.KEEP_ALIVE)
# Sending the keep alive via the .send() method (as with all other applications.)
if self.send(
payload=keep_alive_packet,
dest_ip_address=self.c2_remote_connection,
dest_port=self.c2_config.masquerade_port,
ip_protocol=self.c2_config.masquerade_protocol,
session_id=session_id,
):
# Setting the keep_alive_sent guard condition to True. This is used to prevent packet storms.
# This prevents the _resolve_keep_alive method from calling this method again (until the next timestep.)
self.keep_alive_sent = True
self.sys_log.info(f"{self.name}: Keep Alive sent to {self.c2_remote_connection}")
self.sys_log.debug(
f"{self.name}: Keep Alive sent to {self.c2_remote_connection} "
f"Masquerade Port: {self.c2_config.masquerade_port} "
f"Masquerade Protocol: {self.c2_config.masquerade_protocol} "
)
return True
else:
self.sys_log.warning(
f"{self.name}: Failed to send a Keep Alive. The node may be unable to access networking resources."
)
return False
def _resolve_keep_alive(self, payload: C2Packet, session_id: Optional[str]) -> bool:
"""
Parses the Masquerade Port/Protocol within the received Keep Alive packet.
Used to dynamically set the Masquerade Port and Protocol based on incoming traffic.
Returns True on successfully extracting and configuring the masquerade port/protocols.
Returns False otherwise.
:param payload: The Keep Alive payload received.
:type payload: C2Packet
:param session_id: The transport session_id that the payload is originating from.
:type session_id: str
:return: True on successful configuration, false otherwise.
:rtype: bool
"""
# Validating that they are valid Enums.
if not isinstance(payload.masquerade_port, Port) or not isinstance(payload.masquerade_protocol, IPProtocol):
self.sys_log.warning(
f"{self.name}: Received invalid Masquerade Values within Keep Alive."
f"Port: {payload.masquerade_port} Protocol: {payload.masquerade_protocol}."
)
return False
# Updating the C2 Configuration attribute.
self.c2_config.masquerade_port = payload.masquerade_port
self.c2_config.masquerade_protocol = payload.masquerade_protocol
self.c2_config.keep_alive_frequency = payload.keep_alive_frequency
self.sys_log.debug(
f"{self.name}: C2 Config Resolved Config from Keep Alive:"
f"Masquerade Port: {self.c2_config.masquerade_port}"
f"Masquerade Protocol: {self.c2_config.masquerade_protocol}"
f"Keep Alive Frequency: {self.c2_config.keep_alive_frequency}"
)
# This statement is intended to catch on the C2 Application that is listening for connection.
if self.c2_remote_connection is None:
self.sys_log.debug(f"{self.name}: Attempting to configure remote C2 connection based off received output.")
self.c2_remote_connection = IPv4Address(self.c2_session.with_ip_address)
self.c2_connection_active = True # Sets the connection to active
self.keep_alive_inactivity = 0 # Sets the keep alive inactivity to zero
return True
def _reset_c2_connection(self) -> None:
"""
Resets all currently established C2 communications to their default setting.
This method is called once a C2 application considers their remote connection
severed and reverts back to default settings. Worth noting that that this will
revert any non-default configuration that a user/agent may have set.
"""
self.c2_connection_active = False
self.c2_session = None
self.keep_alive_inactivity = 0
self.keep_alive_frequency = 5
self.c2_remote_connection = None
self.c2_config.masquerade_port = Port.HTTP
self.c2_config.masquerade_protocol = IPProtocol.TCP
@abstractmethod
def _confirm_remote_connection(self, timestep: int) -> bool:
"""
Abstract method - Confirms the suitability of the current C2 application remote connection.
Each application will have perform different behaviour to confirm the remote connection.
:return: Boolean. True if remote connection is confirmed, false otherwise.
"""
def apply_timestep(self, timestep: int) -> None:
"""Apply a timestep to the c2_server & c2 beacon.
Used to keep track of when the c2 server should consider a beacon dead
and set it's c2_remote_connection attribute to false.
1. Each timestep the keep_alive_inactivity is increased.
2. If the keep alive inactivity eclipses that of the keep alive frequency then another keep alive is sent.
3. If a keep alive response packet is received then the ``keep_alive_inactivity`` attribute is reset.
Therefore, if ``keep_alive_inactivity`` attribute is not 0 after a keep alive is sent
then the connection is considered severed and c2 beacon will shut down.
:param timestep: The current timestep of the simulation.
:type timestep: Int
:return bool: Returns false if connection was lost. Returns True if connection is active or re-established.
:rtype bool:
"""
if (
self.operating_state is ApplicationOperatingState.RUNNING
and self.health_state_actual is SoftwareHealthState.GOOD
and self.c2_connection_active is True
):
self.keep_alive_inactivity += 1
self._confirm_remote_connection(timestep)
return super().apply_timestep(timestep=timestep)
def _check_connection(self) -> tuple[bool, RequestResponse]:
"""
Validation method: Checks that the C2 application is capable of sending C2 Command input/output.
Performs a series of connection validation to ensure that the C2 application is capable of
sending and responding to the remote c2 connection. This method is used to confirm connection
before carrying out Agent Commands hence why this method also returns a tuple
containing both a success boolean as well as RequestResponse.
:return: A tuple containing a boolean True/False and a corresponding Request Response
:rtype: tuple[bool, RequestResponse]
"""
if not self._can_perform_network_action():
self.sys_log.warning(f"{self.name}: Unable to make leverage networking resources. Rejecting Command.")
return (
False,
RequestResponse(
status="failure", data={"Reason": "Unable to access networking resources. Unable to send command."}
),
)
if self.c2_remote_connection is None:
self.sys_log.warning(f"{self.name}: C2 Application has yet to establish connection. Rejecting command.")
return (
False,
RequestResponse(
status="failure",
data={"Reason": "C2 Application has yet to establish connection. Unable to send command."},
),
)
return (
True,
RequestResponse(status="success", data={"Reason": "C2 Application is able to send connections."}),
)

View File

@@ -0,0 +1,636 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import Enum
from ipaddress import IPv4Address
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.protocols.masquerade import C2Packet
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts, RansomwareOpts, TerminalOpts
from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection
class C2Beacon(AbstractC2, identifier="C2Beacon"):
"""
C2 Beacon Application.
Represents a vendor generic C2 beacon is used in conjunction with the C2 Server
to simulate malicious communications and infrastructure within primAITE.
Must be configured with the C2 Server's IP Address upon installation.
Please refer to the _configure method for further information.
Extends the Abstract C2 application to include the following:
1. Receiving commands from the C2 Server (Command input)
2. Leveraging the terminal application to execute requests (dependent on the command given)
3. Sending the RequestResponse back to the C2 Server (Command output)
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
"""
keep_alive_attempted: bool = False
"""Indicates if a keep alive has been attempted to be sent this timestep. Used to prevent packet storms."""
terminal_session: TerminalClientConnection = None
"The currently in use terminal session."
@property
def _host_terminal(self) -> Optional[Terminal]:
"""Return the Terminal that is installed on the same machine as the C2 Beacon."""
host_terminal: Terminal = self.software_manager.software.get("Terminal")
if host_terminal is None:
self.sys_log.warning(f"{self.__class__.__name__} cannot find a terminal on its host.")
return host_terminal
@property
def _host_ransomware_script(self) -> RansomwareScript:
"""Return the RansomwareScript that is installed on the same machine as the C2 Beacon."""
ransomware_script: RansomwareScript = self.software_manager.software.get("RansomwareScript")
if ransomware_script is None:
self.sys_log.warning(f"{self.__class__.__name__} cannot find installed ransomware on its host.")
return ransomware_script
def _set_terminal_session(self, username: str, password: str, ip_address: Optional[IPv4Address] = None) -> bool:
"""
Attempts to create and a terminal session using the parameters given.
If an IP Address is passed then this method will attempt to create a remote terminal
session. Otherwise a local terminal session will be created.
:return: Returns true if a terminal session was successfully set. False otherwise.
:rtype: Bool
"""
self.terminal_session is None
host_terminal: Terminal = self._host_terminal
self.terminal_session = host_terminal.login(username=username, password=password, ip_address=ip_address)
return self.terminal_session is not None
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
rm.add_request(
name="execute",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.establish())),
)
def _configure(request: RequestFormat, context: Dict) -> RequestResponse:
"""
Request for configuring the C2 Beacon.
:param request: Request with one element containing a dict of parameters for the configure method.
:type request: RequestFormat
:param context: additional context for resolving this action, currently unused
:type context: dict
:return: RequestResponse object with a success code reflecting whether the configuration could be applied.
:rtype: RequestResponse
"""
c2_remote_ip = request[-1].get("c2_server_ip_address")
if c2_remote_ip is None:
self.sys_log.error(f"{self.name}: Did not receive C2 Server IP in configuration parameters.")
RequestResponse(
status="failure", data={"No C2 Server IP given to C2 beacon. Unable to configure C2 Beacon"}
)
c2_remote_ip = IPv4Address(c2_remote_ip)
frequency = request[-1].get("keep_alive_frequency")
protocol = request[-1].get("masquerade_protocol")
port = request[-1].get("masquerade_port")
return RequestResponse.from_bool(
self.configure(
c2_server_ip_address=c2_remote_ip,
keep_alive_frequency=frequency,
masquerade_protocol=IPProtocol[protocol],
masquerade_port=Port[port],
)
)
rm.add_request("configure", request_type=RequestType(func=_configure))
return rm
def __init__(self, **kwargs):
kwargs["name"] = "C2Beacon"
super().__init__(**kwargs)
# Configure is practically setter method for the ``c2.config`` attribute that also ties into the request manager.
@validate_call
def configure(
self,
c2_server_ip_address: IPv4Address = None,
keep_alive_frequency: int = 5,
masquerade_protocol: Enum = IPProtocol.TCP,
masquerade_port: Enum = Port.HTTP,
) -> bool:
"""
Configures the C2 beacon to communicate with the C2 server.
The C2 Beacon has four different configuration options which can be used to
modify the networking behaviour between the C2 Server and the C2 Beacon.
Configuration Option | Option Meaning
---------------------|------------------------
c2_server_ip_address | The IP Address of the C2 Server. (The C2 Server must be running)
keep_alive_frequency | How often should the C2 Beacon confirm it's connection in timesteps.
masquerade_protocol | What protocol should the C2 traffic masquerade as? (HTTP, FTP or DNS)
masquerade_port | What port should the C2 traffic use? (TCP or UDP)
These configuration options are used to reassign the fields in the inherited inner class
``c2_config``.
If a connection is already in progress then this method also sends a keep alive to the C2
Server in order for the C2 Server to sync with the new configuration settings.
:param c2_server_ip_address: The IP Address of the C2 Server. Used to establish connection.
:type c2_server_ip_address: IPv4Address
:param keep_alive_frequency: The frequency (timesteps) at which the C2 beacon will send keep alive(s).
:type keep_alive_frequency: Int
:param masquerade_protocol: The Protocol that C2 Traffic will masquerade as. Defaults to TCP.
:type masquerade_protocol: Enum (IPProtocol)
:param masquerade_port: The Port that the C2 Traffic will masquerade as. Defaults to FTP.
:type masquerade_port: Enum (Port)
:return: Returns True if the configuration was successful, False otherwise.
"""
self.c2_remote_connection = IPv4Address(c2_server_ip_address)
self.c2_config.keep_alive_frequency = keep_alive_frequency
self.c2_config.masquerade_port = masquerade_port
self.c2_config.masquerade_protocol = masquerade_protocol
self.sys_log.info(
f"{self.name}: Configured {self.name} with remote C2 server connection: {c2_server_ip_address=}."
)
self.sys_log.debug(
f"{self.name}: configured with the following settings:"
f"Remote C2 Server: {c2_server_ip_address}"
f"Keep Alive Frequency {keep_alive_frequency}"
f"Masquerade Protocol: {masquerade_protocol}"
f"Masquerade Port: {masquerade_port}"
)
# Send a keep alive to the C2 Server if we already have a keep alive.
if self.c2_connection_active is True:
self.sys_log.info(f"{self.name}: Updating C2 Server with updated C2 configuration.")
return self._send_keep_alive(self.c2_session.uuid if not None else None)
return True
def establish(self) -> bool:
"""Establishes connection to the C2 server via a send alive. The C2 Beacon must already be configured."""
if self.c2_remote_connection is None:
self.sys_log.info(f"{self.name}: Failed to establish connection. C2 Beacon has not been configured.")
return False
self.run()
self.num_executions += 1
# Creates a new session if using the establish method.
return self._send_keep_alive(session_id=None)
def _handle_command_input(self, payload: C2Packet, session_id: Optional[str]) -> bool:
"""
Handles the parsing of C2 Commands from C2 Traffic (Masquerade Packets).
Dependant the C2 Command parsed from the payload, the following methods are called and returned:
C2 Command | Internal Method
---------------------|------------------------
RANSOMWARE_CONFIGURE | self._command_ransomware_config()
RANSOMWARE_LAUNCH | self._command_ransomware_launch()
DATA_EXFILTRATION | self._command_data_exfiltration()
TERMINAL | self._command_terminal()
Please see each method individually for further information regarding
the implementation of these commands.
:param payload: The INPUT C2 Payload
:type payload: C2Packet
:return: The Request Response provided by the terminal execute method.
:rtype Request Response:
"""
command = payload.command
if not isinstance(command, C2Command):
self.sys_log.warning(f"{self.name}: Received unexpected C2 command. Unable to resolve command")
return self._return_command_output(
command_output=RequestResponse(
status="failure",
data={"Reason": "C2 Beacon received unexpected C2Command. Unable to resolve command."},
),
session_id=session_id,
)
if command == C2Command.RANSOMWARE_CONFIGURE:
self.sys_log.info(f"{self.name}: Received a ransomware configuration C2 command.")
return self._return_command_output(
command_output=self._command_ransomware_config(payload), session_id=session_id
)
elif command == C2Command.RANSOMWARE_LAUNCH:
self.sys_log.info(f"{self.name}: Received a ransomware launch C2 command.")
return self._return_command_output(
command_output=self._command_ransomware_launch(payload), session_id=session_id
)
elif command == C2Command.TERMINAL:
self.sys_log.info(f"{self.name}: Received a terminal C2 command.")
return self._return_command_output(command_output=self._command_terminal(payload), session_id=session_id)
elif command == C2Command.DATA_EXFILTRATION:
self.sys_log.info(f"{self.name}: Received a Data Exfiltration C2 command.")
return self._return_command_output(
command_output=self._command_data_exfiltration(payload), session_id=session_id
)
else:
self.sys_log.error(f"{self.name}: Received an C2 command: {command} but was unable to resolve command.")
return self._return_command_output(
RequestResponse(status="failure", data={"Reason": "Unexpected Behaviour. Unable to resolve command."})
)
def _return_command_output(self, command_output: RequestResponse, session_id: Optional[str] = None) -> bool:
"""Responsible for responding to the C2 Server with the output of the given command.
:param command_output: The RequestResponse returned by the terminal application's execute method.
:type command_output: Request Response
:param session_id: The current session established with the C2 Server.
:type session_id: Str
"""
output_packet = self._craft_packet(c2_payload=C2Payload.OUTPUT, command_options=command_output)
if self.send(
payload=output_packet,
dest_ip_address=self.c2_remote_connection,
dest_port=self.c2_config.masquerade_port,
ip_protocol=self.c2_config.masquerade_protocol,
session_id=session_id,
):
self.sys_log.info(f"{self.name}: Command output sent to {self.c2_remote_connection}")
self.sys_log.debug(
f"{self.name}: on {self.c2_config.masquerade_port} via {self.c2_config.masquerade_protocol}"
)
return True
else:
self.sys_log.warning(
f"{self.name}: failed to send a output packet. The node may be unable to access the network."
)
return False
def _command_ransomware_config(self, payload: C2Packet) -> RequestResponse:
"""
C2 Command: Ransomware Configuration.
Calls the locally installed RansomwareScript application's configure method
and passes the given parameters.
The class attribute self._host_ransomware_script will return None if the host
does not have an instance of the RansomwareScript.
:payload C2Packet: The incoming INPUT command.
:type Masquerade Packet: C2Packet.
:return: Returns the Request Response returned by the Terminal execute method.
:rtype: Request Response
"""
command_opts = RansomwareOpts.model_validate(payload.payload)
if self._host_ransomware_script is None:
return RequestResponse(
status="failure",
data={"Reason": "Cannot find any instances of a RansomwareScript. Have you installed one?"},
)
return RequestResponse.from_bool(
self._host_ransomware_script.configure(
server_ip_address=command_opts.server_ip_address, payload=command_opts.payload
)
)
def _command_ransomware_launch(self, payload: C2Packet) -> RequestResponse:
"""
C2 Command: Ransomware Launch.
Uses the RansomwareScript's public method .attack() to carry out the
ransomware attack and uses the .from_bool method to return a RequestResponse
:payload C2Packet: The incoming INPUT command.
:type Masquerade Packet: C2Packet.
:return: Returns the Request Response returned by the Terminal execute method.
:rtype: Request Response
"""
if self._host_ransomware_script is None:
return RequestResponse(
status="failure",
data={"Reason": "Cannot find any instances of a RansomwareScript. Have you installed one?"},
)
return RequestResponse.from_bool(self._host_ransomware_script.attack())
def _command_data_exfiltration(self, payload: C2Packet) -> RequestResponse:
"""
C2 Command: Data Exfiltration.
Uses the FTP Client & Server services to perform the data exfiltration.
This command instructs the C2 Beacon to ssh into the target ip
and execute a command which causes the FTPClient service to send a
target file will be moved from the target IP address onto the C2 Beacon's host
file system.
However, if no IP is given, then the command will move the target file from this
machine onto the C2 server. (This logic is performed on the C2)
:payload C2Packet: The incoming INPUT command.
:type Masquerade Packet: C2Packet.
:return: Returns a tuple containing Request Response returned by the Terminal execute method.
:rtype: Request Response
"""
if self._host_ftp_server is None:
self.sys_log.warning(f"{self.name}: C2 Beacon unable to the FTP Server. Unable to resolve command.")
return RequestResponse(
status="failure",
data={"Reason": "Cannot find any instances of both a FTP Server & Client. Are they installed?"},
)
command_opts = ExfilOpts.model_validate(payload.payload)
# Setting up the terminal session and the ftp server
if not self._set_terminal_session(
username=command_opts.username, password=command_opts.password, ip_address=command_opts.target_ip_address
):
return RequestResponse(
status="failure", data={"Reason": "Cannot create a terminal session. Are the credentials correct?"}
)
# Using the terminal to start the FTP Client on the remote machine.
self.terminal_session.execute(command=["service", "start", "FTPClient"])
# Need to supply to the FTP Client the C2 Beacon's host IP.
host_network_interfaces = self.software_manager.node.network_interfaces
local_ip = host_network_interfaces.get(next(iter(host_network_interfaces))).ip_address
# Creating the FTP creation options.
ftp_opts = {
"dest_ip_address": str(local_ip),
"src_folder_name": command_opts.target_folder_name,
"src_file_name": command_opts.target_file_name,
"dest_folder_name": command_opts.exfiltration_folder_name,
"dest_file_name": command_opts.target_file_name,
}
attempt_exfiltration: tuple[bool, RequestResponse] = self._perform_exfiltration(ftp_opts)
if attempt_exfiltration[0] is False:
self.sys_log.error(f"{self.name}: File Exfiltration Attempt Failed: {attempt_exfiltration[1].data}")
return attempt_exfiltration[1]
# Sending the transferred target data back to the C2 Server to successfully exfiltrate the data out the network.
return RequestResponse.from_bool(
self._host_ftp_client.send_file(
dest_ip_address=self.c2_remote_connection,
src_folder_name=command_opts.exfiltration_folder_name, # The Exfil folder is inherited attribute.
src_file_name=command_opts.target_file_name,
dest_folder_name=command_opts.exfiltration_folder_name,
dest_file_name=command_opts.target_file_name,
)
)
def _perform_exfiltration(self, ftp_opts: dict) -> tuple[bool, RequestResponse]:
"""
Attempts to exfiltrate a target file from a target using the parameters given.
Uses the current terminal_session to send a command to the
remote host's FTP Client passing the ExfilOpts as command options.
This will instruct the FTP client to send the target file to the
dest_ip_address's destination folder.
This method assumes that the following:
1. The self.terminal_session is the remote target.
2. The target has a functioning FTP Client Service.
:ExfilOpts: A Pydantic model containing the require configuration options
:type ExfilOpts: ExfilOpts
:return: Returns a tuple containing a success boolean and a Request Response..
:rtype: tuple[bool, RequestResponse
"""
# Creating the exfiltration folder .
exfiltration_folder = self.get_exfiltration_folder(ftp_opts.get("dest_folder_name"))
# Using the terminal to send the target data back to the C2 Beacon.
exfil_response: RequestResponse = RequestResponse.from_bool(
self.terminal_session.execute(command=["service", "FTPClient", "send", ftp_opts])
)
# Validating that we successfully received the target data.
if exfil_response.status == "failure":
self.sys_log.warning(f"{self.name}: Remote connection failure. failed to transfer the target data via FTP.")
return [False, exfil_response]
# Target file:
target_file: str = ftp_opts.get("src_file_name")
if exfiltration_folder.get_file(target_file) is None:
self.sys_log.warning(
f"{self.name}: Unable to locate exfiltrated file on local filesystem. "
f"Perhaps the file transfer failed?"
)
return [
False,
RequestResponse(status="failure", data={"reason": "Unable to locate exfiltrated data on file system."}),
]
if self._host_ftp_client is None:
self.sys_log.warning(f"{self.name}: C2 Beacon unable to the FTP Server. Unable to resolve command.")
return [
False,
RequestResponse(
status="failure",
data={"Reason": "Cannot find any instances of both a FTP Server & Client. Are they installed?"},
),
]
return [
True,
RequestResponse(
status="success",
data={"Reason": "Located the target file on local file system. Data exfiltration successful."},
),
]
def _command_terminal(self, payload: C2Packet) -> RequestResponse:
"""
C2 Command: Terminal.
Creates a request that executes a terminal command.
This request is then sent to the terminal service in order to be executed.
:payload C2Packet: The incoming INPUT command.
:type Masquerade Packet: C2Packet.
:return: Returns the Request Response returned by the Terminal execute method.
:rtype: Request Response
"""
command_opts = TerminalOpts.model_validate(payload.payload)
if self._host_terminal is None:
return RequestResponse(
status="failure",
data={"Reason": "Host does not seem to have terminal installed. Unable to resolve command."},
)
terminal_output: Dict[int, RequestResponse] = {}
# Creating a remote terminal session if given an IP Address, otherwise using a local terminal session.
if not self._set_terminal_session(
username=command_opts.username, password=command_opts.password, ip_address=command_opts.ip_address
):
return RequestResponse(
status="failure",
data={"Reason": "Cannot create a terminal session. Are the credentials correct?"},
)
# Converts a singular terminal command: [RequestFormat] into a list with one element [[RequestFormat]]
# Checks the first element - if this element is a str then there must be multiple commands.
command_opts.commands = (
[command_opts.commands] if isinstance(command_opts.commands[0], str) else command_opts.commands
)
for index, given_command in enumerate(command_opts.commands):
# A try catch exception ladder was used but was considered not the best approach
# as it can end up obscuring visibility of actual bugs (Not the expected ones) and was a temporary solution.
# TODO: Refactor + add further validation to ensure that a request is correct. (maybe a pydantic method?)
terminal_output[index] = self.terminal_session.execute(given_command)
# Reset our remote terminal session.
self.terminal_session is None
return RequestResponse(status="success", data=terminal_output)
def _handle_keep_alive(self, payload: C2Packet, session_id: Optional[str]) -> bool:
"""
Handles receiving and sending keep alive payloads. This method is only called if we receive a keep alive.
In the C2 Beacon implementation of this method the c2 connection active boolean
is set to true and the keep alive inactivity is reset only after sending a keep alive
as wel as receiving a response back from the C2 Server.
This is because the C2 Server is the listener and thus will only ever receive packets from
the C2 Beacon rather than the other way around. (The C2 Beacon is akin to a reverse shell)
Therefore, we need a response back from the listener (C2 Server)
before the C2 beacon is able to confirm it's connection.
Returns False if a keep alive was unable to be sent.
Returns True if a keep alive was successfully sent or already has been sent this timestep.
:return: True if successfully handled, false otherwise.
:rtype: Bool
"""
self.sys_log.info(f"{self.name}: Keep Alive Received from {self.c2_remote_connection}.")
# Using this guard clause to prevent packet storms and recognise that we've achieved a connection.
# This guard clause triggers on the c2 suite that establishes connection.
if self.keep_alive_attempted is True:
self.c2_connection_active = True # Sets the connection to active
self.keep_alive_inactivity = 0 # Sets the keep alive inactivity to zero
self.c2_session = self.software_manager.session_manager.sessions_by_uuid[session_id]
# We set keep alive_attempted here to show that we've achieved connection.
self.keep_alive_attempted = False
self.sys_log.warning(f"{self.name}: Connection successfully Established with C2 Server.")
return True
# If we've reached this part of the method then we've received a keep alive but haven't sent a reply.
# Therefore we also need to configure the masquerade attributes based off the keep alive sent.
if self._resolve_keep_alive(payload, session_id) is False:
self.sys_log.warning(f"{self.name}: Keep Alive Could not be resolved correctly. Refusing Keep Alive.")
return False
self.keep_alive_attempted = True
# If this method returns true then we have sent successfully sent a keep alive.
return self._send_keep_alive(session_id)
def _confirm_remote_connection(self, timestep: int) -> bool:
"""Checks the suitability of the current C2 Server connection.
If a connection cannot be confirmed then this method will return false otherwise true.
:param timestep: The current timestep of the simulation.
:type timestep: Int
:return: Returns False if connection was lost. Returns True if connection is active or re-established.
:rtype bool:
"""
self.keep_alive_attempted = False # Resetting keep alive sent.
if self.keep_alive_inactivity == self.c2_config.keep_alive_frequency:
self.sys_log.info(
f"{self.name}: Attempting to Send Keep Alive to {self.c2_remote_connection} at timestep {timestep}."
)
self._send_keep_alive(session_id=self.c2_session.uuid)
if self.keep_alive_inactivity != 0:
self.sys_log.warning(
f"{self.name}: Did not receive keep alive from c2 Server. Connection considered severed."
)
self._reset_c2_connection()
self.close()
return False
return True
# Defining this abstract method from Abstract C2
def _handle_command_output(self, payload: C2Packet):
"""C2 Beacons currently does not need to handle output commands coming from the C2 Servers."""
self.sys_log.warning(f"{self.name}: C2 Beacon received an unexpected OUTPUT payload: {payload}.")
pass
def show(self, markdown: bool = False):
"""
Prints a table of the current status of the C2 Beacon.
Displays the current values of the following C2 attributes:
``C2 Connection Active``:
If the C2 Beacon is currently connected to the C2 Server
``C2 Remote Connection``:
The IP of the C2 Server. (Configured by upon installation)
``Keep Alive Inactivity``:
How many timesteps have occurred since the last keep alive.
``Keep Alive Frequency``:
How often should the C2 Beacon attempt a keep alive?
``Current Masquerade Protocol``:
The current protocol that the C2 Traffic is using. (e.g TCP/UDP)
``Current Masquerade Port``:
The current port that the C2 Traffic is using. (e.g HTTP (Port 80))
:param markdown: If True, outputs the table in markdown format. Default is False.
"""
table = PrettyTable(
[
"C2 Connection Active",
"C2 Remote Connection",
"Keep Alive Inactivity",
"Keep Alive Frequency",
"Current Masquerade Protocol",
"Current Masquerade Port",
]
)
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.name} Running Status"
table.add_row(
[
self.c2_connection_active,
self.c2_remote_connection,
self.keep_alive_inactivity,
self.c2_config.keep_alive_frequency,
self.c2_config.masquerade_protocol,
self.c2_config.masquerade_port,
]
)
print(table)

View File

@@ -0,0 +1,396 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.protocols.masquerade import C2Packet
from primaite.simulator.system.applications.red_applications.c2 import (
CommandOpts,
ExfilOpts,
RansomwareOpts,
TerminalOpts,
)
from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload
class C2Server(AbstractC2, identifier="C2Server"):
"""
C2 Server Application.
Represents a vendor generic C2 Server used in conjunction with the C2 beacon
to simulate malicious communications and infrastructure within primAITE.
The C2 Server must be installed and be in a running state before it's able to receive
red agent actions and send commands to the C2 beacon.
Extends the Abstract C2 application to include the following:
1. Sending commands to the C2 Beacon. (Command input)
2. Parsing terminal RequestResponses back to the Agent.
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
"""
current_command_output: RequestResponse = None
"""The Request Response by the last command send. This attribute is updated by the method _handle_command_output."""
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
def _configure_ransomware_action(request: RequestFormat, context: Dict) -> RequestResponse:
"""Requests - Sends a RANSOMWARE_CONFIGURE C2Command to the C2 Beacon with the given parameters.
:param request: Request with one element containing a dict of parameters for the configure method.
:type request: RequestFormat
:param context: additional context for resolving this action, currently unused
:type context: dict
:return: RequestResponse object with a success code reflecting whether the configuration could be applied.
:rtype: RequestResponse
"""
command_payload = {
"server_ip_address": request[-1].get("server_ip_address"),
"payload": request[-1].get("payload"),
}
return self.send_command(given_command=C2Command.RANSOMWARE_CONFIGURE, command_options=command_payload)
def _launch_ransomware_action(request: RequestFormat, context: Dict) -> RequestResponse:
"""Agent Action - Sends a RANSOMWARE_LAUNCH C2Command to the C2 Beacon with the given parameters.
:param request: Request with one element containing a dict of parameters for the configure method.
:type request: RequestFormat
:param context: additional context for resolving this action, currently unused
:type context: dict
:return: RequestResponse object with a success code reflecting whether the ransomware was launched.
:rtype: RequestResponse
"""
return self.send_command(given_command=C2Command.RANSOMWARE_LAUNCH, command_options={})
def _data_exfiltration_action(request: RequestFormat, context: Dict) -> RequestResponse:
"""Agent Action - Sends a Data Exfiltration C2Command to the C2 Beacon with the given parameters.
:param request: Request with one element containing a dict of parameters for the configure method.
:type request: RequestFormat
:param context: additional context for resolving this action, currently unused
:type context: dict
:return: RequestResponse object with a success code reflecting whether the ransomware was launched.
:rtype: RequestResponse
"""
command_payload = request[-1]
return self.send_command(given_command=C2Command.DATA_EXFILTRATION, command_options=command_payload)
def _remote_terminal_action(request: RequestFormat, context: Dict) -> RequestResponse:
"""Agent Action - Sends a TERMINAL C2Command to the C2 Beacon with the given parameters.
:param request: Request with one element containing a dict of parameters for the configure method.
:type request: RequestFormat
:param context: additional context for resolving this action, currently unused
:type context: dict
:return: RequestResponse object with a success code reflecting whether the ransomware was launched.
:rtype: RequestResponse
"""
command_payload = request[-1]
return self.send_command(given_command=C2Command.TERMINAL, command_options=command_payload)
rm.add_request(
name="ransomware_configure",
request_type=RequestType(func=_configure_ransomware_action),
)
rm.add_request(
name="ransomware_launch",
request_type=RequestType(func=_launch_ransomware_action),
)
rm.add_request(
name="terminal_command",
request_type=RequestType(func=_remote_terminal_action),
)
rm.add_request(
name="exfiltrate",
request_type=RequestType(func=_data_exfiltration_action),
)
return rm
def __init__(self, **kwargs):
kwargs["name"] = "C2Server"
super().__init__(**kwargs)
self.run()
def _handle_command_output(self, payload: C2Packet) -> bool:
"""
Handles the parsing of C2 Command Output from C2 Traffic (Masquerade Packets).
Parses the Request Response from the given C2Packet's payload attribute (Inherited from Data packet).
This RequestResponse is then stored in the C2 Server class attribute self.current_command_output.
If the payload attribute does not contain a RequestResponse, then an error will be raised in syslog and
the self.current_command_output is updated to reflect the error.
:param payload: The OUTPUT C2 Payload
:type payload: C2Packet
:return: Returns True if the self.current_command_output was updated, false otherwise.
:rtype Bool:
"""
self.sys_log.info(f"{self.name}: Received command response from C2 Beacon: {payload}.")
command_output = payload.payload
if not isinstance(command_output, RequestResponse):
self.sys_log.warning(f"{self.name}: C2 Server received invalid command response: {command_output}.")
self.current_command_output = RequestResponse(
status="failure", data={"Reason": "Received unexpected C2 Response."}
)
return False
self.current_command_output = command_output
return True
def _handle_keep_alive(self, payload: C2Packet, session_id: Optional[str]) -> bool:
"""
Handles receiving and sending keep alive payloads. This method is only called if we receive a keep alive.
Abstract method inherited from abstract C2.
In the C2 Server implementation of this method the following logic is performed:
1. The ``self.c2_connection_active`` is set to True. (Indicates that we're received a connection)
2. The received keep alive (Payload parameter) is then resolved by _resolve_keep_alive.
3. After the keep alive is resolved, a keep alive is sent back to confirm connection.
This is because the C2 Server is the listener and thus will only ever receive packets from
the C2 Beacon rather than the other way around.
The C2 Beacon/Server communication is akin to that of a real-world reverse shells.
Returns False if a keep alive was unable to be sent.
Returns True if a keep alive was successfully sent or already has been sent this timestep.
:param payload: The Keep Alive payload received.
:type payload: C2Packet
:param session_id: The transport session_id that the payload originates from.
:type session_id: str
:return: True if the keep alive was successfully handled, false otherwise.
:rtype: Bool
"""
self.sys_log.info(f"{self.name}: Keep Alive Received. Attempting to resolve the remote connection details.")
self.c2_connection_active = True # Sets the connection to active
self.c2_session = self.software_manager.session_manager.sessions_by_uuid[session_id]
if self._resolve_keep_alive(payload, session_id) == False:
self.sys_log.warning(f"{self.name}: Keep Alive Could not be resolved correctly. Refusing Keep Alive.")
return False
self.sys_log.info(f"{self.name}: Remote connection successfully established: {self.c2_remote_connection}.")
self.sys_log.debug(f"{self.name}: Attempting to send Keep Alive response back to {self.c2_remote_connection}.")
# If this method returns true then we have sent successfully sent a keep alive response back.
return self._send_keep_alive(session_id)
@validate_call
def send_command(self, given_command: C2Command, command_options: Dict) -> RequestResponse:
"""
Sends a C2 command to the C2 Beacon using the given parameters.
C2 Command | Command Synopsis
---------------------|------------------------
RANSOMWARE_CONFIGURE | Configures an installed ransomware script based on the passed parameters.
RANSOMWARE_LAUNCH | Launches the installed ransomware script.
DATA_EXFILTRATION | Utilises the FTP Service to exfiltrate data back to the C2 Server.
TERMINAL | Executes a command via the terminal installed on the C2 Beacons Host.
Currently, these commands leverage the pre-existing capability of other applications.
However, the commands are sent via the network rather than the game layer which
grants more opportunity to the blue agent to prevent attacks.
Additionally, future editions of primAITE may expand the C2 repertoire to allow for
more complex red agent behaviour such as establishing further fall back channels
or introduce red applications that are only installable via C2 Servers. (T1105)
For more information on the impact of these commands please refer to the terminal
and the ransomware applications.
:param given_command: The C2 command to be sent to the C2 Beacon.
:type given_command: C2Command.
:param command_options: The relevant C2 Beacon parameters.
:type command_options: Dict
:return: Returns the Request Response of the C2 Beacon's host terminal service execute method.
:rtype: RequestResponse
"""
if not isinstance(given_command, C2Command):
self.sys_log.warning(f"{self.name}: Received unexpected C2 command. Unable to send command.")
return RequestResponse(
status="failure", data={"Reason": "Received unexpected C2Command. Unable to send command."}
)
connection_status: tuple[bool, RequestResponse] = self._check_connection()
if connection_status[0] is False:
return connection_status[1]
setup_success, command_options = self._command_setup(given_command, command_options)
if setup_success is False:
self.sys_log.warning(
f"{self.name}: Failed to perform necessary C2 Server setup for given command: {given_command}."
)
return RequestResponse(
status="failure", data={"Reason": "Failed to perform necessary C2 Server setup for given command."}
)
self.sys_log.info(f"{self.name}: Attempting to send command {given_command}.")
command_packet = self._craft_packet(
c2_payload=C2Payload.INPUT, c2_command=given_command, command_options=command_options.model_dump()
)
if self.send(
payload=command_packet,
dest_ip_address=self.c2_remote_connection,
session_id=self.c2_session.uuid,
dest_port=self.c2_config.masquerade_port,
ip_protocol=self.c2_config.masquerade_protocol,
):
self.sys_log.info(f"{self.name}: Successfully sent {given_command}.")
self.sys_log.info(f"{self.name}: Awaiting command response {given_command}.")
# If the command output was handled currently, the self.current_command_output will contain the RequestResponse.
if self.current_command_output is None:
return RequestResponse(
status="failure", data={"Reason": "Command sent to the C2 Beacon but no response was ever received."}
)
return self.current_command_output
def _command_setup(self, given_command: C2Command, command_options: dict) -> tuple[bool, CommandOpts]:
"""
Performs any necessary C2 Server setup needed to perform certain commands.
This includes any option validation and any other required setup.
The following table details any C2 Server prequisites for following commands.
C2 Command | Command Service/Application Requirements
---------------------|-----------------------------------------
RANSOMWARE_CONFIGURE | N/A
RANSOMWARE_LAUNCH | N/A
DATA_EXFILTRATION | FTP Server & File system folder
TERMINAL | N/A
Currently, only the data exfiltration command require the C2 Server
to perform any necessary setup. Specifically, the Data Exfiltration command requires
the C2 Server to have an running FTP Server service as well as a folder for
storing any exfiltrated data.
:param given_command: Any C2 Command.
:type given_command: C2Command.
:param command_options: The relevant command parameters.
:type command_options: Dict
:returns: Tuple containing a success bool if the setup was successful and the validated c2 opts.
:rtype: tuple[bool, CommandOpts]
"""
server_setup_success: bool = True
if given_command == C2Command.DATA_EXFILTRATION: # Data exfiltration setup
# Validating command options
command_options = ExfilOpts.model_validate(command_options)
if self._host_ftp_server is None:
self.sys_log.warning(f"{self.name}: Unable to setup the FTP Server for data exfiltration")
server_setup_success = False
if self.get_exfiltration_folder(command_options.exfiltration_folder_name) is None:
self.sys_log.warning(f"{self.name}: Unable to create a folder for storing exfiltration data.")
server_setup_success = False
if given_command == C2Command.TERMINAL:
# Validating command options
command_options = TerminalOpts.model_validate(command_options)
if given_command == C2Command.RANSOMWARE_CONFIGURE:
# Validating command options
command_options = RansomwareOpts.model_validate(command_options)
if given_command == C2Command.RANSOMWARE_LAUNCH:
# Validating command options
command_options = CommandOpts.model_validate(command_options)
return [server_setup_success, command_options]
def _confirm_remote_connection(self, timestep: int) -> bool:
"""Checks the suitability of the current C2 Beacon connection.
Inherited Abstract Method.
If a C2 Server has not received a keep alive within the current set
keep alive frequency (self._keep_alive_frequency) then the C2 beacons
connection is considered dead and any commands will be rejected.
This method is called on each timestep (Called by .apply_timestep)
:param timestep: The current timestep of the simulation.
:type timestep: Int
:return: Returns False if the C2 beacon is considered dead. Otherwise True.
:rtype bool:
"""
if self.keep_alive_inactivity > self.c2_config.keep_alive_frequency:
self.sys_log.info(f"{self.name}: C2 Beacon connection considered dead due to inactivity.")
self.sys_log.debug(
f"{self.name}: Did not receive expected keep alive connection from {self.c2_remote_connection}"
f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.c2_config.keep_alive_frequency}"
f"{self.name}: Last Keep Alive received at {(timestep - self.keep_alive_inactivity)}"
)
self._reset_c2_connection()
return False
return True
# Abstract method inherited from abstract C2.
# C2 Servers do not currently receive any input commands from the C2 beacon.
def _handle_command_input(self, payload: C2Packet) -> None:
"""Defining this method (Abstract method inherited from abstract C2) in order to instantiate the class.
C2 Servers currently do not receive input commands coming from the C2 Beacons.
:param payload: The incoming C2Packet
:type payload: C2Packet.
"""
self.sys_log.warning(f"{self.name}: C2 Server received an unexpected INPUT payload: {payload}")
pass
def show(self, markdown: bool = False):
"""
Prints a table of the current C2 attributes on a C2 Server.
Displays the current values of the following C2 attributes:
``C2 Connection Active``:
If the C2 Server has established connection with a C2 Beacon.
``C2 Remote Connection``:
The IP of the C2 Beacon. (Configured by upon receiving a keep alive.)
``Current Masquerade Protocol``:
The current protocol that the C2 Traffic is using. (e.g TCP/UDP)
``Current Masquerade Port``:
The current port that the C2 Traffic is using. (e.g HTTP (Port 80))
:param markdown: If True, outputs the table in markdown format. Default is False.
"""
table = PrettyTable(
["C2 Connection Active", "C2 Remote Connection", "Current Masquerade Protocol", "Current Masquerade Port"]
)
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.name} Running Status"
table.add_row(
[
self.c2_connection_active,
self.c2_remote_connection,
self.c2_config.masquerade_protocol,
self.c2_config.masquerade_port,
]
)
print(table)

View File

@@ -2,6 +2,8 @@
from ipaddress import IPv4Address
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -169,3 +171,25 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
else:
self.sys_log.warning("Attack Attempted to launch too quickly")
return False
def show(self, markdown: bool = False):
"""
Prints a table of the current status of the Ransomware Script.
Displays the current values of the following Ransomware Attributes:
``server_ip_address`:
The IP of the target database.
``payload``:
The payload (type of attack) to be sent to the database.
:param markdown: If True, outputs the table in markdown format. Default is False.
"""
table = PrettyTable(["Target Server IP Address", "Payload"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.name} Running Status"
table.add_row([self.server_ip_address, self.payload])
print(table)

View File

@@ -1,9 +1,11 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from copy import deepcopy
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.core import RequestType
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -20,9 +22,7 @@ if TYPE_CHECKING:
from primaite.simulator.system.services.arp.arp import ARP
from primaite.simulator.system.services.icmp.icmp import ICMP
from typing import Type, TypeVar
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
from typing import Type
class SoftwareManager:
@@ -51,7 +51,7 @@ class SoftwareManager:
self.node = parent_node
self.session_manager = session_manager
self.software: Dict[str, Union[Service, Application]] = {}
self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {}
self._software_class_to_name_map: Dict[Type[IOSoftware], str] = {}
self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {}
self.sys_log: SysLog = sys_log
self.file_system: FileSystem = file_system
@@ -77,6 +77,8 @@ class SoftwareManager:
for software in self.port_protocol_mapping.values():
if software.operating_state in {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}:
open_ports.append(software.port)
if software.listen_on_ports:
open_ports += list(software.listen_on_ports)
return open_ports
def check_port_is_open(self, port: Port, protocol: IPProtocol) -> bool:
@@ -104,33 +106,38 @@ class SoftwareManager:
return True
return False
def install(self, software_class: Type[IOSoftwareClass]):
def install(self, software_class: Type[IOSoftware], **install_kwargs):
"""
Install an Application or Service.
:param software_class: The software class.
"""
# TODO: Software manager and node itself both have an install method. Need to refactor to have more logical
# separation of concerns.
if software_class in self._software_class_to_name_map:
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
return
software = software_class(
software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server
software_manager=self,
sys_log=self.sys_log,
file_system=self.file_system,
dns_server=self.dns_server,
**install_kwargs,
)
software.parent = self.node
if isinstance(software, Application):
software.install()
self.node.applications[software.uuid] = software
self.node._application_request_manager.add_request(
software.name, RequestType(func=software._request_manager)
)
elif isinstance(software, Service):
self.node.services[software.uuid] = software
self.node._service_request_manager.add_request(software.name, RequestType(func=software._request_manager))
software.install()
software.software_manager = self
self.software[software.name] = software
self.port_protocol_mapping[(software.port, software.protocol)] = software
if isinstance(software, Application):
software.operating_state = ApplicationOperatingState.CLOSED
# add the software to the node's registry after it has been fully initialized
if isinstance(software, Service):
self.node.install_service(software)
elif isinstance(software, Application):
self.node.install_application(software)
self.node.sys_log.info(f"Installed {software.name}")
def uninstall(self, software_name: str):
"""
@@ -138,25 +145,31 @@ class SoftwareManager:
:param software_name: The software name.
"""
if software_name in self.software:
self.software[software_name].uninstall()
software = self.software.pop(software_name) # noqa
if isinstance(software, Application):
self.node.uninstall_application(software)
elif isinstance(software, Service):
self.node.uninstall_service(software)
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Uninstalled {software_name}")
if software_name not in self.software:
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
return
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
self.software[software_name].uninstall()
software = self.software.pop(software_name) # noqa
if isinstance(software, Application):
self.node.applications.pop(software.uuid)
self.node._application_request_manager.remove_request(software.name)
elif isinstance(software, Service):
self.node.services.pop(software.uuid)
software.uninstall()
self.node._service_request_manager.remove_request(software.name)
software.parent = None
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Uninstalled {software_name}")
return
def send_internal_payload(self, target_software: str, payload: Any):
"""
@@ -213,7 +226,9 @@ class SoftwareManager:
frame: Frame,
):
"""
Receive a payload from the SessionManager and forward it to the corresponding service or application.
Receive a payload from the SessionManager and forward it to the corresponding service or applications.
This function handles both software assigned a specific port, and software listening in on other ports.
:param payload: The payload being received.
:param session: The transport session the payload originates from.
@@ -221,14 +236,25 @@ class SoftwareManager:
if payload.__class__.__name__ == "PortScanPayload":
self.software.get("NMAP").receive(payload=payload, session_id=session_id)
return
receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
if receiver:
receiver.receive(
main_receiver = self.port_protocol_mapping.get((port, protocol), None)
if main_receiver:
main_receiver.receive(
payload=payload, session_id=session_id, from_network_interface=from_network_interface, frame=frame
)
else:
listening_receivers = [
software
for software in self.software.values()
if port in software.listen_on_ports and software != main_receiver
]
for receiver in listening_receivers:
receiver.receive(
payload=deepcopy(payload),
session_id=session_id,
from_network_interface=from_network_interface,
frame=frame,
)
if not main_receiver and not listening_receivers:
self.sys_log.warning(f"No service or application found for port {port} and protocol {protocol}")
pass
def show(self, markdown: bool = False):
"""

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -191,12 +191,16 @@ class DatabaseService(Service):
:return: Response to connection request containing success info.
:rtype: Dict[str, Union[int, Dict[str, bool]]]
"""
self.sys_log.info(f"{self.name}: Processing new connection request ({connection_request_id}) from {src_ip}")
status_code = 500 # Default internal server error
connection_id = None
if self.operating_state == ServiceOperatingState.RUNNING:
status_code = 503 # service unavailable
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.")
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, service is at "
f"capacity."
)
if self.health_state_actual in [
SoftwareHealthState.GOOD,
SoftwareHealthState.FIXING,
@@ -208,12 +212,16 @@ class DatabaseService(Service):
# try to create connection
if not self.add_connection(connection_id=connection_id, session_id=session_id):
status_code = 500
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
else:
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, "
f"returning status code 500"
)
else:
status_code = 401 # Unauthorised
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} unauthorised "
f"(incorrect password), returning status code 401"
)
else:
status_code = 404 # service not found
return {
@@ -377,6 +385,8 @@ class DatabaseService(Service):
)
else:
result = {"status_code": 401, "type": "sql"}
else:
self.sys_log.info(f"{self.name}: Ignoring payload as it is not a Database payload")
self.send(payload=result, session_id=session_id)
return True

View File

@@ -1,8 +1,10 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address
from typing import Optional
from typing import Dict, Optional
from primaite import getLogger
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -28,6 +30,58 @@ class FTPClient(FTPServiceABC):
super().__init__(**kwargs)
self.start()
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
def _send_data_request(request: RequestFormat, context: Dict) -> RequestResponse:
"""
Request for sending data via the ftp_client using the request options parameters.
:param request: Request with one element containing a dict of parameters for the send method.
:type request: RequestFormat
:param context: additional context for resolving this action, currently unused
:type context: dict
:return: RequestResponse object with a success code reflecting whether the configuration could be applied.
:rtype: RequestResponse
"""
dest_ip = request[-1].get("dest_ip_address")
dest_ip = None if dest_ip is None else IPv4Address(dest_ip)
# Missing FTP Options results is an automatic failure.
src_folder = request[-1].get("src_folder_name", None)
src_file_name = request[-1].get("src_file_name", None)
dest_folder = request[-1].get("dest_folder_name", None)
dest_file_name = request[-1].get("dest_file_name", None)
if not self.file_system.access_file(folder_name=src_folder, file_name=src_file_name):
self.sys_log.debug(
f"{self.name}: Received a FTP Request to transfer file: {src_file_name} to Remote IP: {dest_ip}."
)
return RequestResponse(
status="failure",
data={
"reason": "Unable to locate given file on local file system. Perhaps given options are invalid?"
},
)
return RequestResponse.from_bool(
self.send_file(
dest_ip_address=dest_ip,
src_folder_name=src_folder,
src_file_name=src_file_name,
dest_folder_name=dest_folder,
dest_file_name=dest_file_name,
)
)
rm.add_request("send", request_type=RequestType(func=_send_data_request)),
return rm
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1,545 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import abstractmethod
from datetime import datetime
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4
from pydantic import BaseModel
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.protocols.ssh import (
SSHConnectionMessage,
SSHPacket,
SSHTransportMessage,
SSHUserCredentials,
)
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.service import Service, ServiceOperatingState
# TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor
# the terminal to leverage the user session manager's list. This way we avoid potential bugs and code ducplication
class TerminalClientConnection(BaseModel):
"""
TerminalClientConnection Class.
This class is used to record current User Connections to the Terminal class.
"""
parent_terminal: Terminal
"""The parent Node that this connection was created on."""
ssh_session_id: str = None
"""Session ID that connection is linked to, used for sending commands via session manager."""
connection_uuid: str = None
"""Connection UUID"""
connection_request_id: str = None
"""Connection request ID"""
time: datetime = None
"""Timestamp connection was created."""
ip_address: IPv4Address
"""Source IP of Connection"""
is_active: bool = True
"""Flag to state whether the connection is active or not"""
def __str__(self) -> str:
return f"{self.__class__.__name__}(connection_id: '{self.connection_uuid}, ip_address: {self.ip_address}')"
def __repr__(self) -> str:
return self.__str__()
def __getitem__(self, key: Any) -> Any:
return getattr(self, key)
@property
def client(self) -> Optional[Terminal]:
"""The Terminal that holds this connection."""
return self.parent_terminal
def disconnect(self) -> bool:
"""Disconnect the session."""
return self.parent_terminal._disconnect(connection_uuid=self.connection_uuid)
@abstractmethod
def execute(self, command: Any) -> bool:
"""Execute a given command."""
pass
class LocalTerminalConnection(TerminalClientConnection):
"""
LocalTerminalConnectionClass.
This class represents a local terminal when connected.
"""
ip_address: str = "Local Connection"
def execute(self, command: Any) -> Optional[RequestResponse]:
"""Execute a given command on local Terminal."""
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
return None
if not self.is_active:
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
return None
return self.parent_terminal.execute(command)
class RemoteTerminalConnection(TerminalClientConnection):
"""
RemoteTerminalConnection Class.
This class acts as broker between the terminal and remote.
"""
def execute(self, command: Any) -> bool:
"""Execute a given command on the remote Terminal."""
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
return False
if not self.is_active:
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
return False
# Send command to remote terminal to process.
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_SERVICE_REQUEST
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
payload: SSHPacket = SSHPacket(
transport_message=transport_message,
connection_message=connection_message,
connection_request_uuid=self.connection_request_id,
connection_uuid=self.connection_uuid,
ssh_command=command,
)
return self.parent_terminal.send(payload=payload, session_id=self.ssh_session_id)
class Terminal(Service):
"""Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH."""
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
"""Dictionary of connect requests made to remote nodes."""
def __init__(self, **kwargs):
kwargs["name"] = "Terminal"
kwargs["port"] = Port.SSH
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
return state
def show(self, markdown: bool = False):
"""
Display the remote connections to this terminal instance in tabular format.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
self.show_connections(markdown=markdown)
def _init_request_manager(self) -> RequestManager:
"""Initialise Request manager."""
rm = super()._init_request_manager()
def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse:
login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2])
if login:
return RequestResponse(
status="success",
data={
"ip_address": str(login.ip_address),
"username": request[0],
},
)
else:
return RequestResponse(status="failure", data={})
rm.add_request(
"ssh_to_remote",
request_type=RequestType(func=_remote_login),
)
def _remote_logoff(request: RequestFormat, context: Dict) -> RequestResponse:
"""Logoff from remote connection."""
ip_address = IPv4Address(request[0])
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
if remote_connection:
outcome = self._disconnect(remote_connection.connection_uuid)
if outcome:
return RequestResponse(status="success", data={})
return RequestResponse(status="failure", data={})
rm.add_request("remote_logoff", request_type=RequestType(func=_remote_logoff))
def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
"""Execute an instruction."""
ip_address: IPv4Address = IPv4Address(request[0])
command: str = request[1]["command"]
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
if remote_connection:
outcome = remote_connection.execute(command)
if outcome:
return RequestResponse(
status="success",
data={},
)
else:
return RequestResponse(
status="failure",
data={},
)
rm.add_request(
"send_remote_command",
request_type=RequestType(func=remote_execute_request),
)
return rm
def execute(self, command: List[Any]) -> Optional[RequestResponse]:
"""Execute a passed ssh command via the request manager."""
return self.parent.apply_request(command)
def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]:
"""Find Remote Terminal Connection from a given IP."""
for connection in self._connections.values():
if connection.ip_address == ip_address:
return connection
def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection:
"""Create a new connection object and amend to list of active connections.
:param connection_uuid: Connection ID of the new local connection
:param session_id: Session ID of the new local connection
:return: TerminalClientConnection object
"""
new_connection = LocalTerminalConnection(
parent_terminal=self,
connection_uuid=connection_uuid,
ssh_session_id=session_id,
time=datetime.now(),
)
self._connections[connection_uuid] = new_connection
self._client_connection_requests[connection_uuid] = new_connection
return new_connection
def login(
self, username: str, password: str, ip_address: Optional[IPv4Address] = None
) -> Optional[TerminalClientConnection]:
"""Login to the terminal. Will attempt a remote login if ip_address is given, else local.
:param: username: Username used to connect to the remote node.
:type: username: str
:param: password: Password used to connect to the remote node
:type: password: str
:param: ip_address: Target Node IP address for login attempt. If None, login is assumed local.
:type: ip_address: Optional[IPv4Address]
"""
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning(f"{self.name}: Cannot login as service is not running.")
return None
if ip_address:
# Assuming that if IP is passed we are connecting to remote
return self._send_remote_login(username=username, password=password, ip_address=ip_address)
else:
return self._process_local_login(username=username, password=password)
def _process_local_login(self, username: str, password: str) -> Optional[TerminalClientConnection]:
"""Local session login to terminal.
:param username: Username for login.
:param password: Password for login.
:return: boolean, True if successful, else False
"""
# TODO: Un-comment this when UserSessionManager is merged.
connection_uuid = self.parent.user_session_manager.local_login(username=username, password=password)
if connection_uuid:
self.sys_log.info(f"{self.name}: Login request authorised, connection uuid: {connection_uuid}")
# Add new local session to list of connections and return
return self._create_local_connection(connection_uuid=connection_uuid, session_id="Local_Connection")
else:
self.sys_log.warning(f"{self.name}: Login failed, incorrect Username or Password")
return None
def _validate_client_connection_request(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
return connection_id in self._client_connection_requests
def _check_client_connection(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
if not self.parent.user_session_manager.validate_remote_session_uuid(connection_id):
self._disconnect(connection_id)
return False
return connection_id in self._connections
def _send_remote_login(
self,
username: str,
password: str,
ip_address: IPv4Address,
connection_request_id: Optional[str] = None,
is_reattempt: bool = False,
) -> Optional[RemoteTerminalConnection]:
"""Send a remote login attempt and connect to Node.
:param: username: Username used to connect to the remote node.
:type: username: str
:param: password: Password used to connect to the remote node
:type: password: str
:param: ip_address: Target Node IP address for login attempt.
:type: ip_address: IPv4Address
:param: connection_request_id: Connection Request ID, if not provided, a new one is generated
:type: connection_request_id: Optional[str]
:param: is_reattempt: True if the request has been reattempted. Default False.
:type: is_reattempt: Optional[bool]
:return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False.
"""
connection_request_id = connection_request_id or str(uuid4())
if is_reattempt:
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
if valid_connection_request:
remote_terminal_connection = self._client_connection_requests.pop(connection_request_id)
if isinstance(remote_terminal_connection, RemoteTerminalConnection):
self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.")
return remote_terminal_connection
else:
self.sys_log.warning(f"{self.name}: Connection request {connection_request_id} declined")
return None
else:
self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.")
return None
self.sys_log.info(
f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}"
)
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password)
payload_contents = {
"type": "login_request",
"username": username,
"password": password,
"connection_request_id": connection_request_id,
}
payload: SSHPacket = SSHPacket(
payload=payload_contents,
transport_message=transport_message,
connection_message=connection_message,
user_account=user_details,
connection_request_uuid=connection_request_id,
)
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=ip_address, dest_port=self.port
)
return self._send_remote_login(
username=username,
password=password,
ip_address=ip_address,
is_reattempt=True,
connection_request_id=connection_request_id,
)
def _create_remote_connection(
self, connection_id: str, connection_request_id: str, session_id: str, source_ip: str
) -> None:
"""Create a new TerminalClientConnection Object.
:param: connection_request_id: Connection Request ID
:type: connection_request_id: str
:param: session_id: Session ID of connection.
:type: session_id: str
"""
client_connection = RemoteTerminalConnection(
parent_terminal=self,
ssh_session_id=session_id,
connection_uuid=connection_id,
ip_address=source_ip,
connection_request_id=connection_request_id,
time=datetime.now(),
)
self._connections[connection_id] = client_connection
self._client_connection_requests[connection_request_id] = client_connection
def receive(self, session_id: str, payload: Union[SSHPacket, Dict], **kwargs) -> bool:
"""
Receive a payload from the Software Manager.
:param payload: A payload to receive.
:param session_id: The session id the payload relates to.
:return: True.
"""
source_ip = kwargs["frame"].ip.src_ip_address
self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}")
if isinstance(payload, SSHPacket):
if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
# validate & add connection
# TODO: uncomment this as part of 2781
username = payload.user_account.username
password = payload.user_account.password
connection_id = self.parent.user_session_manager.remote_login(
username=username, password=password, remote_ip_address=source_ip
)
if connection_id:
connection_request_id = payload.connection_request_uuid
self.sys_log.info(f"{self.name}: Connection authorised, session_id: {session_id}")
self._create_remote_connection(
connection_id=connection_id,
connection_request_id=connection_request_id,
session_id=session_id,
source_ip=source_ip,
)
transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS
connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
payload_contents = {
"type": "login_success",
"username": username,
"password": password,
"connection_request_id": connection_request_id,
"connection_id": connection_id,
}
payload: SSHPacket = SSHPacket(
payload=payload_contents,
transport_message=transport_message,
connection_message=connection_message,
connection_request_uuid=connection_request_id,
connection_uuid=connection_id,
)
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_port=self.port, session_id=session_id
)
elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS:
self.sys_log.info(f"{self.name}: Login Successful")
self._create_remote_connection(
connection_id=payload.connection_uuid,
connection_request_id=payload.connection_request_uuid,
session_id=session_id,
source_ip=source_ip,
)
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
# Requesting a command to be executed
self.sys_log.info(f"{self.name}: Received command to execute")
command = payload.ssh_command
valid_connection = self._check_client_connection(payload.connection_uuid)
if valid_connection:
remote_session = self.software_manager.node.user_session_manager.remote_sessions.get(
payload.connection_uuid
)
remote_session.last_active_step = self.software_manager.node.user_session_manager.current_timestep
self.execute(command)
return True
else:
self.sys_log.error(
f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command."
)
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "disconnect":
connection_id = payload["connection_id"]
valid_id = self._check_client_connection(connection_id)
if valid_id:
self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from remote.")
self._disconnect(payload["connection_id"])
self.parent.user_session_manager.remote_logout(remote_session_id=connection_id)
else:
self.sys_log.error(f"{self.name}: No Active connection held for received connection ID.")
if payload["type"] == "user_timeout":
connection_id = payload["connection_id"]
valid_id = connection_id in self._connections
if valid_id:
connection = self._connections.pop(connection_id)
connection.is_active = False
self.sys_log.info(f"{self.name}: Connection {connection_id} disconnected due to inactivity.")
else:
self.sys_log.error(f"{self.name}: Connection {connection_id} is invalid.")
return True
def _disconnect(self, connection_uuid: str) -> bool:
"""Disconnect connection.
:param connection_uuid: Connection ID that we want to disconnect.
:return True if successful, False otherwise.
"""
# TODO: Handle the possibility of attempting to disconnect
if not self._connections:
self.sys_log.warning(f"{self.name}: No remote connection present")
return False
connection = self._connections.pop(connection_uuid, None)
if not connection:
return False
connection.is_active = False
if isinstance(connection, RemoteTerminalConnection):
# Send disconnect command via software manager
session_id = connection.ssh_session_id
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect", "connection_id": connection_uuid},
dest_port=self.port,
session_id=session_id,
)
self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}")
return True
elif isinstance(connection, LocalTerminalConnection):
self.parent.user_session_manager.local_logout()
return True
def send(
self, payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None
) -> bool:
"""
Send a payload out from the Terminal.
:param payload: The payload to be sent.
:param dest_up_address: The IP address of the payload destination.
"""
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning(f"{self.name}: Cannot send commands when Operating state is {self.operating_state}!")
return False
self.sys_log.debug(f"{self.name}: Sending payload: {payload}")
return super().send(
payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id
)

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from primaite import getLogger
@@ -22,7 +22,7 @@ _LOGGER = getLogger(__name__)
class WebServer(Service):
"""Class used to represent a Web Server Service in simulation."""
last_response_status_code: Optional[HttpStatusCode] = None
response_codes_this_timestep: List[HttpStatusCode] = []
def describe_state(self) -> Dict:
"""
@@ -34,11 +34,19 @@ class WebServer(Service):
:rtype: Dict
"""
state = super().describe_state()
state["last_response_status_code"] = (
self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None
)
state["response_codes_this_timestep"] = [code.value for code in self.response_codes_this_timestep]
return state
def pre_timestep(self, timestep: int) -> None:
"""
Logic to execute at the start of the timestep - clear the observation-related attributes.
:param timestep: the current timestep in the episode.
:type timestep: int
"""
self.response_codes_this_timestep = []
return super().pre_timestep(timestep)
def __init__(self, **kwargs):
kwargs["name"] = "WebServer"
kwargs["protocol"] = IPProtocol.TCP
@@ -89,7 +97,7 @@ class WebServer(Service):
self.send(payload=response, session_id=session_id)
# return true if response is OK
self.last_response_status_code = response.status_code
self.response_codes_this_timestep.append(response.status_code)
return response.status_code == HttpStatusCode.OK
def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket:

Some files were not shown because too many files have changed in this diff Show More