#2887 - Resolve conflicts from merge

This commit is contained in:
Charlie Crane
2025-01-23 09:17:27 +00:00
174 changed files with 7047 additions and 8412 deletions

View File

@@ -14,31 +14,36 @@ parameters:
- name: matrix
type: object
default:
# - job_name: 'UbuntuPython38'
# py: '3.8'
# img: 'ubuntu-latest'
# every_time: false
# publish_coverage: false
- job_name: 'UbuntuPython311'
py: '3.11'
- job_name: 'UbuntuPython39'
py: '3.9'
img: 'ubuntu-latest'
every_time: false
publish_coverage: false
- job_name: 'UbuntuPython310'
py: '3.10'
img: 'ubuntu-latest'
every_time: true
publish_coverage: true
# - job_name: 'WindowsPython38'
# py: '3.8'
# img: 'windows-latest'
# every_time: false
# publish_coverage: false
- job_name: 'UbuntuPython311'
py: '3.11'
img: 'ubuntu-latest'
every_time: false
publish_coverage: false
- job_name: 'WindowsPython39'
py: '3.9'
img: 'windows-latest'
every_time: false
publish_coverage: false
- job_name: 'WindowsPython311'
py: '3.11'
img: 'windows-latest'
every_time: false
publish_coverage: false
# - job_name: 'MacOSPython38'
# py: '3.8'
# img: 'macOS-latest'
# every_time: false
# publish_coverage: false
- job_name: 'MacOSPython39'
py: '3.9'
img: 'macOS-latest'
every_time: false
publish_coverage: false
- job_name: 'MacOSPython311'
py: '3.11'
img: 'macOS-latest'
@@ -63,7 +68,7 @@ stages:
displayName: 'Use Python ${{ item.py }}'
- script: |
python -m pip install pre-commit
python -m pip install pre-commit>=6.1
pre-commit install
pre-commit run --all-files
displayName: 'Run pre-commits'
@@ -71,7 +76,6 @@ stages:
- script: |
python -m pip install --upgrade pip==23.0.1
pip install wheel==0.38.4 --upgrade
pip install setuptools==66 --upgrade
pip install build==0.10.0
pip install pytest-azurepipelines
displayName: 'Install build dependencies'

View File

@@ -1,10 +1,10 @@
repos:
- repo: local
hooks:
- id: ensure-copyright-clause
name: ensure copyright clause
entry: python copyright_clause_pre_commit_hook.py
language: python
# - repo: local
# hooks:
# - id: ensure-copyright-clause
# name: ensure copyright clause
# entry: python copyright_clause_pre_commit_hook.py
# language: python
- repo: http://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
@@ -31,7 +31,7 @@ repos:
- id: isort
args: [ "--profile", "black" ]
- repo: http://github.com/PyCQA/flake8
rev: 6.0.0
rev: 6.1.0
hooks:
- id: flake8
additional_dependencies:

View File

@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [4.0.0] = TBC
### Added
### Changed
- Agents now follow a common configuration format, simplifying the configuration of agents and their extensibilty.
- Actions within PrimAITE are now extensible, allowing for plugin support.
- Added a config schema to `ObservationManager`, `ActionManager`, and `RewardFunction`.
- Streamlined the way agents are created from config
- Agent config no longer requires a dummy action space if the action space is empty, the same applies for observation space and reward function
- Actions now support a config schema, to allow yaml data validation and default parameter values
- Action parameters are no longer defined through IDs, instead meaningful data is provided directly in the action map
- Test and example YAMLs have been updated to match the new agent and action schemas, such as:
- Removed empty action spaces, observation spaces, or reward spaces for agent which didn't use them
- Relabeled action parameters to match the new action config schemas, and updated the values to no longer rely on indices
- Removed action space options which were previously used for assigning meaning to action space IDs
- Updated tests that don't use YAMLs to still use the new action and agent schemas
## [3.3.0] - 2024-09-04
### Added

View File

@@ -70,7 +70,7 @@ PrimAITE incorporates the following features:
- Architected with a separate Simulation layer and Game layer. This separation of concerns defines a clear path towards transfer learning with environments of differing fidelity;
- Ability to reconfigure an RL reward function based on (a) the ability to counter the modelled adversarial cyber-attack, and (b) the ability to ensure success for green agents;
- Access Control List (ACL) functions for network devices (routers and firewalls), following standard ACL rule format (e.g., DENY / ALLOW, source / destination IP addresses, protocol and port);
- Access Control List (ACL) functions for network devices (routers and firewalls), following standard ACL rule format (e.g., DENY / PERMIT, source / destination IP addresses, protocol and port);
- Application of traffic to the links of the system laydown adheres to the ACL rulesets and routing tables contained within each network device;
- Provides RL environments adherent to the Farama Foundation Gymnasium (Previously OpenAI Gym) API, allowing integration with any compliant RL Agent frameworks;
- Provides RL environments adherent to Ray RLlib environment specifications for single-agent and multi-agent scenarios;

View File

@@ -184,7 +184,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
- 192.168.1.5
- ANY
- ANY
All ACL rules are considered when applying an IER. Logic follows the order of rules, so a DENY or ALLOW for the same parameters will override an earlier entry.
All ACL rules are considered when applying an IER. Logic follows the order of rules, so a DENY or PERMIT for the same parameters will override an earlier entry.
Observation Spaces
******************
The observation space provides the blue agent with information about the current status of nodes and links.
@@ -331,7 +331,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
* Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...}
The placeholders inside the list under the key '1' mean the following:
* [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
* [0, 1] - Permission (0 = DENY, 1 = ALLOW)
* [0, 1] - Permission (0 = DENY, 1 = PERMIT)
* [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
* [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
* [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)

View File

@@ -23,123 +23,117 @@ The following logic is applied:
+------------------------------------------+---------------------------------------------------------------------+
| Action | Action Mask Logic |
+==========================================+=====================================================================+
| **DONOTHING** | Always Possible. |
| **do_nothing** | Always Possible. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_SERVICE_SCAN** | Node is on. Service is running. |
| **node_service_scan** | Node is on. Service is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_SERVICE_STOP** | Node is on. Service is running. |
| **node_service_stop** | Node is on. Service is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_SERVICE_START** | Node is on. Service is stopped. |
| **node_service_start** | Node is on. Service is stopped. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_SERVICE_PAUSE** | Node is on. Service is running. |
| **node_service_pause** | Node is on. Service is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_SERVICE_RESUME** | Node is on. Service is paused. |
| **node_service_resume** | Node is on. Service is paused. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_SERVICE_RESTART** | Node is on. Service is running. |
| **node_service_restart** | Node is on. Service is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_SERVICE_DISABLE** | Node is on. |
| **node_service_disable** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_SERVICE_ENABLE** | Node is on. Service is disabled. |
| **node_service_enable** | Node is on. Service is disabled. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_SERVICE_FIX** | Node is on. Service is running. |
| **node_service_fix** | Node is on. Service is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_APPLICATION_EXECUTE** | Node is on. |
| **node_application_execute** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_APPLICATION_SCAN** | Node is on. Application is running. |
| **node_application_scan** | Node is on. Application is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_APPLICATION_CLOSE** | Node is on. Application is running. |
| **node_application_close** | Node is on. Application is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_APPLICATION_FIX** | Node is on. Application is running. |
| **node_application_fix** | Node is on. Application is running. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_APPLICATION_INSTALL** | Node is on. |
| **node_application_install** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_APPLICATION_REMOVE** | Node is on. |
| **node_application_remove** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FILE_SCAN** | Node is on. File exists. File not deleted. |
| **node_file_scan** | Node is on. File exists. File not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FILE_CREATE** | Node is on. |
| **node_file_create** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FILE_CHECKHASH** | Node is on. File exists. File not deleted. |
| **node_file_checkhash** | Node is on. File exists. File not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FILE_DELETE** | Node is on. File exists. |
| **node_file_delete** | Node is on. File exists. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FILE_REPAIR** | Node is on. File exists. File not deleted. |
| **node_file_repair** | Node is on. File exists. File not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FILE_RESTORE** | Node is on. File exists. File is deleted. |
| **node_file_restore** | Node is on. File exists. File is deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FILE_CORRUPT** | Node is on. File exists. File not deleted. |
| **node_file_corrupt** | Node is on. File exists. File not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FILE_ACCESS** | Node is on. File exists. File not deleted. |
| **node_file_access** | Node is on. File exists. File not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FOLDER_CREATE** | Node is on. |
| **node_folder_create** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FOLDER_SCAN** | Node is on. Folder exists. Folder not deleted. |
| **node_folder_scan** | Node is on. Folder exists. Folder not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FOLDER_CHECKHASH** | Node is on. Folder exists. Folder not deleted. |
| **node_folder_checkhash** | Node is on. Folder exists. Folder not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FOLDER_REPAIR** | Node is on. Folder exists. Folder not deleted. |
| **node_folder_repair** | Node is on. Folder exists. Folder not deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_FOLDER_RESTORE** | Node is on. Folder exists. Folder is deleted. |
| **node_folder_restore** | Node is on. Folder exists. Folder is deleted. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_OS_SCAN** | Node is on. |
| **node_os_scan** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **HOST_NIC_ENABLE** | NIC is disabled. Node is on. |
| **host_nic_enable** | NIC is disabled. Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **HOST_NIC_DISABLE** | NIC is enabled. Node is on. |
| **host_nic_disable** | NIC is enabled. Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_SHUTDOWN** | Node is on. |
| **node_shutdown** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_STARTUP** | Node is off. |
| **node_startup** | Node is off. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_RESET** | Node is on. |
| **node_reset** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_NMAP_PING_SCAN** | Node is on. |
| **node_nmap_ping_scan** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_NMAP_PORT_SCAN** | Node is on. |
| **node_nmap_port_scan** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_NMAP_NETWORK_SERVICE_RECON** | Node is on. |
| **node_network_service_recon** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NETWORK_PORT_ENABLE** | Node is on. Router is on. |
| **network_port_enable** | Node is on. Router is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NETWORK_PORT_DISABLE** | Router is on. |
| **network_port_disable** | Router is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **ROUTER_ACL_ADDRULE** | Router is on. |
| **router_acl_addrule** | Router is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **ROUTER_ACL_REMOVERULE** | Router is on. |
| **router_acl_removerule** | Router is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **FIREWALL_ACL_ADDRULE** | Firewall is on. |
| **firewall_acl_addrule** | Firewall is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **FIREWALL_ACL_REMOVERULE** | Firewall is on. |
| **firewall_acl_removerule** | Firewall is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_NMAP_PING_SCAN** | Node is on. |
| **configure_database_client** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_NMAP_PORT_SCAN** | Node is on. |
| **configure_ransomware_script** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_NMAP_NETWORK_SERVICE_RECON** | Node is on. |
| **c2_server_ransomware_configure** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **CONFIGURE_DATABASE_CLIENT** | Node is on. |
| **configure_dos_bot** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **CONFIGURE_RANSOMWARE_SCRIPT** | Node is on. |
| **configure_c2_beacon** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **CONFIGURE_DOSBOT** | Node is on. |
| **c2_server_ransomware_launch** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **CONFIGURE_C2_BEACON** | Node is on. |
| **c2_server_terminal_command** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **C2_SERVER_RANSOMWARE_LAUNCH** | Node is on. |
| **c2_server_data_exfiltrate** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **C2_SERVER_RANSOMWARE_CONFIGURE** | Node is on. |
| **node_account_change_password** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **C2_SERVER_TERMINAL_COMMAND** | Node is on. |
| **node_session_remote_login** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **C2_SERVER_DATA_EXFILTRATE** | Node is on. |
| **node_session_remote_logoff** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_ACCOUNTS_CHANGE_PASSWORD** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **SSH_TO_REMOTE** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **SESSIONS_REMOTE_LOGOFF** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+
| **NODE_SEND_REMOTE_COMMAND** | Node is on. |
| **node_send_remote_command** | Node is on. |
+------------------------------------------+---------------------------------------------------------------------+

View File

@@ -23,19 +23,6 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_2
applications:
- application_name: WebBrowser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
reward_function:
reward_components:
- type: DUMMY
@@ -91,10 +78,6 @@ For more information see :py:mod:`primaite.game.agent.observations`
The action space is configured to be made up of individual action types. Once configured, the agent can select an action type and some optional action parameters at every step. For example: The ``NODE_SERVICE_SCAN`` action takes the parameters ``node_id`` and ``service_id``.
``action_list``
^^^^^^^^^^^^^^^
A list of action modules. The options are listed in the :py:mod:`primaite.game.agent.actions.ActionManager.act_class_identifiers` module.
``action_map``
^^^^^^^^^^^^^^

View File

@@ -0,0 +1,67 @@
.. only:: comment
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
Extensible Actions
******************
Changes to Actions class Structure.
===================================
Actions within PrimAITE have been updated to inherit from a base class, AbstractAction, standardising their format and allowing for easier creation of custom actions. Actions now use a ``ConfigSchema`` to define the possible configuration variables, and use pydantic to enforce correct parameters are passed through.
Developing Custom Actions.
==========================
Custom actions within PrimAITE must be a sub-class of `AbstractAction`, and contain 3 key items:
#. ConfigSchema class
#. Unique Identifier
#. `form_request` method.
ConfigSchema
############
The ConfigSchema sub-class of the action must contain all `configurable` variables within the action, that would be specified within the environments configuration YAML file.
Unique Identifier
#################
When declaring a custom class, it must have a unique identifier string, that allows PrimAITE to generate the correct action when needed.
.. code:: Python
class CreateDirectoryAction(AbstractAction, identifier="node_folder_create")
config: CreateDirectoryAction.ConfigSchema
class ConfigSchema(AbstractAction.ConfigSchema):
verb: ClassVar[str] = "create"
node_name: str
directory_name: str
def form_request(cls, config: ConfigSchema) -> RequestFormat:
return ["network",
"node",
config.node_name,
"file_system",
config.verb,
"folder",
config.directory_name,
]
The above action would fail pydantic validation as the identifier "node_folder_create" is already used by the `NodeFolderCreateAction`, and would create a duplicate listing within `AbstractAction._registry`.
form_request method
###################
PrimAITE actions need to have a `form_request` method, which can be passed to the `RequestManager` for processing. This allows the custom action to be actioned within the simulation environment.

View File

@@ -0,0 +1,78 @@
.. only:: comment
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
.. _about:
Extensible Agents
*****************
Agents defined within PrimAITE have been updated to allow for easier creation of new bespoke agents for use in custom environments.
Developing Agents for PrimAITE
==============================
All agent types within PrimAITE must be subclassed from ``AbstractAgent`` in order to be used from configuration YAML files. This then allows you to implement any custom agent logic for the new agent in your training scenario. Examples of implementing custom agent logic can be seen in pre-existing agents, such as the ``DataManipulationBot`` and ``RandomAgent``.
The core features that should be implemented in any new agent are detailed below:
#. **ConfigSchema**:
Configurable items within a new agent within PrimAITE should contain a ``ConfigSchema`` which holds all configurable variables of the agent. This should not include parameters related to its *state*, these would be listed seperately.
Agent generation will fail pydantic checks if incorrect or invalid parameters are passed to the ConfigSchema of the chosen Agent.
.. code-block:: python
class ExampleAgent(AbstractAgent, identifier = "ExampleAgent"):
"""An example agent for demonstration purposes."""
config: "ExampleAgent.ConfigSchema" = Field(default_factory= lambda: ExampleAgent.ConfigSchema())
"""Agent configuration"""
num_executions: int = 0
"""Number of action executions by agent"""
class ConfigSchema(AbstractAgent.ConfigSchema):
"""ExampleAgent configuration schema"""
type: str = "ExampleAgent
"""Name of agent"""
starting_host: int
"""Host node that this agent should start from in the given environment."""
.. code-block:: yaml
- ref: example_green_agent
team: GREEN
type: ExampleAgent
action_space:
action_map:
0:
action: do_nothing
options: {}
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 25
frequency: 20
variance: 5
starting_host: "Server_1"
#. **Identifiers**:
All agent classes should have an ``identifier`` attribute, a unique kebab-case string, for when they are added to the base ``AbstractAgent`` registry. This is then specified in your configuration YAML, and used by PrimAITE to generate the correct Agent.
Changes to YAML file
====================
PrimAITE v4.0.0 introduces some breaking changes to how environment configuration yaml files are created. YAML files created for Primaite versions 3.3.0 should be compatible through a translation function, though it is encouraged that these are updated to reflect the updated format of 4.0.0+.
Agents now follow a more standardised settings definition, so should be more consistent across YAML files and the available agent types with PrimAITE.
All configurable items for agents sit under the ``agent_settings`` heading within your YAML files. There is no need for the inclusion of a ``start_settings``. Please see the above YAML example for full changes to agents.

View File

@@ -2,44 +2,44 @@
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| Name | Version | License | Description | URL |
+===================+=========+====================================+=======================================================================================================+====================================================================+
| gymnasium | 0.28.1 | MIT License | A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym). | https://farama.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| ipywidgets | 8.1.5 | BSD License | Jupyter interactive widgets | http://jupyter.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| jupyterlab | 3.6.1 | BSD License | JupyterLab computational environment | https://jupyter.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| kaleido | 0.2.1 | MIT | Static image export for web-based visualization libraries with zero dependencies | https://github.com/plotly/Kaleido |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| matplotlib | 3.7.1 | Python Software Foundation License | Python plotting package | https://matplotlib.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| networkx | 3.1 | BSD License | Python package for creating and manipulating graphs and networks | https://networkx.org/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| numpy | 1.23.5 | BSD License | NumPy is the fundamental package for array computing with Python. | https://www.numpy.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| platformdirs | 3.5.1 | MIT License | A small Python package for determining appropriate platform-specific dirs, e.g. a "user data dir". | https://github.com/platformdirs/platformdirs |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| plotly | 5.15.0 | MIT License | An open-source, interactive data visualization library for Python | https://plotly.com/python/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| polars | 0.20.30 | MIT License | Blazingly fast DataFrame library | https://www.pola.rs/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| prettytable | 3.8.0 | BSD License (BSD (3 clause)) | A simple Python library for easily displaying tabular data in a visually appealing ASCII table format | https://github.com/jazzband/prettytable |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| pydantic | 2.7.0 | MIT License | Data validation using Python type hints | https://github.com/pydantic/pydantic |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| PyYAML | 6.0 | MIT License | YAML parser and emitter for Python | https://pyyaml.org/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| ray | 2.32.0 | Apache 2.0 | Ray provides a simple, universal API for building distributed applications. | https://github.com/ray-project/ray |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| stable-baselines3 | 2.1.0 | MIT | Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms. | https://github.com/DLR-RM/stable-baselines3 |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| tensorflow | 2.12.0 | Apache Software License | TensorFlow is an open source machine learning framework for everyone. | https://www.tensorflow.org/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| typer | 0.9.0 | MIT License | Typer, build great CLIs. Easy to code. Based on Python type hints. | https://github.com/tiangolo/typer |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| Deepdiff | 8.0.1 | MIT License | Deep difference of dictionaries, iterables, strings, and any other object objects. | https://github.com/seperman/deepdiff |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
| sb3_contrib | 2.1.0 | MIT License | Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code (Action Masking)| https://github.com/Stable-Baselines-Team/stable-baselines3-contrib |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| Name | Supported Version | Built Version | License | Description | URL |
+===================+=====================+===============+======================================+========================================================================================================+=====================================================================+
| gymnasium | 0.28.1 | 0.28.1 | MIT License | A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym). | https://farama.org |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| ipywidgets | ~=8.0 | 8.1.5 | BSD License | Jupyter interactive widgets | http://jupyter.org |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| jupyterlab | 3.6.1 | 3.6.1 | BSD License | JupyterLab computational environment | https://jupyter.org |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| kaleido | ==0.2.1 | 0.2.1 | MIT | Static image export for web-based visualization libraries with zero dependencies | https://github.com/plotly/Kaleido |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| matplotlib | >=3.7.1 | 3.7.1 | Python Software Foundation License | Python plotting package | https://matplotlib.org |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| networkx | 3.1 | 3.1 | BSD License | Python package for creating and manipulating graphs and networks | https://networkx.org/ |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| numpy | ~1.23 | 1.23.5 | BSD License | NumPy is the fundamental package for array computing with Python. | https://www.numpy.org |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| platformdirs | 3.5.1 | 3.5.1 | MIT License | A small Python package for determining appropriate platform-specific dirs, e.g. a "user data dir". | https://github.com/platformdirs/platformdirs |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| plotly | 5.15 | 5.15.0 | MIT License | An open-source, interactive data visualization library for Python | https://plotly.com/python/ |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| polars | 0.20.30 | 0.20.30 | MIT License | Blazingly fast DataFrame library | https://www.pola.rs/ |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| prettytable | 3.8.0 | 3.8.0 | BSD License (BSD (3 clause)) | A simple Python library for easily displaying tabular data in a visually appealing ASCII table format | https://github.com/jazzband/prettytable |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| pydantic | 2.7.0 | 2.7.0 | MIT License | Data validation using Python type hints | https://github.com/pydantic/pydantic |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| PyYAML | >=6.0 | 6.0 | MIT License | YAML parser and emitter for Python | https://pyyaml.org/ |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| ray | >=2.20, <2.33 | 2.32.0 | Apache 2.0 | Ray provides a simple, universal API for building distributed applications. | https://github.com/ray-project/ray |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| stable-baselines3 | 2.1.0 | 2.1.0 | MIT | Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms. | https://github.com/DLR-RM/stable-baselines3 |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| tensorflow | ~=2.12 | 2.12.0 | Apache Software License | TensorFlow is an open source machine learning framework for everyone. | https://www.tensorflow.org/ |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| typer | >=0.9 | 0.9.0 | MIT License | Typer, build great CLIs. Easy to code. Based on Python type hints. | https://github.com/tiangolo/typer |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| Deepdiff | 8.0.1 | 8.0.1 | MIT License | Deep difference of dictionaries, iterables, strings, and any other object objects. | https://github.com/seperman/deepdiff |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
| sb3_contrib | 2.1.0 | 2.1.0 | MIT License | Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code (Action Masking) | https://github.com/Stable-Baselines-Team/stable-baselines3-contrib |
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+

View File

@@ -113,18 +113,6 @@ If not using the data manipulation bot manually, it needs to be used with a data
folders: {}
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY

View File

@@ -70,7 +70,7 @@ Python
Configuration
=============
The RansomwareScript inherits configuration options such as ``fix_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``.
The RansomwareScript inherits configuration options such as ``fixing_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``.
``server_ip``

View File

@@ -22,8 +22,8 @@ options
The configuration options are the attributes that fall under the options for an application or service.
fix_duration
""""""""""""
fixing_duration
"""""""""""""""
Optional. Default value is ``2``.

View File

@@ -7,7 +7,7 @@ name = "primaite"
description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme."
authors = [{name="Defence Science and Technology Laboratory UK", email="oss@dstl.gov.uk"}]
license = {file = "LICENSE"}
requires-python = ">=3.9, <3.12"
requires-python = ">=3.9, <3.13"
dynamic = ["version", "readme"]
classifiers = [
"Development Status :: 5 - Production/Stable",
@@ -26,15 +26,15 @@ dependencies = [
"gymnasium==0.28.1",
"jupyterlab==3.6.1",
"kaleido==0.2.1",
"matplotlib==3.7.1",
"matplotlib>=3.7.1",
"networkx==3.1",
"numpy==1.23.5",
"numpy~=1.23",
"platformdirs==3.5.1",
"plotly==5.15.0",
"polars==0.20.30",
"prettytable==3.8.0",
"PyYAML==6.0",
"typer[all]==0.9.0",
"PyYAML>=6.0",
"typer[all]>=0.9",
"pydantic==2.7.0",
"ipywidgets",
"deepdiff"
@@ -53,8 +53,8 @@ license-files = ["LICENSE"]
[project.optional-dependencies]
rl = [
"ray[rllib] >= 2.20.0, <2.33",
"tensorflow==2.12.0",
"stable-baselines3[extra]==2.1.0",
"tensorflow~=2.12",
"stable-baselines3==2.1.0",
"sb3-contrib==2.1.0",
]
dev = [
@@ -69,7 +69,7 @@ dev = [
"pytest-xdist==3.3.1",
"pytest-cov==4.0.0",
"pytest-flake8==1.1.1",
"setuptools==66",
"setuptools==75.6.0",
"Sphinx==7.1.2",
"sphinx-copybutton==0.5.2",
"wheel==0.38.4",

View File

@@ -1,4 +1,4 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import subprocess
import sys
from typing import Any

View File

@@ -30,35 +30,22 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_2
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
action: do_nothing
options: {}
1:
action: NODE_APPLICATION_EXECUTE
action: node_application_execute
options:
node_id: 0
application_id: 0
node_name: client_2
application_name: WebBrowser
2:
action: NODE_APPLICATION_EXECUTE
action: node_application_execute
options:
node_id: 0
application_id: 1
node_name: client_2
application_name: DatabaseClient
reward_function:
reward_components:
@@ -79,35 +66,22 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
action: do_nothing
options: {}
1:
action: NODE_APPLICATION_EXECUTE
action: node_application_execute
options:
node_id: 0
application_id: 0
node_name: client_1
application_name: WebBrowser
2:
action: NODE_APPLICATION_EXECUTE
action: node_application_execute
options:
node_id: 0
application_id: 1
node_name: client_1
application_name: WebBrowser
reward_function:
reward_components:
@@ -128,33 +102,12 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: DataManipulationBot
- node_name: client_2
applications:
- application_name: DataManipulationBot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
agent_settings:
possible_start_nodes: [client_1, client_2]
target_application: DataManipulationBot
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
@@ -208,8 +161,8 @@ agents:
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
- HTTP
- POSTGRES_SERVER
protocol_list:
- ICMP
- TCP
@@ -235,490 +188,426 @@ agents:
options: {}
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_SERVICE_FIX
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: ROUTER_ACL_ADDRULE
- type: ROUTER_ACL_REMOVERULE
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
action: DONOTHING
action: do_nothing
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
action: node_service_scan
options:
node_id: 1
service_id: 0
node_name: web_server
service_name: WebServer
# stop webapp service
2:
action: NODE_SERVICE_STOP
action: node_service_stop
options:
node_id: 1
service_id: 0
node_name: web_server
service_name: WebServer
# start webapp service
3:
action: "NODE_SERVICE_START"
action: "node_service_start"
options:
node_id: 1
service_id: 0
node_name: web_server
service_name: WebServer
4:
action: "NODE_SERVICE_PAUSE"
action: "node_service_pause"
options:
node_id: 1
service_id: 0
node_name: web_server
service_name: WebServer
5:
action: "NODE_SERVICE_RESUME"
action: "node_service_resume"
options:
node_id: 1
service_id: 0
node_name: web_server
service_name: WebServer
6:
action: "NODE_SERVICE_RESTART"
action: "node_service_restart"
options:
node_id: 1
service_id: 0
node_name: web_server
service_name: WebServer
7:
action: "NODE_SERVICE_DISABLE"
action: "node_service_disable"
options:
node_id: 1
service_id: 0
node_name: web_server
service_name: WebServer
8:
action: "NODE_SERVICE_ENABLE"
action: "node_service_enable"
options:
node_id: 1
service_id: 0
node_name: web_server
service_name: WebServer
9: # check database.db file
action: "NODE_FILE_SCAN"
action: "node_file_scan"
options:
node_id: 2
folder_id: 0
file_id: 0
node_name: database_server
folder_name: database
file_name: database.db
10:
action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
action: "node_file_scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
options:
node_id: 2
folder_id: 0
file_id: 0
node_name: database_server
folder_name: database
file_name: database.db
11:
action: "NODE_FILE_DELETE"
action: "node_file_delete"
options:
node_id: 2
folder_id: 0
file_id: 0
node_name: database_server
folder_name: database
file_name: database.db
12:
action: "NODE_FILE_REPAIR"
action: "node_file_repair"
options:
node_id: 2
folder_id: 0
file_id: 0
node_name: database_server
folder_name: database
file_name: database.db
13:
action: "NODE_SERVICE_FIX"
action: "node_service_fix"
options:
node_id: 2
service_id: 0
node_name: database_server
service_name: DatabaseService
14:
action: "NODE_FOLDER_SCAN"
action: "node_folder_scan"
options:
node_id: 2
folder_id: 0
node_name: database_server
folder_name: database
15:
action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
action: "node_folder_scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
options:
node_id: 2
folder_id: 0
node_name: database_server
folder_name: database
16:
action: "NODE_FOLDER_REPAIR"
action: "node_folder_repair"
options:
node_id: 2
folder_id: 0
node_name: database_server
folder_name: database
17:
action: "NODE_FOLDER_RESTORE"
action: "node_folder_restore"
options:
node_id: 2
folder_id: 0
node_name: database_server
folder_name: database
18:
action: "NODE_OS_SCAN"
action: "node_os_scan"
options:
node_id: 0
node_name: domain_controller
19:
action: "NODE_SHUTDOWN"
action: "node_shutdown"
options:
node_id: 0
node_name: domain_controller
20:
action: NODE_STARTUP
action: node_startup
options:
node_id: 0
node_name: domain_controller
21:
action: NODE_RESET
action: node_reset
options:
node_id: 0
node_name: domain_controller
22:
action: "NODE_OS_SCAN"
action: "node_os_scan"
options:
node_id: 1
node_name: web_server
23:
action: "NODE_SHUTDOWN"
action: "node_shutdown"
options:
node_id: 1
node_name: web_server
24:
action: NODE_STARTUP
action: node_startup
options:
node_id: 1
node_name: web_server
25:
action: NODE_RESET
action: node_reset
options:
node_id: 1
node_name: web_server
26: # old action num: 18
action: "NODE_OS_SCAN"
action: "node_os_scan"
options:
node_id: 2
node_name: database_server
27:
action: "NODE_SHUTDOWN"
action: "node_shutdown"
options:
node_id: 2
node_name: database_server
28:
action: NODE_STARTUP
action: node_startup
options:
node_id: 2
node_name: database_server
29:
action: NODE_RESET
action: node_reset
options:
node_id: 2
node_name: database_server
30:
action: "NODE_OS_SCAN"
action: "node_os_scan"
options:
node_id: 3
node_name: backup_server
31:
action: "NODE_SHUTDOWN"
action: "node_shutdown"
options:
node_id: 3
node_name: backup_server
32:
action: NODE_STARTUP
action: node_startup
options:
node_id: 3
node_name: backup_server
33:
action: NODE_RESET
action: node_reset
options:
node_id: 3
node_name: backup_server
34:
action: "NODE_OS_SCAN"
action: "node_os_scan"
options:
node_id: 4
node_name: security_suite
35:
action: "NODE_SHUTDOWN"
action: "node_shutdown"
options:
node_id: 4
node_name: security_suite
36:
action: NODE_STARTUP
action: node_startup
options:
node_id: 4
node_name: security_suite
37:
action: NODE_RESET
action: node_reset
options:
node_id: 4
node_name: security_suite
38:
action: "NODE_OS_SCAN"
action: "node_os_scan"
options:
node_id: 5
node_name: client_1
39: # old action num: 19 # shutdown client 1
action: "NODE_SHUTDOWN"
action: "node_shutdown"
options:
node_id: 5
node_name: client_1
40: # old action num: 20
action: NODE_STARTUP
action: node_startup
options:
node_id: 5
node_name: client_1
41: # old action num: 21
action: NODE_RESET
action: node_reset
options:
node_id: 5
node_name: client_1
42:
action: "NODE_OS_SCAN"
action: "node_os_scan"
options:
node_id: 6
node_name: client_2
43:
action: "NODE_SHUTDOWN"
action: "node_shutdown"
options:
node_id: 6
node_name: client_2
44:
action: NODE_STARTUP
action: node_startup
options:
node_id: 6
node_name: client_2
45:
action: NODE_RESET
action: node_reset
options:
node_id: 6
node_name: client_2
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
action: "ROUTER_ACL_ADDRULE"
action: "router_acl_add_rule"
options:
target_router: router_1
position: 1
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
permission: DENY
src_ip: 192.168.10.21 # client 1
dst_ip: ALL # ALL
src_port: ALL
dst_port: ALL
protocol_name: ALL
src_wildcard: NONE
dst_wildcard: NONE
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
action: "ROUTER_ACL_ADDRULE"
action: "router_acl_add_rule"
options:
target_router: router_1
position: 2
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
permission: DENY
src_ip: 192.168.10.22 # client 2
dst_ip: ALL # ALL
src_port: ALL
dst_port: ALL
protocol_name: ALL
src_wildcard: NONE
dst_wildcard: NONE
48: # old action num: 24 # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
action: "router_acl_add_rule"
options:
target_router: router_1
position: 3
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
permission: DENY
src_ip: 192.168.10.21 # client 1
dst_ip: 192.168.1.12 # web server
src_port: ALL
dst_port: ALL
protocol_name: TCP
src_wildcard: NONE
dst_wildcard: NONE
49: # old action num: 25 # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
action: "router_acl_add_rule"
options:
target_router: router_1
position: 4
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
permission: DENY
src_ip: 192.168.10.22 # client 2
dst_ip: 192.168.1.12 # web server
src_port: ALL
dst_port: ALL
protocol_name: TCP
src_wildcard: NONE
dst_wildcard: NONE
50: # old action num: 26
action: "ROUTER_ACL_ADDRULE"
action: "router_acl_add_rule"
options:
target_router: router_1
position: 5
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
permission: DENY
src_ip: 192.168.10.21 # client 1
dst_ip: 192.168.1.14 # database
src_port: ALL
dst_port: ALL
protocol_name: TCP
src_wildcard: NONE
dst_wildcard: NONE
51: # old action num: 27
action: "ROUTER_ACL_ADDRULE"
action: "router_acl_add_rule"
options:
target_router: router_1
position: 6
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
permission: DENY
src_ip: 192.168.10.22 # client 2
dst_ip: 192.168.1.14 # database
src_port: ALL
dst_port: ALL
protocol_name: TCP
src_wildcard: NONE
dst_wildcard: NONE
52: # old action num: 28
action: "ROUTER_ACL_REMOVERULE"
action: "router_acl_remove_rule"
options:
target_router: router_1
position: 0
53: # old action num: 29
action: "ROUTER_ACL_REMOVERULE"
action: "router_acl_remove_rule"
options:
target_router: router_1
position: 1
54: # old action num: 30
action: "ROUTER_ACL_REMOVERULE"
action: "router_acl_remove_rule"
options:
target_router: router_1
position: 2
55: # old action num: 31
action: "ROUTER_ACL_REMOVERULE"
action: "router_acl_remove_rule"
options:
target_router: router_1
position: 3
56: # old action num: 32
action: "ROUTER_ACL_REMOVERULE"
action: "router_acl_remove_rule"
options:
target_router: router_1
position: 4
57: # old action num: 33
action: "ROUTER_ACL_REMOVERULE"
action: "router_acl_remove_rule"
options:
target_router: router_1
position: 5
58: # old action num: 34
action: "ROUTER_ACL_REMOVERULE"
action: "router_acl_remove_rule"
options:
target_router: router_1
position: 6
59: # old action num: 35
action: "ROUTER_ACL_REMOVERULE"
action: "router_acl_remove_rule"
options:
target_router: router_1
position: 7
60: # old action num: 36
action: "ROUTER_ACL_REMOVERULE"
action: "router_acl_remove_rule"
options:
target_router: router_1
position: 8
61: # old action num: 37
action: "ROUTER_ACL_REMOVERULE"
action: "router_acl_remove_rule"
options:
target_router: router_1
position: 9
62: # old action num: 38
action: "HOST_NIC_DISABLE"
action: "host_nic_disable"
options:
node_id: 0
nic_id: 0
node_name: domain_controller
nic_num: 1
63: # old action num: 39
action: "HOST_NIC_ENABLE"
action: "host_nic_enable"
options:
node_id: 0
nic_id: 0
node_name: domain_controller
nic_num: 1
64: # old action num: 40
action: "HOST_NIC_DISABLE"
action: "host_nic_disable"
options:
node_id: 1
nic_id: 0
node_name: web_server
nic_num: 1
65: # old action num: 41
action: "HOST_NIC_ENABLE"
action: "host_nic_enable"
options:
node_id: 1
nic_id: 0
node_name: web_server
nic_num: 1
66: # old action num: 42
action: "HOST_NIC_DISABLE"
action: "host_nic_disable"
options:
node_id: 2
nic_id: 0
node_name: database_server
nic_num: 1
67: # old action num: 43
action: "HOST_NIC_ENABLE"
action: "host_nic_enable"
options:
node_id: 2
nic_id: 0
node_name: database_server
nic_num: 1
68: # old action num: 44
action: "HOST_NIC_DISABLE"
action: "host_nic_disable"
options:
node_id: 3
nic_id: 0
node_name: backup_server
nic_num: 1
69: # old action num: 45
action: "HOST_NIC_ENABLE"
action: "host_nic_enable"
options:
node_id: 3
nic_id: 0
node_name: backup_server
nic_num: 1
70: # old action num: 46
action: "HOST_NIC_DISABLE"
action: "host_nic_disable"
options:
node_id: 4
nic_id: 0
node_name: security_suite
nic_num: 1
71: # old action num: 47
action: "HOST_NIC_ENABLE"
action: "host_nic_enable"
options:
node_id: 4
nic_id: 0
node_name: security_suite
nic_num: 1
72: # old action num: 48
action: "HOST_NIC_DISABLE"
action: "host_nic_disable"
options:
node_id: 4
nic_id: 1
node_name: security_suite
nic_num: 2
73: # old action num: 49
action: "HOST_NIC_ENABLE"
action: "host_nic_enable"
options:
node_id: 4
nic_id: 1
node_name: security_suite
nic_num: 2
74: # old action num: 50
action: "HOST_NIC_DISABLE"
action: "host_nic_disable"
options:
node_id: 5
nic_id: 0
node_name: client_1
nic_num: 1
75: # old action num: 51
action: "HOST_NIC_ENABLE"
action: "host_nic_enable"
options:
node_id: 5
nic_id: 0
node_name: client_1
nic_num: 1
76: # old action num: 52
action: "HOST_NIC_DISABLE"
action: "host_nic_disable"
options:
node_id: 6
nic_id: 0
node_name: client_2
nic_num: 1
77: # old action num: 53
action: "HOST_NIC_ENABLE"
action: "host_nic_enable"
options:
node_id: 6
nic_id: 0
node_name: client_2
nic_num: 1
options:
nodes:
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:

File diff suppressed because it is too large Load Diff

View File

@@ -6,68 +6,48 @@ game:
agents:
- ref: RL_Agent
type: ProxyAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
action: DONOTHING
action: do_nothing
options: {}
1:
action: NODE_SHUTDOWN
action: node_shutdown
options:
node_id: 0
node_name: client_1
2:
action: NODE_SHUTDOWN
action: node_shutdown
options:
node_id: 1
node_name: server
3:
action: NODE_STARTUP
action: node_startup
options:
node_id: 0
node_name: client_1
4:
action: NODE_STARTUP
action: node_startup
options:
node_id: 1
node_name: server
5:
action: HOST_NIC_DISABLE
action: host_nic_disable
options:
node_id: 0
nic_id: 0
node_name: client_1
nic_num: 1
6:
action: HOST_NIC_DISABLE
action: host_nic_disable
options:
node_id: 1
nic_id: 0
node_name: server
nic_num: 1
7:
action: HOST_NIC_ENABLE
action: host_nic_enable
options:
node_id: 0
nic_id: 0
node_name: client_1
nic_num: 1
8:
action: HOST_NIC_ENABLE
action: host_nic_enable
options:
node_id: 1
nic_id: 0
options:
nodes:
- node_name: client_1
- node_name: server
max_folders_per_node: 0
max_files_per_folder: 0
max_services_per_node: 0
max_nics_per_node: 1
max_acl_rules: 0
ip_list:
- 192.168.1.2
- 192.168.1.3
reward_function:
reward_components: []
node_name: server
nic_num: 1
simulation:
network:

View File

@@ -6,25 +6,17 @@ agents: &greens
action_probabilities:
0: 0.2
1: 0.8
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DatabaseClient
action_map:
0:
action: DONOTHING
action: do_nothing
options: {}
1:
action: NODE_APPLICATION_EXECUTE
action: node_application_execute
options:
node_id: 0
application_id: 0
node_name: client
application_name: DatabaseClient
reward_function:
reward_components:

View File

@@ -6,25 +6,17 @@ agents: &greens
action_probabilities:
0: 0.95
1: 0.05
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DatabaseClient
action_map:
0:
action: DONOTHING
action: do_nothing
options: {}
1:
action: NODE_APPLICATION_EXECUTE
action: node_application_execute
options:
node_id: 0
application_id: 0
node_name: client
application_name: DatabaseClient
reward_function:
reward_components:

View File

@@ -3,24 +3,9 @@ reds: &reds
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DataManipulationBot
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 10
frequency: 10
variance: 0
possible_start_nodes: [client,]
target_application: DataManipulationBot
start_step: 10
frequency: 10
variance: 0

View File

@@ -3,24 +3,9 @@ reds: &reds
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DataManipulationBot
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 3
frequency: 2
variance: 1
possible_start_nodes: [client_1]
target_application: DataManipulationBot
start_step: 3
frequency: 2
variance: 1

View File

@@ -54,65 +54,46 @@ agents:
- server:eth-1<->switch_1:eth-2
action_space:
action_list:
- type: DONOTHING
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
action: DONOTHING
action: do_nothing
options: {}
1:
action: NODE_SHUTDOWN
action: node_shutdown
options:
node_id: 0
node_name: client_1
2:
action: NODE_SHUTDOWN
action: node_shutdown
options:
node_id: 1
node_name: server
3:
action: NODE_STARTUP
action: node_startup
options:
node_id: 0
node_name: client_1
4:
action: NODE_STARTUP
action: node_startup
options:
node_id: 1
node_name: server
5:
action: HOST_NIC_DISABLE
action: host_nic_disable
options:
node_id: 0
nic_id: 0
node_name: client_1
nic_num: 1
6:
action: HOST_NIC_DISABLE
action: host_nic_disable
options:
node_id: 1
nic_id: 0
node_name: server
nic_num: 1
7:
action: HOST_NIC_ENABLE
action: host_nic_enable
options:
node_id: 0
nic_id: 0
node_name: client_1
nic_num: 1
8:
action: HOST_NIC_ENABLE
action: host_nic_enable
options:
node_id: 1
nic_id: 0
options:
nodes:
- node_name: client
- node_name: server
max_folders_per_node: 0
max_files_per_folder: 0
max_services_per_node: 0
max_nics_per_node: 1
max_acl_rules: 0
ip_list:
- 192.168.1.2
- 192.168.1.3
node_name: server
nic_num: 1
reward_function:
reward_components:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,33 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from primaite.game.agent.actions import (
abstract,
acl,
application,
file,
folder,
host_nic,
manager,
network,
node,
service,
session,
software,
)
from primaite.game.agent.actions.manager import ActionManager
__all__ = (
"abstract",
"acl",
"application",
"software",
"file",
"folder",
"host_nic",
"manager",
"network",
"node",
"service",
"session",
"ActionManager",
)

View File

@@ -0,0 +1,36 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import ABC
from typing import Any, ClassVar, Dict, Optional, Type
from pydantic import BaseModel, ConfigDict
from primaite.interface.request import RequestFormat
class AbstractAction(BaseModel, ABC):
"""Base class for actions."""
config: "AbstractAction.ConfigSchema"
class ConfigSchema(BaseModel, ABC):
"""Base configuration schema for Actions."""
model_config = ConfigDict(extra="forbid")
type: str = ""
_registry: ClassVar[Dict[str, Type[AbstractAction]]] = {}
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if identifier is None:
return
if identifier in cls._registry:
raise ValueError(f"Cannot create new action under reserved name {identifier}")
cls._registry[identifier] = cls
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
pass

View File

@@ -0,0 +1,157 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import ABC
from typing import Literal, Union
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
from primaite.utils.validation.ip_protocol import IPProtocol
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import Port
__all__ = (
"RouterACLAddRuleAction",
"RouterACLRemoveRuleAction",
"FirewallACLAddRuleAction",
"FirewallACLRemoveRuleAction",
)
class ACLAddRuleAbstractAction(AbstractAction, ABC):
"""Base abstract class for ACL add rule actions."""
config: ConfigSchema = "ACLAddRuleAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema base for ACL add rule abstract actions."""
src_ip: IPV4Address
protocol_name: Union[IPProtocol, Literal["ALL"]]
permission: Literal["PERMIT", "DENY"]
position: int
dst_ip: Union[IPV4Address, Literal["ALL"]]
src_port: Union[Port, Literal["ALL"]]
dst_port: Union[Port, Literal["ALL"]]
src_wildcard: Union[IPV4Address, Literal["NONE"]]
dst_wildcard: Union[IPV4Address, Literal["NONE"]]
class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"):
"""Base abstract class for ACL remove rule actions."""
config: ConfigSchema = "ACLRemoveRuleAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema base for ACL remove rule abstract actions."""
position: int
class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_add_rule"):
"""Action which adds a rule to a router's ACL."""
config: "RouterACLAddRuleAction.ConfigSchema"
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
"""Configuration Schema for RouterACLAddRuleAction."""
target_router: str
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.target_router,
"acl",
"add_rule",
config.permission,
config.protocol_name,
str(config.src_ip),
str(config.src_wildcard),
config.src_port,
str(config.dst_ip),
str(config.dst_wildcard),
config.dst_port,
config.position,
]
class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="router_acl_remove_rule"):
"""Action which removes a rule from a router's ACL."""
config: "RouterACLRemoveRuleAction.ConfigSchema"
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
"""Configuration schema for RouterACLRemoveRuleAction."""
target_router: str
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", config.target_router, "acl", "remove_rule", config.position]
class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_acl_add_rule"):
"""Action which adds a rule to a firewall port's ACL."""
config: "FirewallACLAddRuleAction.ConfigSchema"
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
"""Configuration schema for FirewallACLAddRuleAction."""
target_firewall_nodename: str
firewall_port_name: str
firewall_port_direction: str
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.target_firewall_nodename,
config.firewall_port_name,
config.firewall_port_direction,
"acl",
"add_rule",
config.permission,
config.protocol_name,
str(config.src_ip),
str(config.src_wildcard),
config.src_port,
str(config.dst_ip),
str(config.dst_wildcard),
config.dst_port,
config.position,
]
class FirewallACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="firewall_acl_remove_rule"):
"""Action which removes a rule from a firewall port's ACL."""
config: "FirewallACLRemoveRuleAction.ConfigSchema"
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
"""Configuration schema for FirewallACLRemoveRuleAction."""
target_firewall_nodename: str
firewall_port_name: str
firewall_port_direction: str
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.target_firewall_nodename,
config.firewall_port_name,
config.firewall_port_direction,
"acl",
"remove_rule",
config.position,
]

View File

@@ -0,0 +1,137 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC
from typing import ClassVar
from primaite.game.agent.actions.abstract import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeApplicationExecuteAction",
"NodeApplicationScanAction",
"NodeApplicationCloseAction",
"NodeApplicationFixAction",
"NodeApplicationInstallAction",
"NodeApplicationRemoveAction",
)
class NodeApplicationAbstractAction(AbstractAction, ABC):
"""
Base class for application actions.
Any action which applies to an application and uses node_name and application_name as its only two parameters can
inherit from this base class.
"""
config: "NodeApplicationAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration schema for Node Application actions."""
node_name: str
application_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"application",
config.application_name,
config.verb,
]
class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="node_application_execute"):
"""Action which executes an application."""
config: "NodeApplicationExecuteAction.ConfigSchema"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationExecuteAction."""
verb: str = "execute"
class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_application_scan"):
"""Action which scans an application."""
config: "NodeApplicationScanAction.ConfigSchema"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationScanAction."""
verb: str = "scan"
class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node_application_close"):
"""Action which closes an application."""
config: "NodeApplicationCloseAction.ConfigSchema"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationCloseAction."""
verb: str = "close"
class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_application_fix"):
"""Action which fixes an application."""
config: "NodeApplicationFixAction.ConfigSchema"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationFixAction."""
verb: str = "fix"
class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="node_application_install"):
"""Action which installs an application."""
config: "NodeApplicationInstallAction.ConfigSchema"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationInstallAction."""
verb: str = "install"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"software_manager",
"application",
config.verb,
config.application_name,
]
class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="node_application_remove"):
"""Action which removes/uninstalls an application."""
config: "NodeApplicationRemoveAction.ConfigSchema"
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
"""Configuration schema for NodeApplicationRemoveAction."""
verb: str = "uninstall"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"software_manager",
"application",
config.verb,
config.application_name,
]

View File

@@ -0,0 +1,189 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeFileCreateAction",
"NodeFileScanAction",
"NodeFileDeleteAction",
"NodeFileRestoreAction",
"NodeFileCorruptAction",
"NodeFileAccessAction",
"NodeFileCheckhashAction",
"NodeFileRepairAction",
)
class NodeFileAbstractAction(AbstractAction, ABC):
"""Abstract base class for file actions.
Any action which applies to a file and uses node_name, folder_name, and file_name as its
only three parameters can inherit from this base class.
"""
config: "NodeFileAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema for NodeFileAbstractAction."""
node_name: str
folder_name: str
file_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None or config.file_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
"folder",
config.folder_name,
"file",
config.file_name,
config.verb,
]
class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create"):
"""Action which creates a new file in a given folder."""
config: "NodeFileCreateAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileCreateAction."""
verb: ClassVar[str] = "create"
force: bool = False
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None or config.file_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
config.verb,
"file",
config.folder_name,
config.file_name,
config.verb,
]
class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"):
"""Action which scans a file."""
config: "NodeFileScanAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileScanAction."""
verb: ClassVar[str] = "scan"
class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete"):
"""Action which deletes a file."""
config: "NodeFileDeleteAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileDeleteAction."""
verb: ClassVar[str] = "delete"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None or config.file_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
config.verb,
"file",
config.folder_name,
config.file_name,
]
class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restore"):
"""Action which restores a file."""
config: "NodeFileRestoreAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileRestoreAction."""
verb: ClassVar[str] = "restore"
class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrupt"):
"""Action which corrupts a file."""
config: "NodeFileCorruptAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileCorruptAction."""
verb: ClassVar[str] = "corrupt"
class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access"):
"""Action which increases a file's access count."""
config: "NodeFileAccessAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileAccessAction."""
verb: ClassVar[str] = "access"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None or config.file_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
config.verb,
config.folder_name,
config.file_name,
]
class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_checkhash"):
"""Action which checks the hash of a file."""
config: "NodeFileCheckhashAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration schema for NodeFileCheckhashAction."""
verb: ClassVar[str] = "checkhash"
class NodeFileRepairAction(NodeFileAbstractAction, identifier="node_file_repair"):
"""Action which repairs a file."""
config: "NodeFileRepairAction.ConfigSchema"
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
"""Configuration Schema for NodeFileRepairAction."""
verb: ClassVar[str] = "repair"

View File

@@ -0,0 +1,117 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeFolderScanAction",
"NodeFolderCheckhashAction",
"NodeFolderRepairAction",
"NodeFolderRestoreAction",
"NodeFolderCreateAction",
)
class NodeFolderAbstractAction(AbstractAction, ABC):
"""
Base class for folder actions.
Any action which applies to a folder and uses node_name and folder_name as its only two parameters can inherit from
this base class.
"""
config: "NodeFolderAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base configuration schema for NodeFolder actions."""
node_name: str
folder_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
"folder",
config.folder_name,
config.verb,
]
class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_scan"):
"""Action which scans a folder."""
config: "NodeFolderScanAction.ConfigSchema"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderScanAction."""
verb: ClassVar[str] = "scan"
class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folder_checkhash"):
"""Action which checks the hash of a folder."""
config: "NodeFolderCheckhashAction.ConfigSchema"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderCheckhashAction."""
verb: ClassVar[str] = "checkhash"
class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_repair"):
"""Action which repairs a folder."""
config: "NodeFolderRepairAction.ConfigSchema"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderRepairAction."""
verb: ClassVar[str] = "repair"
class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_restore"):
"""Action which restores a folder."""
config: "NodeFolderRestoreAction.ConfigSchema"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderRestoreAction."""
verb: ClassVar[str] = "restore"
class NodeFolderCreateAction(NodeFolderAbstractAction, identifier="node_folder_create"):
"""Action which creates a new folder."""
config: "NodeFolderCreateAction.ConfigSchema"
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
"""Configuration schema for NodeFolderCreateAction."""
verb: ClassVar[str] = "create"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.folder_name is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"file_system",
config.verb,
"folder",
config.folder_name,
]

View File

@@ -0,0 +1,62 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = ("HostNICEnableAction", "HostNICDisableAction")
class HostNICAbstractAction(AbstractAction, ABC):
"""
Abstract base class for NIC actions.
Any action which applies to a NIC and uses node_name and nic_num as its only two parameters can inherit from this
base class.
"""
config: "HostNICAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration schema for HostNIC actions."""
node_name: str
nic_num: int
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.nic_num is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"network_interface",
config.nic_num,
config.verb,
]
class HostNICEnableAction(HostNICAbstractAction, identifier="host_nic_enable"):
"""Action which enables a NIC."""
config: "HostNICEnableAction.ConfigSchema"
class ConfigSchema(HostNICAbstractAction.ConfigSchema):
"""Configuration schema for HostNICEnableAction."""
verb: ClassVar[str] = "enable"
class HostNICDisableAction(HostNICAbstractAction, identifier="host_nic_disable"):
"""Action which disables a NIC."""
config: "HostNICDisableAction.ConfigSchema"
class ConfigSchema(HostNICAbstractAction.ConfigSchema):
"""Configuration schema for HostNICDisableAction."""
verb: ClassVar[str] = "disable"

View File

@@ -0,0 +1,108 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
"""yaml example.
agents:
- name: agent_1
action_space:
actions:
- do_nothing
- node_service_start
- node_service_stop
action_map:
"""
from __future__ import annotations
from typing import Dict, Tuple
from gymnasium import spaces
from pydantic import BaseModel, ConfigDict, Field, field_validator
from primaite.game.agent.actions.abstract import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = ("DoNothingAction", "ActionManager")
class DoNothingAction(AbstractAction, identifier="do_nothing"):
"""Do Nothing Action."""
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration Schema for do_nothingAction."""
type: str = "do_nothing"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["do_nothing"]
class _ActionMapItem(BaseModel):
model_config = ConfigDict(extra="forbid")
action: str
options: Dict
class ActionManager(BaseModel):
"""Class which manages the action space for an agent."""
class ConfigSchema(BaseModel):
"""Config Schema for ActionManager."""
model_config = ConfigDict(extra="forbid")
action_map: Dict[int, _ActionMapItem] = {}
"""Mapping between integer action choices and CAOS actions."""
@field_validator("action_map", mode="after")
def consecutive_action_nums(cls, v: Dict) -> Dict:
"""Make sure all numbers between 0 and N are represented as dict keys in action map."""
assert all([i in v.keys() for i in range(len(v))])
return v
config: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
action_map: Dict[int, Tuple[str, Dict]] = {}
"""Init as empty, populate after model validation."""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.action_map = {n: (v.action, v.options) for n, v in self.config.action_map.items()}
def get_action(self, action: int) -> Tuple[str, Dict]:
"""
Produce action in CAOS format.
The agent chooses an action (as an integer), this is converted into an action in CAOS format
The CAOS format is basically an action identifier, followed by parameters stored in a dictionary.
"""
act_identifier, act_options = self.action_map[action]
return act_identifier, act_options
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
act_class = AbstractAction._registry[action_identifier]
config = act_class.ConfigSchema(**action_options)
return act_class.form_request(config=config)
@property
def space(self) -> spaces.Space:
"""Return the gymnasium action space for this agent."""
return spaces.Discrete(len(self.action_map))
@classmethod
def from_config(cls, cfg: Dict) -> "ActionManager":
"""
Construct an ActionManager from a config dictionary.
The action space config supports must contain the following key:
``action_map`` - List of actions available to the agent, formatted as a dictionary where the key is the
action number between 0 - N, and the value is the CAOS-formatted action.
:param cfg: The action space config.
:type cfg: Dict
:return: The constructed ActionManager.
:rtype: ActionManager
"""
return cls(**cfg.get("options", {}), act_map=cfg.get("action_map"))

View File

@@ -0,0 +1,57 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction")
class NetworkPortAbstractAction(AbstractAction, identifier="network_port_abstract"):
"""Base class for Network port actions."""
config: "NetworkPortAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base configuration schema for NetworkPort actions."""
target_nodename: str
port_num: int
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.target_nodename is None or config.port_num is None:
return ["do_nothing"]
return [
"network",
"node",
config.target_nodename,
"network_interface",
config.port_num,
config.verb,
]
class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_port_enable"):
"""Action which enables are port on a router or a firewall."""
config: "NetworkPortEnableAction.ConfigSchema"
class ConfigSchema(NetworkPortAbstractAction.ConfigSchema):
"""Configuration schema for NetworkPortEnableAction."""
verb: ClassVar[str] = "enable"
class NetworkPortDisableAction(NetworkPortAbstractAction, identifier="network_port_disable"):
"""Action which disables are port on a router or a firewall."""
config: "NetworkPortDisableAction.ConfigSchema"
class ConfigSchema(NetworkPortAbstractAction.ConfigSchema):
"""Configuration schema for NetworkPortDisableAction."""
verb: ClassVar[str] = "disable"

View File

@@ -0,0 +1,186 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from typing import ClassVar, List, Optional, Union
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
from primaite.utils.validation.ip_protocol import IPProtocol
from primaite.utils.validation.port import Port
__all__ = (
"NodeOSScanAction",
"NodeShutdownAction",
"NodeStartupAction",
"NodeResetAction",
"NodeNMAPPingScanAction",
"NodeNMAPPortScanAction",
"NodeNetworkServiceReconAction",
)
class NodeAbstractAction(AbstractAction, identifier="node_abstract"):
"""
Abstract base class for node actions.
Any action which applies to a node and uses node_name as its only parameter can inherit from this base class.
"""
config: "NodeAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration schema for Node actions."""
node_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
print(config)
return ["network", "node", config.node_name, config.verb]
class NodeOSScanAction(NodeAbstractAction, identifier="node_os_scan"):
"""Action which scans a node's OS."""
config: "NodeOSScanAction.ConfigSchema"
class ConfigSchema(NodeAbstractAction.ConfigSchema):
"""Configuration schema for NodeOSScanAction."""
verb: ClassVar[str] = "scan"
class NodeShutdownAction(NodeAbstractAction, identifier="node_shutdown"):
"""Action which shuts down a node."""
config: "NodeShutdownAction.ConfigSchema"
class ConfigSchema(NodeAbstractAction.ConfigSchema):
"""Configuration schema for NodeShutdownAction."""
verb: ClassVar[str] = "shutdown"
class NodeStartupAction(NodeAbstractAction, identifier="node_startup"):
"""Action which starts up a node."""
config: "NodeStartupAction.ConfigSchema"
class ConfigSchema(NodeAbstractAction.ConfigSchema):
"""Configuration schema for NodeStartupAction."""
verb: ClassVar[str] = "startup"
class NodeResetAction(NodeAbstractAction, identifier="node_reset"):
"""Action which resets a node."""
config: "NodeResetAction.ConfigSchema"
class ConfigSchema(NodeAbstractAction.ConfigSchema):
"""Configuration schema for NodeResetAction."""
verb: ClassVar[str] = "reset"
class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_action"):
"""Base class for NodeNMAP actions."""
config: "NodeNMAPAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base Configuration Schema for NodeNMAP actions."""
target_ip_address: Union[str, List[str]]
show: bool = False
source_node: str
@classmethod
@abstractmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
# NMAP action requests don't share a common format for their requests
# This is just a placeholder to ensure the method is defined.
pass
class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"):
"""Action which performs an NMAP ping scan."""
config: "NodeNMAPPingScanAction.ConfigSchema"
@classmethod
def form_request(cls, config: "NodeNMAPPingScanAction.ConfigSchema") -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.source_node,
"application",
"NMAP",
"ping_scan",
{"target_ip_address": config.target_ip_address, "show": config.show},
]
class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_scan"):
"""Action which performs an NMAP port scan."""
config: "NodeNMAPPortScanAction.ConfigSchema"
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
"""Configuration Schema for NodeNMAPPortScanAction."""
source_node: str
target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None
target_port: Optional[Union[Port, List[Port]]] = None
show: Optional[bool] = (False,)
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.source_node,
"application",
"NMAP",
"port_scan",
{
"target_ip_address": config.target_ip_address,
"target_port": config.target_port,
"target_protocol": config.target_protocol,
"show": config.show,
},
]
class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_network_service_recon"):
"""Action which performs an NMAP network service recon (ping scan followed by port scan)."""
config: "NodeNetworkServiceReconAction.ConfigSchema"
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
"""Configuration schema for NodeNetworkServiceReconAction."""
target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None
target_port: Optional[Union[Port, List[Port]]] = None
show: Optional[bool] = (False,)
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.source_node,
"application",
"NMAP",
"network_service_recon",
{
"target_ip_address": config.target_ip_address,
"target_port": config.target_port,
"target_protocol": config.target_protocol,
"show": config.show,
},
]

View File

@@ -0,0 +1,135 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeServiceScanAction",
"NodeServiceStopAction",
"NodeServiceStartAction",
"NodeServicePauseAction",
"NodeServiceResumeAction",
"NodeServiceRestartAction",
"NodeServiceDisableAction",
"NodeServiceEnableAction",
"NodeServiceFixAction",
)
class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstract"):
"""Abstract Action for Node Service related actions.
Any actions which use node_name and service_name can inherit from this class.
"""
config: "NodeServiceAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
node_name: str
service_name: str
verb: ClassVar[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", config.node_name, "service", config.service_name, config.verb]
class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_scan"):
"""Action which scans a service."""
config: "NodeServiceScanAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceScanAction."""
verb: ClassVar[str] = "scan"
class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_stop"):
"""Action which stops a service."""
config: "NodeServiceStopAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceStopAction."""
verb: ClassVar[str] = "stop"
class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service_start"):
"""Action which starts a service."""
config: "NodeServiceStartAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceStartAction."""
verb: ClassVar[str] = "start"
class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service_pause"):
"""Action which pauses a service."""
config: "NodeServicePauseAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServicePauseAction."""
verb: ClassVar[str] = "pause"
class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_service_resume"):
"""Action which resumes a service."""
config: "NodeServiceResumeAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceResumeAction."""
verb: ClassVar[str] = "resume"
class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_service_restart"):
"""Action which restarts a service."""
config: "NodeServiceRestartAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceRestartAction."""
verb: ClassVar[str] = "restart"
class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_service_disable"):
"""Action which disables a service."""
config: "NodeServiceDisableAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceDisableAction."""
verb: ClassVar[str] = "disable"
class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_service_enable"):
"""Action which enables a service."""
config: "NodeServiceEnableAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceEnableAction."""
verb: ClassVar[str] = "enable"
class NodeServiceFixAction(NodeServiceAbstractAction, identifier="node_service_fix"):
"""Action which fixes a service."""
config: "NodeServiceFixAction.ConfigSchema"
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
"""Configuration Schema for NodeServiceFixAction."""
verb: ClassVar[str] = "fix"

View File

@@ -0,0 +1,108 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"NodeSessionsRemoteLoginAction",
"NodeSessionsRemoteLogoutAction",
"NodeAccountChangePasswordAction",
)
class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstract"):
"""Base class for NodeSession actions."""
config: "NodeSessionAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Base configuration schema for NodeSessionAbstractActions."""
node_name: str
remote_ip: str
@classmethod
@abstractmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""
Abstract method for request forming.
Should return the action formatted as a request which can be ingested by the PrimAITE simulation.
"""
pass
class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_session_remote_login"):
"""Action which performs a remote session login."""
config: "NodeSessionsRemoteLoginAction.ConfigSchema"
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeSessionsRemoteLoginAction."""
username: str
password: str
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.remote_ip is None:
return ["do_nothing"]
return [
"network",
"node",
config.node_name,
"service",
"Terminal",
"node_session_remote_login",
config.username,
config.password,
config.remote_ip,
]
class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node_session_remote_logoff"):
"""Action which performs a remote session logout."""
config: "NodeSessionsRemoteLogoutAction.ConfigSchema"
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeSessionsRemoteLogoutAction."""
verb: str = "remote_logoff"
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if config.node_name is None or config.remote_ip is None:
return ["do_nothing"]
return ["network", "node", config.node_name, "service", "Terminal", config.verb, config.remote_ip]
class NodeAccountChangePasswordAction(NodeSessionAbstractAction, identifier="node_account_change_password"):
"""Action which changes the password for a user."""
config: "NodeAccountChangePasswordAction.ConfigSchema"
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
"""Configuration schema for NodeAccountsChangePasswordAction."""
username: str
current_password: str
new_password: str
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"service",
"UserManager",
"change_password",
config.username,
config.current_password,
config.new_password,
]

View File

@@ -0,0 +1,241 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import List, Optional, Union
from pydantic import ConfigDict, Field
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
__all__ = (
"ConfigureRansomwareScriptAction",
"ConfigureDoSBotAction",
"ConfigureC2BeaconAction",
"NodeSendRemoteCommandAction",
"TerminalC2ServerAction",
"RansomwareLaunchC2ServerAction",
"ExfiltrationC2ServerAction",
"ConfigureDatabaseClientAction",
)
class ConfigureRansomwareScriptAction(AbstractAction, identifier="configure_ransomware_script"):
"""Action which sets config parameters for a ransomware script on a node."""
config: "ConfigureRansomwareScriptAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for ConfigureRansomwareScriptAction."""
node_name: str
server_ip_address: Optional[str] = None
server_password: Optional[str] = None
payload: Optional[str] = None
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
data = dict(
server_ip_address=config.server_ip_address,
server_password=config.server_password,
payload=config.payload,
)
return ["network", "node", config.node_name, "application", "RansomwareScript", "configure", data]
class RansomwareConfigureC2ServerAction(ConfigureRansomwareScriptAction, identifier="c2_server_ransomware_configure"):
"""Action which causes a C2 server to send a command to set options on a ransomware script remotely."""
@classmethod
def form_request(cls, config: ConfigureRansomwareScriptAction.ConfigSchema) -> RequestFormat:
data = dict(
server_ip_address=config.server_ip_address, server_password=config.server_password, payload=config.payload
)
return ["network", "node", config.node_name, "application", "C2Server", "ransomware_configure", data]
class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"):
"""Action which sets config parameters for a DoS bot on a node."""
class ConfigSchema(AbstractAction.ConfigSchema):
"""Schema for options that can be passed to this action."""
model_config = ConfigDict(extra="forbid")
node_name: str
target_ip_address: Optional[str] = None
target_port: Optional[str] = None
payload: Optional[str] = None
repeat: Optional[bool] = None
port_scan_p_of_success: Optional[float] = None
dos_intensity: Optional[float] = None
max_sessions: Optional[int] = None
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
data = dict(
target_ip_address=config.target_ip_address,
target_port=config.target_port,
payload=config.payload,
repeat=config.repeat,
port_scan_p_of_success=config.port_scan_p_of_success,
dos_intensity=config.dos_intensity,
max_sessions=config.max_sessions,
)
data = {k: v for k, v in data.items() if v is not None}
return ["network", "node", config.node_name, "application", "DoSBot", "configure", data]
class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"):
"""Action which configures a C2 Beacon based on the parameters given."""
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for ConfigureC2BeaconAction."""
node_name: str
c2_server_ip_address: str
keep_alive_frequency: int = Field(default=5, ge=1)
masquerade_protocol: str = Field(default="TCP")
masquerade_port: str = Field(default="HTTP")
@classmethod
def form_request(self, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
data = dict(
c2_server_ip_address=config.c2_server_ip_address,
keep_alive_frequency=config.keep_alive_frequency,
masquerade_protocol=config.masquerade_protocol,
masquerade_port=config.masquerade_port,
)
return ["network", "node", config.node_name, "application", "C2Beacon", "configure", data]
class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_command"):
"""Action which sends a terminal command to a remote node via SSH."""
config: "NodeSendRemoteCommandAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for NodeSendRemoteCommandAction."""
node_name: str
remote_ip: str
command: RequestFormat
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return [
"network",
"node",
config.node_name,
"service",
"Terminal",
"send_remote_command",
config.remote_ip,
{"command": config.command},
]
class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_command"):
"""Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed."""
config: "TerminalC2ServerAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Schema for options that can be passed to this action."""
node_name: str
commands: Union[List[RequestFormat], RequestFormat]
ip_address: Optional[str]
username: Optional[str]
password: Optional[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
command_model = {
"commands": config.commands,
"ip_address": config.ip_address,
"username": config.username,
"password": config.password,
}
return ["network", "node", config.node_name, "application", "C2Server", "terminal_command", command_model]
class RansomwareLaunchC2ServerAction(AbstractAction, identifier="c2_server_ransomware_launch"):
"""Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript."""
config: "RansomwareLaunchC2ServerAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Configuration schema for RansomwareLaunchC2ServerAction."""
node_name: str
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
# This action currently doesn't require any further configuration options.
return ["network", "node", config.node_name, "application", "C2Server", "ransomware_launch"]
class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfiltrate"):
"""Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server."""
config: "ExfiltrationC2ServerAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Schema for options that can be passed to this action."""
node_name: str
username: Optional[str]
password: Optional[str]
target_ip_address: str
target_file_name: str
target_folder_name: str
exfiltration_folder_name: Optional[str]
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
command_model = {
"target_file_name": config.target_file_name,
"target_folder_name": config.target_folder_name,
"exfiltration_folder_name": config.exfiltration_folder_name,
"target_ip_address": config.target_ip_address,
"username": config.username,
"password": config.password,
}
return ["network", "node", config.node_name, "application", "C2Server", "exfiltrate", command_model]
class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_database_client"):
"""Action which sets config parameters for a database client on a node."""
config: "ConfigureDatabaseClientAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
"""Schema for options that can be passed to this action."""
node_name: str
server_ip_address: Optional[str] = None
server_password: Optional[str] = None
@classmethod
def form_request(cls, config: ConfigSchema) -> RequestFormat:
"""Return the action formatted as a request that can be ingested by the simulation."""
if config.node_name is None:
return ["do_nothing"]
data = {"server_ip_address": config.server_ip_address, "server_password": config.server_password}
return ["network", "node", config.node_name, "application", "DatabaseClient", "configure", data]

View File

@@ -1,6 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import logging
from pathlib import Path
from typing import Optional
from prettytable import MARKDOWN, PrettyTable
@@ -20,20 +21,22 @@ class _NotJSONFilter(logging.Filter):
class AgentLog:
"""
A Agent Log class is a simple logger dedicated to managing and writing logging updates and information for an agent.
An Agent Log class is a simple logger dedicated to managing and writing updates and information for an agent.
Each log message is written to a file located at: <simulation output directory>/agent_name/agent_name.log
Each log message is written to a file located at:
<simulation output directory>/agent_name/agent_name.log
"""
def __init__(self, agent_name: str):
def __init__(self, agent_name: Optional[str]):
"""
Constructs a Agent Log instance for a given hostname.
:param hostname: The hostname associated with the system logs being recorded.
:param agent_name: The agent_name associated with the system logs being recorded.
"""
self.agent_name = agent_name
self.current_episode: int = 1
super().__init__()
self.agent_name = agent_name if agent_name else "unnamed_agent"
self.current_timestep: int = 0
self.current_episode: int = 1
self.setup_logger()
@property
@@ -90,7 +93,7 @@ class AgentLog:
def _write_to_terminal(self, msg: str, level: str, to_terminal: bool = False):
if to_terminal or SIM_OUTPUT.write_agent_log_to_terminal:
print(f"{self.agent_name}: ({ self.timestep}) ({level}) {msg}")
print(f"{self.agent_name}: ({self.timestep}) ({level}) {msg}")
def debug(self, msg: str, to_terminal: bool = False):
"""

View File

@@ -1,10 +1,12 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
"""Interface for agents."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, ConfigDict, Field
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.agent_log import AgentLog
@@ -15,6 +17,8 @@ from primaite.interface.request import RequestFormat, RequestResponse
if TYPE_CHECKING:
pass
__all__ = ("AgentHistoryItem", "AbstractAgent", "AbstractScriptedAgent", "ProxyAgent")
class AgentHistoryItem(BaseModel):
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
@@ -39,89 +43,56 @@ class AgentHistoryItem(BaseModel):
reward_info: Dict[str, Any] = {}
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
start_step: int = 5
"The timestep at which an agent begins performing it's actions"
frequency: int = 5
"The number of timesteps to wait between performing actions"
variance: int = 0
"The amount the frequency can randomly change to"
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "AgentStartSettings":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
"""
if self.variance > self.frequency:
raise ValueError(
f"Agent start settings error: variance must be lower than frequency "
f"{self.variance=}, {self.frequency=}"
)
return self
class AgentSettings(BaseModel):
"""Settings for configuring the operation of an agent."""
start_settings: Optional[AgentStartSettings] = None
"Configuration for when an agent begins performing it's actions"
flatten_obs: bool = True
"Whether to flatten the observation space before passing it to the agent. True by default."
action_masking: bool = False
"Whether to return action masks at each step."
@classmethod
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
"""Construct agent settings from a config dictionary.
:param config: A dict of options for the agent settings.
:type config: Dict
:return: The agent settings.
:rtype: AgentSettings
"""
if config is None:
return cls()
return cls(**config)
class AbstractAgent(ABC):
class AbstractAgent(BaseModel, ABC):
"""Base class for scripted and RL agents."""
def __init__(
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
agent_settings: Optional[AgentSettings] = None,
) -> None:
"""
Initialize an agent.
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
:param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
:type agent_name: Optional[str]
:param action_space: Action space for the agent.
:type action_space: Optional[ActionManager]
:param observation_space: Observation space for the agent.
:type observation_space: Optional[ObservationSpace]
:param reward_function: Reward function for the agent.
:type reward_function: Optional[RewardFunction]
:param agent_settings: Configurable Options for Abstracted Agents
:type agent_settings: Optional[AgentSettings]
"""
self.agent_name: str = agent_name or "unnamed_agent"
self.action_manager: Optional[ActionManager] = action_space
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
self.agent_settings = agent_settings or AgentSettings()
self.history: List[AgentHistoryItem] = []
self.logger = AgentLog(agent_name)
class AgentSettingsSchema(BaseModel, ABC):
"""Schema for the 'agent_settings' key."""
model_config = ConfigDict(extra="forbid")
class ConfigSchema(BaseModel, ABC):
"""Configuration Schema for AbstractAgents."""
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
type: str
ref: str = ""
"""name of the agent."""
team: Optional[Literal["BLUE", "GREEN", "RED"]] = None
agent_settings: AbstractAgent.AgentSettingsSchema = Field(default=lambda: AbstractAgent.AgentSettingsSchema())
action_space: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
observation_space: ObservationManager.ConfigSchema = Field(
default_factory=lambda: ObservationManager.ConfigSchema()
)
reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
config: "AbstractAgent.ConfigSchema" = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
history: List[AgentHistoryItem] = []
action_manager: ActionManager = Field(default_factory=lambda: ActionManager())
observation_manager: ObservationManager = Field(default_factory=lambda: ObservationManager())
reward_function: RewardFunction = Field(default_factory=lambda: RewardFunction())
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if identifier is None:
return
if identifier in cls._registry:
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
cls._registry[identifier] = cls
def model_post_init(self, __context: Any) -> None:
"""Overwrite the default empty action, observation, and rewards with ones defined through the config."""
self.action_manager = ActionManager(config=self.config.action_space)
self.observation_manager = ObservationManager(config=self.config.observation_space)
self.reward_function = RewardFunction(config=self.config.reward_function)
return super().model_post_init(__context)
def update_observation(self, state: Dict) -> ObsType:
"""
@@ -159,9 +130,9 @@ class AbstractAgent(ABC):
"""
# in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39,
# then use a bespoke conversion to take 1-40 int back into CAOS action
return ("DO_NOTHING", {})
return ("do_nothing", {})
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]:
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> RequestFormat:
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
@@ -182,36 +153,47 @@ class AbstractAgent(ABC):
"""Update the most recent history item with the reward value."""
self.history[-1].reward = self.reward_function.current_reward
@classmethod
def from_config(cls, config: Dict) -> AbstractAgent:
"""Grab the relevant agent class and construct an instance from a config dict."""
agent_type = config["type"]
agent_class = cls._registry[agent_type]
return agent_class(config=config)
class AbstractScriptedAgent(AbstractAgent):
class AbstractScriptedAgent(AbstractAgent, identifier="AbstractScriptedAgent"):
"""Base class for actors which generate their own behaviour."""
config: "AbstractScriptedAgent.ConfigSchema" = Field(default_factory=lambda: AbstractScriptedAgent.ConfigSchema())
class ConfigSchema(AbstractAgent.ConfigSchema):
"""Configuration Schema for AbstractScriptedAgents."""
type: str = "AbstractScriptedAgent"
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""Return an action to be taken in the environment."""
return super().get_action(obs=obs, timestep=timestep)
class ProxyAgent(AbstractAgent):
class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
"""Agent that sends observations to an RL model and receives actions from that model."""
def __init__(
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
agent_settings: Optional[AgentSettings] = None,
) -> None:
super().__init__(
agent_name=agent_name,
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
)
self.most_recent_action: ActType
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
self.action_masking: bool = agent_settings.action_masking if agent_settings else False
config: "ProxyAgent.ConfigSchema" = Field(default_factory=lambda: ProxyAgent.ConfigSchema())
most_recent_action: ActType = None
class AgentSettingsSchema(AbstractAgent.AgentSettingsSchema):
"""Schema for the `agent_settings` part of the agent config."""
flatten_obs: bool = False
action_masking: bool = False
class ConfigSchema(AbstractAgent.ConfigSchema):
"""Configuration Schema for Proxy Agent."""
type: str = "Proxy_Agent"
agent_settings: ProxyAgent.AgentSettingsSchema = Field(default_factory=lambda: ProxyAgent.AgentSettingsSchema())
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""
@@ -233,3 +215,8 @@ class ProxyAgent(AbstractAgent):
The environment is responsible for calling this method when it receives an action from the agent policy.
"""
self.most_recent_action = action
@property
def flatten_obs(self) -> bool:
"""Return agent flatten_obs param."""
return self.config.agent_settings.flatten_obs

View File

@@ -17,5 +17,5 @@ from primaite.game.agent.observations.software_observation import ApplicationObs
__all__ = [
"ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation",
"LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation",
"ObservationManager", "ApplicationObservation", "ServiceObservation",]
"ObservationManager", "ApplicationObservation", "ServiceObservation", "RouterObservation", "LinkObservation",]
# fmt: on

View File

@@ -24,8 +24,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
"""List of IP addresses."""
wildcard_list: Optional[List[str]] = None
"""List of wildcard strings."""
port_list: Optional[List[int]] = None
"""List of port numbers."""
port_list: Optional[List[str]] = None
"""List of port names."""
protocol_list: Optional[List[str]] = None
"""List of protocol names."""
num_rules: Optional[int] = None
@@ -37,7 +37,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
num_rules: int,
ip_list: List[IPv4Address],
wildcard_list: List[str],
port_list: List[int],
port_list: List[str],
protocol_list: List[str],
) -> None:
"""
@@ -51,8 +51,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
:type ip_list: List[IPv4Address]
:param wildcard_list: List of wildcard strings.
:type wildcard_list: List[str]
:param port_list: List of port numbers.
:type port_list: List[int]
:param port_list: List of port names.
:type port_list: List[str]
:param protocol_list: List of protocol names.
:type protocol_list: List[str]
"""
@@ -60,7 +60,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
self.num_rules: int = num_rules
self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}
self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
self.port_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(port_list)}
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
self.default_observation: Dict = {
i

View File

@@ -190,6 +190,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
if self.files:
self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)}
self.cached_obs: Optional[ObsType] = self.default_observation
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
@@ -204,7 +206,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
return self.default_observation
if self.file_system_requires_scan:
health_status = folder_state["visible_status"]
if not folder_state["scanned_this_step"]:
health_status = self.cached_obs["health_status"]
else:
health_status = folder_state["visible_status"]
else:
health_status = folder_state["health_status"]

View File

@@ -27,13 +27,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
port_list: Optional[List[str]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
include_users: Optional[bool] = True
include_users: Optional[bool] = None
"""If True, report user session information."""
def __init__(
@@ -41,7 +41,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
where: WhereType,
ip_list: List[str],
wildcard_list: List[str],
port_list: List[int],
port_list: List[str],
protocol_list: List[str],
num_rules: int,
include_users: bool,
@@ -56,8 +56,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:type ip_list: List[str]
:param wildcard_list: List of wildcard rules.
:type wildcard_list: List[str]
:param port_list: List of port numbers.
:type port_list: List[int]
:param port_list: List of port names.
:type port_list: List[str]
:param protocol_list: List of protocol types.
:type protocol_list: List[str]
:param num_rules: Number of rules configured in the firewall.
@@ -72,7 +72,6 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
self.ports: List[PortObservation] = [
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
]
# TODO: check what the port nums are for firewall.
self.internal_inbound_acl = ACLObservation(
where=self.where + ["internal_inbound_acl", "acl"],
@@ -140,6 +139,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
},
},
}
if self.include_users:
self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0}
def observe(self, state: Dict) -> ObsType:
"""
@@ -153,29 +154,35 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
firewall_state = access_from_nested_dict(state, self.where)
if firewall_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
"ACL": {
"INTERNAL": {
"INBOUND": self.internal_inbound_acl.observe(state),
"OUTBOUND": self.internal_outbound_acl.observe(state),
is_on = firewall_state["operating_state"] == 1
if not is_on:
obs = {**self.default_observation}
else:
obs = {
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
"ACL": {
"INTERNAL": {
"INBOUND": self.internal_inbound_acl.observe(state),
"OUTBOUND": self.internal_outbound_acl.observe(state),
},
"DMZ": {
"INBOUND": self.dmz_inbound_acl.observe(state),
"OUTBOUND": self.dmz_outbound_acl.observe(state),
},
"EXTERNAL": {
"INBOUND": self.external_inbound_acl.observe(state),
"OUTBOUND": self.external_outbound_acl.observe(state),
},
},
"DMZ": {
"INBOUND": self.dmz_inbound_acl.observe(state),
"OUTBOUND": self.dmz_outbound_acl.observe(state),
},
"EXTERNAL": {
"INBOUND": self.external_inbound_acl.observe(state),
"OUTBOUND": self.external_outbound_acl.observe(state),
},
},
}
if self.include_users:
sess = firewall_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
if self.include_users:
sess = firewall_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -186,34 +193,36 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
:return: Gymnasium space representing the observation space for firewall status.
:rtype: spaces.Space
"""
space = spaces.Dict(
{
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
"ACL": spaces.Dict(
{
"INTERNAL": spaces.Dict(
{
"INBOUND": self.internal_inbound_acl.space,
"OUTBOUND": self.internal_outbound_acl.space,
}
),
"DMZ": spaces.Dict(
{
"INBOUND": self.dmz_inbound_acl.space,
"OUTBOUND": self.dmz_outbound_acl.space,
}
),
"EXTERNAL": spaces.Dict(
{
"INBOUND": self.external_inbound_acl.space,
"OUTBOUND": self.external_outbound_acl.space,
}
),
}
),
}
)
return space
shape = {
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
"ACL": spaces.Dict(
{
"INTERNAL": spaces.Dict(
{
"INBOUND": self.internal_inbound_acl.space,
"OUTBOUND": self.internal_outbound_acl.space,
}
),
"DMZ": spaces.Dict(
{
"INBOUND": self.dmz_inbound_acl.space,
"OUTBOUND": self.dmz_outbound_acl.space,
}
),
"EXTERNAL": spaces.Dict(
{
"INBOUND": self.external_inbound_acl.space,
"OUTBOUND": self.external_outbound_acl.space,
}
),
}
),
}
if self.include_users:
shape["users"] = spaces.Dict(
{"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)}
)
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation:

View File

@@ -54,7 +54,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
"""
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
"""
include_users: Optional[bool] = True
include_users: Optional[bool] = None
"""If True, report user session information."""
def __init__(
@@ -191,25 +191,31 @@ class HostObservation(AbstractObservation, identifier="HOST"):
if node_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
is_on = node_state["operating_state"] == 1
if not is_on:
obs = {**self.default_observation}
else:
obs = {}
if self.services:
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
if self.applications:
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
if self.folders:
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
if self.nics:
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
if self.include_num_access:
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
if self.include_users:
sess = node_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
obs["operating_status"] = node_state["operating_state"]
if self.services:
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
if self.applications:
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
if self.folders:
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
if self.nics:
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
if self.include_num_access:
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
if self.include_users:
sess = node_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property

View File

@@ -56,7 +56,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
port_list: Optional[List[str]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""

View File

@@ -1,11 +1,12 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from functools import cached_property
from typing import Any, Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
from pydantic import BaseModel, computed_field, ConfigDict, Field, model_validator, ValidationError
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
@@ -140,7 +141,7 @@ class NullObservation(AbstractObservation, identifier="NONE"):
return cls()
class ObservationManager:
class ObservationManager(BaseModel):
"""
Manage the observations of an Agent.
@@ -150,15 +151,66 @@ class ObservationManager:
3. Formatting this information so an agent can use it to make decisions.
"""
def __init__(self, obs: AbstractObservation) -> None:
"""Initialise observation space.
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
:param observation: Observation object
:type observation: AbstractObservation
"""
self.obs: AbstractObservation = obs
self.current_observation: ObsType
"""Cached copy of the observation at the time it was most recently calculated."""
class ConfigSchema(BaseModel):
"""Config Schema for Observation Manager."""
model_config = ConfigDict(extra="forbid")
type: str = "NONE"
"""Identifier name for the top-level observation."""
options: AbstractObservation.ConfigSchema = Field(
default_factory=lambda: NullObservation.ConfigSchema(), validate_default=True
)
"""Options to pass into the top-level observation during creation."""
@model_validator(mode="before")
@classmethod
def resolve_obs_options_type(cls, data: Any) -> Any:
"""
When constructing the model from a dict, resolve the correct observation class based on `type` field.
Workaround: The `options` field is statically typed as AbstractObservation. Therefore, it falls over when
passing in data that adheres to a subclass schema rather than the plain AbstractObservation schema. There is
a way to do this properly using discriminated union, but most advice on the internet assumes that the full
list of types between which to discriminate is known ahead-of-time. That is not the case for us, because of
our plugin architecture.
We may be able to revisit and implement a better solution when needed using the following resources as
research starting points:
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
https://github.com/pydantic/pydantic/issues/7366
https://github.com/pydantic/pydantic/issues/7462
https://github.com/pydantic/pydantic/pull/7983
"""
if not isinstance(data, dict):
return data
# (TODO: duplicate default definition between here and the actual model)
obs_type = data["type"] if "type" in data else "NONE"
obs_class = AbstractObservation._registry[obs_type]
# if no options are passed in, try to create a default schema. Only works if there are no mandatory fields
if "options" not in data:
data["options"] = obs_class.ConfigSchema()
# if options passed as a dict, validate against schema
elif isinstance(data["options"], dict):
data["options"] = obs_class.ConfigSchema(**data["options"])
return data
config: ConfigSchema = Field(default_factory=lambda: ObservationManager.ConfigSchema())
current_observation: ObsType = 0
@computed_field
@cached_property
def obs(self) -> AbstractObservation:
"""Create the main observation component for the observation manager from the config."""
obs_class = AbstractObservation._registry[self.config.type]
obs_instance = obs_class.from_config(config=self.config.options)
return obs_instance
def update(self, state: Dict) -> Dict:
"""

View File

@@ -31,7 +31,7 @@ class AbstractObservation(ABC):
"""Initialise an observation. This method must be overwritten."""
self.default_observation: ObsType
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
"""
Register an observation type.
@@ -40,6 +40,8 @@ class AbstractObservation(ABC):
:raises ValueError: When attempting to create a component with a name that is already in use.
"""
super().__init_subclass__(**kwargs)
if identifier is None:
return
if identifier in cls._registry:
raise ValueError(f"Duplicate observation component type {identifier}")
cls._registry[identifier] = cls

View File

@@ -33,13 +33,13 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
port_list: Optional[List[str]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
include_users: Optional[bool] = True
include_users: Optional[bool] = None
"""If True, report user session information."""
def __init__(
@@ -84,6 +84,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
}
if self.ports:
self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)}
if self.include_users:
self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0}
def observe(self, state: Dict) -> ObsType:
"""
@@ -98,16 +100,21 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
if router_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["ACL"] = self.acl.observe(state)
if self.ports:
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
if self.include_users:
sess = router_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
is_on = router_state["operating_state"] == 1
if not is_on:
obs = {**self.default_observation}
else:
obs = {}
obs["ACL"] = self.acl.observe(state)
if self.ports:
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
if self.include_users:
sess = router_state["services"]["UserSessionManager"]
obs["users"] = {
"local_login": 1 if sess["current_local_user"] else 0,
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
}
return obs
@property
@@ -121,6 +128,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
shape = {"ACL": self.acl.space}
if self.ports:
shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)})
if self.include_users:
shape["users"] = spaces.Dict(
{"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)}
)
return spaces.Dict(shape)
@classmethod

View File

@@ -30,7 +30,7 @@ the structure:
from abc import ABC, abstractmethod
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Never
from primaite import getLogger
@@ -48,21 +48,17 @@ class AbstractReward(BaseModel):
config: "AbstractReward.ConfigSchema"
# def __init__(self, schema_name, **kwargs):
# super.__init__(self, **kwargs)
# # Create ConfigSchema class
# self.config_class = type(schema_name, (BaseModel, ABC), **kwargs)
# self.config = self.config_class()
class ConfigSchema(BaseModel, ABC):
"""Config schema for AbstractReward."""
type: str
type: str = ""
_registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {}
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if identifier is None:
return
if identifier in cls._registry:
raise ValueError(f"Duplicate reward {identifier}")
cls._registry[identifier] = cls
@@ -381,14 +377,19 @@ class SharedReward(AbstractReward, identifier="SHARED_REWARD"):
class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
"""Apply a negative reward when taking any action except DONOTHING."""
"""Apply a negative reward when taking any action except do_nothing."""
config: "ActionPenalty.ConfigSchema"
class ConfigSchema(AbstractReward.ConfigSchema):
"""Config schema for ActionPenalty."""
"""Config schema for ActionPenalty.
:param action_penalty: Reward to give agents for taking any action except do_nothing
:type action_penalty: float
:param do_nothing_penalty: Reward to give agent for taking the do_nothing action
:type do_nothing_penalty: float
"""
type: str = "ACTION_PENALTY"
action_penalty: float = -1.0
do_nothing_penalty: float = 0.0
@@ -402,21 +403,81 @@ class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
:return: Reward value
:rtype: float
"""
if last_action_response.action == "DONOTHING":
if last_action_response.action == "do_nothing":
return self.config.do_nothing_penalty
else:
return self.config.action_penalty
class RewardFunction:
class _SingleComponentConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
type: str
options: AbstractReward.ConfigSchema
weight: float = 1.0
@model_validator(mode="before")
@classmethod
def resolve_obs_options_type(cls, data: Any) -> Any:
"""
When constructing the model from a dict, resolve the correct reward class based on `type` field.
Workaround: The `options` field is statically typed as AbstractReward. Therefore, it falls over when
passing in data that adheres to a subclass schema rather than the plain AbstractReward schema. There is
a way to do this properly using discriminated union, but most advice on the internet assumes that the full
list of types between which to discriminate is known ahead-of-time. That is not the case for us, because of
our plugin architecture.
We may be able to revisit and implement a better solution when needed using the following resources as
research starting points:
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
https://github.com/pydantic/pydantic/issues/7366
https://github.com/pydantic/pydantic/issues/7462
https://github.com/pydantic/pydantic/pull/7983
"""
if not isinstance(data, dict):
return data
assert "type" in data, ValueError('Reward component definition is missing the "type" key.')
rew_type = data["type"]
rew_class = AbstractReward._registry[rew_type]
# if no options are passed in, try to create a default schema. Only works if there are no mandatory fields.
if "options" not in data:
data["options"] = rew_class.ConfigSchema()
# if options are passed as a dict, validate against schema
elif isinstance(data["options"], dict):
data["options"] = rew_class.ConfigSchema(**data["options"])
return data
class RewardFunction(BaseModel):
"""Manages the reward function for the agent."""
def __init__(self):
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
"attribute reward_components keeps track of reward components and the weights assigned to each."
self.current_reward: float = 0.0
self.total_reward: float = 0.0
model_config = ConfigDict(extra="forbid")
class ConfigSchema(BaseModel):
"""Config Schema for RewardFunction."""
model_config = ConfigDict(extra="forbid")
reward_components: Iterable[_SingleComponentConfig] = []
config: ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
reward_components: List[Tuple[AbstractReward, float]] = []
current_reward: float = 0.0
total_reward: float = 0.0
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
for rew_config in self.config.reward_components:
rew_class = AbstractReward._registry[rew_config.type]
rew_instance = rew_class(config=rew_config.options)
self.register_component(component=rew_instance, weight=rew_config.weight)
def register_component(self, component: AbstractReward, weight: float = 1.0) -> None:
"""Add a reward component to the reward function.

View File

@@ -1 +1,6 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from primaite.game.agent import interface
from primaite.game.agent.scripted_agents import abstract_tap, data_manipulation_bot, probabilistic_agent, random_agent
__all__ = ("abstract_tap", "data_manipulation_bot", "interface", "probabilistic_agent", "random_agent")

View File

@@ -0,0 +1,61 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
import random
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple
from gymnasium.core import ObsType
from pydantic import Field
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
__all__ = "AbstractTAPAgent"
class AbstractTAPAgent(PeriodicAgent, identifier="AbstractTAP"):
"""Base class for TAP agents to inherit from."""
config: "AbstractTAPAgent.ConfigSchema" = Field(default_factory=lambda: AbstractTAPAgent.ConfigSchema())
next_execution_timestep: int = 0
class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema):
"""Schema for the `agent_settings` part of the agent config."""
possible_starting_nodes: List[str] = Field(default_factory=list)
class ConfigSchema(PeriodicAgent.ConfigSchema):
"""Configuration schema for Abstract TAP agents."""
type: str = "AbstractTAP"
agent_settings: AbstractTAPAgent.AgentSettingsSchema = Field(
default_factory=lambda: AbstractTAPAgent.AgentSettingsSchema()
)
starting_node: Optional[str] = None
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""Return an action to be taken in the environment."""
return super().get_action(obs=obs, timestep=timestep)
@abstractmethod
def setup_agent(self) -> None:
"""Set up agent."""
pass
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.config.agent_settings.variance, self.config.agent_settings.variance
)
self.next_execution_timestep = timestep + random_timestep_increment
def _select_start_node(self) -> None:
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
# we are assuming that every node in the node manager has a data manipulation application at idx 0
self.starting_node = random.choice(self.config.agent_settings.possible_starting_nodes)
self.logger.debug(f"Selected starting node: {self.starting_node}")

View File

@@ -1,31 +1,35 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import random
from typing import Dict, Tuple
from gymnasium.core import ObsType
from pydantic import Field
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
__all__ = "DataManipulationAgent"
class DataManipulationAgent(AbstractScriptedAgent):
class DataManipulationAgent(PeriodicAgent, identifier="RedDatabaseCorruptingAgent"):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
next_execution_timestep: int = 0
starting_node_idx: int = 0
class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema):
"""Schema for the `agent_settings` part of the agent config."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setup_agent()
target_application: str = "DataManipulationBot"
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
class ConfigSchema(PeriodicAgent.ConfigSchema):
"""Configuration Schema for DataManipulationAgent."""
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
type: str = "RedDatabaseCorruptingAgent"
agent_settings: "DataManipulationAgent.AgentSettingsSchema" = Field(
default_factory=lambda: DataManipulationAgent.AgentSettingsSchema()
)
self.next_execution_timestep = timestep + random_timestep_increment
config: "DataManipulationAgent.ConfigSchema" = Field(default_factory=lambda: DataManipulationAgent.ConfigSchema())
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._set_next_execution_timestep(timestep=self.config.agent_settings.start_step, variance=0)
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Waits until a specific timestep, then attempts to execute its data manipulation application.
@@ -38,21 +42,14 @@ class DataManipulationAgent(AbstractScriptedAgent):
:rtype: Tuple[str, Dict]
"""
if timestep < self.next_execution_timestep:
self.logger.debug(msg="Performing do NOTHING")
return "DONOTHING", {}
self.logger.debug(msg="Performing do nothing action")
return "do_nothing", {}
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
self._set_next_execution_timestep(
timestep=timestep + self.config.agent_settings.frequency, variance=self.config.agent_settings.variance
)
self.logger.info(msg="Performing a data manipulation attack!")
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
def setup_agent(self) -> None:
"""Set the next execution timestep when the episode resets."""
self._select_start_node()
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
def _select_start_node(self) -> None:
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
# we are assuming that every node in the node manager has a data manipulation application at idx 0
num_nodes = len(self.action_manager.node_names)
self.starting_node_idx = random.randint(0, num_nodes - 1)
self.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}")
return "node_application_execute", {
"node_name": self.start_node,
"application_name": self.config.agent_settings.target_application,
}

View File

@@ -1,29 +1,28 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
"""Agents with predefined behaviours."""
from typing import Dict, Optional, Tuple
from typing import Dict, Tuple
import numpy as np
import pydantic
from gymnasium.core import ObsType
from numpy.random import Generator
from pydantic import Field
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction
__all__ = "ProbabilisticAgent"
class ProbabilisticAgent(AbstractScriptedAgent):
class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent"):
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
class Settings(pydantic.BaseModel):
"""Config schema for Probabilistic agent settings."""
rng: Generator = Field(default_factory=lambda: np.random.default_rng(np.random.randint(0, 65535)))
model_config = pydantic.ConfigDict(extra="forbid")
"""Strict validation."""
action_probabilities: Dict[int, float]
class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema):
"""Schema for the `agent_settings` part of the agent config."""
action_probabilities: Dict[int, float] = None
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
@pydantic.field_validator("action_probabilities", mode="after")
@classmethod
@@ -44,31 +43,20 @@ class ProbabilisticAgent(AbstractScriptedAgent):
)
return v
def __init__(
self,
agent_name: str,
action_space: Optional[ActionManager],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
settings: Dict = {},
) -> None:
# If the action probabilities are not specified, create equal probabilities for all actions
if "action_probabilities" not in settings:
num_actions = len(action_space.action_map)
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration schema for Probabilistic Agent."""
# The random number seed for np.random is dependent on whether a random number seed is set
# in the config file. If there is one it is processed by set_random_seed() in environment.py
# and as a consequence the the sequence of rng_seed's used here will be repeatable.
self.settings = ProbabilisticAgent.Settings(**settings)
rng_seed = np.random.randint(0, 65535)
self.rng = np.random.default_rng(rng_seed)
type: str = "ProbabilisticAgent"
agent_settings: "ProbabilisticAgent.AgentSettingsSchema" = Field(
default_factory=lambda: ProbabilisticAgent.AgentSettingsSchema()
)
# convert probabilities from
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
config: "ProbabilisticAgent.ConfigSchema" = Field(default_factory=lambda: ProbabilisticAgent.ConfigSchema())
super().__init__(agent_name, action_space, observation_space, reward_function)
self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
@property
def probabilities(self) -> Dict[str, int]:
"""Convenience method to view the probabilities of the Agent."""
return np.asarray(list(self.config.agent_settings.action_probabilities.values()))
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -1,20 +1,27 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import random
from typing import Dict, Optional, Tuple
from functools import cached_property
from typing import Dict, List, Tuple
from gymnasium.core import ObsType
from pydantic import BaseModel
from pydantic import computed_field, Field, model_validator
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction
__all__ = ("RandomAgent", "PeriodicAgent")
class RandomAgent(AbstractScriptedAgent):
class RandomAgent(AbstractScriptedAgent, identifier="RandomAgent"):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
config: "RandomAgent.ConfigSchema" = Field(default_factory=lambda: RandomAgent.ConfigSchema())
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration Schema for Random Agents."""
type: str = "RandomAgent"
def get_action(self) -> Tuple[str, Dict]:
"""Sample the action space randomly.
:param obs: Current observation for this agent, not used in RandomAgent
@@ -27,41 +34,60 @@ class RandomAgent(AbstractScriptedAgent):
return self.action_manager.get_action(self.action_manager.space.sample())
class PeriodicAgent(AbstractScriptedAgent):
class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
"""Agent that does nothing most of the time, but executes application at regular intervals (with variance)."""
class Settings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
config: "PeriodicAgent.ConfigSchema" = Field(default_factory=lambda: PeriodicAgent.ConfigSchema())
start_step: int = 20
"The timestep at which an agent begins performing it's actions."
start_variance: int = 5
"Deviation around the start step."
class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema):
"""Schema for the `agent_settings` part of the agent config."""
start_step: int = 5
"The timestep at which an agent begins performing it's actions"
frequency: int = 5
"The number of timesteps to wait between performing actions."
"The number of timesteps to wait between performing actions"
variance: int = 0
"The amount the frequency can randomly change to."
max_executions: int = 999999
"Maximum number of times the agent can execute its action."
"The amount the frequency can randomly change to"
possible_start_nodes: List[str]
target_application: str
def __init__(
self,
agent_name: str,
action_space: ActionManager,
observation_space: ObservationManager,
reward_function: RewardFunction,
settings: Optional[Settings] = None,
) -> None:
"""Initialise PeriodicAgent."""
super().__init__(
agent_name=agent_name,
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "PeriodicAgent.ConfigSchema":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance).
If variance were greater than frequency, sometimes the bracketed term would be negative
and the attack would never happen again.
"""
if self.variance >= self.frequency:
raise ValueError(
f"Agent start settings error: variance must be lower than frequency "
f"{self.variance=}, {self.frequency=}"
)
return self
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration Schema for Periodic Agent."""
type: str = "PeriodicAgent"
"""Name of the agent."""
agent_settings: "PeriodicAgent.AgentSettingsSchema" = Field(
default_factory=lambda: PeriodicAgent.AgentSettingsSchema()
)
self.settings = settings or PeriodicAgent.Settings()
self._set_next_execution_timestep(timestep=self.settings.start_step, variance=self.settings.start_variance)
self.num_executions = 0
max_executions: int = 999999
"Maximum number of times the agent can execute its action."
num_executions: int = 0
"""Number of times the agent has executed an action."""
next_execution_timestep: int = 0
"""Timestep of the next action execution by the agent."""
@computed_field
@cached_property
def start_node(self) -> str:
"""On instantiation, randomly select a start node."""
return random.choice(self.config.agent_settings.possible_start_nodes)
def _set_next_execution_timestep(self, timestep: int, variance: int) -> None:
"""Set the next execution timestep with a configured random variance.
@@ -76,9 +102,14 @@ class PeriodicAgent(AbstractScriptedAgent):
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
if timestep == self.next_execution_timestep and self.num_executions < self.settings.max_executions:
if timestep == self.next_execution_timestep and self.num_executions < self.max_executions:
self.num_executions += 1
self._set_next_execution_timestep(timestep + self.settings.frequency, self.settings.variance)
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
self._set_next_execution_timestep(
timestep + self.config.agent_settings.frequency, self.config.agent_settings.variance
)
return "node_application_execute", {
"node_name": self.start_node,
"application_name": self.config.agent_settings.target_application,
}
return "DONOTHING", {}
return "do_nothing", {}

View File

@@ -1,78 +0,0 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import random
from typing import Dict, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
class TAP001(AbstractScriptedAgent):
"""
TAP001 | Mobile Malware -- Ransomware Variant.
Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setup_agent()
next_execution_timestep: int = 0
starting_node_idx: int = 0
installed: bool = False
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
)
self.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Waits until a specific timestep, then attempts to execute the ransomware application.
This application acts a wrapper around the kill-chain, similar to green-analyst and
the previous UC2 data manipulation bot.
:param obs: Current observation for this agent.
:type obs: ObsType
:param timestep: The current simulation timestep, used for scheduling actions
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
if timestep < self.next_execution_timestep:
return "DONOTHING", {}
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
if not self.installed:
self.installed = True
return "NODE_APPLICATION_INSTALL", {
"node_id": self.starting_node_idx,
"application_name": "RansomwareScript",
}
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
def setup_agent(self) -> None:
"""Set the next execution timestep when the episode resets."""
self._select_start_node()
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
for n, act in self.action_manager.action_map.items():
if not act[0] == "NODE_APPLICATION_INSTALL":
continue
if act[1]["node_id"] == self.starting_node_idx:
self.ip_address = act[1]["ip_address"]
return
raise RuntimeError("TAP001 agent could not find database server ip address in action map")
def _select_start_node(self) -> None:
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
# we are assuming that every node in the node manager has a data manipulation application at idx 0
num_nodes = len(self.action_manager.node_names)
self.starting_node_idx = random.randint(0, num_nodes - 1)

View File

@@ -7,14 +7,8 @@ import numpy as np
from pydantic import BaseModel, ConfigDict
from primaite import DEFAULT_BANDWIDTH, getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction, SharedReward
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
from primaite.game.agent.rewards import SharedReward
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.creation import NetworkNodeAdder
@@ -44,7 +38,7 @@ from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.system.services.web_server.web_server import WebServer
from primaite.simulator.system.software import Software
from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
from primaite.utils.validation.ip_protocol import IPProtocol
from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
@@ -258,6 +252,7 @@ class PrimaiteGame:
net = sim.network
simulation_config = cfg.get("simulation", {})
defaults_config = cfg.get("defaults", {})
network_config = simulation_config.get("network", {})
airspace_cfg = network_config.get("airspace", {})
frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {})
@@ -283,6 +278,18 @@ class PrimaiteGame:
_LOGGER.error(msg)
raise ValueError(msg)
# TODO: handle simulation defaults more cleanly
if "node_start_up_duration" in defaults_config:
new_node.start_up_duration = defaults_config["node_startup_duration"]
if "node_shut_down_duration" in defaults_config:
new_node.shut_down_duration = defaults_config["node_shut_down_duration"]
if "node_scan_duration" in defaults_config:
new_node.node_scan_duration = defaults_config["node_scan_duration"]
if "folder_scan_duration" in defaults_config:
new_node.file_system._default_folder_scan_duration = defaults_config["folder_scan_duration"]
if "folder_restore_duration" in defaults_config:
new_node.file_system._default_folder_restore_duration = defaults_config["folder_restore_duration"]
if "users" in node_cfg and new_node.software_manager.software.get("UserManager"):
user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa
for user_cfg in node_cfg["users"]:
@@ -315,12 +322,12 @@ class PrimaiteGame:
if service_class is not None:
_LOGGER.debug(f"installing {service_type} on node {new_node.config.hostname}")
new_node.software_manager.install(service_class, **service_cfg.get("options", {}))
new_node.software_manager.install(service_class)
new_service = new_node.software_manager.software[service_class.__name__]
# fixing duration for the service
if "fix_duration" in service_cfg.get("options", {}):
new_service.fixing_duration = service_cfg["options"]["fix_duration"]
if "fixing_duration" in service_cfg.get("options", {}):
new_service.config.fixing_duration = service_cfg["options"]["fixing_duration"]
_set_software_listen_on_ports(new_service, service_cfg)
# start the service
@@ -329,6 +336,15 @@ class PrimaiteGame:
msg = f"Configuration contains an invalid service type: {service_type}"
_LOGGER.error(msg)
raise ValueError(msg)
# TODO: handle simulation defaults more cleanly
if "service_fix_duration" in defaults_config:
new_service.fixing_duration = defaults_config["service_fix_duration"]
if "service_restart_duration" in defaults_config:
new_service.restart_duration = defaults_config["service_restart_duration"]
if "service_install_duration" in defaults_config:
new_service.install_duration = defaults_config["service_install_duration"]
# service-dependent options
if service_type == "DNSClient":
if "options" in service_cfg:
@@ -361,74 +377,20 @@ class PrimaiteGame:
application_type = application_cfg["type"]
if application_type in Application._registry:
new_node.software_manager.install(Application._registry[application_type])
application_class = Application._registry[application_type]
application_options = application_cfg.get("options", {})
application_options["type"] = application_type
new_node.software_manager.install(application_class, software_config=application_options)
new_application = new_node.software_manager.software[application_type] # grab the instance
# fixing duration for the application
if "fix_duration" in application_cfg.get("options", {}):
new_application.fixing_duration = application_cfg["options"]["fix_duration"]
else:
msg = f"Configuration contains an invalid application type: {application_type}"
_LOGGER.error(msg)
raise ValueError(msg)
_set_software_listen_on_ports(new_application, application_cfg)
# run the application
new_application.run()
if application_type == "DataManipulationBot":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")),
server_password=opt.get("server_password"),
payload=opt.get("payload", "DELETE"),
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
)
elif application_type == "RansomwareScript":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None,
server_password=opt.get("server_password"),
payload=opt.get("payload", "ENCRYPT"),
)
elif application_type == "DatabaseClient":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("db_server_ip")),
server_password=opt.get("server_password"),
)
elif application_type == "WebBrowser":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.target_url = opt.get("target_url")
elif application_type == "DoSBot":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
target_ip_address=IPv4Address(opt.get("target_ip_address")),
target_port=PORT_LOOKUP[opt.get("target_port", "POSTGRES_SERVER")],
payload=opt.get("payload"),
repeat=bool(opt.get("repeat")),
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
dos_intensity=float(opt.get("dos_intensity", "1.0")),
max_sessions=int(opt.get("max_sessions", "1000")),
)
elif application_type == "C2Beacon":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")),
keep_alive_frequency=(opt.get("keep_alive_frequency", 5)),
masquerade_protocol=PROTOCOL_LOOKUP[
(opt.get("masquerade_protocol", PROTOCOL_LOOKUP["TCP"]))
],
masquerade_port=PORT_LOOKUP[(opt.get("masquerade_port", PORT_LOOKUP["HTTP"]))],
)
if "network_interfaces" in node_cfg:
for nic_num, nic_cfg in node_cfg["network_interfaces"].items():
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
@@ -470,76 +432,10 @@ class PrimaiteGame:
agents_cfg = cfg.get("agents", [])
for agent_cfg in agents_cfg:
agent_ref = agent_cfg["ref"] # noqa: F841
agent_type = agent_cfg["type"]
action_space_cfg = agent_cfg["action_space"]
observation_space_cfg = agent_cfg["observation_space"]
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationManager.from_config(observation_space_cfg)
# CREATE ACTION SPACE
action_space = ActionManager.from_config(game, action_space_cfg)
# CREATE REWARD FUNCTION
reward_function = RewardFunction.from_config(reward_function_cfg)
# CREATE AGENT
if agent_type == "ProbabilisticAgent":
# TODO: implement non-random agents and fix this parsing
settings = agent_cfg.get("agent_settings", {})
new_agent = ProbabilisticAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
settings=settings,
)
elif agent_type == "PeriodicAgent":
settings = PeriodicAgent.Settings(**agent_cfg.get("settings", {}))
new_agent = PeriodicAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
settings=settings,
)
elif agent_type == "ProxyAgent":
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
agent_settings=agent_settings,
)
game.rl_agents[agent_cfg["ref"]] = new_agent
elif agent_type == "RedDatabaseCorruptingAgent":
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
agent_settings=agent_settings,
)
elif agent_type == "TAP001":
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
new_agent = TAP001(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=reward_function,
agent_settings=agent_settings,
)
else:
msg = f"Configuration error: {agent_type} is not a valid agent type."
_LOGGER.error(msg)
raise ValueError(msg)
new_agent = AbstractAgent.from_config(agent_cfg)
game.agents[agent_cfg["ref"]] = new_agent
if isinstance(new_agent, ProxyAgent):
game.rl_agents[agent_cfg["ref"]] = new_agent
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
game.setup_reward_sharing()

View File

@@ -11,6 +11,15 @@
"PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -19,7 +28,7 @@
"source": [
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from prettytable import PrettyTable\n"
"from prettytable import PrettyTable"
]
},
{
@@ -195,7 +204,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -209,7 +218,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -51,24 +51,15 @@
" - ref: CustomC2Agent\n",
" team: RED\n",
" type: ProxyAgent\n",
" observation_space: null\n",
"\n",
" action_space:\n",
" action_list:\n",
" - type: DONOTHING\n",
" - type: NODE_APPLICATION_INSTALL\n",
" - type: NODE_APPLICATION_EXECUTE\n",
" - type: CONFIGURE_C2_BEACON\n",
" - type: C2_SERVER_RANSOMWARE_LAUNCH\n",
" - type: C2_SERVER_RANSOMWARE_CONFIGURE\n",
" - type: C2_SERVER_TERMINAL_COMMAND\n",
" - type: C2_SERVER_DATA_EXFILTRATE\n",
" options:\n",
" nodes:\n",
" - node_name: web_server\n",
" applications: \n",
" applications:\n",
" - application_name: C2Beacon\n",
" - node_name: client_1\n",
" applications: \n",
" applications:\n",
" - application_name: C2Server\n",
" max_folders_per_node: 1\n",
" max_files_per_folder: 1\n",
@@ -82,15 +73,15 @@
" - 0.0.0.1\n",
" action_map:\n",
" 0:\n",
" action: DONOTHING\n",
" action: do_nothing\n",
" options: {}\n",
" 1:\n",
" action: NODE_APPLICATION_INSTALL\n",
" action: node_application_install\n",
" options:\n",
" node_id: 0\n",
" application_name: C2Beacon\n",
" 2:\n",
" action: CONFIGURE_C2_BEACON\n",
" action: configure_c2_beacon\n",
" options:\n",
" node_id: 0\n",
" config:\n",
@@ -99,12 +90,12 @@
" masquerade_protocol:\n",
" masquerade_port:\n",
" 3:\n",
" action: NODE_APPLICATION_EXECUTE\n",
" action: node_application_execute\n",
" options:\n",
" node_id: 0\n",
" application_id: 0 \n",
" application_id: 0\n",
" 4:\n",
" action: C2_SERVER_TERMINAL_COMMAND\n",
" action: c2_server_terminal_command\n",
" options:\n",
" node_id: 1\n",
" ip_address:\n",
@@ -112,20 +103,20 @@
" username: admin\n",
" password: admin\n",
" commands:\n",
" - \n",
" -\n",
" - software_manager\n",
" - application\n",
" - install\n",
" - RansomwareScript\n",
" 5:\n",
" action: C2_SERVER_RANSOMWARE_CONFIGURE\n",
" action: c2_server_ransomware_configure\n",
" options:\n",
" node_id: 1\n",
" config:\n",
" server_ip_address: 192.168.1.14\n",
" payload: ENCRYPT\n",
" 6:\n",
" action: C2_SERVER_DATA_EXFILTRATE\n",
" action: c2_server_data_exfiltrate\n",
" options:\n",
" node_id: 1\n",
" target_file_name: \"database.db\"\n",
@@ -134,14 +125,14 @@
" target_ip_address: 192.168.1.14\n",
" account:\n",
" username: admin\n",
" password: admin \n",
" password: admin\n",
"\n",
" 7:\n",
" action: C2_SERVER_RANSOMWARE_LAUNCH\n",
" action: c2_server_ransomware_launch\n",
" options:\n",
" node_id: 1\n",
" 8:\n",
" action: CONFIGURE_C2_BEACON\n",
" action: configure_c2_beacon\n",
" options:\n",
" node_id: 0\n",
" config:\n",
@@ -150,7 +141,7 @@
" masquerade_protocol: TCP\n",
" masquerade_port: DNS\n",
" 9:\n",
" action: CONFIGURE_C2_BEACON\n",
" action: configure_c2_beacon\n",
" options:\n",
" node_id: 0\n",
" config:\n",
@@ -177,7 +168,7 @@
" # removing all agents & adding the custom agent.\n",
" cfg['agents'] = {}\n",
" cfg['agents'] = c2_agent_yaml\n",
" \n",
"\n",
"\n",
"env = PrimaiteGymEnv(env_config=cfg)"
]
@@ -222,7 +213,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### **Command and Control** | C2 Beacon Actions | NODE_APPLICATION_INSTALL\n",
"### **Command and Control** | C2 Beacon Actions | node_application_install\n",
"\n",
"The custom proxy red agent defined at the start of this notebook has been configured to install the C2 Beacon as action ``1`` in it's action map. \n",
"\n",
@@ -230,10 +221,6 @@
"\n",
"```yaml\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: NODE_APPLICATION_INSTALL\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" - node_name: web_server\n",
@@ -243,7 +230,7 @@
" ...\n",
" action_map:\n",
" 1:\n",
" action: NODE_APPLICATION_INSTALL \n",
" action: node_application_install \n",
" options:\n",
" node_id: 0 # Index 0 at the node list.\n",
" application_name: C2Beacon\n",
@@ -265,7 +252,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### **Command and Control** | C2 Beacon Actions | CONFIGURE_C2_BEACON \n",
"### **Command and Control** | C2 Beacon Actions | configure_c2_beacon \n",
"\n",
"The custom proxy red agent defined at the start of this notebook can configure the C2 Beacon via action ``2`` in it's action map. \n",
"\n",
@@ -273,10 +260,6 @@
"\n",
"```yaml\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: CONFIGURE_C2_BEACON\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" - node_name: web_server\n",
@@ -285,7 +268,7 @@
" action_map:\n",
" ...\n",
" 2:\n",
" action: CONFIGURE_C2_BEACON\n",
" action: configure_c2_beacon\n",
" options:\n",
" node_id: 0 # Node Index\n",
" config: # Further information about these config options can be found at the bottom of this notebook.\n",
@@ -312,18 +295,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### **Command and Control** | C2 Beacon Actions | NODE_APPLICATION_EXECUTE\n",
"### **Command and Control** | C2 Beacon Actions | node_application_execute\n",
"\n",
"The final action is ``NODE_APPLICATION_EXECUTE`` which is used to establish a connection for the C2 application. This action can be called by the Red Agent via action ``3`` in it's action map. \n",
"The final action is ``node_application_execute`` which is used to establish a connection for the C2 application. This action can be called by the Red Agent via action ``3`` in it's action map. \n",
"\n",
"The yaml snippet below shows all the relevant agent options for this action:\n",
"\n",
"```yaml\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: NODE_APPLICATION_EXECUTE\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" - node_name: web_server\n",
@@ -334,7 +313,7 @@
" action_map:\n",
" ...\n",
" 3:\n",
" action: NODE_APPLICATION_EXECUTE\n",
" action: node_application_execute\n",
" options:\n",
" node_id: 0\n",
" application_id: 0\n",
@@ -347,7 +326,7 @@
"metadata": {},
"outputs": [],
"source": [
"env.step(3) "
"env.step(3)"
]
},
{
@@ -390,10 +369,6 @@
"\n",
"``` yaml\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: C2_SERVER_TERMINAL_COMMAND\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" ...\n",
@@ -441,7 +416,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### **Command and Control** | C2 Server Actions | C2_SERVER_RANSOMWARE_CONFIGURE\n",
"### **Command and Control** | C2 Server Actions | c2_server_ransomware_configure\n",
"\n",
"Another action the C2 Server grants is the ability for a Red Agent to configure the RansomwareScript via the C2 Server rather than the note directly.\n",
"\n",
@@ -451,10 +426,6 @@
"\n",
"``` yaml\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: C2_SERVER_RANSOMWARE_CONFIGURE\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" ...\n",
@@ -464,7 +435,7 @@
" ...\n",
" action_map:\n",
" 5:\n",
" action: C2_SERVER_RANSOMWARE_CONFIG\n",
" action: c2_server_ransomware_configure\n",
" options:\n",
" node_id: 1\n",
" config:\n",
@@ -497,9 +468,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### **Command and Control** | C2 Server Actions | C2_SERVER_DATA_EXFILTRATE\n",
"### **Command and Control** | C2 Server Actions | c2_server_data_exfiltrate\n",
"\n",
"The second to last action available is the ``C2_SERVER_DATA_EXFILTRATE`` which is indexed as action ``6`` in the action map.\n",
"The second to last action available is the ``c2_server_data_exfiltrate`` which is indexed as action ``6`` in the action map.\n",
"\n",
"This action can be used to exfiltrate a target file on a remote node to the C2 Beacon and the C2 Server's host file system via the ``FTP`` services.\n",
"\n",
@@ -507,10 +478,6 @@
"\n",
"``` yaml\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: C2_SERVER_DATA_EXFILTRATE\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" ...\n",
@@ -520,7 +487,7 @@
" ...\n",
" action_map:\n",
" 6:\n",
" action: C2_SERVER_DATA_EXFILTRATE\n",
" action: c2_server_data_exfiltrate\n",
" options:\n",
" node_id: 1\n",
" target_file_name: \"database.db\"\n",
@@ -567,9 +534,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### **Command and Control** | C2 Server Actions | C2_SERVER_RANSOMWARE_LAUNCH\n",
"### **Command and Control** | C2 Server Actions | c2_server_ransomware_launch\n",
"\n",
"Finally, the last available action is for the C2_SERVER_RANSOMWARE_LAUNCH to start the ransomware script installed on the same node as the C2 beacon.\n",
"Finally, the last available action is for the c2_server_ransomware_launch to start the ransomware script installed on the same node as the C2 beacon.\n",
"\n",
"This action is indexed as action ``7``.\n",
"\n",
@@ -577,10 +544,6 @@
"\n",
"``` yaml\n",
" action_space:\n",
" action_list:\n",
" ...\n",
" - type: C2_SERVER_RANSOMWARE_LAUNCH\n",
" ...\n",
" options:\n",
" nodes: # Node List\n",
" ...\n",
@@ -590,7 +553,7 @@
" ...\n",
" action_map:\n",
" 7:\n",
" action: C2_SERVER_RANSOMWARE_LAUNCH\n",
" action: c2_server_ransomware_launch\n",
" options:\n",
" node_id: 1\n",
"```\n"
@@ -632,7 +595,7 @@
"metadata": {},
"outputs": [],
"source": [
"custom_blue_agent_yaml = \"\"\" \n",
"custom_blue_agent_yaml = \"\"\"\n",
" - ref: defender\n",
" team: BLUE\n",
" type: ProxyAgent\n",
@@ -715,28 +678,23 @@
" - type: \"NONE\"\n",
" label: ICS\n",
" options: {}\n",
" \n",
"\n",
" action_space:\n",
" action_list:\n",
" - type: NODE_APPLICATION_REMOVE\n",
" - type: NODE_SHUTDOWN\n",
" - type: ROUTER_ACL_ADDRULE\n",
" - type: DONOTHING\n",
" action_map:\n",
" 0:\n",
" action: DONOTHING\n",
" action: do_nothing\n",
" options: {}\n",
" 1:\n",
" action: NODE_APPLICATION_REMOVE\n",
" action: node_application_remove\n",
" options:\n",
" node_id: 0\n",
" application_name: C2Beacon\n",
" 2:\n",
" action: NODE_SHUTDOWN\n",
" action: node_shutdown\n",
" options:\n",
" node_id: 0\n",
" 3:\n",
" action: ROUTER_ACL_ADDRULE\n",
" action: router_acl_add_rule\n",
" options:\n",
" target_router: router_1\n",
" position: 1\n",
@@ -747,7 +705,7 @@
" dest_port_id: 2\n",
" protocol_id: 1\n",
" source_wildcard_id: 0\n",
" dest_wildcard_id: 0 \n",
" dest_wildcard_id: 0\n",
"\n",
"\n",
" options:\n",
@@ -796,7 +754,7 @@
" # removing all agents & adding the custom agent.\n",
" cfg['agents'] = {}\n",
" cfg['agents'] = custom_blue\n",
" \n",
"\n",
"\n",
"blue_env = PrimaiteGymEnv(env_config=cfg)"
]
@@ -1121,7 +1079,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The code cell below uses the custom blue agent defined at the start of this section perform a NODE_APPLICATION_REMOVE on the C2 beacon:"
"The code cell below uses the custom blue agent defined at the start of this section perform a node_application_remove on the C2 beacon:"
]
},
{
@@ -1130,7 +1088,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Using CAOS ACTION: NODE_APPLICATION_REMOVE & capturing the OBS\n",
"# Using CAOS ACTION: node_application_remove & capturing the OBS\n",
"post_blue_action_obs, _, _, _, _ = blue_env.step(1)"
]
},
@@ -1216,7 +1174,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The code cell below uses the custom blue agent defined at the start of this section to perform a ``NODE_SHUT_DOWN`` action on the web server."
"The code cell below uses the custom blue agent defined at the start of this section to perform a ``node_shut_down`` action on the web server."
]
},
{
@@ -1225,7 +1183,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Using CAOS ACTION: NODE_SHUT_DOWN & capturing the OBS\n",
"# Using CAOS ACTION: node_shut_down & capturing the OBS\n",
"post_blue_action_obs, _, _, _, _ = blue_env.step(2)"
]
},
@@ -1306,7 +1264,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The code cell below uses the custom blue agent defined at the start of this section to perform a ROUTER_ACL_ADDRULE on router 1."
"The code cell below uses the custom blue agent defined at the start of this section to perform a router_acl_add_rule on router 1."
]
},
{
@@ -1315,7 +1273,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Using CAOS ACTION: ROUTER_ACL_ADDRULE & capturing the OBS\n",
"# Using CAOS ACTION: router_acl_add_rule & capturing the OBS\n",
"post_blue_action_obs, _, _, _, _ = blue_env.step(3)"
]
},
@@ -1429,11 +1387,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"As demonstrated earlier, red agents can use the ``CONFIGURE_C2_BEACON`` action to configure these settings mid episode through the configuration options:\n",
"As demonstrated earlier, red agents can use the ``configure_c2_beacon`` action to configure these settings mid episode through the configuration options:\n",
"\n",
"``` YAML\n",
"...\n",
" action: CONFIGURE_C2_BEACON\n",
" action: configure_c2_beacon\n",
" options:\n",
" node_id: 0\n",
" config:\n",
@@ -1468,7 +1426,7 @@
" # removing all agents & adding the custom agent.\n",
" cfg['agents'] = {}\n",
" cfg['agents'] = c2_agent_yaml\n",
" \n",
"\n",
"\n",
"c2_config_env = PrimaiteGymEnv(env_config=cfg)"
]
@@ -1555,7 +1513,7 @@
"source": [
"for i in range(6):\n",
" env.step(0)\n",
" \n",
"\n",
"c2_server_1.show()"
]
},
@@ -1676,7 +1634,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Comparing the OBS of the default frequency to a timestep frequency of 1 \n",
"# Comparing the OBS of the default frequency to a timestep frequency of 1\n",
"for i in range(2):\n",
" keep_alive_obs, _, _, _, _ = blue_config_env.step(0)\n",
" display_obs_diffs(default_obs, keep_alive_obs, blue_config_env.game.step_counter)"
@@ -1760,7 +1718,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Capturing default C2 Traffic \n",
"# Capturing default C2 Traffic\n",
"for i in range(3):\n",
" tcp_c2_obs, _, _, _, _ = blue_config_env.step(0)\n",
"\n",

View File

@@ -15,6 +15,15 @@
"*(For a full explanation of the Data Manipulation scenario, check out the data manipulation scenario notebook)*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -67,9 +76,9 @@
" # parse the info dict form step output and write out what the red agent is doing\n",
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'DONOTHING':\n",
" if red_action == 'do_nothing':\n",
" red_str = 'DO NOTHING'\n",
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
" elif red_action == 'node_application_execute':\n",
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
" red_str = f\"ATTACK from {client}\"\n",
" return red_str"
@@ -147,12 +156,7 @@
" nodes: {}\n",
"\n",
" action_space:\n",
"\n",
" # The agent has two action choices, either do nothing, or execute a pre-scripted attack by using \n",
" action_list:\n",
" - type: DONOTHING\n",
" - type: NODE_APPLICATION_EXECUTE\n",
"\n",
" \n",
" # The agent has access to the DataManipulationBoth on clients 1 and 2.\n",
" options:\n",
" nodes:\n",
@@ -306,19 +310,9 @@
"outputs": [],
"source": [
"change = yaml.safe_load(\"\"\"\n",
"action_space:\n",
" action_list:\n",
" - type: DONOTHING\n",
" - type: NODE_APPLICATION_EXECUTE\n",
" options:\n",
" nodes:\n",
" - node_name: client_1\n",
" applications:\n",
" - application_name: DataManipulationBot\n",
" max_folders_per_node: 1\n",
" max_files_per_folder: 1\n",
" max_services_per_node: 1\n",
"# TODO:\n",
"\"\"\")\n",
"#TODO 2869 fix\n",
"\n",
"with open(data_manipulation_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
@@ -444,7 +438,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},

View File

@@ -165,13 +165,13 @@
"\n",
"| node_id | node name |\n",
"|---------|------------------|\n",
"| 1 | domain_controller|\n",
"| 2 | web_server |\n",
"| 3 | database_server |\n",
"| 4 | backup_server |\n",
"| 5 | security_suite |\n",
"| 6 | client_1 |\n",
"| 7 | client_2 |\n",
"| 0 | domain_controller|\n",
"| 1 | web_server |\n",
"| 2 | database_server |\n",
"| 3 | backup_server |\n",
"| 4 | security_suite |\n",
"| 5 | client_1 |\n",
"| 6 | client_2 |\n",
"\n",
"Service 1 on node 2 (web_server) corresponds to the Web Server service. Other services are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n",
"\n",
@@ -371,6 +371,15 @@
"First, load the required modules"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -449,9 +458,9 @@
" # parse the info dict form step output and write out what the red agent is doing\n",
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'DONOTHING':\n",
" if red_action == 'do_nothing':\n",
" red_str = 'DO NOTHING'\n",
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
" elif red_action == 'node_application_execute':\n",
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
" red_str = f\"ATTACK from {client}\"\n",
" return red_str"
@@ -547,7 +556,7 @@
"\n",
"The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 0.9 when both green agents make successful requests.\n",
"\n",
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should increase. If you run it enough times, another red attack will happen and the reward will drop again."
"Run the following cell until the green action is `node_application_execute` for application 0, then the reward should increase. If you run it enough times, another red attack will happen and the reward will drop again."
]
},
{

View File

@@ -9,6 +9,15 @@
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -201,7 +201,7 @@
"source": [
"caos_action = [\n",
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
" \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
" \"service\", \"Terminal\", \"node_session_remote_login\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
"]\n",
"game.simulation.apply_request(caos_action)"
]
@@ -259,7 +259,7 @@
"source": [
"caos_action = [\n",
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
" \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_rt.network_interface[4].ip_address)\n",
" \"service\", \"Terminal\", \"node_session_remote_login\", \"admin\", \"admin\", str(some_tech_rt.network_interface[4].ip_address)\n",
"]\n",
"game.simulation.apply_request(caos_action)"
]
@@ -396,7 +396,7 @@
"source": [
"caos_action = [\n",
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
" \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
" \"service\", \"Terminal\", \"node_session_remote_login\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
"]\n",
"game.simulation.apply_request(caos_action)"
]

View File

@@ -25,6 +25,15 @@
"Let's set up a minimal network simulation and send some requests to see how it works."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -18,6 +18,15 @@
"The Terminal service comes pre-installed on most Nodes (The exception being Switches, as these are currently dumb). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -18,6 +18,15 @@
"#### First, Import packages and read our config file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -32,8 +41,6 @@
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
"\n",
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
"# to copy the files to your user data path.\n",
"with open(PRIMAITE_PATHS.user_config_path / 'example_config/data_manipulation_marl.yaml', 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",

View File

@@ -11,6 +11,15 @@
"This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -95,7 +104,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -109,7 +118,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -18,6 +18,15 @@
"#### First, we import the inital packages and read in our configuration file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -168,7 +177,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -182,7 +191,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -238,7 +238,7 @@
"### Episode 2\n",
"When we reset the environment again, it moves onto episode 2, where it will bring in greens_1 and reds_1 for green and red agent definitions. Let's verify the agent names and that they take actions at the defined frequency.\n",
"\n",
"Most green actions will be `NODE_APPLICATION_EXECUTE` while red will `DONOTHING` except at steps 10 and 20."
"Most green actions will be `node_application_execute` while red will `DONOTHING` except at steps 10 and 20."
]
},
{
@@ -269,7 +269,7 @@
"### Episode 3\n",
"When we reset the environment again, it moves onto episode 3, where it will bring in greens_2 and reds_2 for green and red agent definitions. Let's verify the agent names and that they take actions at the defined frequency.\n",
"\n",
"Now, green will perform `NODE_APPLICATION_EXECUTE` only 5% of the time, while red will perform `NODE_APPLICATION_EXECUTE` more frequently than before."
"Now, green will perform `node_application_execute` only 5% of the time, while red will perform `node_application_execute` more frequently than before."
]
},
{

Binary file not shown.

Before

Width:  |  Height:  |  Size: 110 KiB

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 69 KiB

After

Width:  |  Height:  |  Size: 69 KiB

View File

@@ -18,6 +18,15 @@
"Import packages and read config file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -89,7 +89,7 @@ class PrimaiteGymEnv(gymnasium.Env):
:return: Action mask
:rtype: List[bool]
"""
if not self.agent.action_masking:
if not self.agent.config.agent_settings.action_masking:
return np.asarray([True] * len(self.agent.action_manager.action_map))
else:
return self.game.action_mask(self._agent_name)

View File

@@ -44,7 +44,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
)
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
if agent.action_masking:
if agent.config.agent_settings.action_masking:
self.observation_space[agent_name] = spaces.Dict(
{
"action_mask": spaces.MultiBinary(agent.action_manager.space.n),
@@ -143,7 +143,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs = gymnasium.spaces.flatten(unflat_space, unflat_obs)
if agent.action_masking:
if agent.config.agent_settings.action_masking:
all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs}
else:
all_obs[agent_name] = obs
@@ -168,7 +168,7 @@ class PrimaiteRayEnv(gymnasium.Env):
self.env = PrimaiteGymEnv(env_config=env_config)
# self.env.episode_counter -= 1
self.action_space = self.env.action_space
if self.env.agent.action_masking:
if self.env.agent.config.agent_settings.action_masking:
self.observation_space = spaces.Dict(
{"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space}
)
@@ -178,7 +178,7 @@ class PrimaiteRayEnv(gymnasium.Env):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
super().reset() # Ensure PRNG seed is set everywhere
if self.env.agent.action_masking:
if self.env.agent.config.agent_settings.action_masking:
obs, *_ = self.env.reset(seed=seed)
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
return new_obs, *_
@@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env):
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
"""Perform a step in the environment."""
# if action masking is enabled, intercept the step method and add action mask to observation
if self.env.agent.action_masking:
if self.env.agent.config.agent_settings.action_masking:
obs, *_ = self.env.step(action)
new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs}
return new_obs, *_

View File

@@ -264,7 +264,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -664,7 +664,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@@ -130,8 +130,8 @@ class File(FileSystemItemABC):
Return False if corruption is detected, otherwise True
"""
warnings.warn("NODE_FILE_CHECKHASH is currently not implemented.")
self.sys_log.warning("NODE_FILE_CHECKHASH is currently not implemented.")
warnings.warn("node_file_checkhash is currently not implemented.")
self.sys_log.warning("node_file_checkhash is currently not implemented.")
return False
if self.deleted:

View File

@@ -30,6 +30,11 @@ class FileSystem(SimComponent):
num_file_deletions: int = 0
"Number of file deletions in the current step."
_default_folder_scan_duration: Optional[int] = None
"Override default scan duration for folders"
_default_folder_restore_duration: Optional[int] = None
"Override default restore duration for folders"
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Ensure a default root folder
@@ -258,6 +263,11 @@ class FileSystem(SimComponent):
name=folder.name, request_type=RequestType(func=folder._request_manager)
)
self.folders[folder.uuid] = folder
# set the folder scan and restore durations.
if self._default_folder_scan_duration is not None:
folder.scan_duration = self._default_folder_scan_duration
if self._default_folder_restore_duration is not None:
folder.restore_duration = self._default_folder_restore_duration
return folder
def delete_folder(self, folder_name: str) -> bool:

View File

@@ -43,6 +43,9 @@ def convert_size(size_bytes: int) -> str:
class FileSystemItemHealthStatus(Enum):
"""Status of the FileSystemItem."""
NONE = 0
"""File system item health status is not known."""
GOOD = 1
"""File/Folder is OK."""
@@ -72,7 +75,7 @@ class FileSystemItemABC(SimComponent):
health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD
"Actual status of the current FileSystemItem"
visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD
visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.NONE
"Visible status of the current FileSystemItem"
previous_hash: Optional[str] = None

View File

@@ -46,7 +46,7 @@ class Folder(FileSystemItemABC):
:param sys_log: The SysLog instance to us to create system logs.
"""
super().__init__(**kwargs)
self._scanned_this_step: bool = False
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
def _init_request_manager(self) -> RequestManager:
@@ -83,6 +83,7 @@ class Folder(FileSystemItemABC):
state = super().describe_state()
state["files"] = {file.name: file.describe_state() for uuid, file in self.files.items()}
state["deleted_files"] = {file.name: file.describe_state() for uuid, file in self.deleted_files.items()}
state["scanned_this_step"] = self._scanned_this_step
return state
def show(self, markdown: bool = False):
@@ -135,7 +136,7 @@ class Folder(FileSystemItemABC):
def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
self._scanned_this_step = False
for file in self.files.values():
file.pre_timestep(timestep)
@@ -148,9 +149,17 @@ class Folder(FileSystemItemABC):
for file_id in self.files:
file = self.get_file_by_id(file_uuid=file_id)
file.scan()
if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT:
self.health_status = FileSystemItemHealthStatus.CORRUPT
# set folder health to worst file's health by generating a list of file healths. If no files, use 0
self.health_status = FileSystemItemHealthStatus(
max(
[f.health_status.value for f in self.files.values()]
or [
0,
]
)
)
self.visible_health_status = self.health_status
self._scanned_this_step = True
def _reveal_to_red_timestep(self) -> None:
"""Apply reveal to red timestep."""
@@ -387,8 +396,8 @@ class Folder(FileSystemItemABC):
Return False if corruption is detected, otherwise True
"""
warnings.warn("NODE_FOLDER_CHECKHASH is currently not implemented.")
self.sys_log.error("NODE_FOLDER_CHECKHASH is currently not implemented.")
warnings.warn("node_folder_checkhash is currently not implemented.")
self.sys_log.error("node_folder_checkhash is currently not implemented.")
return False
if self.deleted:

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Literal, Type
from typing import Any, ClassVar, Dict, Literal, Optional, Type
from pydantic import BaseModel, model_validator
@@ -49,7 +49,7 @@ class NetworkNodeAdder(BaseModel):
_registry: ClassVar[Dict[str, Type["NetworkNodeAdder"]]] = {}
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: Optional[str], **kwargs: Any) -> None:
"""
Register a network node adder class.
@@ -58,6 +58,8 @@ class NetworkNodeAdder(BaseModel):
:raises ValueError: When attempting to register a name that is already reserved.
"""
super().__init_subclass__(**kwargs)
if identifier is None:
return
if identifier in cls._registry:
raise ValueError(f"Duplicate node adder {identifier}")
cls._registry[identifier] = cls

View File

@@ -824,7 +824,7 @@ class User(SimComponent):
return self.model_dump()
class UserManager(Service):
class UserManager(Service, identifier="UserManager"):
"""
Manages users within the PrimAITE system, handling creation, authentication, and administration.
@@ -833,11 +833,18 @@ class UserManager(Service):
:param disabled_admins: A dictionary of currently disabled admin users by their usernames
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for UserManager."""
type: str = "UserManager"
config: "UserManager.ConfigSchema" = Field(default_factory=lambda: UserManager.ConfigSchema())
users: Dict[str, User] = {}
def __init__(self, **kwargs):
"""
Initializes a UserManager instanc.
Initializes a UserManager instance.
:param username: The username for the default admin user
:param password: The password for the default admin user
@@ -1130,13 +1137,20 @@ class RemoteUserSession(UserSession):
return state
class UserSessionManager(Service):
class UserSessionManager(Service, identifier="UserSessionManager"):
"""
Manages user sessions on a Node, including local and remote sessions.
This class handles authentication, session management, and session timeouts for users interacting with the Node.
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for UserSessionManager."""
type: str = "UserSessionManager"
config: "UserSessionManager.ConfigSchema" = Field(default_factory=lambda: UserSessionManager.ConfigSchema())
local_session: Optional[UserSession] = None
"""The current local user session, if any."""
@@ -1554,7 +1568,6 @@ class Node(SimComponent, ABC):
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
@classmethod
def from_config(cls, config: Dict) -> "Node":
"""Create Node object from a given configuration dictionary."""
@@ -1564,7 +1577,7 @@ class Node(SimComponent, ABC):
obj = cls(config=cls.ConfigSchema(**config))
return obj
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
"""
Register a node type.
@@ -1572,10 +1585,10 @@ class Node(SimComponent, ABC):
:type identifier: str
:raises ValueError: When attempting to register an node with a name that is already allocated.
"""
if identifier == "default":
super().__init_subclass__(**kwargs)
if identifier is None:
return
identifier = identifier.lower()
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Tried to define new node {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls

View File

@@ -1,10 +1,12 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import abstractmethod
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, ClassVar, Dict, Optional, Set, Type
from pydantic import Field
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
@@ -21,13 +23,20 @@ class ApplicationOperatingState(Enum):
"The application is being installed or updated."
class Application(IOSoftware):
class Application(IOSoftware, ABC):
"""
Represents an Application in the simulation environment.
Applications are user-facing programs that may perform input/output operations.
"""
class ConfigSchema(IOSoftware.ConfigSchema, ABC):
"""Config Schema for Application class."""
type: str
config: ConfigSchema = Field(default_factory=lambda: Application.ConfigSchema())
operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED
"The current operating state of the Application."
execution_control_status: str = "manual"
@@ -44,21 +53,36 @@ class Application(IOSoftware):
_registry: ClassVar[Dict[str, Type["Application"]]] = {}
"""Registry of application types. Automatically populated when subclasses are defined."""
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
"""
Register an application type.
:param identifier: Uniquely specifies an application class by name. Used for finding items by config.
:type identifier: str
:type identifier: Optional[str]
:raises ValueError: When attempting to register an application with a name that is already allocated.
"""
if identifier == "default":
return
super().__init_subclass__(**kwargs)
if identifier is None:
return
if identifier in cls._registry:
raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.")
cls._registry[identifier] = cls
@classmethod
def from_config(cls, config: Dict) -> "Application":
"""Create an application from a config dictionary.
:param config: dict of options for application components constructor
:type config: dict
:return: The application component.
:rtype: Application
"""
if config["type"] not in cls._registry:
raise ValueError(f"Invalid Application type {config['type']}")
application_class = cls._registry[config["type"]]
application_object = application_class(config=application_class.ConfigSchema(**config))
return application_object
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union
from uuid import uuid4
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel
from pydantic import BaseModel, Field
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
@@ -67,11 +67,19 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
Extends the Application class to provide functionality for connecting, querying, and disconnecting from a
Database Service. It mainly operates over TCP protocol.
:ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None.
"""
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for DatabaseClient."""
type: str = "DatabaseClient"
db_server_ip: Optional[IPV4Address] = None
server_password: Optional[str] = None
config: ConfigSchema = Field(default_factory=lambda: DatabaseClient.ConfigSchema())
server_ip_address: Optional[IPv4Address] = None
"""The IPv4 address of the Database Service server, defaults to None."""
server_password: Optional[str] = None
_query_success_tracker: Dict[str, bool] = {}
"""Keep track of connections that were established or verified during this step. Used for rewards."""
@@ -93,6 +101,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
self.server_ip_address = self.config.db_server_ip
self.server_password = self.config.server_password
def _init_request_manager(self) -> RequestManager:
"""

View File

@@ -3,7 +3,7 @@ from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union
from prettytable import PrettyTable
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
@@ -52,6 +52,13 @@ class NMAP(Application, identifier="NMAP"):
as ping scans to discover active hosts and port scans to detect open ports on those hosts.
"""
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for NMAP."""
type: str = "NMAP"
config: "NMAP.ConfigSchema" = Field(default_factory=lambda: NMAP.ConfigSchema())
_active_port_scans: Dict[str, PortScanPayload] = {}
_port_scan_responses: Dict[str, PortScanPayload] = {}

View File

@@ -2,9 +2,9 @@
from abc import abstractmethod
from enum import Enum
from ipaddress import IPv4Address
from typing import Dict, Optional, Union
from typing import Dict, Optional, Set, Union
from pydantic import BaseModel, Field, validate_call
from pydantic import Field, validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.file_system.file_system import FileSystem, Folder
@@ -45,10 +45,10 @@ class C2Payload(Enum):
"""C2 Input Command payload. Used by the C2 Server to send a command to the c2 beacon."""
OUTPUT = "output_command"
"""C2 Output Command. Used by the C2 Beacon to send the results of a Input command to the c2 server."""
"""C2 Output Command. Used by the C2 Beacon to send the results of an Input command to the c2 server."""
class AbstractC2(Application, identifier="AbstractC2"):
class AbstractC2(Application):
"""
An abstract command and control (c2) application.
@@ -60,9 +60,25 @@ class AbstractC2(Application, identifier="AbstractC2"):
Defaults to masquerading as HTTP (Port 80) via TCP.
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
"""
class ConfigSchema(Application.ConfigSchema):
"""Configuration for AbstractC2."""
keep_alive_frequency: int = Field(default=5, ge=1)
"""The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon."""
masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"])
"""The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP."""
masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"])
"""The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP."""
listen_on_ports: Set[Port] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]}
config: ConfigSchema = Field(default_factory=lambda: AbstractC2.ConfigSchema())
c2_connection_active: bool = False
"""Indicates if the c2 server and c2 beacon are currently connected."""
@@ -75,19 +91,6 @@ class AbstractC2(Application, identifier="AbstractC2"):
keep_alive_inactivity: int = 0
"""Indicates how many timesteps since the last time the c2 application received a keep alive."""
class _C2Opts(BaseModel):
"""A Pydantic Schema for the different C2 configuration options."""
keep_alive_frequency: int = Field(default=5, ge=1)
"""The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon."""
masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"])
"""The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP."""
masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"])
"""The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP."""
c2_config: _C2Opts = _C2Opts()
"""
Holds the current configuration settings of the C2 Suite.
@@ -100,6 +103,12 @@ class AbstractC2(Application, identifier="AbstractC2"):
C2 beacon to reconfigure it's configuration settings.
"""
def __init__(self, **kwargs):
"""Initialise the C2 applications to by default listen for HTTP traffic."""
kwargs["port"] = PORT_LOOKUP["NONE"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
def _craft_packet(
self, c2_payload: C2Payload, c2_command: Optional[C2Command] = None, command_options: Optional[Dict] = {}
) -> C2Packet:
@@ -118,13 +127,13 @@ class AbstractC2(Application, identifier="AbstractC2"):
:type c2_command: C2Command.
:param command_options: The relevant C2 Beacon parameters.F
:type command_options: Dict
:return: Returns the construct C2Packet
:return: Returns the constructed C2Packet
:rtype: C2Packet
"""
constructed_packet = C2Packet(
masquerade_protocol=self.c2_config.masquerade_protocol,
masquerade_port=self.c2_config.masquerade_port,
keep_alive_frequency=self.c2_config.keep_alive_frequency,
masquerade_protocol=self.config.masquerade_protocol,
masquerade_port=self.config.masquerade_port,
keep_alive_frequency=self.config.keep_alive_frequency,
payload_type=c2_payload,
command=c2_command,
payload=command_options,
@@ -140,13 +149,6 @@ class AbstractC2(Application, identifier="AbstractC2"):
"""
return super().describe_state()
def __init__(self, **kwargs):
"""Initialise the C2 applications to by default listen for HTTP traffic."""
kwargs["listen_on_ports"] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]}
kwargs["port"] = PORT_LOOKUP["NONE"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
@property
def _host_ftp_client(self) -> Optional[FTPClient]:
"""Return the FTPClient that is installed C2 Application's host.
@@ -330,8 +332,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
if self.send(
payload=keep_alive_packet,
dest_ip_address=self.c2_remote_connection,
dest_port=self.c2_config.masquerade_port,
ip_protocol=self.c2_config.masquerade_protocol,
dest_port=self.config.masquerade_port,
ip_protocol=self.config.masquerade_protocol,
session_id=session_id,
):
# Setting the keep_alive_sent guard condition to True. This is used to prevent packet storms.
@@ -340,8 +342,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
self.sys_log.info(f"{self.name}: Keep Alive sent to {self.c2_remote_connection}")
self.sys_log.debug(
f"{self.name}: Keep Alive sent to {self.c2_remote_connection} "
f"Masquerade Port: {self.c2_config.masquerade_port} "
f"Masquerade Protocol: {self.c2_config.masquerade_protocol} "
f"Masquerade Port: {self.config.masquerade_port} "
f"Masquerade Protocol: {self.config.masquerade_protocol} "
)
return True
else:
@@ -376,15 +378,15 @@ class AbstractC2(Application, identifier="AbstractC2"):
# Updating the C2 Configuration attribute.
self.c2_config.masquerade_port = payload.masquerade_port
self.c2_config.masquerade_protocol = payload.masquerade_protocol
self.c2_config.keep_alive_frequency = payload.keep_alive_frequency
self.config.masquerade_port = payload.masquerade_port
self.config.masquerade_protocol = payload.masquerade_protocol
self.config.keep_alive_frequency = payload.keep_alive_frequency
self.sys_log.debug(
f"{self.name}: C2 Config Resolved Config from Keep Alive:"
f"Masquerade Port: {self.c2_config.masquerade_port}"
f"Masquerade Protocol: {self.c2_config.masquerade_protocol}"
f"Keep Alive Frequency: {self.c2_config.keep_alive_frequency}"
f"Masquerade Port: {self.config.masquerade_port}"
f"Masquerade Protocol: {self.config.masquerade_protocol}"
f"Keep Alive Frequency: {self.config.keep_alive_frequency}"
)
# This statement is intended to catch on the C2 Application that is listening for connection.
@@ -410,8 +412,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
self.keep_alive_inactivity = 0
self.keep_alive_frequency = 5
self.c2_remote_connection = None
self.c2_config.masquerade_port = PORT_LOOKUP["HTTP"]
self.c2_config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"]
self.config.masquerade_port = PORT_LOOKUP["HTTP"]
self.config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"]
@abstractmethod
def _confirm_remote_connection(self, timestep: int) -> bool:

View File

@@ -3,7 +3,7 @@ from ipaddress import IPv4Address
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
@@ -12,8 +12,9 @@ from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts
from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import Port, PORT_LOOKUP
class C2Beacon(AbstractC2, identifier="C2Beacon"):
@@ -32,15 +33,30 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
2. Leveraging the terminal application to execute requests (dependent on the command given)
3. Sending the RequestResponse back to the C2 Server (Command output)
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
"""
class ConfigSchema(AbstractC2.ConfigSchema):
"""ConfigSchema for C2Beacon."""
type: str = "C2Beacon"
c2_server_ip_address: Optional[IPV4Address] = None
keep_alive_frequency: int = 5
masquerade_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"]
masquerade_port: Port = PORT_LOOKUP["HTTP"]
config: ConfigSchema = Field(default_factory=lambda: C2Beacon.ConfigSchema())
keep_alive_attempted: bool = False
"""Indicates if a keep alive has been attempted to be sent this timestep. Used to prevent packet storms."""
terminal_session: TerminalClientConnection = None
"The currently in use terminal session."
def __init__(self, **kwargs):
kwargs["name"] = "C2Beacon"
super().__init__(**kwargs)
@property
def _host_terminal(self) -> Optional[Terminal]:
"""Return the Terminal that is installed on the same machine as the C2 Beacon."""
@@ -119,10 +135,6 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
rm.add_request("configure", request_type=RequestType(func=_configure))
return rm
def __init__(self, **kwargs):
kwargs["name"] = "C2Beacon"
super().__init__(**kwargs)
# Configure is practically setter method for the ``c2.config`` attribute that also ties into the request manager.
@validate_call
def configure(
@@ -146,7 +158,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
masquerade_port | What port should the C2 traffic use? (TCP or UDP)
These configuration options are used to reassign the fields in the inherited inner class
``c2_config``.
``config``.
If a connection is already in progress then this method also sends a keep alive to the C2
Server in order for the C2 Server to sync with the new configuration settings.
@@ -162,9 +174,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
:return: Returns True if the configuration was successful, False otherwise.
"""
self.c2_remote_connection = IPv4Address(c2_server_ip_address)
self.c2_config.keep_alive_frequency = keep_alive_frequency
self.c2_config.masquerade_port = masquerade_port
self.c2_config.masquerade_protocol = masquerade_protocol
self.config.keep_alive_frequency = keep_alive_frequency
self.config.masquerade_port = masquerade_port
self.config.masquerade_protocol = masquerade_protocol
self.sys_log.info(
f"{self.name}: Configured {self.name} with remote C2 server connection: {c2_server_ip_address=}."
)
@@ -263,14 +275,12 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
if self.send(
payload=output_packet,
dest_ip_address=self.c2_remote_connection,
dest_port=self.c2_config.masquerade_port,
ip_protocol=self.c2_config.masquerade_protocol,
dest_port=self.config.masquerade_port,
ip_protocol=self.config.masquerade_protocol,
session_id=session_id,
):
self.sys_log.info(f"{self.name}: Command output sent to {self.c2_remote_connection}")
self.sys_log.debug(
f"{self.name}: on {self.c2_config.masquerade_port} via {self.c2_config.masquerade_protocol}"
)
self.sys_log.debug(f"{self.name}: on {self.config.masquerade_port} via {self.config.masquerade_protocol}")
return True
else:
self.sys_log.warning(
@@ -562,7 +572,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
:rtype bool:
"""
self.keep_alive_attempted = False # Resetting keep alive sent.
if self.keep_alive_inactivity == self.c2_config.keep_alive_frequency:
if self.keep_alive_inactivity == self.config.keep_alive_frequency:
self.sys_log.info(
f"{self.name}: Attempting to Send Keep Alive to {self.c2_remote_connection} at timestep {timestep}."
)
@@ -627,9 +637,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
self.c2_connection_active,
self.c2_remote_connection,
self.keep_alive_inactivity,
self.c2_config.keep_alive_frequency,
self.c2_config.masquerade_protocol,
self.c2_config.masquerade_port,
self.config.keep_alive_frequency,
self.config.masquerade_protocol,
self.config.masquerade_port,
]
)
print(table)

View File

@@ -2,7 +2,7 @@
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from pydantic import Field, validate_call
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
@@ -31,9 +31,16 @@ class C2Server(AbstractC2, identifier="C2Server"):
1. Sending commands to the C2 Beacon. (Command input)
2. Parsing terminal RequestResponses back to the Agent.
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
"""
class ConfigSchema(AbstractC2.ConfigSchema):
"""ConfigSchema for C2Server."""
type: str = "C2Server"
config: ConfigSchema = Field(default_factory=lambda: C2Server.ConfigSchema())
current_command_output: RequestResponse = None
"""The Request Response by the last command send. This attribute is updated by the method _handle_command_output."""
@@ -251,8 +258,8 @@ class C2Server(AbstractC2, identifier="C2Server"):
payload=command_packet,
dest_ip_address=self.c2_remote_connection,
session_id=self.c2_session.uuid,
dest_port=self.c2_config.masquerade_port,
ip_protocol=self.c2_config.masquerade_protocol,
dest_port=self.config.masquerade_port,
ip_protocol=self.config.masquerade_protocol,
):
self.sys_log.info(f"{self.name}: Successfully sent {given_command}.")
self.sys_log.info(f"{self.name}: Awaiting command response {given_command}.")
@@ -334,11 +341,11 @@ class C2Server(AbstractC2, identifier="C2Server"):
:return: Returns False if the C2 beacon is considered dead. Otherwise True.
:rtype bool:
"""
if self.keep_alive_inactivity > self.c2_config.keep_alive_frequency:
if self.keep_alive_inactivity > self.config.keep_alive_frequency:
self.sys_log.info(f"{self.name}: C2 Beacon connection considered dead due to inactivity.")
self.sys_log.debug(
f"{self.name}: Did not receive expected keep alive connection from {self.c2_remote_connection}"
f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.c2_config.keep_alive_frequency}"
f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.config.keep_alive_frequency}"
f"{self.name}: Last Keep Alive received at {(timestep - self.keep_alive_inactivity)}"
)
self._reset_c2_connection()
@@ -389,8 +396,8 @@ class C2Server(AbstractC2, identifier="C2Server"):
[
self.c2_connection_active,
self.c2_remote_connection,
self.c2_config.masquerade_protocol,
self.c2_config.masquerade_port,
self.config.masquerade_protocol,
self.config.masquerade_port,
]
)
print(table)

View File

@@ -3,6 +3,8 @@ from enum import IntEnum
from ipaddress import IPv4Address
from typing import Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.interface.request import RequestResponse
@@ -10,6 +12,7 @@ from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
@@ -40,6 +43,18 @@ class DataManipulationAttackStage(IntEnum):
class DataManipulationBot(Application, identifier="DataManipulationBot"):
"""A bot that simulates a script which performs a SQL injection attack."""
class ConfigSchema(Application.ConfigSchema):
"""Configuration schema for DataManipulationBot."""
type: str = "DataManipulationBot"
server_ip: Optional[IPV4Address] = None
server_password: Optional[str] = None
payload: str = "DELETE"
port_scan_p_of_success: float = 0.1
data_manipulation_p_of_success: float = 0.1
config: "DataManipulationBot.ConfigSchema" = Field(default_factory=lambda: DataManipulationBot.ConfigSchema())
payload: Optional[str] = None
port_scan_p_of_success: float = 0.1
data_manipulation_p_of_success: float = 0.1
@@ -56,6 +71,12 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"):
super().__init__(**kwargs)
self._db_connection: Optional[DatabaseClientConnection] = None
self.server_ip_address = self.config.server_ip
self.server_password = self.config.server_password
self.payload = self.config.payload
self.port_scan_p_of_success = self.config.port_scan_p_of_success
self.data_manipulation_p_of_success = self.config.data_manipulation_p_of_success
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -3,11 +3,14 @@ from enum import IntEnum
from ipaddress import IPv4Address
from typing import Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
@@ -32,6 +35,20 @@ class DoSAttackStage(IntEnum):
class DoSBot(DatabaseClient, identifier="DoSBot"):
"""A bot that simulates a Denial of Service attack."""
class ConfigSchema(DatabaseClient.ConfigSchema):
"""ConfigSchema for DoSBot."""
type: str = "DoSBot"
target_ip_address: Optional[IPV4Address] = None
target_port: Port = PORT_LOOKUP["POSTGRES_SERVER"]
payload: Optional[str] = None
repeat: bool = False
port_scan_p_of_success: float = 0.1
dos_intensity: float = 1.0
max_sessions: int = 1000
config: "DoSBot.ConfigSchema" = Field(default_factory=lambda: DoSBot.ConfigSchema())
target_ip_address: Optional[IPv4Address] = None
"""IP address of the target service."""
@@ -56,7 +73,13 @@ class DoSBot(DatabaseClient, identifier="DoSBot"):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = "DoSBot"
self.max_sessions = 1000 # override normal max sessions
self.target_ip_address = self.config.target_ip_address
self.target_port = self.config.target_port
self.payload = self.config.payload
self.repeat = self.config.repeat
self.port_scan_p_of_success = self.config.port_scan_p_of_success
self.dos_intensity = self.config.dos_intensity
self.max_sessions = self.config.max_sessions
def _init_request_manager(self) -> RequestManager:
"""

View File

@@ -3,12 +3,14 @@ from ipaddress import IPv4Address
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import PORT_LOOKUP
@@ -18,6 +20,16 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
:ivar payload: The attack stage query payload. (Default ENCRYPT)
"""
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for RansomwareScript."""
type: str = "RansomwareScript"
server_ip: Optional[IPV4Address] = None
server_password: Optional[str] = None
payload: str = "ENCRYPT"
config: "RansomwareScript.ConfigSchema" = Field(default_factory=lambda: RansomwareScript.ConfigSchema())
server_ip_address: Optional[IPv4Address] = None
"""IP address of node which hosts the database."""
server_password: Optional[str] = None
@@ -32,6 +44,9 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
super().__init__(**kwargs)
self._db_connection: Optional[DatabaseClientConnection] = None
self.server_ip_address = self.config.server_ip
self.server_password = self.config.server_password
self.payload = self.config.payload
def describe_state(self) -> Dict:
"""

View File

@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
from typing import Dict, List, Optional
from urllib.parse import urlparse
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from primaite import getLogger
from primaite.interface.request import RequestResponse
@@ -30,7 +30,13 @@ class WebBrowser(Application, identifier="WebBrowser"):
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
"""
target_url: Optional[str] = None
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for WebBrowser."""
type: str = "WebBrowser"
target_url: Optional[str] = None
config: "WebBrowser.ConfigSchema" = Field(default_factory=lambda: WebBrowser.ConfigSchema())
domain_name_ip_address: Optional[IPv4Address] = None
"The IP address of the domain name for the webpage."
@@ -86,7 +92,7 @@ class WebBrowser(Application, identifier="WebBrowser"):
:param: url: The address of the web page the browser requests
:type: url: str
"""
url = url or self.target_url
url = url or self.config.target_url
if not self._can_perform_action():
return False

View File

@@ -106,7 +106,7 @@ class SoftwareManager:
return True
return False
def install(self, software_class: Type[IOSoftware], **install_kwargs):
def install(self, software_class: Type[IOSoftware], software_config: Optional[IOSoftware.ConfigSchema] = None):
"""
Install an Application or Service.
@@ -115,13 +115,22 @@ class SoftwareManager:
if software_class in self._software_class_to_name_map:
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
return
software = software_class(
software_manager=self,
sys_log=self.sys_log,
file_system=self.file_system,
dns_server=self.dns_server,
**install_kwargs,
)
if software_config is None:
software = software_class(
software_manager=self,
sys_log=self.sys_log,
file_system=self.file_system,
dns_server=self.dns_server,
)
else:
software = software_class(
software_manager=self,
sys_log=self.sys_log,
file_system=self.file_system,
dns_server=self.dns_server,
config=software_config,
)
software.parent = self.node
if isinstance(software, Application):
self.node.applications[software.uuid] = software

View File

@@ -5,6 +5,7 @@ from abc import abstractmethod
from typing import Any, Dict, Optional, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
@@ -14,7 +15,7 @@ from primaite.utils.validation.ipv4_address import IPV4Address
from primaite.utils.validation.port import PORT_LOOKUP
class ARP(Service):
class ARP(Service, identifier="ARP"):
"""
The ARP (Address Resolution Protocol) Service.
@@ -22,6 +23,13 @@ class ARP(Service):
sends ARP requests and replies, and processes incoming ARP packets.
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for ARP."""
type: str = "ARP"
config: "ARP.ConfigSchema" = Field(default_factory=lambda: ARP.ConfigSchema())
arp: Dict[IPV4Address, ARPEntry] = {}
def __init__(self, **kwargs):

View File

@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, Union
from uuid import uuid4
from pydantic import Field
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
@@ -17,13 +19,21 @@ from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
class DatabaseService(Service):
class DatabaseService(Service, identifier="DatabaseService"):
"""
A class for simulating a generic SQL Server service.
This class inherits from the `Service` class and provides methods to simulate a SQL database.
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for DatabaseService."""
type: str = "DatabaseService"
backup_server_ip: Optional[IPv4Address] = None
config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema())
password: Optional[str] = None
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
@@ -42,6 +52,7 @@ class DatabaseService(Service):
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
self._create_db_file()
self.backup_server_ip = self.config.backup_server_ip
def install(self):
"""

View File

@@ -2,6 +2,8 @@
from ipaddress import IPv4Address
from typing import Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
from primaite.simulator.system.core.software_manager import SoftwareManager
@@ -12,9 +14,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
class DNSClient(Service):
class DNSClient(Service, identifier="DNSClient"):
"""Represents a DNS Client as a Service."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for DNSClient."""
type: str = "DNSClient"
config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema())
dns_cache: Dict[str, IPv4Address] = {}
"A dict of known mappings between domain/URLs names and IPv4 addresses."
dns_server: Optional[IPv4Address] = None

View File

@@ -3,6 +3,7 @@ from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.protocols.dns import DNSPacket
@@ -13,9 +14,17 @@ from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
class DNSServer(Service):
class DNSServer(Service, identifier="DNSServer"):
"""Represents a DNS Server as a Service."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for DNSServer."""
type: str = "DNSServer"
domain_mapping: dict = {}
config: "DNSServer.ConfigSchema" = Field(default_factory=lambda: DNSServer.ConfigSchema())
dns_table: Dict[str, IPv4Address] = {}
"A dict of mappings between domain names and IPv4 addresses."

View File

@@ -2,6 +2,8 @@
from ipaddress import IPv4Address
from typing import Dict, Optional
from pydantic import Field
from primaite import getLogger
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
@@ -9,20 +11,28 @@ from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import Service
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import Port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
class FTPClient(FTPServiceABC):
class FTPClient(FTPServiceABC, identifier="FTPClient"):
"""
A class for simulating an FTP client service.
This class inherits from the `Service` class and provides methods to emulate FTP
This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
"""
config: "FTPClient.ConfigSchema" = Field(default_factory=lambda: FTPClient.ConfigSchema())
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for FTPClient."""
type: str = "FTPClient"
def __init__(self, **kwargs):
kwargs["name"] = "FTPClient"
kwargs["port"] = PORT_LOOKUP["FTP"]
@@ -108,6 +118,7 @@ class FTPClient(FTPServiceABC):
session_id: Optional[str] = None,
is_reattempt: Optional[bool] = False,
) -> bool:
self._active = True
"""
Connects the client to a given FTP server.
@@ -164,6 +175,7 @@ class FTPClient(FTPServiceABC):
:param: is_reattempt: Set to True if attempt to disconnect from FTP Server has been attempted. Default False.
:type: is_reattempt: Optional[bool]
"""
self._active = True
# send a disconnect request payload to FTP server
payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.QUIT)
software_manager: SoftwareManager = self.software_manager
@@ -209,6 +221,7 @@ class FTPClient(FTPServiceABC):
:param: session_id: The id of the session
:type: session_id: Optional[str]
"""
self._active = True
# check if the file to transfer exists on the client
file_to_transfer: File = self.file_system.get_file(folder_name=src_folder_name, file_name=src_file_name)
if not file_to_transfer:
@@ -266,6 +279,7 @@ class FTPClient(FTPServiceABC):
:param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"].
:type: dest_port: Optional[int]
"""
self._active = True
# check if FTP is currently connected to IP
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
@@ -317,6 +331,7 @@ class FTPClient(FTPServiceABC):
This helps prevent an FTP request loop - FTP client and servers can exist on
the same node.
"""
self._active = True
if not self._can_perform_action():
return False

View File

@@ -1,26 +1,36 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from typing import Any, Optional
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import Service
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP
_LOGGER = getLogger(__name__)
class FTPServer(FTPServiceABC):
class FTPServer(FTPServiceABC, identifier="FTPServer"):
"""
A class for simulating an FTP server service.
This class inherits from the `Service` class and provides methods to emulate FTP
This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
"""
config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema())
server_password: Optional[str] = None
"""Password needed to connect to FTP server. Default is None."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for FTPServer."""
type: str = "FTPServer"
def __init__(self, **kwargs):
kwargs["name"] = "FTPServer"
kwargs["port"] = PORT_LOOKUP["FTP"]

View File

@@ -3,9 +3,11 @@ from abc import ABC
from ipaddress import IPv4Address
from typing import Dict, Optional
from pydantic import StrictBool
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.utils.validation.port import Port
@@ -16,9 +18,22 @@ class FTPServiceABC(Service, ABC):
Contains shared methods between both classes.
"""
_active: StrictBool = False
"""Flag that is True on timesteps where service transmits data and False when idle. Used for describe_state."""
def pre_timestep(self, timestep: int) -> None:
"""When a new timestep begins, clear the _active attribute."""
self._active = False
return super().pre_timestep(timestep)
def describe_state(self) -> Dict:
"""Returns a Dict of the FTPService state."""
return super().describe_state()
state = super().describe_state()
# override so that the service is shows as running only if actively transmitting data this timestep
if self.operating_state == ServiceOperatingState.RUNNING and not self._active:
state["operating_state"] = ServiceOperatingState.STOPPED.value
return state
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
@@ -29,6 +44,7 @@ class FTPServiceABC(Service, ABC):
:param: session_id: session ID linked to the FTP Packet. Optional.
:type: session_id: Optional[str]
"""
self._active = True
if payload.ftp_command is not None:
self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.")
@@ -51,6 +67,7 @@ class FTPServiceABC(Service, ABC):
:param: payload: The FTP Packet that contains the file data
:type: FTPPacket
"""
self._active = True
try:
file_name = payload.ftp_command_args["dest_file_name"]
folder_name = payload.ftp_command_args["dest_folder_name"]
@@ -106,6 +123,7 @@ class FTPServiceABC(Service, ABC):
:param: is_response: is true if the data being sent is in response to a request. Default False.
:type: is_response: bool
"""
self._active = True
# send STOR request
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.STOR,
@@ -135,6 +153,7 @@ class FTPServiceABC(Service, ABC):
:param: payload: The FTP Packet that contains the file data
:type: FTPPacket
"""
self._active = True
try:
# find the file
file_name = payload.ftp_command_args["src_file_name"]
@@ -181,6 +200,7 @@ class FTPServiceABC(Service, ABC):
:return: True if successful, False otherwise.
"""
self._active = True
self.sys_log.info(f"{self.name}: Sending FTP {payload.ftp_command.name} {payload.ftp_command_args}")
return super().send(

View File

@@ -3,6 +3,8 @@ import secrets
from ipaddress import IPv4Address
from typing import Any, Dict, Optional, Tuple, Union
from pydantic import Field
from primaite import getLogger
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType
@@ -14,7 +16,7 @@ from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
class ICMP(Service):
class ICMP(Service, identifier="ICMP"):
"""
The Internet Control Message Protocol (ICMP) service.
@@ -22,6 +24,13 @@ class ICMP(Service):
network diagnostics, notably the ping command.
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for ICMP."""
type: str = "ICMP"
config: "ICMP.ConfigSchema" = Field(default_factory=lambda: ICMP.ConfigSchema())
request_replies: Dict = {}
def __init__(self, **kwargs):

Some files were not shown because too many files have changed in this diff Show More