#2887 - Resolve conflicts from merge
This commit is contained in:
@@ -14,31 +14,36 @@ parameters:
|
||||
- name: matrix
|
||||
type: object
|
||||
default:
|
||||
# - job_name: 'UbuntuPython38'
|
||||
# py: '3.8'
|
||||
# img: 'ubuntu-latest'
|
||||
# every_time: false
|
||||
# publish_coverage: false
|
||||
- job_name: 'UbuntuPython311'
|
||||
py: '3.11'
|
||||
- job_name: 'UbuntuPython39'
|
||||
py: '3.9'
|
||||
img: 'ubuntu-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'UbuntuPython310'
|
||||
py: '3.10'
|
||||
img: 'ubuntu-latest'
|
||||
every_time: true
|
||||
publish_coverage: true
|
||||
# - job_name: 'WindowsPython38'
|
||||
# py: '3.8'
|
||||
# img: 'windows-latest'
|
||||
# every_time: false
|
||||
# publish_coverage: false
|
||||
- job_name: 'UbuntuPython311'
|
||||
py: '3.11'
|
||||
img: 'ubuntu-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'WindowsPython39'
|
||||
py: '3.9'
|
||||
img: 'windows-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'WindowsPython311'
|
||||
py: '3.11'
|
||||
img: 'windows-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
# - job_name: 'MacOSPython38'
|
||||
# py: '3.8'
|
||||
# img: 'macOS-latest'
|
||||
# every_time: false
|
||||
# publish_coverage: false
|
||||
- job_name: 'MacOSPython39'
|
||||
py: '3.9'
|
||||
img: 'macOS-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'MacOSPython311'
|
||||
py: '3.11'
|
||||
img: 'macOS-latest'
|
||||
@@ -63,7 +68,7 @@ stages:
|
||||
displayName: 'Use Python ${{ item.py }}'
|
||||
|
||||
- script: |
|
||||
python -m pip install pre-commit
|
||||
python -m pip install pre-commit>=6.1
|
||||
pre-commit install
|
||||
pre-commit run --all-files
|
||||
displayName: 'Run pre-commits'
|
||||
@@ -71,7 +76,6 @@ stages:
|
||||
- script: |
|
||||
python -m pip install --upgrade pip==23.0.1
|
||||
pip install wheel==0.38.4 --upgrade
|
||||
pip install setuptools==66 --upgrade
|
||||
pip install build==0.10.0
|
||||
pip install pytest-azurepipelines
|
||||
displayName: 'Install build dependencies'
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: ensure-copyright-clause
|
||||
name: ensure copyright clause
|
||||
entry: python copyright_clause_pre_commit_hook.py
|
||||
language: python
|
||||
# - repo: local
|
||||
# hooks:
|
||||
# - id: ensure-copyright-clause
|
||||
# name: ensure copyright clause
|
||||
# entry: python copyright_clause_pre_commit_hook.py
|
||||
# language: python
|
||||
- repo: http://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
@@ -31,7 +31,7 @@ repos:
|
||||
- id: isort
|
||||
args: [ "--profile", "black" ]
|
||||
- repo: http://github.com/PyCQA/flake8
|
||||
rev: 6.0.0
|
||||
rev: 6.1.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
additional_dependencies:
|
||||
|
||||
18
CHANGELOG.md
18
CHANGELOG.md
@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [4.0.0] = TBC
|
||||
|
||||
### Added
|
||||
|
||||
### Changed
|
||||
- Agents now follow a common configuration format, simplifying the configuration of agents and their extensibilty.
|
||||
- Actions within PrimAITE are now extensible, allowing for plugin support.
|
||||
- Added a config schema to `ObservationManager`, `ActionManager`, and `RewardFunction`.
|
||||
- Streamlined the way agents are created from config
|
||||
- Agent config no longer requires a dummy action space if the action space is empty, the same applies for observation space and reward function
|
||||
- Actions now support a config schema, to allow yaml data validation and default parameter values
|
||||
- Action parameters are no longer defined through IDs, instead meaningful data is provided directly in the action map
|
||||
- Test and example YAMLs have been updated to match the new agent and action schemas, such as:
|
||||
- Removed empty action spaces, observation spaces, or reward spaces for agent which didn't use them
|
||||
- Relabeled action parameters to match the new action config schemas, and updated the values to no longer rely on indices
|
||||
- Removed action space options which were previously used for assigning meaning to action space IDs
|
||||
- Updated tests that don't use YAMLs to still use the new action and agent schemas
|
||||
|
||||
## [3.3.0] - 2024-09-04
|
||||
|
||||
### Added
|
||||
|
||||
@@ -70,7 +70,7 @@ PrimAITE incorporates the following features:
|
||||
|
||||
- Architected with a separate Simulation layer and Game layer. This separation of concerns defines a clear path towards transfer learning with environments of differing fidelity;
|
||||
- Ability to reconfigure an RL reward function based on (a) the ability to counter the modelled adversarial cyber-attack, and (b) the ability to ensure success for green agents;
|
||||
- Access Control List (ACL) functions for network devices (routers and firewalls), following standard ACL rule format (e.g., DENY / ALLOW, source / destination IP addresses, protocol and port);
|
||||
- Access Control List (ACL) functions for network devices (routers and firewalls), following standard ACL rule format (e.g., DENY / PERMIT, source / destination IP addresses, protocol and port);
|
||||
- Application of traffic to the links of the system laydown adheres to the ACL rulesets and routing tables contained within each network device;
|
||||
- Provides RL environments adherent to the Farama Foundation Gymnasium (Previously OpenAI Gym) API, allowing integration with any compliant RL Agent frameworks;
|
||||
- Provides RL environments adherent to Ray RLlib environment specifications for single-agent and multi-agent scenarios;
|
||||
|
||||
@@ -184,7 +184,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
|
||||
- 192.168.1.5
|
||||
- ANY
|
||||
- ANY
|
||||
All ACL rules are considered when applying an IER. Logic follows the order of rules, so a DENY or ALLOW for the same parameters will override an earlier entry.
|
||||
All ACL rules are considered when applying an IER. Logic follows the order of rules, so a DENY or PERMIT for the same parameters will override an earlier entry.
|
||||
Observation Spaces
|
||||
******************
|
||||
The observation space provides the blue agent with information about the current status of nodes and links.
|
||||
@@ -331,7 +331,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
|
||||
* Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...}
|
||||
The placeholders inside the list under the key '1' mean the following:
|
||||
* [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
* [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
* [0, 1] - Permission (0 = DENY, 1 = PERMIT)
|
||||
* [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
* [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
* [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
|
||||
@@ -23,123 +23,117 @@ The following logic is applied:
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| Action | Action Mask Logic |
|
||||
+==========================================+=====================================================================+
|
||||
| **DONOTHING** | Always Possible. |
|
||||
| **do_nothing** | Always Possible. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_SERVICE_SCAN** | Node is on. Service is running. |
|
||||
| **node_service_scan** | Node is on. Service is running. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_SERVICE_STOP** | Node is on. Service is running. |
|
||||
| **node_service_stop** | Node is on. Service is running. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_SERVICE_START** | Node is on. Service is stopped. |
|
||||
| **node_service_start** | Node is on. Service is stopped. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_SERVICE_PAUSE** | Node is on. Service is running. |
|
||||
| **node_service_pause** | Node is on. Service is running. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_SERVICE_RESUME** | Node is on. Service is paused. |
|
||||
| **node_service_resume** | Node is on. Service is paused. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_SERVICE_RESTART** | Node is on. Service is running. |
|
||||
| **node_service_restart** | Node is on. Service is running. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_SERVICE_DISABLE** | Node is on. |
|
||||
| **node_service_disable** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_SERVICE_ENABLE** | Node is on. Service is disabled. |
|
||||
| **node_service_enable** | Node is on. Service is disabled. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_SERVICE_FIX** | Node is on. Service is running. |
|
||||
| **node_service_fix** | Node is on. Service is running. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_APPLICATION_EXECUTE** | Node is on. |
|
||||
| **node_application_execute** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_APPLICATION_SCAN** | Node is on. Application is running. |
|
||||
| **node_application_scan** | Node is on. Application is running. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_APPLICATION_CLOSE** | Node is on. Application is running. |
|
||||
| **node_application_close** | Node is on. Application is running. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_APPLICATION_FIX** | Node is on. Application is running. |
|
||||
| **node_application_fix** | Node is on. Application is running. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_APPLICATION_INSTALL** | Node is on. |
|
||||
| **node_application_install** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_APPLICATION_REMOVE** | Node is on. |
|
||||
| **node_application_remove** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FILE_SCAN** | Node is on. File exists. File not deleted. |
|
||||
| **node_file_scan** | Node is on. File exists. File not deleted. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FILE_CREATE** | Node is on. |
|
||||
| **node_file_create** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FILE_CHECKHASH** | Node is on. File exists. File not deleted. |
|
||||
| **node_file_checkhash** | Node is on. File exists. File not deleted. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FILE_DELETE** | Node is on. File exists. |
|
||||
| **node_file_delete** | Node is on. File exists. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FILE_REPAIR** | Node is on. File exists. File not deleted. |
|
||||
| **node_file_repair** | Node is on. File exists. File not deleted. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FILE_RESTORE** | Node is on. File exists. File is deleted. |
|
||||
| **node_file_restore** | Node is on. File exists. File is deleted. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FILE_CORRUPT** | Node is on. File exists. File not deleted. |
|
||||
| **node_file_corrupt** | Node is on. File exists. File not deleted. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FILE_ACCESS** | Node is on. File exists. File not deleted. |
|
||||
| **node_file_access** | Node is on. File exists. File not deleted. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FOLDER_CREATE** | Node is on. |
|
||||
| **node_folder_create** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FOLDER_SCAN** | Node is on. Folder exists. Folder not deleted. |
|
||||
| **node_folder_scan** | Node is on. Folder exists. Folder not deleted. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FOLDER_CHECKHASH** | Node is on. Folder exists. Folder not deleted. |
|
||||
| **node_folder_checkhash** | Node is on. Folder exists. Folder not deleted. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FOLDER_REPAIR** | Node is on. Folder exists. Folder not deleted. |
|
||||
| **node_folder_repair** | Node is on. Folder exists. Folder not deleted. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_FOLDER_RESTORE** | Node is on. Folder exists. Folder is deleted. |
|
||||
| **node_folder_restore** | Node is on. Folder exists. Folder is deleted. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_OS_SCAN** | Node is on. |
|
||||
| **node_os_scan** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **HOST_NIC_ENABLE** | NIC is disabled. Node is on. |
|
||||
| **host_nic_enable** | NIC is disabled. Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **HOST_NIC_DISABLE** | NIC is enabled. Node is on. |
|
||||
| **host_nic_disable** | NIC is enabled. Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_SHUTDOWN** | Node is on. |
|
||||
| **node_shutdown** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_STARTUP** | Node is off. |
|
||||
| **node_startup** | Node is off. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_RESET** | Node is on. |
|
||||
| **node_reset** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_NMAP_PING_SCAN** | Node is on. |
|
||||
| **node_nmap_ping_scan** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_NMAP_PORT_SCAN** | Node is on. |
|
||||
| **node_nmap_port_scan** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_NMAP_NETWORK_SERVICE_RECON** | Node is on. |
|
||||
| **node_network_service_recon** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NETWORK_PORT_ENABLE** | Node is on. Router is on. |
|
||||
| **network_port_enable** | Node is on. Router is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NETWORK_PORT_DISABLE** | Router is on. |
|
||||
| **network_port_disable** | Router is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **ROUTER_ACL_ADDRULE** | Router is on. |
|
||||
| **router_acl_addrule** | Router is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **ROUTER_ACL_REMOVERULE** | Router is on. |
|
||||
| **router_acl_removerule** | Router is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **FIREWALL_ACL_ADDRULE** | Firewall is on. |
|
||||
| **firewall_acl_addrule** | Firewall is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **FIREWALL_ACL_REMOVERULE** | Firewall is on. |
|
||||
| **firewall_acl_removerule** | Firewall is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_NMAP_PING_SCAN** | Node is on. |
|
||||
| **configure_database_client** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_NMAP_PORT_SCAN** | Node is on. |
|
||||
| **configure_ransomware_script** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_NMAP_NETWORK_SERVICE_RECON** | Node is on. |
|
||||
| **c2_server_ransomware_configure** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **CONFIGURE_DATABASE_CLIENT** | Node is on. |
|
||||
| **configure_dos_bot** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **CONFIGURE_RANSOMWARE_SCRIPT** | Node is on. |
|
||||
| **configure_c2_beacon** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **CONFIGURE_DOSBOT** | Node is on. |
|
||||
| **c2_server_ransomware_launch** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **CONFIGURE_C2_BEACON** | Node is on. |
|
||||
| **c2_server_terminal_command** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **C2_SERVER_RANSOMWARE_LAUNCH** | Node is on. |
|
||||
| **c2_server_data_exfiltrate** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **C2_SERVER_RANSOMWARE_CONFIGURE** | Node is on. |
|
||||
| **node_account_change_password** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **C2_SERVER_TERMINAL_COMMAND** | Node is on. |
|
||||
| **node_session_remote_login** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **C2_SERVER_DATA_EXFILTRATE** | Node is on. |
|
||||
| **node_session_remote_logoff** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_ACCOUNTS_CHANGE_PASSWORD** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **SSH_TO_REMOTE** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **SESSIONS_REMOTE_LOGOFF** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
| **NODE_SEND_REMOTE_COMMAND** | Node is on. |
|
||||
| **node_send_remote_command** | Node is on. |
|
||||
+------------------------------------------+---------------------------------------------------------------------+
|
||||
|
||||
|
||||
|
||||
@@ -23,19 +23,6 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
@@ -91,10 +78,6 @@ For more information see :py:mod:`primaite.game.agent.observations`
|
||||
|
||||
The action space is configured to be made up of individual action types. Once configured, the agent can select an action type and some optional action parameters at every step. For example: The ``NODE_SERVICE_SCAN`` action takes the parameters ``node_id`` and ``service_id``.
|
||||
|
||||
``action_list``
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
A list of action modules. The options are listed in the :py:mod:`primaite.game.agent.actions.ActionManager.act_class_identifiers` module.
|
||||
|
||||
``action_map``
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
67
docs/source/how_to_guides/extensible_actions.rst
Normal file
67
docs/source/how_to_guides/extensible_actions.rst
Normal file
@@ -0,0 +1,67 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
|
||||
Extensible Actions
|
||||
******************
|
||||
|
||||
|
||||
Changes to Actions class Structure.
|
||||
===================================
|
||||
|
||||
Actions within PrimAITE have been updated to inherit from a base class, AbstractAction, standardising their format and allowing for easier creation of custom actions. Actions now use a ``ConfigSchema`` to define the possible configuration variables, and use pydantic to enforce correct parameters are passed through.
|
||||
|
||||
|
||||
Developing Custom Actions.
|
||||
==========================
|
||||
|
||||
Custom actions within PrimAITE must be a sub-class of `AbstractAction`, and contain 3 key items:
|
||||
|
||||
#. ConfigSchema class
|
||||
|
||||
#. Unique Identifier
|
||||
|
||||
#. `form_request` method.
|
||||
|
||||
|
||||
ConfigSchema
|
||||
############
|
||||
|
||||
The ConfigSchema sub-class of the action must contain all `configurable` variables within the action, that would be specified within the environments configuration YAML file.
|
||||
|
||||
|
||||
Unique Identifier
|
||||
#################
|
||||
|
||||
When declaring a custom class, it must have a unique identifier string, that allows PrimAITE to generate the correct action when needed.
|
||||
|
||||
.. code:: Python
|
||||
|
||||
class CreateDirectoryAction(AbstractAction, identifier="node_folder_create")
|
||||
|
||||
config: CreateDirectoryAction.ConfigSchema
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
|
||||
verb: ClassVar[str] = "create"
|
||||
node_name: str
|
||||
directory_name: str
|
||||
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
return ["network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
"folder",
|
||||
config.directory_name,
|
||||
]
|
||||
|
||||
The above action would fail pydantic validation as the identifier "node_folder_create" is already used by the `NodeFolderCreateAction`, and would create a duplicate listing within `AbstractAction._registry`.
|
||||
|
||||
|
||||
form_request method
|
||||
###################
|
||||
|
||||
PrimAITE actions need to have a `form_request` method, which can be passed to the `RequestManager` for processing. This allows the custom action to be actioned within the simulation environment.
|
||||
78
docs/source/how_to_guides/extensible_agents.rst
Normal file
78
docs/source/how_to_guides/extensible_agents.rst
Normal file
@@ -0,0 +1,78 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _about:
|
||||
|
||||
Extensible Agents
|
||||
*****************
|
||||
|
||||
Agents defined within PrimAITE have been updated to allow for easier creation of new bespoke agents for use in custom environments.
|
||||
|
||||
|
||||
Developing Agents for PrimAITE
|
||||
==============================
|
||||
|
||||
All agent types within PrimAITE must be subclassed from ``AbstractAgent`` in order to be used from configuration YAML files. This then allows you to implement any custom agent logic for the new agent in your training scenario. Examples of implementing custom agent logic can be seen in pre-existing agents, such as the ``DataManipulationBot`` and ``RandomAgent``.
|
||||
|
||||
The core features that should be implemented in any new agent are detailed below:
|
||||
|
||||
#. **ConfigSchema**:
|
||||
|
||||
Configurable items within a new agent within PrimAITE should contain a ``ConfigSchema`` which holds all configurable variables of the agent. This should not include parameters related to its *state*, these would be listed seperately.
|
||||
Agent generation will fail pydantic checks if incorrect or invalid parameters are passed to the ConfigSchema of the chosen Agent.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class ExampleAgent(AbstractAgent, identifier = "ExampleAgent"):
|
||||
"""An example agent for demonstration purposes."""
|
||||
|
||||
config: "ExampleAgent.ConfigSchema" = Field(default_factory= lambda: ExampleAgent.ConfigSchema())
|
||||
"""Agent configuration"""
|
||||
num_executions: int = 0
|
||||
"""Number of action executions by agent"""
|
||||
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""ExampleAgent configuration schema"""
|
||||
|
||||
type: str = "ExampleAgent
|
||||
"""Name of agent"""
|
||||
starting_host: int
|
||||
"""Host node that this agent should start from in the given environment."""
|
||||
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
- ref: example_green_agent
|
||||
team: GREEN
|
||||
type: ExampleAgent
|
||||
|
||||
action_space:
|
||||
action_map:
|
||||
0:
|
||||
action: do_nothing
|
||||
options: {}
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
starting_host: "Server_1"
|
||||
|
||||
|
||||
#. **Identifiers**:
|
||||
|
||||
All agent classes should have an ``identifier`` attribute, a unique kebab-case string, for when they are added to the base ``AbstractAgent`` registry. This is then specified in your configuration YAML, and used by PrimAITE to generate the correct Agent.
|
||||
|
||||
Changes to YAML file
|
||||
====================
|
||||
|
||||
PrimAITE v4.0.0 introduces some breaking changes to how environment configuration yaml files are created. YAML files created for Primaite versions 3.3.0 should be compatible through a translation function, though it is encouraged that these are updated to reflect the updated format of 4.0.0+.
|
||||
|
||||
Agents now follow a more standardised settings definition, so should be more consistent across YAML files and the available agent types with PrimAITE.
|
||||
|
||||
All configurable items for agents sit under the ``agent_settings`` heading within your YAML files. There is no need for the inclusion of a ``start_settings``. Please see the above YAML example for full changes to agents.
|
||||
@@ -2,44 +2,44 @@
|
||||
|
||||
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| Name | Version | License | Description | URL |
|
||||
+===================+=========+====================================+=======================================================================================================+====================================================================+
|
||||
| gymnasium | 0.28.1 | MIT License | A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym). | https://farama.org |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| ipywidgets | 8.1.5 | BSD License | Jupyter interactive widgets | http://jupyter.org |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| jupyterlab | 3.6.1 | BSD License | JupyterLab computational environment | https://jupyter.org |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| kaleido | 0.2.1 | MIT | Static image export for web-based visualization libraries with zero dependencies | https://github.com/plotly/Kaleido |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| matplotlib | 3.7.1 | Python Software Foundation License | Python plotting package | https://matplotlib.org |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| networkx | 3.1 | BSD License | Python package for creating and manipulating graphs and networks | https://networkx.org/ |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| numpy | 1.23.5 | BSD License | NumPy is the fundamental package for array computing with Python. | https://www.numpy.org |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| platformdirs | 3.5.1 | MIT License | A small Python package for determining appropriate platform-specific dirs, e.g. a "user data dir". | https://github.com/platformdirs/platformdirs |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| plotly | 5.15.0 | MIT License | An open-source, interactive data visualization library for Python | https://plotly.com/python/ |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| polars | 0.20.30 | MIT License | Blazingly fast DataFrame library | https://www.pola.rs/ |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| prettytable | 3.8.0 | BSD License (BSD (3 clause)) | A simple Python library for easily displaying tabular data in a visually appealing ASCII table format | https://github.com/jazzband/prettytable |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| pydantic | 2.7.0 | MIT License | Data validation using Python type hints | https://github.com/pydantic/pydantic |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| PyYAML | 6.0 | MIT License | YAML parser and emitter for Python | https://pyyaml.org/ |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| ray | 2.32.0 | Apache 2.0 | Ray provides a simple, universal API for building distributed applications. | https://github.com/ray-project/ray |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| stable-baselines3 | 2.1.0 | MIT | Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms. | https://github.com/DLR-RM/stable-baselines3 |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| tensorflow | 2.12.0 | Apache Software License | TensorFlow is an open source machine learning framework for everyone. | https://www.tensorflow.org/ |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| typer | 0.9.0 | MIT License | Typer, build great CLIs. Easy to code. Based on Python type hints. | https://github.com/tiangolo/typer |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| Deepdiff | 8.0.1 | MIT License | Deep difference of dictionaries, iterables, strings, and any other object objects. | https://github.com/seperman/deepdiff |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
| sb3_contrib | 2.1.0 | MIT License | Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code (Action Masking)| https://github.com/Stable-Baselines-Team/stable-baselines3-contrib |
|
||||
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------+
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| Name | Supported Version | Built Version | License | Description | URL |
|
||||
+===================+=====================+===============+======================================+========================================================================================================+=====================================================================+
|
||||
| gymnasium | 0.28.1 | 0.28.1 | MIT License | A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym). | https://farama.org |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| ipywidgets | ~=8.0 | 8.1.5 | BSD License | Jupyter interactive widgets | http://jupyter.org |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| jupyterlab | 3.6.1 | 3.6.1 | BSD License | JupyterLab computational environment | https://jupyter.org |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| kaleido | ==0.2.1 | 0.2.1 | MIT | Static image export for web-based visualization libraries with zero dependencies | https://github.com/plotly/Kaleido |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| matplotlib | >=3.7.1 | 3.7.1 | Python Software Foundation License | Python plotting package | https://matplotlib.org |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| networkx | 3.1 | 3.1 | BSD License | Python package for creating and manipulating graphs and networks | https://networkx.org/ |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| numpy | ~1.23 | 1.23.5 | BSD License | NumPy is the fundamental package for array computing with Python. | https://www.numpy.org |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| platformdirs | 3.5.1 | 3.5.1 | MIT License | A small Python package for determining appropriate platform-specific dirs, e.g. a "user data dir". | https://github.com/platformdirs/platformdirs |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| plotly | 5.15 | 5.15.0 | MIT License | An open-source, interactive data visualization library for Python | https://plotly.com/python/ |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| polars | 0.20.30 | 0.20.30 | MIT License | Blazingly fast DataFrame library | https://www.pola.rs/ |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| prettytable | 3.8.0 | 3.8.0 | BSD License (BSD (3 clause)) | A simple Python library for easily displaying tabular data in a visually appealing ASCII table format | https://github.com/jazzband/prettytable |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| pydantic | 2.7.0 | 2.7.0 | MIT License | Data validation using Python type hints | https://github.com/pydantic/pydantic |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| PyYAML | >=6.0 | 6.0 | MIT License | YAML parser and emitter for Python | https://pyyaml.org/ |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| ray | >=2.20, <2.33 | 2.32.0 | Apache 2.0 | Ray provides a simple, universal API for building distributed applications. | https://github.com/ray-project/ray |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| stable-baselines3 | 2.1.0 | 2.1.0 | MIT | Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms. | https://github.com/DLR-RM/stable-baselines3 |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| tensorflow | ~=2.12 | 2.12.0 | Apache Software License | TensorFlow is an open source machine learning framework for everyone. | https://www.tensorflow.org/ |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| typer | >=0.9 | 0.9.0 | MIT License | Typer, build great CLIs. Easy to code. Based on Python type hints. | https://github.com/tiangolo/typer |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| Deepdiff | 8.0.1 | 8.0.1 | MIT License | Deep difference of dictionaries, iterables, strings, and any other object objects. | https://github.com/seperman/deepdiff |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
| sb3_contrib | 2.1.0 | 2.1.0 | MIT License | Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code (Action Masking) | https://github.com/Stable-Baselines-Team/stable-baselines3-contrib |
|
||||
+-------------------+---------------------+---------------+--------------------------------------+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------+
|
||||
|
||||
@@ -113,18 +113,6 @@ If not using the data manipulation bot manually, it needs to be used with a data
|
||||
folders: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_ref: data_manipulation_bot
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
@@ -70,7 +70,7 @@ Python
|
||||
Configuration
|
||||
=============
|
||||
|
||||
The RansomwareScript inherits configuration options such as ``fix_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``.
|
||||
The RansomwareScript inherits configuration options such as ``fixing_duration`` from its parent class. However, for the ``RansomwareScript`` the most relevant option is ``server_ip``.
|
||||
|
||||
|
||||
``server_ip``
|
||||
|
||||
@@ -22,8 +22,8 @@ options
|
||||
|
||||
The configuration options are the attributes that fall under the options for an application or service.
|
||||
|
||||
fix_duration
|
||||
""""""""""""
|
||||
fixing_duration
|
||||
"""""""""""""""
|
||||
|
||||
Optional. Default value is ``2``.
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ name = "primaite"
|
||||
description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme."
|
||||
authors = [{name="Defence Science and Technology Laboratory UK", email="oss@dstl.gov.uk"}]
|
||||
license = {file = "LICENSE"}
|
||||
requires-python = ">=3.9, <3.12"
|
||||
requires-python = ">=3.9, <3.13"
|
||||
dynamic = ["version", "readme"]
|
||||
classifiers = [
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
@@ -26,15 +26,15 @@ dependencies = [
|
||||
"gymnasium==0.28.1",
|
||||
"jupyterlab==3.6.1",
|
||||
"kaleido==0.2.1",
|
||||
"matplotlib==3.7.1",
|
||||
"matplotlib>=3.7.1",
|
||||
"networkx==3.1",
|
||||
"numpy==1.23.5",
|
||||
"numpy~=1.23",
|
||||
"platformdirs==3.5.1",
|
||||
"plotly==5.15.0",
|
||||
"polars==0.20.30",
|
||||
"prettytable==3.8.0",
|
||||
"PyYAML==6.0",
|
||||
"typer[all]==0.9.0",
|
||||
"PyYAML>=6.0",
|
||||
"typer[all]>=0.9",
|
||||
"pydantic==2.7.0",
|
||||
"ipywidgets",
|
||||
"deepdiff"
|
||||
@@ -53,8 +53,8 @@ license-files = ["LICENSE"]
|
||||
[project.optional-dependencies]
|
||||
rl = [
|
||||
"ray[rllib] >= 2.20.0, <2.33",
|
||||
"tensorflow==2.12.0",
|
||||
"stable-baselines3[extra]==2.1.0",
|
||||
"tensorflow~=2.12",
|
||||
"stable-baselines3==2.1.0",
|
||||
"sb3-contrib==2.1.0",
|
||||
]
|
||||
dev = [
|
||||
@@ -69,7 +69,7 @@ dev = [
|
||||
"pytest-xdist==3.3.1",
|
||||
"pytest-cov==4.0.0",
|
||||
"pytest-flake8==1.1.1",
|
||||
"setuptools==66",
|
||||
"setuptools==75.6.0",
|
||||
"Sphinx==7.1.2",
|
||||
"sphinx-copybutton==0.5.2",
|
||||
"wheel==0.38.4",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
@@ -30,35 +30,22 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
- application_name: DatabaseClient
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
action: node_application_execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
node_name: client_2
|
||||
application_name: WebBrowser
|
||||
2:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
action: node_application_execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 1
|
||||
node_name: client_2
|
||||
application_name: DatabaseClient
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
@@ -79,35 +66,22 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
- application_name: DatabaseClient
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
action: node_application_execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
node_name: client_1
|
||||
application_name: WebBrowser
|
||||
2:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
action: node_application_execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 1
|
||||
node_name: client_1
|
||||
application_name: WebBrowser
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
@@ -128,33 +102,12 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_settings:
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
agent_settings:
|
||||
possible_start_nodes: [client_1, client_2]
|
||||
target_application: DataManipulationBot
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
@@ -208,8 +161,8 @@ agents:
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
@@ -235,490 +188,426 @@ agents:
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_SERVICE_FIX
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: ROUTER_ACL_ADDRULE
|
||||
- type: ROUTER_ACL_REMOVERULE
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
action: node_service_scan
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
action: node_service_stop
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
action: "node_service_start"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
action: "node_service_pause"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
action: "node_service_resume"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
action: "node_service_restart"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
action: "node_service_disable"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
action: "node_service_enable"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
node_name: web_server
|
||||
service_name: WebServer
|
||||
9: # check database.db file
|
||||
action: "NODE_FILE_SCAN"
|
||||
action: "node_file_scan"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
|
||||
action: "node_file_scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
action: "node_file_delete"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
action: "node_file_repair"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
file_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
13:
|
||||
action: "NODE_SERVICE_FIX"
|
||||
action: "node_service_fix"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 0
|
||||
node_name: database_server
|
||||
service_name: DatabaseService
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
action: "node_folder_scan"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
|
||||
action: "node_folder_scan" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context.
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
action: "node_folder_repair"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
action: "node_folder_restore"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 0
|
||||
node_name: database_server
|
||||
folder_name: database
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: domain_controller
|
||||
19:
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: domain_controller
|
||||
20:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: domain_controller
|
||||
21:
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: domain_controller
|
||||
22:
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: web_server
|
||||
23:
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: web_server
|
||||
24:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: web_server
|
||||
25:
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: web_server
|
||||
26: # old action num: 18
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 2
|
||||
node_name: database_server
|
||||
27:
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 2
|
||||
node_name: database_server
|
||||
28:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 2
|
||||
node_name: database_server
|
||||
29:
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 2
|
||||
node_name: database_server
|
||||
30:
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 3
|
||||
node_name: backup_server
|
||||
31:
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 3
|
||||
node_name: backup_server
|
||||
32:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 3
|
||||
node_name: backup_server
|
||||
33:
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 3
|
||||
node_name: backup_server
|
||||
34:
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 4
|
||||
node_name: security_suite
|
||||
35:
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 4
|
||||
node_name: security_suite
|
||||
36:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 4
|
||||
node_name: security_suite
|
||||
37:
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 4
|
||||
node_name: security_suite
|
||||
38:
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 5
|
||||
node_name: client_1
|
||||
39: # old action num: 19 # shutdown client 1
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 5
|
||||
node_name: client_1
|
||||
40: # old action num: 20
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 5
|
||||
node_name: client_1
|
||||
41: # old action num: 21
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 5
|
||||
node_name: client_1
|
||||
42:
|
||||
action: "NODE_OS_SCAN"
|
||||
action: "node_os_scan"
|
||||
options:
|
||||
node_id: 6
|
||||
node_name: client_2
|
||||
43:
|
||||
action: "NODE_SHUTDOWN"
|
||||
action: "node_shutdown"
|
||||
options:
|
||||
node_id: 6
|
||||
node_name: client_2
|
||||
44:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 6
|
||||
node_name: client_2
|
||||
45:
|
||||
action: NODE_RESET
|
||||
action: node_reset
|
||||
options:
|
||||
node_id: 6
|
||||
node_name: client_2
|
||||
|
||||
46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
action: "router_acl_add_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.21 # client 1
|
||||
dst_ip: ALL # ALL
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: ALL
|
||||
src_wildcard: NONE
|
||||
dst_wildcard: NONE
|
||||
47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2"
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
action: "router_acl_add_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.22 # client 2
|
||||
dst_ip: ALL # ALL
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: ALL
|
||||
src_wildcard: NONE
|
||||
dst_wildcard: NONE
|
||||
48: # old action num: 24 # block tcp traffic from client 1 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
action: "router_acl_add_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 3
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.21 # client 1
|
||||
dst_ip: 192.168.1.12 # web server
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: TCP
|
||||
src_wildcard: NONE
|
||||
dst_wildcard: NONE
|
||||
49: # old action num: 25 # block tcp traffic from client 2 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
action: "router_acl_add_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 4
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.22 # client 2
|
||||
dst_ip: 192.168.1.12 # web server
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: TCP
|
||||
src_wildcard: NONE
|
||||
dst_wildcard: NONE
|
||||
50: # old action num: 26
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
action: "router_acl_add_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 5
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.21 # client 1
|
||||
dst_ip: 192.168.1.14 # database
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: TCP
|
||||
src_wildcard: NONE
|
||||
dst_wildcard: NONE
|
||||
51: # old action num: 27
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
action: "router_acl_add_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 6
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
permission: DENY
|
||||
src_ip: 192.168.10.22 # client 2
|
||||
dst_ip: 192.168.1.14 # database
|
||||
src_port: ALL
|
||||
dst_port: ALL
|
||||
protocol_name: TCP
|
||||
src_wildcard: NONE
|
||||
dst_wildcard: NONE
|
||||
52: # old action num: 28
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 0
|
||||
53: # old action num: 29
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 1
|
||||
54: # old action num: 30
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 2
|
||||
55: # old action num: 31
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 3
|
||||
56: # old action num: 32
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 4
|
||||
57: # old action num: 33
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 5
|
||||
58: # old action num: 34
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 6
|
||||
59: # old action num: 35
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 7
|
||||
60: # old action num: 36
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 8
|
||||
61: # old action num: 37
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
action: "router_acl_remove_rule"
|
||||
options:
|
||||
target_router: router_1
|
||||
position: 9
|
||||
62: # old action num: 38
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
node_name: domain_controller
|
||||
nic_num: 1
|
||||
63: # old action num: 39
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
node_name: domain_controller
|
||||
nic_num: 1
|
||||
64: # old action num: 40
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
node_name: web_server
|
||||
nic_num: 1
|
||||
65: # old action num: 41
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
node_name: web_server
|
||||
nic_num: 1
|
||||
66: # old action num: 42
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
node_name: database_server
|
||||
nic_num: 1
|
||||
67: # old action num: 43
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
node_name: database_server
|
||||
nic_num: 1
|
||||
68: # old action num: 44
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
node_name: backup_server
|
||||
nic_num: 1
|
||||
69: # old action num: 45
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
node_name: backup_server
|
||||
nic_num: 1
|
||||
70: # old action num: 46
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
node_name: security_suite
|
||||
nic_num: 1
|
||||
71: # old action num: 47
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
node_name: security_suite
|
||||
nic_num: 1
|
||||
72: # old action num: 48
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
node_name: security_suite
|
||||
nic_num: 2
|
||||
73: # old action num: 49
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
node_name: security_suite
|
||||
nic_num: 2
|
||||
74: # old action num: 50
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
node_name: client_1
|
||||
nic_num: 1
|
||||
75: # old action num: 51
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
node_name: client_1
|
||||
nic_num: 1
|
||||
76: # old action num: 52
|
||||
action: "HOST_NIC_DISABLE"
|
||||
action: "host_nic_disable"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
node_name: client_2
|
||||
nic_num: 1
|
||||
77: # old action num: 53
|
||||
action: "HOST_NIC_ENABLE"
|
||||
action: "host_nic_enable"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
node_name: client_2
|
||||
nic_num: 1
|
||||
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_name: domain_controller
|
||||
- node_name: web_server
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_name: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
services:
|
||||
- service_name: DatabaseService
|
||||
- node_name: backup_server
|
||||
- node_name: security_suite
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,68 +6,48 @@ game:
|
||||
agents:
|
||||
- ref: RL_Agent
|
||||
type: ProxyAgent
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_SHUTDOWN
|
||||
action: node_shutdown
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: client_1
|
||||
2:
|
||||
action: NODE_SHUTDOWN
|
||||
action: node_shutdown
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: server
|
||||
3:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: client_1
|
||||
4:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: server
|
||||
5:
|
||||
action: HOST_NIC_DISABLE
|
||||
action: host_nic_disable
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
node_name: client_1
|
||||
nic_num: 1
|
||||
6:
|
||||
action: HOST_NIC_DISABLE
|
||||
action: host_nic_disable
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
node_name: server
|
||||
nic_num: 1
|
||||
7:
|
||||
action: HOST_NIC_ENABLE
|
||||
action: host_nic_enable
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
node_name: client_1
|
||||
nic_num: 1
|
||||
8:
|
||||
action: HOST_NIC_ENABLE
|
||||
action: host_nic_enable
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
- node_name: server
|
||||
max_folders_per_node: 0
|
||||
max_files_per_folder: 0
|
||||
max_services_per_node: 0
|
||||
max_nics_per_node: 1
|
||||
max_acl_rules: 0
|
||||
ip_list:
|
||||
- 192.168.1.2
|
||||
- 192.168.1.3
|
||||
reward_function:
|
||||
reward_components: []
|
||||
node_name: server
|
||||
nic_num: 1
|
||||
|
||||
simulation:
|
||||
network:
|
||||
|
||||
@@ -6,25 +6,17 @@ agents: &greens
|
||||
action_probabilities:
|
||||
0: 0.2
|
||||
1: 0.8
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
action: node_application_execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
node_name: client
|
||||
application_name: DatabaseClient
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -6,25 +6,17 @@ agents: &greens
|
||||
action_probabilities:
|
||||
0: 0.95
|
||||
1: 0.05
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
action: node_application_execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
node_name: client
|
||||
application_name: DatabaseClient
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
@@ -3,24 +3,9 @@ reds: &reds
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 10
|
||||
frequency: 10
|
||||
variance: 0
|
||||
possible_start_nodes: [client,]
|
||||
target_application: DataManipulationBot
|
||||
start_step: 10
|
||||
frequency: 10
|
||||
variance: 0
|
||||
|
||||
@@ -3,24 +3,9 @@ reds: &reds
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 3
|
||||
frequency: 2
|
||||
variance: 1
|
||||
possible_start_nodes: [client_1]
|
||||
target_application: DataManipulationBot
|
||||
start_step: 3
|
||||
frequency: 2
|
||||
variance: 1
|
||||
|
||||
@@ -54,65 +54,46 @@ agents:
|
||||
- server:eth-1<->switch_1:eth-2
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
action: do_nothing
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_SHUTDOWN
|
||||
action: node_shutdown
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: client_1
|
||||
2:
|
||||
action: NODE_SHUTDOWN
|
||||
action: node_shutdown
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: server
|
||||
3:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 0
|
||||
node_name: client_1
|
||||
4:
|
||||
action: NODE_STARTUP
|
||||
action: node_startup
|
||||
options:
|
||||
node_id: 1
|
||||
node_name: server
|
||||
5:
|
||||
action: HOST_NIC_DISABLE
|
||||
action: host_nic_disable
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
node_name: client_1
|
||||
nic_num: 1
|
||||
6:
|
||||
action: HOST_NIC_DISABLE
|
||||
action: host_nic_disable
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
node_name: server
|
||||
nic_num: 1
|
||||
7:
|
||||
action: HOST_NIC_ENABLE
|
||||
action: host_nic_enable
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
node_name: client_1
|
||||
nic_num: 1
|
||||
8:
|
||||
action: HOST_NIC_ENABLE
|
||||
action: host_nic_enable
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
- node_name: server
|
||||
|
||||
max_folders_per_node: 0
|
||||
max_files_per_folder: 0
|
||||
max_services_per_node: 0
|
||||
max_nics_per_node: 1
|
||||
max_acl_rules: 0
|
||||
ip_list:
|
||||
- 192.168.1.2
|
||||
- 192.168.1.3
|
||||
node_name: server
|
||||
nic_num: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
33
src/primaite/game/agent/actions/__init__.py
Normal file
33
src/primaite/game/agent/actions/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
from primaite.game.agent.actions import (
|
||||
abstract,
|
||||
acl,
|
||||
application,
|
||||
file,
|
||||
folder,
|
||||
host_nic,
|
||||
manager,
|
||||
network,
|
||||
node,
|
||||
service,
|
||||
session,
|
||||
software,
|
||||
)
|
||||
from primaite.game.agent.actions.manager import ActionManager
|
||||
|
||||
__all__ = (
|
||||
"abstract",
|
||||
"acl",
|
||||
"application",
|
||||
"software",
|
||||
"file",
|
||||
"folder",
|
||||
"host_nic",
|
||||
"manager",
|
||||
"network",
|
||||
"node",
|
||||
"service",
|
||||
"session",
|
||||
"ActionManager",
|
||||
)
|
||||
36
src/primaite/game/agent/actions/abstract.py
Normal file
36
src/primaite/game/agent/actions/abstract.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, ClassVar, Dict, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
|
||||
class AbstractAction(BaseModel, ABC):
|
||||
"""Base class for actions."""
|
||||
|
||||
config: "AbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Base configuration schema for Actions."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
type: str = ""
|
||||
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAction]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Cannot create new action under reserved name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
pass
|
||||
157
src/primaite/game/agent/actions/acl.py
Normal file
157
src/primaite/game/agent/actions/acl.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Literal, Union
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
__all__ = (
|
||||
"RouterACLAddRuleAction",
|
||||
"RouterACLRemoveRuleAction",
|
||||
"FirewallACLAddRuleAction",
|
||||
"FirewallACLRemoveRuleAction",
|
||||
)
|
||||
|
||||
|
||||
class ACLAddRuleAbstractAction(AbstractAction, ABC):
|
||||
"""Base abstract class for ACL add rule actions."""
|
||||
|
||||
config: ConfigSchema = "ACLAddRuleAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema base for ACL add rule abstract actions."""
|
||||
|
||||
src_ip: IPV4Address
|
||||
protocol_name: Union[IPProtocol, Literal["ALL"]]
|
||||
permission: Literal["PERMIT", "DENY"]
|
||||
position: int
|
||||
dst_ip: Union[IPV4Address, Literal["ALL"]]
|
||||
src_port: Union[Port, Literal["ALL"]]
|
||||
dst_port: Union[Port, Literal["ALL"]]
|
||||
src_wildcard: Union[IPV4Address, Literal["NONE"]]
|
||||
dst_wildcard: Union[IPV4Address, Literal["NONE"]]
|
||||
|
||||
|
||||
class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"):
|
||||
"""Base abstract class for ACL remove rule actions."""
|
||||
|
||||
config: ConfigSchema = "ACLRemoveRuleAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema base for ACL remove rule abstract actions."""
|
||||
|
||||
position: int
|
||||
|
||||
|
||||
class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_add_rule"):
|
||||
"""Action which adds a rule to a router's ACL."""
|
||||
|
||||
config: "RouterACLAddRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for RouterACLAddRuleAction."""
|
||||
|
||||
target_router: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.target_router,
|
||||
"acl",
|
||||
"add_rule",
|
||||
config.permission,
|
||||
config.protocol_name,
|
||||
str(config.src_ip),
|
||||
str(config.src_wildcard),
|
||||
config.src_port,
|
||||
str(config.dst_ip),
|
||||
str(config.dst_wildcard),
|
||||
config.dst_port,
|
||||
config.position,
|
||||
]
|
||||
|
||||
|
||||
class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="router_acl_remove_rule"):
|
||||
"""Action which removes a rule from a router's ACL."""
|
||||
|
||||
config: "RouterACLRemoveRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for RouterACLRemoveRuleAction."""
|
||||
|
||||
target_router: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", config.target_router, "acl", "remove_rule", config.position]
|
||||
|
||||
|
||||
class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_acl_add_rule"):
|
||||
"""Action which adds a rule to a firewall port's ACL."""
|
||||
|
||||
config: "FirewallACLAddRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for FirewallACLAddRuleAction."""
|
||||
|
||||
target_firewall_nodename: str
|
||||
firewall_port_name: str
|
||||
firewall_port_direction: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.target_firewall_nodename,
|
||||
config.firewall_port_name,
|
||||
config.firewall_port_direction,
|
||||
"acl",
|
||||
"add_rule",
|
||||
config.permission,
|
||||
config.protocol_name,
|
||||
str(config.src_ip),
|
||||
str(config.src_wildcard),
|
||||
config.src_port,
|
||||
str(config.dst_ip),
|
||||
str(config.dst_wildcard),
|
||||
config.dst_port,
|
||||
config.position,
|
||||
]
|
||||
|
||||
|
||||
class FirewallACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="firewall_acl_remove_rule"):
|
||||
"""Action which removes a rule from a firewall port's ACL."""
|
||||
|
||||
config: "FirewallACLRemoveRuleAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for FirewallACLRemoveRuleAction."""
|
||||
|
||||
target_firewall_nodename: str
|
||||
firewall_port_name: str
|
||||
firewall_port_direction: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.target_firewall_nodename,
|
||||
config.firewall_port_name,
|
||||
config.firewall_port_direction,
|
||||
"acl",
|
||||
"remove_rule",
|
||||
config.position,
|
||||
]
|
||||
137
src/primaite/game/agent/actions/application.py
Normal file
137
src/primaite/game/agent/actions/application.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.abstract import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeApplicationExecuteAction",
|
||||
"NodeApplicationScanAction",
|
||||
"NodeApplicationCloseAction",
|
||||
"NodeApplicationFixAction",
|
||||
"NodeApplicationInstallAction",
|
||||
"NodeApplicationRemoveAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeApplicationAbstractAction(AbstractAction, ABC):
|
||||
"""
|
||||
Base class for application actions.
|
||||
|
||||
Any action which applies to an application and uses node_name and application_name as its only two parameters can
|
||||
inherit from this base class.
|
||||
"""
|
||||
|
||||
config: "NodeApplicationAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration schema for Node Application actions."""
|
||||
|
||||
node_name: str
|
||||
application_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"application",
|
||||
config.application_name,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="node_application_execute"):
|
||||
"""Action which executes an application."""
|
||||
|
||||
config: "NodeApplicationExecuteAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationExecuteAction."""
|
||||
|
||||
verb: str = "execute"
|
||||
|
||||
|
||||
class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_application_scan"):
|
||||
"""Action which scans an application."""
|
||||
|
||||
config: "NodeApplicationScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationScanAction."""
|
||||
|
||||
verb: str = "scan"
|
||||
|
||||
|
||||
class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node_application_close"):
|
||||
"""Action which closes an application."""
|
||||
|
||||
config: "NodeApplicationCloseAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationCloseAction."""
|
||||
|
||||
verb: str = "close"
|
||||
|
||||
|
||||
class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_application_fix"):
|
||||
"""Action which fixes an application."""
|
||||
|
||||
config: "NodeApplicationFixAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationFixAction."""
|
||||
|
||||
verb: str = "fix"
|
||||
|
||||
|
||||
class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="node_application_install"):
|
||||
"""Action which installs an application."""
|
||||
|
||||
config: "NodeApplicationInstallAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationInstallAction."""
|
||||
|
||||
verb: str = "install"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"software_manager",
|
||||
"application",
|
||||
config.verb,
|
||||
config.application_name,
|
||||
]
|
||||
|
||||
|
||||
class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="node_application_remove"):
|
||||
"""Action which removes/uninstalls an application."""
|
||||
|
||||
config: "NodeApplicationRemoveAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeApplicationRemoveAction."""
|
||||
|
||||
verb: str = "uninstall"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"software_manager",
|
||||
"application",
|
||||
config.verb,
|
||||
config.application_name,
|
||||
]
|
||||
189
src/primaite/game/agent/actions/file.py
Normal file
189
src/primaite/game/agent/actions/file.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeFileCreateAction",
|
||||
"NodeFileScanAction",
|
||||
"NodeFileDeleteAction",
|
||||
"NodeFileRestoreAction",
|
||||
"NodeFileCorruptAction",
|
||||
"NodeFileAccessAction",
|
||||
"NodeFileCheckhashAction",
|
||||
"NodeFileRepairAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeFileAbstractAction(AbstractAction, ABC):
|
||||
"""Abstract base class for file actions.
|
||||
|
||||
Any action which applies to a file and uses node_name, folder_name, and file_name as its
|
||||
only three parameters can inherit from this base class.
|
||||
"""
|
||||
|
||||
config: "NodeFileAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeFileAbstractAction."""
|
||||
|
||||
node_name: str
|
||||
folder_name: str
|
||||
file_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None or config.file_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
"folder",
|
||||
config.folder_name,
|
||||
"file",
|
||||
config.file_name,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create"):
|
||||
"""Action which creates a new file in a given folder."""
|
||||
|
||||
config: "NodeFileCreateAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileCreateAction."""
|
||||
|
||||
verb: ClassVar[str] = "create"
|
||||
force: bool = False
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None or config.file_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
"file",
|
||||
config.folder_name,
|
||||
config.file_name,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"):
|
||||
"""Action which scans a file."""
|
||||
|
||||
config: "NodeFileScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileScanAction."""
|
||||
|
||||
verb: ClassVar[str] = "scan"
|
||||
|
||||
|
||||
class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete"):
|
||||
"""Action which deletes a file."""
|
||||
|
||||
config: "NodeFileDeleteAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileDeleteAction."""
|
||||
|
||||
verb: ClassVar[str] = "delete"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None or config.file_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
"file",
|
||||
config.folder_name,
|
||||
config.file_name,
|
||||
]
|
||||
|
||||
|
||||
class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restore"):
|
||||
"""Action which restores a file."""
|
||||
|
||||
config: "NodeFileRestoreAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileRestoreAction."""
|
||||
|
||||
verb: ClassVar[str] = "restore"
|
||||
|
||||
|
||||
class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrupt"):
|
||||
"""Action which corrupts a file."""
|
||||
|
||||
config: "NodeFileCorruptAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileCorruptAction."""
|
||||
|
||||
verb: ClassVar[str] = "corrupt"
|
||||
|
||||
|
||||
class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access"):
|
||||
"""Action which increases a file's access count."""
|
||||
|
||||
config: "NodeFileAccessAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileAccessAction."""
|
||||
|
||||
verb: ClassVar[str] = "access"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None or config.file_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
config.folder_name,
|
||||
config.file_name,
|
||||
]
|
||||
|
||||
|
||||
class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_checkhash"):
|
||||
"""Action which checks the hash of a file."""
|
||||
|
||||
config: "NodeFileCheckhashAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFileCheckhashAction."""
|
||||
|
||||
verb: ClassVar[str] = "checkhash"
|
||||
|
||||
|
||||
class NodeFileRepairAction(NodeFileAbstractAction, identifier="node_file_repair"):
|
||||
"""Action which repairs a file."""
|
||||
|
||||
config: "NodeFileRepairAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFileAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeFileRepairAction."""
|
||||
|
||||
verb: ClassVar[str] = "repair"
|
||||
117
src/primaite/game/agent/actions/folder.py
Normal file
117
src/primaite/game/agent/actions/folder.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeFolderScanAction",
|
||||
"NodeFolderCheckhashAction",
|
||||
"NodeFolderRepairAction",
|
||||
"NodeFolderRestoreAction",
|
||||
"NodeFolderCreateAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeFolderAbstractAction(AbstractAction, ABC):
|
||||
"""
|
||||
Base class for folder actions.
|
||||
|
||||
Any action which applies to a folder and uses node_name and folder_name as its only two parameters can inherit from
|
||||
this base class.
|
||||
"""
|
||||
|
||||
config: "NodeFolderAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base configuration schema for NodeFolder actions."""
|
||||
|
||||
node_name: str
|
||||
folder_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
"folder",
|
||||
config.folder_name,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_scan"):
|
||||
"""Action which scans a folder."""
|
||||
|
||||
config: "NodeFolderScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderScanAction."""
|
||||
|
||||
verb: ClassVar[str] = "scan"
|
||||
|
||||
|
||||
class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folder_checkhash"):
|
||||
"""Action which checks the hash of a folder."""
|
||||
|
||||
config: "NodeFolderCheckhashAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderCheckhashAction."""
|
||||
|
||||
verb: ClassVar[str] = "checkhash"
|
||||
|
||||
|
||||
class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_repair"):
|
||||
"""Action which repairs a folder."""
|
||||
|
||||
config: "NodeFolderRepairAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderRepairAction."""
|
||||
|
||||
verb: ClassVar[str] = "repair"
|
||||
|
||||
|
||||
class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_restore"):
|
||||
"""Action which restores a folder."""
|
||||
|
||||
config: "NodeFolderRestoreAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderRestoreAction."""
|
||||
|
||||
verb: ClassVar[str] = "restore"
|
||||
|
||||
|
||||
class NodeFolderCreateAction(NodeFolderAbstractAction, identifier="node_folder_create"):
|
||||
"""Action which creates a new folder."""
|
||||
|
||||
config: "NodeFolderCreateAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeFolderAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeFolderCreateAction."""
|
||||
|
||||
verb: ClassVar[str] = "create"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.folder_name is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"file_system",
|
||||
config.verb,
|
||||
"folder",
|
||||
config.folder_name,
|
||||
]
|
||||
62
src/primaite/game/agent/actions/host_nic.py
Normal file
62
src/primaite/game/agent/actions/host_nic.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = ("HostNICEnableAction", "HostNICDisableAction")
|
||||
|
||||
|
||||
class HostNICAbstractAction(AbstractAction, ABC):
|
||||
"""
|
||||
Abstract base class for NIC actions.
|
||||
|
||||
Any action which applies to a NIC and uses node_name and nic_num as its only two parameters can inherit from this
|
||||
base class.
|
||||
"""
|
||||
|
||||
config: "HostNICAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration schema for HostNIC actions."""
|
||||
|
||||
node_name: str
|
||||
nic_num: int
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.nic_num is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"network_interface",
|
||||
config.nic_num,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class HostNICEnableAction(HostNICAbstractAction, identifier="host_nic_enable"):
|
||||
"""Action which enables a NIC."""
|
||||
|
||||
config: "HostNICEnableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(HostNICAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for HostNICEnableAction."""
|
||||
|
||||
verb: ClassVar[str] = "enable"
|
||||
|
||||
|
||||
class HostNICDisableAction(HostNICAbstractAction, identifier="host_nic_disable"):
|
||||
"""Action which disables a NIC."""
|
||||
|
||||
config: "HostNICDisableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(HostNICAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for HostNICDisableAction."""
|
||||
|
||||
verb: ClassVar[str] = "disable"
|
||||
108
src/primaite/game/agent/actions/manager.py
Normal file
108
src/primaite/game/agent/actions/manager.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
"""yaml example.
|
||||
|
||||
agents:
|
||||
- name: agent_1
|
||||
action_space:
|
||||
actions:
|
||||
- do_nothing
|
||||
- node_service_start
|
||||
- node_service_stop
|
||||
action_map:
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium import spaces
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from primaite.game.agent.actions.abstract import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = ("DoNothingAction", "ActionManager")
|
||||
|
||||
|
||||
class DoNothingAction(AbstractAction, identifier="do_nothing"):
|
||||
"""Do Nothing Action."""
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for do_nothingAction."""
|
||||
|
||||
type: str = "do_nothing"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["do_nothing"]
|
||||
|
||||
|
||||
class _ActionMapItem(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
action: str
|
||||
options: Dict
|
||||
|
||||
|
||||
class ActionManager(BaseModel):
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
"""Config Schema for ActionManager."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
action_map: Dict[int, _ActionMapItem] = {}
|
||||
"""Mapping between integer action choices and CAOS actions."""
|
||||
|
||||
@field_validator("action_map", mode="after")
|
||||
def consecutive_action_nums(cls, v: Dict) -> Dict:
|
||||
"""Make sure all numbers between 0 and N are represented as dict keys in action map."""
|
||||
assert all([i in v.keys() for i in range(len(v))])
|
||||
return v
|
||||
|
||||
config: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
|
||||
|
||||
action_map: Dict[int, Tuple[str, Dict]] = {}
|
||||
"""Init as empty, populate after model validation."""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.action_map = {n: (v.action, v.options) for n, v in self.config.action_map.items()}
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Produce action in CAOS format.
|
||||
|
||||
The agent chooses an action (as an integer), this is converted into an action in CAOS format
|
||||
The CAOS format is basically an action identifier, followed by parameters stored in a dictionary.
|
||||
"""
|
||||
act_identifier, act_options = self.action_map[action]
|
||||
return act_identifier, act_options
|
||||
|
||||
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
|
||||
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
|
||||
act_class = AbstractAction._registry[action_identifier]
|
||||
config = act_class.ConfigSchema(**action_options)
|
||||
return act_class.form_request(config=config)
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Return the gymnasium action space for this agent."""
|
||||
return spaces.Discrete(len(self.action_map))
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict) -> "ActionManager":
|
||||
"""
|
||||
Construct an ActionManager from a config dictionary.
|
||||
|
||||
The action space config supports must contain the following key:
|
||||
``action_map`` - List of actions available to the agent, formatted as a dictionary where the key is the
|
||||
action number between 0 - N, and the value is the CAOS-formatted action.
|
||||
|
||||
:param cfg: The action space config.
|
||||
:type cfg: Dict
|
||||
:return: The constructed ActionManager.
|
||||
:rtype: ActionManager
|
||||
"""
|
||||
return cls(**cfg.get("options", {}), act_map=cfg.get("action_map"))
|
||||
57
src/primaite/game/agent/actions/network.py
Normal file
57
src/primaite/game/agent/actions/network.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction")
|
||||
|
||||
|
||||
class NetworkPortAbstractAction(AbstractAction, identifier="network_port_abstract"):
|
||||
"""Base class for Network port actions."""
|
||||
|
||||
config: "NetworkPortAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base configuration schema for NetworkPort actions."""
|
||||
|
||||
target_nodename: str
|
||||
port_num: int
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.target_nodename is None or config.port_num is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.target_nodename,
|
||||
"network_interface",
|
||||
config.port_num,
|
||||
config.verb,
|
||||
]
|
||||
|
||||
|
||||
class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_port_enable"):
|
||||
"""Action which enables are port on a router or a firewall."""
|
||||
|
||||
config: "NetworkPortEnableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NetworkPortAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NetworkPortEnableAction."""
|
||||
|
||||
verb: ClassVar[str] = "enable"
|
||||
|
||||
|
||||
class NetworkPortDisableAction(NetworkPortAbstractAction, identifier="network_port_disable"):
|
||||
"""Action which disables are port on a router or a firewall."""
|
||||
|
||||
config: "NetworkPortDisableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NetworkPortAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NetworkPortDisableAction."""
|
||||
|
||||
verb: ClassVar[str] = "disable"
|
||||
186
src/primaite/game/agent/actions/node.py
Normal file
186
src/primaite/game/agent/actions/node.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import abstractmethod
|
||||
from typing import ClassVar, List, Optional, Union
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
__all__ = (
|
||||
"NodeOSScanAction",
|
||||
"NodeShutdownAction",
|
||||
"NodeStartupAction",
|
||||
"NodeResetAction",
|
||||
"NodeNMAPPingScanAction",
|
||||
"NodeNMAPPortScanAction",
|
||||
"NodeNetworkServiceReconAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeAbstractAction(AbstractAction, identifier="node_abstract"):
|
||||
"""
|
||||
Abstract base class for node actions.
|
||||
|
||||
Any action which applies to a node and uses node_name as its only parameter can inherit from this base class.
|
||||
"""
|
||||
|
||||
config: "NodeAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration schema for Node actions."""
|
||||
|
||||
node_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
print(config)
|
||||
return ["network", "node", config.node_name, config.verb]
|
||||
|
||||
|
||||
class NodeOSScanAction(NodeAbstractAction, identifier="node_os_scan"):
|
||||
"""Action which scans a node's OS."""
|
||||
|
||||
config: "NodeOSScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeOSScanAction."""
|
||||
|
||||
verb: ClassVar[str] = "scan"
|
||||
|
||||
|
||||
class NodeShutdownAction(NodeAbstractAction, identifier="node_shutdown"):
|
||||
"""Action which shuts down a node."""
|
||||
|
||||
config: "NodeShutdownAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeShutdownAction."""
|
||||
|
||||
verb: ClassVar[str] = "shutdown"
|
||||
|
||||
|
||||
class NodeStartupAction(NodeAbstractAction, identifier="node_startup"):
|
||||
"""Action which starts up a node."""
|
||||
|
||||
config: "NodeStartupAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeStartupAction."""
|
||||
|
||||
verb: ClassVar[str] = "startup"
|
||||
|
||||
|
||||
class NodeResetAction(NodeAbstractAction, identifier="node_reset"):
|
||||
"""Action which resets a node."""
|
||||
|
||||
config: "NodeResetAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeResetAction."""
|
||||
|
||||
verb: ClassVar[str] = "reset"
|
||||
|
||||
|
||||
class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_action"):
|
||||
"""Base class for NodeNMAP actions."""
|
||||
|
||||
config: "NodeNMAPAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base Configuration Schema for NodeNMAP actions."""
|
||||
|
||||
target_ip_address: Union[str, List[str]]
|
||||
show: bool = False
|
||||
source_node: str
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
# NMAP action requests don't share a common format for their requests
|
||||
# This is just a placeholder to ensure the method is defined.
|
||||
pass
|
||||
|
||||
|
||||
class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"):
|
||||
"""Action which performs an NMAP ping scan."""
|
||||
|
||||
config: "NodeNMAPPingScanAction.ConfigSchema"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: "NodeNMAPPingScanAction.ConfigSchema") -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"ping_scan",
|
||||
{"target_ip_address": config.target_ip_address, "show": config.show},
|
||||
]
|
||||
|
||||
|
||||
class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_scan"):
|
||||
"""Action which performs an NMAP port scan."""
|
||||
|
||||
config: "NodeNMAPPortScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeNMAPPortScanAction."""
|
||||
|
||||
source_node: str
|
||||
target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None
|
||||
target_port: Optional[Union[Port, List[Port]]] = None
|
||||
show: Optional[bool] = (False,)
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"port_scan",
|
||||
{
|
||||
"target_ip_address": config.target_ip_address,
|
||||
"target_port": config.target_port,
|
||||
"target_protocol": config.target_protocol,
|
||||
"show": config.show,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_network_service_recon"):
|
||||
"""Action which performs an NMAP network service recon (ping scan followed by port scan)."""
|
||||
|
||||
config: "NodeNetworkServiceReconAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeNetworkServiceReconAction."""
|
||||
|
||||
target_protocol: Optional[Union[IPProtocol, List[IPProtocol]]] = None
|
||||
target_port: Optional[Union[Port, List[Port]]] = None
|
||||
show: Optional[bool] = (False,)
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.source_node,
|
||||
"application",
|
||||
"NMAP",
|
||||
"network_service_recon",
|
||||
{
|
||||
"target_ip_address": config.target_ip_address,
|
||||
"target_port": config.target_port,
|
||||
"target_protocol": config.target_protocol,
|
||||
"show": config.show,
|
||||
},
|
||||
]
|
||||
135
src/primaite/game/agent/actions/service.py
Normal file
135
src/primaite/game/agent/actions/service.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from typing import ClassVar
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeServiceScanAction",
|
||||
"NodeServiceStopAction",
|
||||
"NodeServiceStartAction",
|
||||
"NodeServicePauseAction",
|
||||
"NodeServiceResumeAction",
|
||||
"NodeServiceRestartAction",
|
||||
"NodeServiceDisableAction",
|
||||
"NodeServiceEnableAction",
|
||||
"NodeServiceFixAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstract"):
|
||||
"""Abstract Action for Node Service related actions.
|
||||
|
||||
Any actions which use node_name and service_name can inherit from this class.
|
||||
"""
|
||||
|
||||
config: "NodeServiceAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
node_name: str
|
||||
service_name: str
|
||||
verb: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return ["network", "node", config.node_name, "service", config.service_name, config.verb]
|
||||
|
||||
|
||||
class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_scan"):
|
||||
"""Action which scans a service."""
|
||||
|
||||
config: "NodeServiceScanAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceScanAction."""
|
||||
|
||||
verb: ClassVar[str] = "scan"
|
||||
|
||||
|
||||
class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_stop"):
|
||||
"""Action which stops a service."""
|
||||
|
||||
config: "NodeServiceStopAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceStopAction."""
|
||||
|
||||
verb: ClassVar[str] = "stop"
|
||||
|
||||
|
||||
class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service_start"):
|
||||
"""Action which starts a service."""
|
||||
|
||||
config: "NodeServiceStartAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceStartAction."""
|
||||
|
||||
verb: ClassVar[str] = "start"
|
||||
|
||||
|
||||
class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service_pause"):
|
||||
"""Action which pauses a service."""
|
||||
|
||||
config: "NodeServicePauseAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServicePauseAction."""
|
||||
|
||||
verb: ClassVar[str] = "pause"
|
||||
|
||||
|
||||
class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_service_resume"):
|
||||
"""Action which resumes a service."""
|
||||
|
||||
config: "NodeServiceResumeAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceResumeAction."""
|
||||
|
||||
verb: ClassVar[str] = "resume"
|
||||
|
||||
|
||||
class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_service_restart"):
|
||||
"""Action which restarts a service."""
|
||||
|
||||
config: "NodeServiceRestartAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceRestartAction."""
|
||||
|
||||
verb: ClassVar[str] = "restart"
|
||||
|
||||
|
||||
class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_service_disable"):
|
||||
"""Action which disables a service."""
|
||||
|
||||
config: "NodeServiceDisableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceDisableAction."""
|
||||
|
||||
verb: ClassVar[str] = "disable"
|
||||
|
||||
|
||||
class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_service_enable"):
|
||||
"""Action which enables a service."""
|
||||
|
||||
config: "NodeServiceEnableAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceEnableAction."""
|
||||
|
||||
verb: ClassVar[str] = "enable"
|
||||
|
||||
|
||||
class NodeServiceFixAction(NodeServiceAbstractAction, identifier="node_service_fix"):
|
||||
"""Action which fixes a service."""
|
||||
|
||||
config: "NodeServiceFixAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeServiceAbstractAction.ConfigSchema):
|
||||
"""Configuration Schema for NodeServiceFixAction."""
|
||||
|
||||
verb: ClassVar[str] = "fix"
|
||||
108
src/primaite/game/agent/actions/session.py
Normal file
108
src/primaite/game/agent/actions/session.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import abstractmethod
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"NodeSessionsRemoteLoginAction",
|
||||
"NodeSessionsRemoteLogoutAction",
|
||||
"NodeAccountChangePasswordAction",
|
||||
)
|
||||
|
||||
|
||||
class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstract"):
|
||||
"""Base class for NodeSession actions."""
|
||||
|
||||
config: "NodeSessionAbstractAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Base configuration schema for NodeSessionAbstractActions."""
|
||||
|
||||
node_name: str
|
||||
remote_ip: str
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""
|
||||
Abstract method for request forming.
|
||||
|
||||
Should return the action formatted as a request which can be ingested by the PrimAITE simulation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_session_remote_login"):
|
||||
"""Action which performs a remote session login."""
|
||||
|
||||
config: "NodeSessionsRemoteLoginAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSessionsRemoteLoginAction."""
|
||||
|
||||
username: str
|
||||
password: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.remote_ip is None:
|
||||
return ["do_nothing"]
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"Terminal",
|
||||
"node_session_remote_login",
|
||||
config.username,
|
||||
config.password,
|
||||
config.remote_ip,
|
||||
]
|
||||
|
||||
|
||||
class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node_session_remote_logoff"):
|
||||
"""Action which performs a remote session logout."""
|
||||
|
||||
config: "NodeSessionsRemoteLogoutAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSessionsRemoteLogoutAction."""
|
||||
|
||||
verb: str = "remote_logoff"
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
if config.node_name is None or config.remote_ip is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", config.node_name, "service", "Terminal", config.verb, config.remote_ip]
|
||||
|
||||
|
||||
class NodeAccountChangePasswordAction(NodeSessionAbstractAction, identifier="node_account_change_password"):
|
||||
"""Action which changes the password for a user."""
|
||||
|
||||
config: "NodeAccountChangePasswordAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeAccountsChangePasswordAction."""
|
||||
|
||||
username: str
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"UserManager",
|
||||
"change_password",
|
||||
config.username,
|
||||
config.current_password,
|
||||
config.new_password,
|
||||
]
|
||||
241
src/primaite/game/agent/actions/software.py
Normal file
241
src/primaite/game/agent/actions/software.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
|
||||
__all__ = (
|
||||
"ConfigureRansomwareScriptAction",
|
||||
"ConfigureDoSBotAction",
|
||||
"ConfigureC2BeaconAction",
|
||||
"NodeSendRemoteCommandAction",
|
||||
"TerminalC2ServerAction",
|
||||
"RansomwareLaunchC2ServerAction",
|
||||
"ExfiltrationC2ServerAction",
|
||||
"ConfigureDatabaseClientAction",
|
||||
)
|
||||
|
||||
|
||||
class ConfigureRansomwareScriptAction(AbstractAction, identifier="configure_ransomware_script"):
|
||||
"""Action which sets config parameters for a ransomware script on a node."""
|
||||
|
||||
config: "ConfigureRansomwareScriptAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for ConfigureRansomwareScriptAction."""
|
||||
|
||||
node_name: str
|
||||
server_ip_address: Optional[str] = None
|
||||
server_password: Optional[str] = None
|
||||
payload: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
data = dict(
|
||||
server_ip_address=config.server_ip_address,
|
||||
server_password=config.server_password,
|
||||
payload=config.payload,
|
||||
)
|
||||
return ["network", "node", config.node_name, "application", "RansomwareScript", "configure", data]
|
||||
|
||||
|
||||
class RansomwareConfigureC2ServerAction(ConfigureRansomwareScriptAction, identifier="c2_server_ransomware_configure"):
|
||||
"""Action which causes a C2 server to send a command to set options on a ransomware script remotely."""
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigureRansomwareScriptAction.ConfigSchema) -> RequestFormat:
|
||||
data = dict(
|
||||
server_ip_address=config.server_ip_address, server_password=config.server_password, payload=config.payload
|
||||
)
|
||||
return ["network", "node", config.node_name, "application", "C2Server", "ransomware_configure", data]
|
||||
|
||||
|
||||
class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"):
|
||||
"""Action which sets config parameters for a DoS bot on a node."""
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
node_name: str
|
||||
target_ip_address: Optional[str] = None
|
||||
target_port: Optional[str] = None
|
||||
payload: Optional[str] = None
|
||||
repeat: Optional[bool] = None
|
||||
port_scan_p_of_success: Optional[float] = None
|
||||
dos_intensity: Optional[float] = None
|
||||
max_sessions: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
data = dict(
|
||||
target_ip_address=config.target_ip_address,
|
||||
target_port=config.target_port,
|
||||
payload=config.payload,
|
||||
repeat=config.repeat,
|
||||
port_scan_p_of_success=config.port_scan_p_of_success,
|
||||
dos_intensity=config.dos_intensity,
|
||||
max_sessions=config.max_sessions,
|
||||
)
|
||||
data = {k: v for k, v in data.items() if v is not None}
|
||||
return ["network", "node", config.node_name, "application", "DoSBot", "configure", data]
|
||||
|
||||
|
||||
class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"):
|
||||
"""Action which configures a C2 Beacon based on the parameters given."""
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for ConfigureC2BeaconAction."""
|
||||
|
||||
node_name: str
|
||||
c2_server_ip_address: str
|
||||
keep_alive_frequency: int = Field(default=5, ge=1)
|
||||
masquerade_protocol: str = Field(default="TCP")
|
||||
masquerade_port: str = Field(default="HTTP")
|
||||
|
||||
@classmethod
|
||||
def form_request(self, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
data = dict(
|
||||
c2_server_ip_address=config.c2_server_ip_address,
|
||||
keep_alive_frequency=config.keep_alive_frequency,
|
||||
masquerade_protocol=config.masquerade_protocol,
|
||||
masquerade_port=config.masquerade_port,
|
||||
)
|
||||
return ["network", "node", config.node_name, "application", "C2Beacon", "configure", data]
|
||||
|
||||
|
||||
class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_command"):
|
||||
"""Action which sends a terminal command to a remote node via SSH."""
|
||||
|
||||
config: "NodeSendRemoteCommandAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSendRemoteCommandAction."""
|
||||
|
||||
node_name: str
|
||||
remote_ip: str
|
||||
command: RequestFormat
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"Terminal",
|
||||
"send_remote_command",
|
||||
config.remote_ip,
|
||||
{"command": config.command},
|
||||
]
|
||||
|
||||
|
||||
class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_command"):
|
||||
"""Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed."""
|
||||
|
||||
config: "TerminalC2ServerAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
node_name: str
|
||||
commands: Union[List[RequestFormat], RequestFormat]
|
||||
ip_address: Optional[str]
|
||||
username: Optional[str]
|
||||
password: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
|
||||
command_model = {
|
||||
"commands": config.commands,
|
||||
"ip_address": config.ip_address,
|
||||
"username": config.username,
|
||||
"password": config.password,
|
||||
}
|
||||
return ["network", "node", config.node_name, "application", "C2Server", "terminal_command", command_model]
|
||||
|
||||
|
||||
class RansomwareLaunchC2ServerAction(AbstractAction, identifier="c2_server_ransomware_launch"):
|
||||
"""Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript."""
|
||||
|
||||
config: "RansomwareLaunchC2ServerAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for RansomwareLaunchC2ServerAction."""
|
||||
|
||||
node_name: str
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
# This action currently doesn't require any further configuration options.
|
||||
return ["network", "node", config.node_name, "application", "C2Server", "ransomware_launch"]
|
||||
|
||||
|
||||
class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfiltrate"):
|
||||
"""Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server."""
|
||||
|
||||
config: "ExfiltrationC2ServerAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
node_name: str
|
||||
username: Optional[str]
|
||||
password: Optional[str]
|
||||
target_ip_address: str
|
||||
target_file_name: str
|
||||
target_folder_name: str
|
||||
exfiltration_folder_name: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
|
||||
command_model = {
|
||||
"target_file_name": config.target_file_name,
|
||||
"target_folder_name": config.target_folder_name,
|
||||
"exfiltration_folder_name": config.exfiltration_folder_name,
|
||||
"target_ip_address": config.target_ip_address,
|
||||
"username": config.username,
|
||||
"password": config.password,
|
||||
}
|
||||
return ["network", "node", config.node_name, "application", "C2Server", "exfiltrate", command_model]
|
||||
|
||||
|
||||
class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_database_client"):
|
||||
"""Action which sets config parameters for a database client on a node."""
|
||||
|
||||
config: "ConfigureDatabaseClientAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Schema for options that can be passed to this action."""
|
||||
|
||||
node_name: str
|
||||
server_ip_address: Optional[str] = None
|
||||
server_password: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def form_request(cls, config: ConfigSchema) -> RequestFormat:
|
||||
"""Return the action formatted as a request that can be ingested by the simulation."""
|
||||
if config.node_name is None:
|
||||
return ["do_nothing"]
|
||||
data = {"server_ip_address": config.server_ip_address, "server_password": config.server_password}
|
||||
return ["network", "node", config.node_name, "application", "DatabaseClient", "configure", data]
|
||||
@@ -1,6 +1,7 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
@@ -20,20 +21,22 @@ class _NotJSONFilter(logging.Filter):
|
||||
|
||||
class AgentLog:
|
||||
"""
|
||||
A Agent Log class is a simple logger dedicated to managing and writing logging updates and information for an agent.
|
||||
An Agent Log class is a simple logger dedicated to managing and writing updates and information for an agent.
|
||||
|
||||
Each log message is written to a file located at: <simulation output directory>/agent_name/agent_name.log
|
||||
Each log message is written to a file located at:
|
||||
<simulation output directory>/agent_name/agent_name.log
|
||||
"""
|
||||
|
||||
def __init__(self, agent_name: str):
|
||||
def __init__(self, agent_name: Optional[str]):
|
||||
"""
|
||||
Constructs a Agent Log instance for a given hostname.
|
||||
|
||||
:param hostname: The hostname associated with the system logs being recorded.
|
||||
:param agent_name: The agent_name associated with the system logs being recorded.
|
||||
"""
|
||||
self.agent_name = agent_name
|
||||
self.current_episode: int = 1
|
||||
super().__init__()
|
||||
self.agent_name = agent_name if agent_name else "unnamed_agent"
|
||||
self.current_timestep: int = 0
|
||||
self.current_episode: int = 1
|
||||
self.setup_logger()
|
||||
|
||||
@property
|
||||
@@ -90,7 +93,7 @@ class AgentLog:
|
||||
|
||||
def _write_to_terminal(self, msg: str, level: str, to_terminal: bool = False):
|
||||
if to_terminal or SIM_OUTPUT.write_agent_log_to_terminal:
|
||||
print(f"{self.agent_name}: ({ self.timestep}) ({level}) {msg}")
|
||||
print(f"{self.agent_name}: ({self.timestep}) ({level}) {msg}")
|
||||
|
||||
def debug(self, msg: str, to_terminal: bool = False):
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
"""Interface for agents."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.agent_log import AgentLog
|
||||
@@ -15,6 +17,8 @@ from primaite.interface.request import RequestFormat, RequestResponse
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
__all__ = ("AgentHistoryItem", "AbstractAgent", "AbstractScriptedAgent", "ProxyAgent")
|
||||
|
||||
|
||||
class AgentHistoryItem(BaseModel):
|
||||
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
|
||||
@@ -39,89 +43,56 @@ class AgentHistoryItem(BaseModel):
|
||||
reward_info: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class AgentStartSettings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
|
||||
start_step: int = 5
|
||||
"The timestep at which an agent begins performing it's actions"
|
||||
frequency: int = 5
|
||||
"The number of timesteps to wait between performing actions"
|
||||
variance: int = 0
|
||||
"The amount the frequency can randomly change to"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_variance_lt_frequency(self) -> "AgentStartSettings":
|
||||
"""
|
||||
Make sure variance is equal to or lower than frequency.
|
||||
|
||||
This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
|
||||
greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
|
||||
"""
|
||||
if self.variance > self.frequency:
|
||||
raise ValueError(
|
||||
f"Agent start settings error: variance must be lower than frequency "
|
||||
f"{self.variance=}, {self.frequency=}"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class AgentSettings(BaseModel):
|
||||
"""Settings for configuring the operation of an agent."""
|
||||
|
||||
start_settings: Optional[AgentStartSettings] = None
|
||||
"Configuration for when an agent begins performing it's actions"
|
||||
flatten_obs: bool = True
|
||||
"Whether to flatten the observation space before passing it to the agent. True by default."
|
||||
action_masking: bool = False
|
||||
"Whether to return action masks at each step."
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
|
||||
"""Construct agent settings from a config dictionary.
|
||||
|
||||
:param config: A dict of options for the agent settings.
|
||||
:type config: Dict
|
||||
:return: The agent settings.
|
||||
:rtype: AgentSettings
|
||||
"""
|
||||
if config is None:
|
||||
return cls()
|
||||
|
||||
return cls(**config)
|
||||
|
||||
|
||||
class AbstractAgent(ABC):
|
||||
class AbstractAgent(BaseModel, ABC):
|
||||
"""Base class for scripted and RL agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
agent_settings: Optional[AgentSettings] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize an agent.
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
:param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
|
||||
:type agent_name: Optional[str]
|
||||
:param action_space: Action space for the agent.
|
||||
:type action_space: Optional[ActionManager]
|
||||
:param observation_space: Observation space for the agent.
|
||||
:type observation_space: Optional[ObservationSpace]
|
||||
:param reward_function: Reward function for the agent.
|
||||
:type reward_function: Optional[RewardFunction]
|
||||
:param agent_settings: Configurable Options for Abstracted Agents
|
||||
:type agent_settings: Optional[AgentSettings]
|
||||
"""
|
||||
self.agent_name: str = agent_name or "unnamed_agent"
|
||||
self.action_manager: Optional[ActionManager] = action_space
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
self.agent_settings = agent_settings or AgentSettings()
|
||||
self.history: List[AgentHistoryItem] = []
|
||||
self.logger = AgentLog(agent_name)
|
||||
class AgentSettingsSchema(BaseModel, ABC):
|
||||
"""Schema for the 'agent_settings' key."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Configuration Schema for AbstractAgents."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
type: str
|
||||
ref: str = ""
|
||||
"""name of the agent."""
|
||||
team: Optional[Literal["BLUE", "GREEN", "RED"]] = None
|
||||
agent_settings: AbstractAgent.AgentSettingsSchema = Field(default=lambda: AbstractAgent.AgentSettingsSchema())
|
||||
action_space: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
|
||||
observation_space: ObservationManager.ConfigSchema = Field(
|
||||
default_factory=lambda: ObservationManager.ConfigSchema()
|
||||
)
|
||||
reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
|
||||
|
||||
config: "AbstractAgent.ConfigSchema" = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
|
||||
|
||||
logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
|
||||
history: List[AgentHistoryItem] = []
|
||||
|
||||
action_manager: ActionManager = Field(default_factory=lambda: ActionManager())
|
||||
observation_manager: ObservationManager = Field(default_factory=lambda: ObservationManager())
|
||||
reward_function: RewardFunction = Field(default_factory=lambda: RewardFunction())
|
||||
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Overwrite the default empty action, observation, and rewards with ones defined through the config."""
|
||||
self.action_manager = ActionManager(config=self.config.action_space)
|
||||
self.observation_manager = ObservationManager(config=self.config.observation_space)
|
||||
self.reward_function = RewardFunction(config=self.config.reward_function)
|
||||
return super().model_post_init(__context)
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -159,9 +130,9 @@ class AbstractAgent(ABC):
|
||||
"""
|
||||
# in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ("DO_NOTHING", {})
|
||||
return ("do_nothing", {})
|
||||
|
||||
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]:
|
||||
def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> RequestFormat:
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
@@ -182,36 +153,47 @@ class AbstractAgent(ABC):
|
||||
"""Update the most recent history item with the reward value."""
|
||||
self.history[-1].reward = self.reward_function.current_reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> AbstractAgent:
|
||||
"""Grab the relevant agent class and construct an instance from a config dict."""
|
||||
agent_type = config["type"]
|
||||
agent_class = cls._registry[agent_type]
|
||||
return agent_class(config=config)
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent, identifier="AbstractScriptedAgent"):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
|
||||
config: "AbstractScriptedAgent.ConfigSchema" = Field(default_factory=lambda: AbstractScriptedAgent.ConfigSchema())
|
||||
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""Configuration Schema for AbstractScriptedAgents."""
|
||||
|
||||
type: str = "AbstractScriptedAgent"
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Return an action to be taken in the environment."""
|
||||
return super().get_action(obs=obs, timestep=timestep)
|
||||
|
||||
|
||||
class ProxyAgent(AbstractAgent):
|
||||
class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
|
||||
"""Agent that sends observations to an RL model and receives actions from that model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
agent_settings: Optional[AgentSettings] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
)
|
||||
self.most_recent_action: ActType
|
||||
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
self.action_masking: bool = agent_settings.action_masking if agent_settings else False
|
||||
config: "ProxyAgent.ConfigSchema" = Field(default_factory=lambda: ProxyAgent.ConfigSchema())
|
||||
most_recent_action: ActType = None
|
||||
|
||||
class AgentSettingsSchema(AbstractAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
flatten_obs: bool = False
|
||||
action_masking: bool = False
|
||||
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""Configuration Schema for Proxy Agent."""
|
||||
|
||||
type: str = "Proxy_Agent"
|
||||
agent_settings: ProxyAgent.AgentSettingsSchema = Field(default_factory=lambda: ProxyAgent.AgentSettingsSchema())
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
@@ -233,3 +215,8 @@ class ProxyAgent(AbstractAgent):
|
||||
The environment is responsible for calling this method when it receives an action from the agent policy.
|
||||
"""
|
||||
self.most_recent_action = action
|
||||
|
||||
@property
|
||||
def flatten_obs(self) -> bool:
|
||||
"""Return agent flatten_obs param."""
|
||||
return self.config.agent_settings.flatten_obs
|
||||
|
||||
@@ -17,5 +17,5 @@ from primaite.game.agent.observations.software_observation import ApplicationObs
|
||||
__all__ = [
|
||||
"ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation",
|
||||
"LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation",
|
||||
"ObservationManager", "ApplicationObservation", "ServiceObservation",]
|
||||
"ObservationManager", "ApplicationObservation", "ServiceObservation", "RouterObservation", "LinkObservation",]
|
||||
# fmt: on
|
||||
|
||||
@@ -24,8 +24,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"""List of IP addresses."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of wildcard strings."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of port numbers."""
|
||||
port_list: Optional[List[str]] = None
|
||||
"""List of port names."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocol names."""
|
||||
num_rules: Optional[int] = None
|
||||
@@ -37,7 +37,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
num_rules: int,
|
||||
ip_list: List[IPv4Address],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[int],
|
||||
port_list: List[str],
|
||||
protocol_list: List[str],
|
||||
) -> None:
|
||||
"""
|
||||
@@ -51,8 +51,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
:type ip_list: List[IPv4Address]
|
||||
:param wildcard_list: List of wildcard strings.
|
||||
:type wildcard_list: List[str]
|
||||
:param port_list: List of port numbers.
|
||||
:type port_list: List[int]
|
||||
:param port_list: List of port names.
|
||||
:type port_list: List[str]
|
||||
:param protocol_list: List of protocol names.
|
||||
:type protocol_list: List[str]
|
||||
"""
|
||||
@@ -60,7 +60,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
self.num_rules: int = num_rules
|
||||
self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}
|
||||
self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
|
||||
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
|
||||
self.port_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(port_list)}
|
||||
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
|
||||
@@ -190,6 +190,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
if self.files:
|
||||
self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)}
|
||||
|
||||
self.cached_obs: Optional[ObsType] = self.default_observation
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
@@ -204,7 +206,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
return self.default_observation
|
||||
|
||||
if self.file_system_requires_scan:
|
||||
health_status = folder_state["visible_status"]
|
||||
if not folder_state["scanned_this_step"]:
|
||||
health_status = self.cached_obs["health_status"]
|
||||
else:
|
||||
health_status = folder_state["visible_status"]
|
||||
else:
|
||||
health_status = folder_state["health_status"]
|
||||
|
||||
|
||||
@@ -27,13 +27,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
port_list: Optional[List[str]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
include_users: Optional[bool] = True
|
||||
include_users: Optional[bool] = None
|
||||
"""If True, report user session information."""
|
||||
|
||||
def __init__(
|
||||
@@ -41,7 +41,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
where: WhereType,
|
||||
ip_list: List[str],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[int],
|
||||
port_list: List[str],
|
||||
protocol_list: List[str],
|
||||
num_rules: int,
|
||||
include_users: bool,
|
||||
@@ -56,8 +56,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
:type ip_list: List[str]
|
||||
:param wildcard_list: List of wildcard rules.
|
||||
:type wildcard_list: List[str]
|
||||
:param port_list: List of port numbers.
|
||||
:type port_list: List[int]
|
||||
:param port_list: List of port names.
|
||||
:type port_list: List[str]
|
||||
:param protocol_list: List of protocol types.
|
||||
:type protocol_list: List[str]
|
||||
:param num_rules: Number of rules configured in the firewall.
|
||||
@@ -72,7 +72,6 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
self.ports: List[PortObservation] = [
|
||||
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
|
||||
]
|
||||
# TODO: check what the port nums are for firewall.
|
||||
|
||||
self.internal_inbound_acl = ACLObservation(
|
||||
where=self.where + ["internal_inbound_acl", "acl"],
|
||||
@@ -140,6 +139,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
},
|
||||
},
|
||||
}
|
||||
if self.include_users:
|
||||
self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -153,29 +154,35 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
firewall_state = access_from_nested_dict(state, self.where)
|
||||
if firewall_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
obs = {
|
||||
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
|
||||
"ACL": {
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.internal_outbound_acl.observe(state),
|
||||
|
||||
is_on = firewall_state["operating_state"] == 1
|
||||
if not is_on:
|
||||
obs = {**self.default_observation}
|
||||
|
||||
else:
|
||||
obs = {
|
||||
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
|
||||
"ACL": {
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.internal_outbound_acl.observe(state),
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.dmz_outbound_acl.observe(state),
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.external_outbound_acl.observe(state),
|
||||
},
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.dmz_outbound_acl.observe(state),
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.external_outbound_acl.observe(state),
|
||||
},
|
||||
},
|
||||
}
|
||||
if self.include_users:
|
||||
sess = firewall_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
if self.include_users:
|
||||
sess = firewall_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
@@ -186,34 +193,36 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
:return: Gymnasium space representing the observation space for firewall status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
|
||||
"ACL": spaces.Dict(
|
||||
{
|
||||
"INTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.internal_inbound_acl.space,
|
||||
"OUTBOUND": self.internal_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"DMZ": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.dmz_inbound_acl.space,
|
||||
"OUTBOUND": self.dmz_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"EXTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.external_inbound_acl.space,
|
||||
"OUTBOUND": self.external_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
return space
|
||||
shape = {
|
||||
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
|
||||
"ACL": spaces.Dict(
|
||||
{
|
||||
"INTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.internal_inbound_acl.space,
|
||||
"OUTBOUND": self.internal_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"DMZ": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.dmz_inbound_acl.space,
|
||||
"OUTBOUND": self.dmz_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"EXTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.external_inbound_acl.space,
|
||||
"OUTBOUND": self.external_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
if self.include_users:
|
||||
shape["users"] = spaces.Dict(
|
||||
{"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)}
|
||||
)
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation:
|
||||
|
||||
@@ -54,7 +54,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
"""
|
||||
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
include_users: Optional[bool] = True
|
||||
include_users: Optional[bool] = None
|
||||
"""If True, report user session information."""
|
||||
|
||||
def __init__(
|
||||
@@ -191,25 +191,31 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
if node_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
is_on = node_state["operating_state"] == 1
|
||||
if not is_on:
|
||||
obs = {**self.default_observation}
|
||||
|
||||
else:
|
||||
obs = {}
|
||||
if self.services:
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
if self.applications:
|
||||
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
|
||||
if self.folders:
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
if self.nics:
|
||||
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
|
||||
if self.include_num_access:
|
||||
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
|
||||
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
|
||||
if self.include_users:
|
||||
sess = node_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
if self.services:
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
if self.applications:
|
||||
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
|
||||
if self.folders:
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
if self.nics:
|
||||
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
|
||||
if self.include_num_access:
|
||||
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
|
||||
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
|
||||
if self.include_users:
|
||||
sess = node_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
|
||||
@@ -56,7 +56,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
port_list: Optional[List[str]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
|
||||
from pydantic import BaseModel, computed_field, ConfigDict, Field, model_validator, ValidationError
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
|
||||
@@ -140,7 +141,7 @@ class NullObservation(AbstractObservation, identifier="NONE"):
|
||||
return cls()
|
||||
|
||||
|
||||
class ObservationManager:
|
||||
class ObservationManager(BaseModel):
|
||||
"""
|
||||
Manage the observations of an Agent.
|
||||
|
||||
@@ -150,15 +151,66 @@ class ObservationManager:
|
||||
3. Formatting this information so an agent can use it to make decisions.
|
||||
"""
|
||||
|
||||
def __init__(self, obs: AbstractObservation) -> None:
|
||||
"""Initialise observation space.
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
:param observation: Observation object
|
||||
:type observation: AbstractObservation
|
||||
"""
|
||||
self.obs: AbstractObservation = obs
|
||||
self.current_observation: ObsType
|
||||
"""Cached copy of the observation at the time it was most recently calculated."""
|
||||
class ConfigSchema(BaseModel):
|
||||
"""Config Schema for Observation Manager."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
type: str = "NONE"
|
||||
"""Identifier name for the top-level observation."""
|
||||
options: AbstractObservation.ConfigSchema = Field(
|
||||
default_factory=lambda: NullObservation.ConfigSchema(), validate_default=True
|
||||
)
|
||||
"""Options to pass into the top-level observation during creation."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def resolve_obs_options_type(cls, data: Any) -> Any:
|
||||
"""
|
||||
When constructing the model from a dict, resolve the correct observation class based on `type` field.
|
||||
|
||||
Workaround: The `options` field is statically typed as AbstractObservation. Therefore, it falls over when
|
||||
passing in data that adheres to a subclass schema rather than the plain AbstractObservation schema. There is
|
||||
a way to do this properly using discriminated union, but most advice on the internet assumes that the full
|
||||
list of types between which to discriminate is known ahead-of-time. That is not the case for us, because of
|
||||
our plugin architecture.
|
||||
|
||||
We may be able to revisit and implement a better solution when needed using the following resources as
|
||||
research starting points:
|
||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
|
||||
https://github.com/pydantic/pydantic/issues/7366
|
||||
https://github.com/pydantic/pydantic/issues/7462
|
||||
https://github.com/pydantic/pydantic/pull/7983
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
# (TODO: duplicate default definition between here and the actual model)
|
||||
obs_type = data["type"] if "type" in data else "NONE"
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
|
||||
# if no options are passed in, try to create a default schema. Only works if there are no mandatory fields
|
||||
if "options" not in data:
|
||||
data["options"] = obs_class.ConfigSchema()
|
||||
|
||||
# if options passed as a dict, validate against schema
|
||||
elif isinstance(data["options"], dict):
|
||||
data["options"] = obs_class.ConfigSchema(**data["options"])
|
||||
|
||||
return data
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: ObservationManager.ConfigSchema())
|
||||
|
||||
current_observation: ObsType = 0
|
||||
|
||||
@computed_field
|
||||
@cached_property
|
||||
def obs(self) -> AbstractObservation:
|
||||
"""Create the main observation component for the observation manager from the config."""
|
||||
obs_class = AbstractObservation._registry[self.config.type]
|
||||
obs_instance = obs_class.from_config(config=self.config.options)
|
||||
return obs_instance
|
||||
|
||||
def update(self, state: Dict) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -31,7 +31,7 @@ class AbstractObservation(ABC):
|
||||
"""Initialise an observation. This method must be overwritten."""
|
||||
self.default_observation: ObsType
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register an observation type.
|
||||
|
||||
@@ -40,6 +40,8 @@ class AbstractObservation(ABC):
|
||||
:raises ValueError: When attempting to create a component with a name that is already in use.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate observation component type {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@@ -33,13 +33,13 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
port_list: Optional[List[str]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
include_users: Optional[bool] = True
|
||||
include_users: Optional[bool] = None
|
||||
"""If True, report user session information."""
|
||||
|
||||
def __init__(
|
||||
@@ -84,6 +84,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
}
|
||||
if self.ports:
|
||||
self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)}
|
||||
if self.include_users:
|
||||
self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -98,16 +100,21 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
if router_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
if self.ports:
|
||||
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
|
||||
if self.include_users:
|
||||
sess = router_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
is_on = router_state["operating_state"] == 1
|
||||
if not is_on:
|
||||
obs = {**self.default_observation}
|
||||
|
||||
else:
|
||||
obs = {}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
if self.ports:
|
||||
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
|
||||
if self.include_users:
|
||||
sess = router_state["services"]["UserSessionManager"]
|
||||
obs["users"] = {
|
||||
"local_login": 1 if sess["current_local_user"] else 0,
|
||||
"remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])),
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
@@ -121,6 +128,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
shape = {"ACL": self.acl.space}
|
||||
if self.ports:
|
||||
shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)})
|
||||
if self.include_users:
|
||||
shape["users"] = spaces.Dict(
|
||||
{"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)}
|
||||
)
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -30,7 +30,7 @@ the structure:
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Never
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -48,21 +48,17 @@ class AbstractReward(BaseModel):
|
||||
|
||||
config: "AbstractReward.ConfigSchema"
|
||||
|
||||
# def __init__(self, schema_name, **kwargs):
|
||||
# super.__init__(self, **kwargs)
|
||||
# # Create ConfigSchema class
|
||||
# self.config_class = type(schema_name, (BaseModel, ABC), **kwargs)
|
||||
# self.config = self.config_class()
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Config schema for AbstractReward."""
|
||||
|
||||
type: str
|
||||
type: str = ""
|
||||
|
||||
_registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate reward {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
@@ -381,14 +377,19 @@ class SharedReward(AbstractReward, identifier="SHARED_REWARD"):
|
||||
|
||||
|
||||
class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
|
||||
"""Apply a negative reward when taking any action except DONOTHING."""
|
||||
"""Apply a negative reward when taking any action except do_nothing."""
|
||||
|
||||
config: "ActionPenalty.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""Config schema for ActionPenalty."""
|
||||
"""Config schema for ActionPenalty.
|
||||
|
||||
:param action_penalty: Reward to give agents for taking any action except do_nothing
|
||||
:type action_penalty: float
|
||||
:param do_nothing_penalty: Reward to give agent for taking the do_nothing action
|
||||
:type do_nothing_penalty: float
|
||||
"""
|
||||
|
||||
type: str = "ACTION_PENALTY"
|
||||
action_penalty: float = -1.0
|
||||
do_nothing_penalty: float = 0.0
|
||||
|
||||
@@ -402,21 +403,81 @@ class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
if last_action_response.action == "DONOTHING":
|
||||
if last_action_response.action == "do_nothing":
|
||||
return self.config.do_nothing_penalty
|
||||
|
||||
else:
|
||||
return self.config.action_penalty
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
class _SingleComponentConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
type: str
|
||||
options: AbstractReward.ConfigSchema
|
||||
weight: float = 1.0
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def resolve_obs_options_type(cls, data: Any) -> Any:
|
||||
"""
|
||||
When constructing the model from a dict, resolve the correct reward class based on `type` field.
|
||||
|
||||
Workaround: The `options` field is statically typed as AbstractReward. Therefore, it falls over when
|
||||
passing in data that adheres to a subclass schema rather than the plain AbstractReward schema. There is
|
||||
a way to do this properly using discriminated union, but most advice on the internet assumes that the full
|
||||
list of types between which to discriminate is known ahead-of-time. That is not the case for us, because of
|
||||
our plugin architecture.
|
||||
|
||||
We may be able to revisit and implement a better solution when needed using the following resources as
|
||||
research starting points:
|
||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
|
||||
https://github.com/pydantic/pydantic/issues/7366
|
||||
https://github.com/pydantic/pydantic/issues/7462
|
||||
https://github.com/pydantic/pydantic/pull/7983
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
assert "type" in data, ValueError('Reward component definition is missing the "type" key.')
|
||||
rew_type = data["type"]
|
||||
rew_class = AbstractReward._registry[rew_type]
|
||||
|
||||
# if no options are passed in, try to create a default schema. Only works if there are no mandatory fields.
|
||||
if "options" not in data:
|
||||
data["options"] = rew_class.ConfigSchema()
|
||||
|
||||
# if options are passed as a dict, validate against schema
|
||||
elif isinstance(data["options"], dict):
|
||||
data["options"] = rew_class.ConfigSchema(**data["options"])
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class RewardFunction(BaseModel):
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise the reward function object."""
|
||||
self.reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
"attribute reward_components keeps track of reward components and the weights assigned to each."
|
||||
self.current_reward: float = 0.0
|
||||
self.total_reward: float = 0.0
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
"""Config Schema for RewardFunction."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
reward_components: Iterable[_SingleComponentConfig] = []
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
|
||||
|
||||
reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
|
||||
current_reward: float = 0.0
|
||||
total_reward: float = 0.0
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
for rew_config in self.config.reward_components:
|
||||
rew_class = AbstractReward._registry[rew_config.type]
|
||||
rew_instance = rew_class(config=rew_config.options)
|
||||
self.register_component(component=rew_instance, weight=rew_config.weight)
|
||||
|
||||
def register_component(self, component: AbstractReward, weight: float = 1.0) -> None:
|
||||
"""Add a reward component to the reward function.
|
||||
|
||||
@@ -1 +1,6 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
from primaite.game.agent import interface
|
||||
from primaite.game.agent.scripted_agents import abstract_tap, data_manipulation_bot, probabilistic_agent, random_agent
|
||||
|
||||
__all__ = ("abstract_tap", "data_manipulation_bot", "interface", "probabilistic_agent", "random_agent")
|
||||
|
||||
61
src/primaite/game/agent/scripted_agents/abstract_tap.py
Normal file
61
src/primaite/game/agent/scripted_agents/abstract_tap.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
|
||||
|
||||
__all__ = "AbstractTAPAgent"
|
||||
|
||||
|
||||
class AbstractTAPAgent(PeriodicAgent, identifier="AbstractTAP"):
|
||||
"""Base class for TAP agents to inherit from."""
|
||||
|
||||
config: "AbstractTAPAgent.ConfigSchema" = Field(default_factory=lambda: AbstractTAPAgent.ConfigSchema())
|
||||
next_execution_timestep: int = 0
|
||||
|
||||
class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
possible_starting_nodes: List[str] = Field(default_factory=list)
|
||||
|
||||
class ConfigSchema(PeriodicAgent.ConfigSchema):
|
||||
"""Configuration schema for Abstract TAP agents."""
|
||||
|
||||
type: str = "AbstractTAP"
|
||||
agent_settings: AbstractTAPAgent.AgentSettingsSchema = Field(
|
||||
default_factory=lambda: AbstractTAPAgent.AgentSettingsSchema()
|
||||
)
|
||||
|
||||
starting_node: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Return an action to be taken in the environment."""
|
||||
return super().get_action(obs=obs, timestep=timestep)
|
||||
|
||||
@abstractmethod
|
||||
def setup_agent(self) -> None:
|
||||
"""Set up agent."""
|
||||
pass
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(
|
||||
-self.config.agent_settings.variance, self.config.agent_settings.variance
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
self.starting_node = random.choice(self.config.agent_settings.possible_starting_nodes)
|
||||
self.logger.debug(f"Selected starting node: {self.starting_node}")
|
||||
@@ -1,31 +1,35 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import random
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
|
||||
|
||||
__all__ = "DataManipulationAgent"
|
||||
|
||||
|
||||
class DataManipulationAgent(AbstractScriptedAgent):
|
||||
class DataManipulationAgent(PeriodicAgent, identifier="RedDatabaseCorruptingAgent"):
|
||||
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
|
||||
|
||||
next_execution_timestep: int = 0
|
||||
starting_node_idx: int = 0
|
||||
class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.setup_agent()
|
||||
target_application: str = "DataManipulationBot"
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
class ConfigSchema(PeriodicAgent.ConfigSchema):
|
||||
"""Configuration Schema for DataManipulationAgent."""
|
||||
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(
|
||||
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
|
||||
type: str = "RedDatabaseCorruptingAgent"
|
||||
agent_settings: "DataManipulationAgent.AgentSettingsSchema" = Field(
|
||||
default_factory=lambda: DataManipulationAgent.AgentSettingsSchema()
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
config: "DataManipulationAgent.ConfigSchema" = Field(default_factory=lambda: DataManipulationAgent.ConfigSchema())
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._set_next_execution_timestep(timestep=self.config.agent_settings.start_step, variance=0)
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Waits until a specific timestep, then attempts to execute its data manipulation application.
|
||||
@@ -38,21 +42,14 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.next_execution_timestep:
|
||||
self.logger.debug(msg="Performing do NOTHING")
|
||||
return "DONOTHING", {}
|
||||
self.logger.debug(msg="Performing do nothing action")
|
||||
return "do_nothing", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
|
||||
self._set_next_execution_timestep(
|
||||
timestep=timestep + self.config.agent_settings.frequency, variance=self.config.agent_settings.variance
|
||||
)
|
||||
self.logger.info(msg="Performing a data manipulation attack!")
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.action_manager.node_names)
|
||||
self.starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
self.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}")
|
||||
return "node_application_execute", {
|
||||
"node_name": self.start_node,
|
||||
"application_name": self.config.agent_settings.target_application,
|
||||
}
|
||||
|
||||
@@ -1,29 +1,28 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
"""Agents with predefined behaviours."""
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pydantic
|
||||
from gymnasium.core import ObsType
|
||||
from numpy.random import Generator
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
__all__ = "ProbabilisticAgent"
|
||||
|
||||
|
||||
class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent"):
|
||||
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
|
||||
|
||||
class Settings(pydantic.BaseModel):
|
||||
"""Config schema for Probabilistic agent settings."""
|
||||
rng: Generator = Field(default_factory=lambda: np.random.default_rng(np.random.randint(0, 65535)))
|
||||
|
||||
model_config = pydantic.ConfigDict(extra="forbid")
|
||||
"""Strict validation."""
|
||||
action_probabilities: Dict[int, float]
|
||||
class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
action_probabilities: Dict[int, float] = None
|
||||
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
|
||||
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
|
||||
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
|
||||
|
||||
@pydantic.field_validator("action_probabilities", mode="after")
|
||||
@classmethod
|
||||
@@ -44,31 +43,20 @@ class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
)
|
||||
return v
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str,
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
settings: Dict = {},
|
||||
) -> None:
|
||||
# If the action probabilities are not specified, create equal probabilities for all actions
|
||||
if "action_probabilities" not in settings:
|
||||
num_actions = len(action_space.action_map)
|
||||
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration schema for Probabilistic Agent."""
|
||||
|
||||
# The random number seed for np.random is dependent on whether a random number seed is set
|
||||
# in the config file. If there is one it is processed by set_random_seed() in environment.py
|
||||
# and as a consequence the the sequence of rng_seed's used here will be repeatable.
|
||||
self.settings = ProbabilisticAgent.Settings(**settings)
|
||||
rng_seed = np.random.randint(0, 65535)
|
||||
self.rng = np.random.default_rng(rng_seed)
|
||||
type: str = "ProbabilisticAgent"
|
||||
agent_settings: "ProbabilisticAgent.AgentSettingsSchema" = Field(
|
||||
default_factory=lambda: ProbabilisticAgent.AgentSettingsSchema()
|
||||
)
|
||||
|
||||
# convert probabilities from
|
||||
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
|
||||
config: "ProbabilisticAgent.ConfigSchema" = Field(default_factory=lambda: ProbabilisticAgent.ConfigSchema())
|
||||
|
||||
super().__init__(agent_name, action_space, observation_space, reward_function)
|
||||
self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
|
||||
@property
|
||||
def probabilities(self) -> Dict[str, int]:
|
||||
"""Convenience method to view the probabilities of the Agent."""
|
||||
return np.asarray(list(self.config.agent_settings.action_probabilities.values()))
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import random
|
||||
from typing import Dict, Optional, Tuple
|
||||
from functools import cached_property
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel
|
||||
from pydantic import computed_field, Field, model_validator
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
__all__ = ("RandomAgent", "PeriodicAgent")
|
||||
|
||||
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
class RandomAgent(AbstractScriptedAgent, identifier="RandomAgent"):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
config: "RandomAgent.ConfigSchema" = Field(default_factory=lambda: RandomAgent.ConfigSchema())
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration Schema for Random Agents."""
|
||||
|
||||
type: str = "RandomAgent"
|
||||
|
||||
def get_action(self) -> Tuple[str, Dict]:
|
||||
"""Sample the action space randomly.
|
||||
|
||||
:param obs: Current observation for this agent, not used in RandomAgent
|
||||
@@ -27,41 +34,60 @@ class RandomAgent(AbstractScriptedAgent):
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
|
||||
|
||||
class PeriodicAgent(AbstractScriptedAgent):
|
||||
class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
|
||||
"""Agent that does nothing most of the time, but executes application at regular intervals (with variance)."""
|
||||
|
||||
class Settings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
config: "PeriodicAgent.ConfigSchema" = Field(default_factory=lambda: PeriodicAgent.ConfigSchema())
|
||||
|
||||
start_step: int = 20
|
||||
"The timestep at which an agent begins performing it's actions."
|
||||
start_variance: int = 5
|
||||
"Deviation around the start step."
|
||||
class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema):
|
||||
"""Schema for the `agent_settings` part of the agent config."""
|
||||
|
||||
start_step: int = 5
|
||||
"The timestep at which an agent begins performing it's actions"
|
||||
frequency: int = 5
|
||||
"The number of timesteps to wait between performing actions."
|
||||
"The number of timesteps to wait between performing actions"
|
||||
variance: int = 0
|
||||
"The amount the frequency can randomly change to."
|
||||
max_executions: int = 999999
|
||||
"Maximum number of times the agent can execute its action."
|
||||
"The amount the frequency can randomly change to"
|
||||
possible_start_nodes: List[str]
|
||||
target_application: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str,
|
||||
action_space: ActionManager,
|
||||
observation_space: ObservationManager,
|
||||
reward_function: RewardFunction,
|
||||
settings: Optional[Settings] = None,
|
||||
) -> None:
|
||||
"""Initialise PeriodicAgent."""
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
@model_validator(mode="after")
|
||||
def check_variance_lt_frequency(self) -> "PeriodicAgent.ConfigSchema":
|
||||
"""
|
||||
Make sure variance is equal to or lower than frequency.
|
||||
|
||||
This is because the calculation for the next execution time is now + (frequency +- variance).
|
||||
If variance were greater than frequency, sometimes the bracketed term would be negative
|
||||
and the attack would never happen again.
|
||||
"""
|
||||
if self.variance >= self.frequency:
|
||||
raise ValueError(
|
||||
f"Agent start settings error: variance must be lower than frequency "
|
||||
f"{self.variance=}, {self.frequency=}"
|
||||
)
|
||||
return self
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration Schema for Periodic Agent."""
|
||||
|
||||
type: str = "PeriodicAgent"
|
||||
"""Name of the agent."""
|
||||
agent_settings: "PeriodicAgent.AgentSettingsSchema" = Field(
|
||||
default_factory=lambda: PeriodicAgent.AgentSettingsSchema()
|
||||
)
|
||||
self.settings = settings or PeriodicAgent.Settings()
|
||||
self._set_next_execution_timestep(timestep=self.settings.start_step, variance=self.settings.start_variance)
|
||||
self.num_executions = 0
|
||||
|
||||
max_executions: int = 999999
|
||||
"Maximum number of times the agent can execute its action."
|
||||
num_executions: int = 0
|
||||
"""Number of times the agent has executed an action."""
|
||||
next_execution_timestep: int = 0
|
||||
"""Timestep of the next action execution by the agent."""
|
||||
|
||||
@computed_field
|
||||
@cached_property
|
||||
def start_node(self) -> str:
|
||||
"""On instantiation, randomly select a start node."""
|
||||
return random.choice(self.config.agent_settings.possible_start_nodes)
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int, variance: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
@@ -76,9 +102,14 @@ class PeriodicAgent(AbstractScriptedAgent):
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
|
||||
if timestep == self.next_execution_timestep and self.num_executions < self.settings.max_executions:
|
||||
if timestep == self.next_execution_timestep and self.num_executions < self.max_executions:
|
||||
self.num_executions += 1
|
||||
self._set_next_execution_timestep(timestep + self.settings.frequency, self.settings.variance)
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}
|
||||
self._set_next_execution_timestep(
|
||||
timestep + self.config.agent_settings.frequency, self.config.agent_settings.variance
|
||||
)
|
||||
return "node_application_execute", {
|
||||
"node_name": self.start_node,
|
||||
"application_name": self.config.agent_settings.target_application,
|
||||
}
|
||||
|
||||
return "DONOTHING", {}
|
||||
return "do_nothing", {}
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import random
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
|
||||
|
||||
class TAP001(AbstractScriptedAgent):
|
||||
"""
|
||||
TAP001 | Mobile Malware -- Ransomware Variant.
|
||||
|
||||
Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.setup_agent()
|
||||
|
||||
next_execution_timestep: int = 0
|
||||
starting_node_idx: int = 0
|
||||
installed: bool = False
|
||||
|
||||
def _set_next_execution_timestep(self, timestep: int) -> None:
|
||||
"""Set the next execution timestep with a configured random variance.
|
||||
|
||||
:param timestep: The timestep to add variance to.
|
||||
"""
|
||||
random_timestep_increment = random.randint(
|
||||
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
|
||||
)
|
||||
self.next_execution_timestep = timestep + random_timestep_increment
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
|
||||
"""Waits until a specific timestep, then attempts to execute the ransomware application.
|
||||
|
||||
This application acts a wrapper around the kill-chain, similar to green-analyst and
|
||||
the previous UC2 data manipulation bot.
|
||||
|
||||
:param obs: Current observation for this agent.
|
||||
:type obs: ObsType
|
||||
:param timestep: The current simulation timestep, used for scheduling actions
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.next_execution_timestep:
|
||||
return "DONOTHING", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
|
||||
|
||||
if not self.installed:
|
||||
self.installed = True
|
||||
return "NODE_APPLICATION_INSTALL", {
|
||||
"node_id": self.starting_node_idx,
|
||||
"application_name": "RansomwareScript",
|
||||
}
|
||||
|
||||
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
|
||||
|
||||
def setup_agent(self) -> None:
|
||||
"""Set the next execution timestep when the episode resets."""
|
||||
self._select_start_node()
|
||||
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
|
||||
for n, act in self.action_manager.action_map.items():
|
||||
if not act[0] == "NODE_APPLICATION_INSTALL":
|
||||
continue
|
||||
if act[1]["node_id"] == self.starting_node_idx:
|
||||
self.ip_address = act[1]["ip_address"]
|
||||
return
|
||||
raise RuntimeError("TAP001 agent could not find database server ip address in action map")
|
||||
|
||||
def _select_start_node(self) -> None:
|
||||
"""Set the starting starting node of the agent to be a random node from this agent's action manager."""
|
||||
# we are assuming that every node in the node manager has a data manipulation application at idx 0
|
||||
num_nodes = len(self.action_manager.node_names)
|
||||
self.starting_node_idx = random.randint(0, num_nodes - 1)
|
||||
@@ -7,14 +7,8 @@ import numpy as np
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import DEFAULT_BANDWIDTH, getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction, SharedReward
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
|
||||
from primaite.game.agent.scripted_agents.tap001 import TAP001
|
||||
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
|
||||
from primaite.game.agent.rewards import SharedReward
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.network.creation import NetworkNodeAdder
|
||||
@@ -44,7 +38,7 @@ from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from primaite.simulator.system.software import Software
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -258,6 +252,7 @@ class PrimaiteGame:
|
||||
net = sim.network
|
||||
|
||||
simulation_config = cfg.get("simulation", {})
|
||||
defaults_config = cfg.get("defaults", {})
|
||||
network_config = simulation_config.get("network", {})
|
||||
airspace_cfg = network_config.get("airspace", {})
|
||||
frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {})
|
||||
@@ -283,6 +278,18 @@ class PrimaiteGame:
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
# TODO: handle simulation defaults more cleanly
|
||||
if "node_start_up_duration" in defaults_config:
|
||||
new_node.start_up_duration = defaults_config["node_startup_duration"]
|
||||
if "node_shut_down_duration" in defaults_config:
|
||||
new_node.shut_down_duration = defaults_config["node_shut_down_duration"]
|
||||
if "node_scan_duration" in defaults_config:
|
||||
new_node.node_scan_duration = defaults_config["node_scan_duration"]
|
||||
if "folder_scan_duration" in defaults_config:
|
||||
new_node.file_system._default_folder_scan_duration = defaults_config["folder_scan_duration"]
|
||||
if "folder_restore_duration" in defaults_config:
|
||||
new_node.file_system._default_folder_restore_duration = defaults_config["folder_restore_duration"]
|
||||
|
||||
if "users" in node_cfg and new_node.software_manager.software.get("UserManager"):
|
||||
user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa
|
||||
for user_cfg in node_cfg["users"]:
|
||||
@@ -315,12 +322,12 @@ class PrimaiteGame:
|
||||
|
||||
if service_class is not None:
|
||||
_LOGGER.debug(f"installing {service_type} on node {new_node.config.hostname}")
|
||||
new_node.software_manager.install(service_class, **service_cfg.get("options", {}))
|
||||
new_node.software_manager.install(service_class)
|
||||
new_service = new_node.software_manager.software[service_class.__name__]
|
||||
|
||||
# fixing duration for the service
|
||||
if "fix_duration" in service_cfg.get("options", {}):
|
||||
new_service.fixing_duration = service_cfg["options"]["fix_duration"]
|
||||
if "fixing_duration" in service_cfg.get("options", {}):
|
||||
new_service.config.fixing_duration = service_cfg["options"]["fixing_duration"]
|
||||
|
||||
_set_software_listen_on_ports(new_service, service_cfg)
|
||||
# start the service
|
||||
@@ -329,6 +336,15 @@ class PrimaiteGame:
|
||||
msg = f"Configuration contains an invalid service type: {service_type}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
# TODO: handle simulation defaults more cleanly
|
||||
if "service_fix_duration" in defaults_config:
|
||||
new_service.fixing_duration = defaults_config["service_fix_duration"]
|
||||
if "service_restart_duration" in defaults_config:
|
||||
new_service.restart_duration = defaults_config["service_restart_duration"]
|
||||
if "service_install_duration" in defaults_config:
|
||||
new_service.install_duration = defaults_config["service_install_duration"]
|
||||
|
||||
# service-dependent options
|
||||
if service_type == "DNSClient":
|
||||
if "options" in service_cfg:
|
||||
@@ -361,74 +377,20 @@ class PrimaiteGame:
|
||||
application_type = application_cfg["type"]
|
||||
|
||||
if application_type in Application._registry:
|
||||
new_node.software_manager.install(Application._registry[application_type])
|
||||
application_class = Application._registry[application_type]
|
||||
application_options = application_cfg.get("options", {})
|
||||
application_options["type"] = application_type
|
||||
new_node.software_manager.install(application_class, software_config=application_options)
|
||||
new_application = new_node.software_manager.software[application_type] # grab the instance
|
||||
|
||||
# fixing duration for the application
|
||||
if "fix_duration" in application_cfg.get("options", {}):
|
||||
new_application.fixing_duration = application_cfg["options"]["fix_duration"]
|
||||
else:
|
||||
msg = f"Configuration contains an invalid application type: {application_type}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
_set_software_listen_on_ports(new_application, application_cfg)
|
||||
|
||||
# run the application
|
||||
new_application.run()
|
||||
|
||||
if application_type == "DataManipulationBot":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")),
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload", "DELETE"),
|
||||
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
|
||||
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
|
||||
)
|
||||
elif application_type == "RansomwareScript":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")) if opt.get("server_ip") else None,
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload", "ENCRYPT"),
|
||||
)
|
||||
elif application_type == "DatabaseClient":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("db_server_ip")),
|
||||
server_password=opt.get("server_password"),
|
||||
)
|
||||
elif application_type == "WebBrowser":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.target_url = opt.get("target_url")
|
||||
elif application_type == "DoSBot":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
target_ip_address=IPv4Address(opt.get("target_ip_address")),
|
||||
target_port=PORT_LOOKUP[opt.get("target_port", "POSTGRES_SERVER")],
|
||||
payload=opt.get("payload"),
|
||||
repeat=bool(opt.get("repeat")),
|
||||
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
|
||||
dos_intensity=float(opt.get("dos_intensity", "1.0")),
|
||||
max_sessions=int(opt.get("max_sessions", "1000")),
|
||||
)
|
||||
elif application_type == "C2Beacon":
|
||||
if "options" in application_cfg:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
c2_server_ip_address=IPv4Address(opt.get("c2_server_ip_address")),
|
||||
keep_alive_frequency=(opt.get("keep_alive_frequency", 5)),
|
||||
masquerade_protocol=PROTOCOL_LOOKUP[
|
||||
(opt.get("masquerade_protocol", PROTOCOL_LOOKUP["TCP"]))
|
||||
],
|
||||
masquerade_port=PORT_LOOKUP[(opt.get("masquerade_port", PORT_LOOKUP["HTTP"]))],
|
||||
)
|
||||
if "network_interfaces" in node_cfg:
|
||||
for nic_num, nic_cfg in node_cfg["network_interfaces"].items():
|
||||
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
|
||||
@@ -470,76 +432,10 @@ class PrimaiteGame:
|
||||
agents_cfg = cfg.get("agents", [])
|
||||
|
||||
for agent_cfg in agents_cfg:
|
||||
agent_ref = agent_cfg["ref"] # noqa: F841
|
||||
agent_type = agent_cfg["type"]
|
||||
action_space_cfg = agent_cfg["action_space"]
|
||||
observation_space_cfg = agent_cfg["observation_space"]
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg)
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space = ActionManager.from_config(game, action_space_cfg)
|
||||
|
||||
# CREATE REWARD FUNCTION
|
||||
reward_function = RewardFunction.from_config(reward_function_cfg)
|
||||
|
||||
# CREATE AGENT
|
||||
if agent_type == "ProbabilisticAgent":
|
||||
# TODO: implement non-random agents and fix this parsing
|
||||
settings = agent_cfg.get("agent_settings", {})
|
||||
new_agent = ProbabilisticAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
settings=settings,
|
||||
)
|
||||
elif agent_type == "PeriodicAgent":
|
||||
settings = PeriodicAgent.Settings(**agent_cfg.get("settings", {}))
|
||||
new_agent = PeriodicAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
elif agent_type == "ProxyAgent":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
new_agent = ProxyAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
|
||||
new_agent = DataManipulationAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
elif agent_type == "TAP001":
|
||||
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
|
||||
new_agent = TAP001(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=reward_function,
|
||||
agent_settings=agent_settings,
|
||||
)
|
||||
else:
|
||||
msg = f"Configuration error: {agent_type} is not a valid agent type."
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
new_agent = AbstractAgent.from_config(agent_cfg)
|
||||
game.agents[agent_cfg["ref"]] = new_agent
|
||||
if isinstance(new_agent, ProxyAgent):
|
||||
game.rl_agents[agent_cfg["ref"]] = new_agent
|
||||
|
||||
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
|
||||
game.setup_reward_sharing()
|
||||
|
||||
@@ -11,6 +11,15 @@
|
||||
"PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -19,7 +28,7 @@
|
||||
"source": [
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from prettytable import PrettyTable\n"
|
||||
"from prettytable import PrettyTable"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -195,7 +204,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -209,7 +218,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -51,24 +51,15 @@
|
||||
" - ref: CustomC2Agent\n",
|
||||
" team: RED\n",
|
||||
" type: ProxyAgent\n",
|
||||
" observation_space: null\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" - type: DONOTHING\n",
|
||||
" - type: NODE_APPLICATION_INSTALL\n",
|
||||
" - type: NODE_APPLICATION_EXECUTE\n",
|
||||
" - type: CONFIGURE_C2_BEACON\n",
|
||||
" - type: C2_SERVER_RANSOMWARE_LAUNCH\n",
|
||||
" - type: C2_SERVER_RANSOMWARE_CONFIGURE\n",
|
||||
" - type: C2_SERVER_TERMINAL_COMMAND\n",
|
||||
" - type: C2_SERVER_DATA_EXFILTRATE\n",
|
||||
" options:\n",
|
||||
" nodes:\n",
|
||||
" - node_name: web_server\n",
|
||||
" applications: \n",
|
||||
" applications:\n",
|
||||
" - application_name: C2Beacon\n",
|
||||
" - node_name: client_1\n",
|
||||
" applications: \n",
|
||||
" applications:\n",
|
||||
" - application_name: C2Server\n",
|
||||
" max_folders_per_node: 1\n",
|
||||
" max_files_per_folder: 1\n",
|
||||
@@ -82,15 +73,15 @@
|
||||
" - 0.0.0.1\n",
|
||||
" action_map:\n",
|
||||
" 0:\n",
|
||||
" action: DONOTHING\n",
|
||||
" action: do_nothing\n",
|
||||
" options: {}\n",
|
||||
" 1:\n",
|
||||
" action: NODE_APPLICATION_INSTALL\n",
|
||||
" action: node_application_install\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" application_name: C2Beacon\n",
|
||||
" 2:\n",
|
||||
" action: CONFIGURE_C2_BEACON\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" config:\n",
|
||||
@@ -99,12 +90,12 @@
|
||||
" masquerade_protocol:\n",
|
||||
" masquerade_port:\n",
|
||||
" 3:\n",
|
||||
" action: NODE_APPLICATION_EXECUTE\n",
|
||||
" action: node_application_execute\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" application_id: 0 \n",
|
||||
" application_id: 0\n",
|
||||
" 4:\n",
|
||||
" action: C2_SERVER_TERMINAL_COMMAND\n",
|
||||
" action: c2_server_terminal_command\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" ip_address:\n",
|
||||
@@ -112,20 +103,20 @@
|
||||
" username: admin\n",
|
||||
" password: admin\n",
|
||||
" commands:\n",
|
||||
" - \n",
|
||||
" -\n",
|
||||
" - software_manager\n",
|
||||
" - application\n",
|
||||
" - install\n",
|
||||
" - RansomwareScript\n",
|
||||
" 5:\n",
|
||||
" action: C2_SERVER_RANSOMWARE_CONFIGURE\n",
|
||||
" action: c2_server_ransomware_configure\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" config:\n",
|
||||
" server_ip_address: 192.168.1.14\n",
|
||||
" payload: ENCRYPT\n",
|
||||
" 6:\n",
|
||||
" action: C2_SERVER_DATA_EXFILTRATE\n",
|
||||
" action: c2_server_data_exfiltrate\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" target_file_name: \"database.db\"\n",
|
||||
@@ -134,14 +125,14 @@
|
||||
" target_ip_address: 192.168.1.14\n",
|
||||
" account:\n",
|
||||
" username: admin\n",
|
||||
" password: admin \n",
|
||||
" password: admin\n",
|
||||
"\n",
|
||||
" 7:\n",
|
||||
" action: C2_SERVER_RANSOMWARE_LAUNCH\n",
|
||||
" action: c2_server_ransomware_launch\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" 8:\n",
|
||||
" action: CONFIGURE_C2_BEACON\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" config:\n",
|
||||
@@ -150,7 +141,7 @@
|
||||
" masquerade_protocol: TCP\n",
|
||||
" masquerade_port: DNS\n",
|
||||
" 9:\n",
|
||||
" action: CONFIGURE_C2_BEACON\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" config:\n",
|
||||
@@ -177,7 +168,7 @@
|
||||
" # removing all agents & adding the custom agent.\n",
|
||||
" cfg['agents'] = {}\n",
|
||||
" cfg['agents'] = c2_agent_yaml\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(env_config=cfg)"
|
||||
]
|
||||
@@ -222,7 +213,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Command and Control** | C2 Beacon Actions | NODE_APPLICATION_INSTALL\n",
|
||||
"### **Command and Control** | C2 Beacon Actions | node_application_install\n",
|
||||
"\n",
|
||||
"The custom proxy red agent defined at the start of this notebook has been configured to install the C2 Beacon as action ``1`` in it's action map. \n",
|
||||
"\n",
|
||||
@@ -230,10 +221,6 @@
|
||||
"\n",
|
||||
"```yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: NODE_APPLICATION_INSTALL\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: web_server\n",
|
||||
@@ -243,7 +230,7 @@
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 1:\n",
|
||||
" action: NODE_APPLICATION_INSTALL \n",
|
||||
" action: node_application_install \n",
|
||||
" options:\n",
|
||||
" node_id: 0 # Index 0 at the node list.\n",
|
||||
" application_name: C2Beacon\n",
|
||||
@@ -265,7 +252,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Command and Control** | C2 Beacon Actions | CONFIGURE_C2_BEACON \n",
|
||||
"### **Command and Control** | C2 Beacon Actions | configure_c2_beacon \n",
|
||||
"\n",
|
||||
"The custom proxy red agent defined at the start of this notebook can configure the C2 Beacon via action ``2`` in it's action map. \n",
|
||||
"\n",
|
||||
@@ -273,10 +260,6 @@
|
||||
"\n",
|
||||
"```yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: CONFIGURE_C2_BEACON\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: web_server\n",
|
||||
@@ -285,7 +268,7 @@
|
||||
" action_map:\n",
|
||||
" ...\n",
|
||||
" 2:\n",
|
||||
" action: CONFIGURE_C2_BEACON\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0 # Node Index\n",
|
||||
" config: # Further information about these config options can be found at the bottom of this notebook.\n",
|
||||
@@ -312,18 +295,14 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Command and Control** | C2 Beacon Actions | NODE_APPLICATION_EXECUTE\n",
|
||||
"### **Command and Control** | C2 Beacon Actions | node_application_execute\n",
|
||||
"\n",
|
||||
"The final action is ``NODE_APPLICATION_EXECUTE`` which is used to establish a connection for the C2 application. This action can be called by the Red Agent via action ``3`` in it's action map. \n",
|
||||
"The final action is ``node_application_execute`` which is used to establish a connection for the C2 application. This action can be called by the Red Agent via action ``3`` in it's action map. \n",
|
||||
"\n",
|
||||
"The yaml snippet below shows all the relevant agent options for this action:\n",
|
||||
"\n",
|
||||
"```yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: NODE_APPLICATION_EXECUTE\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: web_server\n",
|
||||
@@ -334,7 +313,7 @@
|
||||
" action_map:\n",
|
||||
" ...\n",
|
||||
" 3:\n",
|
||||
" action: NODE_APPLICATION_EXECUTE\n",
|
||||
" action: node_application_execute\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" application_id: 0\n",
|
||||
@@ -347,7 +326,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(3) "
|
||||
"env.step(3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -390,10 +369,6 @@
|
||||
"\n",
|
||||
"``` yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: C2_SERVER_TERMINAL_COMMAND\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" ...\n",
|
||||
@@ -441,7 +416,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Command and Control** | C2 Server Actions | C2_SERVER_RANSOMWARE_CONFIGURE\n",
|
||||
"### **Command and Control** | C2 Server Actions | c2_server_ransomware_configure\n",
|
||||
"\n",
|
||||
"Another action the C2 Server grants is the ability for a Red Agent to configure the RansomwareScript via the C2 Server rather than the note directly.\n",
|
||||
"\n",
|
||||
@@ -451,10 +426,6 @@
|
||||
"\n",
|
||||
"``` yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: C2_SERVER_RANSOMWARE_CONFIGURE\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" ...\n",
|
||||
@@ -464,7 +435,7 @@
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 5:\n",
|
||||
" action: C2_SERVER_RANSOMWARE_CONFIG\n",
|
||||
" action: c2_server_ransomware_configure\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" config:\n",
|
||||
@@ -497,9 +468,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Command and Control** | C2 Server Actions | C2_SERVER_DATA_EXFILTRATE\n",
|
||||
"### **Command and Control** | C2 Server Actions | c2_server_data_exfiltrate\n",
|
||||
"\n",
|
||||
"The second to last action available is the ``C2_SERVER_DATA_EXFILTRATE`` which is indexed as action ``6`` in the action map.\n",
|
||||
"The second to last action available is the ``c2_server_data_exfiltrate`` which is indexed as action ``6`` in the action map.\n",
|
||||
"\n",
|
||||
"This action can be used to exfiltrate a target file on a remote node to the C2 Beacon and the C2 Server's host file system via the ``FTP`` services.\n",
|
||||
"\n",
|
||||
@@ -507,10 +478,6 @@
|
||||
"\n",
|
||||
"``` yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: C2_SERVER_DATA_EXFILTRATE\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" ...\n",
|
||||
@@ -520,7 +487,7 @@
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 6:\n",
|
||||
" action: C2_SERVER_DATA_EXFILTRATE\n",
|
||||
" action: c2_server_data_exfiltrate\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" target_file_name: \"database.db\"\n",
|
||||
@@ -567,9 +534,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Command and Control** | C2 Server Actions | C2_SERVER_RANSOMWARE_LAUNCH\n",
|
||||
"### **Command and Control** | C2 Server Actions | c2_server_ransomware_launch\n",
|
||||
"\n",
|
||||
"Finally, the last available action is for the C2_SERVER_RANSOMWARE_LAUNCH to start the ransomware script installed on the same node as the C2 beacon.\n",
|
||||
"Finally, the last available action is for the c2_server_ransomware_launch to start the ransomware script installed on the same node as the C2 beacon.\n",
|
||||
"\n",
|
||||
"This action is indexed as action ``7``.\n",
|
||||
"\n",
|
||||
@@ -577,10 +544,6 @@
|
||||
"\n",
|
||||
"``` yaml\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: C2_SERVER_RANSOMWARE_LAUNCH\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" ...\n",
|
||||
@@ -590,7 +553,7 @@
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 7:\n",
|
||||
" action: C2_SERVER_RANSOMWARE_LAUNCH\n",
|
||||
" action: c2_server_ransomware_launch\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
"```\n"
|
||||
@@ -632,7 +595,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"custom_blue_agent_yaml = \"\"\" \n",
|
||||
"custom_blue_agent_yaml = \"\"\"\n",
|
||||
" - ref: defender\n",
|
||||
" team: BLUE\n",
|
||||
" type: ProxyAgent\n",
|
||||
@@ -715,28 +678,23 @@
|
||||
" - type: \"NONE\"\n",
|
||||
" label: ICS\n",
|
||||
" options: {}\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" - type: NODE_APPLICATION_REMOVE\n",
|
||||
" - type: NODE_SHUTDOWN\n",
|
||||
" - type: ROUTER_ACL_ADDRULE\n",
|
||||
" - type: DONOTHING\n",
|
||||
" action_map:\n",
|
||||
" 0:\n",
|
||||
" action: DONOTHING\n",
|
||||
" action: do_nothing\n",
|
||||
" options: {}\n",
|
||||
" 1:\n",
|
||||
" action: NODE_APPLICATION_REMOVE\n",
|
||||
" action: node_application_remove\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" application_name: C2Beacon\n",
|
||||
" 2:\n",
|
||||
" action: NODE_SHUTDOWN\n",
|
||||
" action: node_shutdown\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" 3:\n",
|
||||
" action: ROUTER_ACL_ADDRULE\n",
|
||||
" action: router_acl_add_rule\n",
|
||||
" options:\n",
|
||||
" target_router: router_1\n",
|
||||
" position: 1\n",
|
||||
@@ -747,7 +705,7 @@
|
||||
" dest_port_id: 2\n",
|
||||
" protocol_id: 1\n",
|
||||
" source_wildcard_id: 0\n",
|
||||
" dest_wildcard_id: 0 \n",
|
||||
" dest_wildcard_id: 0\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" options:\n",
|
||||
@@ -796,7 +754,7 @@
|
||||
" # removing all agents & adding the custom agent.\n",
|
||||
" cfg['agents'] = {}\n",
|
||||
" cfg['agents'] = custom_blue\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"blue_env = PrimaiteGymEnv(env_config=cfg)"
|
||||
]
|
||||
@@ -1121,7 +1079,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The code cell below uses the custom blue agent defined at the start of this section perform a NODE_APPLICATION_REMOVE on the C2 beacon:"
|
||||
"The code cell below uses the custom blue agent defined at the start of this section perform a node_application_remove on the C2 beacon:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1130,7 +1088,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Using CAOS ACTION: NODE_APPLICATION_REMOVE & capturing the OBS\n",
|
||||
"# Using CAOS ACTION: node_application_remove & capturing the OBS\n",
|
||||
"post_blue_action_obs, _, _, _, _ = blue_env.step(1)"
|
||||
]
|
||||
},
|
||||
@@ -1216,7 +1174,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The code cell below uses the custom blue agent defined at the start of this section to perform a ``NODE_SHUT_DOWN`` action on the web server."
|
||||
"The code cell below uses the custom blue agent defined at the start of this section to perform a ``node_shut_down`` action on the web server."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1225,7 +1183,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Using CAOS ACTION: NODE_SHUT_DOWN & capturing the OBS\n",
|
||||
"# Using CAOS ACTION: node_shut_down & capturing the OBS\n",
|
||||
"post_blue_action_obs, _, _, _, _ = blue_env.step(2)"
|
||||
]
|
||||
},
|
||||
@@ -1306,7 +1264,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The code cell below uses the custom blue agent defined at the start of this section to perform a ROUTER_ACL_ADDRULE on router 1."
|
||||
"The code cell below uses the custom blue agent defined at the start of this section to perform a router_acl_add_rule on router 1."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1315,7 +1273,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Using CAOS ACTION: ROUTER_ACL_ADDRULE & capturing the OBS\n",
|
||||
"# Using CAOS ACTION: router_acl_add_rule & capturing the OBS\n",
|
||||
"post_blue_action_obs, _, _, _, _ = blue_env.step(3)"
|
||||
]
|
||||
},
|
||||
@@ -1429,11 +1387,11 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As demonstrated earlier, red agents can use the ``CONFIGURE_C2_BEACON`` action to configure these settings mid episode through the configuration options:\n",
|
||||
"As demonstrated earlier, red agents can use the ``configure_c2_beacon`` action to configure these settings mid episode through the configuration options:\n",
|
||||
"\n",
|
||||
"``` YAML\n",
|
||||
"...\n",
|
||||
" action: CONFIGURE_C2_BEACON\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" config:\n",
|
||||
@@ -1468,7 +1426,7 @@
|
||||
" # removing all agents & adding the custom agent.\n",
|
||||
" cfg['agents'] = {}\n",
|
||||
" cfg['agents'] = c2_agent_yaml\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"c2_config_env = PrimaiteGymEnv(env_config=cfg)"
|
||||
]
|
||||
@@ -1555,7 +1513,7 @@
|
||||
"source": [
|
||||
"for i in range(6):\n",
|
||||
" env.step(0)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"c2_server_1.show()"
|
||||
]
|
||||
},
|
||||
@@ -1676,7 +1634,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Comparing the OBS of the default frequency to a timestep frequency of 1 \n",
|
||||
"# Comparing the OBS of the default frequency to a timestep frequency of 1\n",
|
||||
"for i in range(2):\n",
|
||||
" keep_alive_obs, _, _, _, _ = blue_config_env.step(0)\n",
|
||||
" display_obs_diffs(default_obs, keep_alive_obs, blue_config_env.game.step_counter)"
|
||||
@@ -1760,7 +1718,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Capturing default C2 Traffic \n",
|
||||
"# Capturing default C2 Traffic\n",
|
||||
"for i in range(3):\n",
|
||||
" tcp_c2_obs, _, _, _, _ = blue_config_env.step(0)\n",
|
||||
"\n",
|
||||
|
||||
@@ -15,6 +15,15 @@
|
||||
"*(For a full explanation of the Data Manipulation scenario, check out the data manipulation scenario notebook)*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -67,9 +76,9 @@
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" if red_action == 'do_nothing':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
|
||||
" elif red_action == 'node_application_execute':\n",
|
||||
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
|
||||
" red_str = f\"ATTACK from {client}\"\n",
|
||||
" return red_str"
|
||||
@@ -147,12 +156,7 @@
|
||||
" nodes: {}\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
"\n",
|
||||
" # The agent has two action choices, either do nothing, or execute a pre-scripted attack by using \n",
|
||||
" action_list:\n",
|
||||
" - type: DONOTHING\n",
|
||||
" - type: NODE_APPLICATION_EXECUTE\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" # The agent has access to the DataManipulationBoth on clients 1 and 2.\n",
|
||||
" options:\n",
|
||||
" nodes:\n",
|
||||
@@ -306,19 +310,9 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"change = yaml.safe_load(\"\"\"\n",
|
||||
"action_space:\n",
|
||||
" action_list:\n",
|
||||
" - type: DONOTHING\n",
|
||||
" - type: NODE_APPLICATION_EXECUTE\n",
|
||||
" options:\n",
|
||||
" nodes:\n",
|
||||
" - node_name: client_1\n",
|
||||
" applications:\n",
|
||||
" - application_name: DataManipulationBot\n",
|
||||
" max_folders_per_node: 1\n",
|
||||
" max_files_per_folder: 1\n",
|
||||
" max_services_per_node: 1\n",
|
||||
"# TODO:\n",
|
||||
"\"\"\")\n",
|
||||
"#TODO 2869 fix\n",
|
||||
"\n",
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
@@ -444,7 +438,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
||||
@@ -165,13 +165,13 @@
|
||||
"\n",
|
||||
"| node_id | node name |\n",
|
||||
"|---------|------------------|\n",
|
||||
"| 1 | domain_controller|\n",
|
||||
"| 2 | web_server |\n",
|
||||
"| 3 | database_server |\n",
|
||||
"| 4 | backup_server |\n",
|
||||
"| 5 | security_suite |\n",
|
||||
"| 6 | client_1 |\n",
|
||||
"| 7 | client_2 |\n",
|
||||
"| 0 | domain_controller|\n",
|
||||
"| 1 | web_server |\n",
|
||||
"| 2 | database_server |\n",
|
||||
"| 3 | backup_server |\n",
|
||||
"| 4 | security_suite |\n",
|
||||
"| 5 | client_1 |\n",
|
||||
"| 6 | client_2 |\n",
|
||||
"\n",
|
||||
"Service 1 on node 2 (web_server) corresponds to the Web Server service. Other services are only there for padding to ensure that each node's observation space has the same shape. They are filled with zeroes.\n",
|
||||
"\n",
|
||||
@@ -371,6 +371,15 @@
|
||||
"First, load the required modules"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -449,9 +458,9 @@
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" if red_action == 'do_nothing':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
" elif red_action == 'NODE_APPLICATION_EXECUTE':\n",
|
||||
" elif red_action == 'node_application_execute':\n",
|
||||
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
|
||||
" red_str = f\"ATTACK from {client}\"\n",
|
||||
" return red_str"
|
||||
@@ -547,7 +556,7 @@
|
||||
"\n",
|
||||
"The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 0.9 when both green agents make successful requests.\n",
|
||||
"\n",
|
||||
"Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should increase. If you run it enough times, another red attack will happen and the reward will drop again."
|
||||
"Run the following cell until the green action is `node_application_execute` for application 0, then the reward should increase. If you run it enough times, another red attack will happen and the reward will drop again."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,6 +9,15 @@
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -201,7 +201,7 @@
|
||||
"source": [
|
||||
"caos_action = [\n",
|
||||
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
|
||||
" \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
|
||||
" \"service\", \"Terminal\", \"node_session_remote_login\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
|
||||
"]\n",
|
||||
"game.simulation.apply_request(caos_action)"
|
||||
]
|
||||
@@ -259,7 +259,7 @@
|
||||
"source": [
|
||||
"caos_action = [\n",
|
||||
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
|
||||
" \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_rt.network_interface[4].ip_address)\n",
|
||||
" \"service\", \"Terminal\", \"node_session_remote_login\", \"admin\", \"admin\", str(some_tech_rt.network_interface[4].ip_address)\n",
|
||||
"]\n",
|
||||
"game.simulation.apply_request(caos_action)"
|
||||
]
|
||||
@@ -396,7 +396,7 @@
|
||||
"source": [
|
||||
"caos_action = [\n",
|
||||
" \"network\", \"node\", \"some_tech_jnr_dev_pc\", \n",
|
||||
" \"service\", \"Terminal\", \"ssh_to_remote\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
|
||||
" \"service\", \"Terminal\", \"node_session_remote_login\", \"admin\", \"admin\", str(some_tech_storage_srv.network_interface[1].ip_address)\n",
|
||||
"]\n",
|
||||
"game.simulation.apply_request(caos_action)"
|
||||
]
|
||||
@@ -25,6 +25,15 @@
|
||||
"Let's set up a minimal network simulation and send some requests to see how it works."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -18,6 +18,15 @@
|
||||
"The Terminal service comes pre-installed on most Nodes (The exception being Switches, as these are currently dumb). "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -18,6 +18,15 @@
|
||||
"#### First, Import packages and read our config file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -32,8 +41,6 @@
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
|
||||
"\n",
|
||||
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
|
||||
"# to copy the files to your user data path.\n",
|
||||
"with open(PRIMAITE_PATHS.user_config_path / 'example_config/data_manipulation_marl.yaml', 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
|
||||
@@ -11,6 +11,15 @@
|
||||
"This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -95,7 +104,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -109,7 +118,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -18,6 +18,15 @@
|
||||
"#### First, we import the inital packages and read in our configuration file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -168,7 +177,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -182,7 +191,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.8"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -238,7 +238,7 @@
|
||||
"### Episode 2\n",
|
||||
"When we reset the environment again, it moves onto episode 2, where it will bring in greens_1 and reds_1 for green and red agent definitions. Let's verify the agent names and that they take actions at the defined frequency.\n",
|
||||
"\n",
|
||||
"Most green actions will be `NODE_APPLICATION_EXECUTE` while red will `DONOTHING` except at steps 10 and 20."
|
||||
"Most green actions will be `node_application_execute` while red will `DONOTHING` except at steps 10 and 20."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -269,7 +269,7 @@
|
||||
"### Episode 3\n",
|
||||
"When we reset the environment again, it moves onto episode 3, where it will bring in greens_2 and reds_2 for green and red agent definitions. Let's verify the agent names and that they take actions at the defined frequency.\n",
|
||||
"\n",
|
||||
"Now, green will perform `NODE_APPLICATION_EXECUTE` only 5% of the time, while red will perform `NODE_APPLICATION_EXECUTE` more frequently than before."
|
||||
"Now, green will perform `node_application_execute` only 5% of the time, while red will perform `node_application_execute` more frequently than before."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 110 KiB After Width: | Height: | Size: 110 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 69 KiB After Width: | Height: | Size: 69 KiB |
@@ -18,6 +18,15 @@
|
||||
"Import packages and read config file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -89,7 +89,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
:return: Action mask
|
||||
:rtype: List[bool]
|
||||
"""
|
||||
if not self.agent.action_masking:
|
||||
if not self.agent.config.agent_settings.action_masking:
|
||||
return np.asarray([True] * len(self.agent.action_manager.action_map))
|
||||
else:
|
||||
return self.game.action_mask(self._agent_name)
|
||||
|
||||
@@ -44,7 +44,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
)
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
if agent.action_masking:
|
||||
if agent.config.agent_settings.action_masking:
|
||||
self.observation_space[agent_name] = spaces.Dict(
|
||||
{
|
||||
"action_mask": spaces.MultiBinary(agent.action_manager.space.n),
|
||||
@@ -143,7 +143,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
if agent.action_masking:
|
||||
if agent.config.agent_settings.action_masking:
|
||||
all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs}
|
||||
else:
|
||||
all_obs[agent_name] = obs
|
||||
@@ -168,7 +168,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
self.env = PrimaiteGymEnv(env_config=env_config)
|
||||
# self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
if self.env.agent.action_masking:
|
||||
if self.env.agent.config.agent_settings.action_masking:
|
||||
self.observation_space = spaces.Dict(
|
||||
{"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space}
|
||||
)
|
||||
@@ -178,7 +178,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
super().reset() # Ensure PRNG seed is set everywhere
|
||||
if self.env.agent.action_masking:
|
||||
if self.env.agent.config.agent_settings.action_masking:
|
||||
obs, *_ = self.env.reset(seed=seed)
|
||||
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
|
||||
return new_obs, *_
|
||||
@@ -187,7 +187,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
# if action masking is enabled, intercept the step method and add action mask to observation
|
||||
if self.env.agent.action_masking:
|
||||
if self.env.agent.config.agent_settings.action_masking:
|
||||
obs, *_ = self.env.step(action)
|
||||
new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs}
|
||||
return new_obs, *_
|
||||
|
||||
@@ -264,7 +264,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -664,7 +664,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -130,8 +130,8 @@ class File(FileSystemItemABC):
|
||||
|
||||
Return False if corruption is detected, otherwise True
|
||||
"""
|
||||
warnings.warn("NODE_FILE_CHECKHASH is currently not implemented.")
|
||||
self.sys_log.warning("NODE_FILE_CHECKHASH is currently not implemented.")
|
||||
warnings.warn("node_file_checkhash is currently not implemented.")
|
||||
self.sys_log.warning("node_file_checkhash is currently not implemented.")
|
||||
return False
|
||||
|
||||
if self.deleted:
|
||||
|
||||
@@ -30,6 +30,11 @@ class FileSystem(SimComponent):
|
||||
num_file_deletions: int = 0
|
||||
"Number of file deletions in the current step."
|
||||
|
||||
_default_folder_scan_duration: Optional[int] = None
|
||||
"Override default scan duration for folders"
|
||||
_default_folder_restore_duration: Optional[int] = None
|
||||
"Override default restore duration for folders"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
# Ensure a default root folder
|
||||
@@ -258,6 +263,11 @@ class FileSystem(SimComponent):
|
||||
name=folder.name, request_type=RequestType(func=folder._request_manager)
|
||||
)
|
||||
self.folders[folder.uuid] = folder
|
||||
# set the folder scan and restore durations.
|
||||
if self._default_folder_scan_duration is not None:
|
||||
folder.scan_duration = self._default_folder_scan_duration
|
||||
if self._default_folder_restore_duration is not None:
|
||||
folder.restore_duration = self._default_folder_restore_duration
|
||||
return folder
|
||||
|
||||
def delete_folder(self, folder_name: str) -> bool:
|
||||
|
||||
@@ -43,6 +43,9 @@ def convert_size(size_bytes: int) -> str:
|
||||
class FileSystemItemHealthStatus(Enum):
|
||||
"""Status of the FileSystemItem."""
|
||||
|
||||
NONE = 0
|
||||
"""File system item health status is not known."""
|
||||
|
||||
GOOD = 1
|
||||
"""File/Folder is OK."""
|
||||
|
||||
@@ -72,7 +75,7 @@ class FileSystemItemABC(SimComponent):
|
||||
health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD
|
||||
"Actual status of the current FileSystemItem"
|
||||
|
||||
visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.GOOD
|
||||
visible_health_status: FileSystemItemHealthStatus = FileSystemItemHealthStatus.NONE
|
||||
"Visible status of the current FileSystemItem"
|
||||
|
||||
previous_hash: Optional[str] = None
|
||||
|
||||
@@ -46,7 +46,7 @@ class Folder(FileSystemItemABC):
|
||||
:param sys_log: The SysLog instance to us to create system logs.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._scanned_this_step: bool = False
|
||||
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
@@ -83,6 +83,7 @@ class Folder(FileSystemItemABC):
|
||||
state = super().describe_state()
|
||||
state["files"] = {file.name: file.describe_state() for uuid, file in self.files.items()}
|
||||
state["deleted_files"] = {file.name: file.describe_state() for uuid, file in self.deleted_files.items()}
|
||||
state["scanned_this_step"] = self._scanned_this_step
|
||||
return state
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
@@ -135,7 +136,7 @@ class Folder(FileSystemItemABC):
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
super().pre_timestep(timestep)
|
||||
|
||||
self._scanned_this_step = False
|
||||
for file in self.files.values():
|
||||
file.pre_timestep(timestep)
|
||||
|
||||
@@ -148,9 +149,17 @@ class Folder(FileSystemItemABC):
|
||||
for file_id in self.files:
|
||||
file = self.get_file_by_id(file_uuid=file_id)
|
||||
file.scan()
|
||||
if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
self.health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
# set folder health to worst file's health by generating a list of file healths. If no files, use 0
|
||||
self.health_status = FileSystemItemHealthStatus(
|
||||
max(
|
||||
[f.health_status.value for f in self.files.values()]
|
||||
or [
|
||||
0,
|
||||
]
|
||||
)
|
||||
)
|
||||
self.visible_health_status = self.health_status
|
||||
self._scanned_this_step = True
|
||||
|
||||
def _reveal_to_red_timestep(self) -> None:
|
||||
"""Apply reveal to red timestep."""
|
||||
@@ -387,8 +396,8 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
Return False if corruption is detected, otherwise True
|
||||
"""
|
||||
warnings.warn("NODE_FOLDER_CHECKHASH is currently not implemented.")
|
||||
self.sys_log.error("NODE_FOLDER_CHECKHASH is currently not implemented.")
|
||||
warnings.warn("node_folder_checkhash is currently not implemented.")
|
||||
self.sys_log.error("node_folder_checkhash is currently not implemented.")
|
||||
return False
|
||||
|
||||
if self.deleted:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, ClassVar, Dict, Literal, Type
|
||||
from typing import Any, ClassVar, Dict, Literal, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
@@ -49,7 +49,7 @@ class NetworkNodeAdder(BaseModel):
|
||||
|
||||
_registry: ClassVar[Dict[str, Type["NetworkNodeAdder"]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str], **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a network node adder class.
|
||||
|
||||
@@ -58,6 +58,8 @@ class NetworkNodeAdder(BaseModel):
|
||||
:raises ValueError: When attempting to register a name that is already reserved.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate node adder {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@@ -824,7 +824,7 @@ class User(SimComponent):
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
class UserManager(Service):
|
||||
class UserManager(Service, identifier="UserManager"):
|
||||
"""
|
||||
Manages users within the PrimAITE system, handling creation, authentication, and administration.
|
||||
|
||||
@@ -833,11 +833,18 @@ class UserManager(Service):
|
||||
:param disabled_admins: A dictionary of currently disabled admin users by their usernames
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for UserManager."""
|
||||
|
||||
type: str = "UserManager"
|
||||
|
||||
config: "UserManager.ConfigSchema" = Field(default_factory=lambda: UserManager.ConfigSchema())
|
||||
|
||||
users: Dict[str, User] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initializes a UserManager instanc.
|
||||
Initializes a UserManager instance.
|
||||
|
||||
:param username: The username for the default admin user
|
||||
:param password: The password for the default admin user
|
||||
@@ -1130,13 +1137,20 @@ class RemoteUserSession(UserSession):
|
||||
return state
|
||||
|
||||
|
||||
class UserSessionManager(Service):
|
||||
class UserSessionManager(Service, identifier="UserSessionManager"):
|
||||
"""
|
||||
Manages user sessions on a Node, including local and remote sessions.
|
||||
|
||||
This class handles authentication, session management, and session timeouts for users interacting with the Node.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for UserSessionManager."""
|
||||
|
||||
type: str = "UserSessionManager"
|
||||
|
||||
config: "UserSessionManager.ConfigSchema" = Field(default_factory=lambda: UserSessionManager.ConfigSchema())
|
||||
|
||||
local_session: Optional[UserSession] = None
|
||||
"""The current local user session, if any."""
|
||||
|
||||
@@ -1554,7 +1568,6 @@ class Node(SimComponent, ABC):
|
||||
red_scan_countdown: int = 0
|
||||
"Time steps until reveal to red scan is complete."
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "Node":
|
||||
"""Create Node object from a given configuration dictionary."""
|
||||
@@ -1564,7 +1577,7 @@ class Node(SimComponent, ABC):
|
||||
obj = cls(config=cls.ConfigSchema(**config))
|
||||
return obj
|
||||
|
||||
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a node type.
|
||||
|
||||
@@ -1572,10 +1585,10 @@ class Node(SimComponent, ABC):
|
||||
:type identifier: str
|
||||
:raises ValueError: When attempting to register an node with a name that is already allocated.
|
||||
"""
|
||||
if identifier == "default":
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
identifier = identifier.lower()
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Tried to define new node {identifier}, but this name is already reserved.")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Dict, Optional, Set, Type
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
@@ -21,13 +23,20 @@ class ApplicationOperatingState(Enum):
|
||||
"The application is being installed or updated."
|
||||
|
||||
|
||||
class Application(IOSoftware):
|
||||
class Application(IOSoftware, ABC):
|
||||
"""
|
||||
Represents an Application in the simulation environment.
|
||||
|
||||
Applications are user-facing programs that may perform input/output operations.
|
||||
"""
|
||||
|
||||
class ConfigSchema(IOSoftware.ConfigSchema, ABC):
|
||||
"""Config Schema for Application class."""
|
||||
|
||||
type: str
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: Application.ConfigSchema())
|
||||
|
||||
operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED
|
||||
"The current operating state of the Application."
|
||||
execution_control_status: str = "manual"
|
||||
@@ -44,21 +53,36 @@ class Application(IOSoftware):
|
||||
_registry: ClassVar[Dict[str, Type["Application"]]] = {}
|
||||
"""Registry of application types. Automatically populated when subclasses are defined."""
|
||||
|
||||
def __init_subclass__(cls, identifier: str = "default", **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: Optional[str] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register an application type.
|
||||
|
||||
:param identifier: Uniquely specifies an application class by name. Used for finding items by config.
|
||||
:type identifier: str
|
||||
:type identifier: Optional[str]
|
||||
:raises ValueError: When attempting to register an application with a name that is already allocated.
|
||||
"""
|
||||
if identifier == "default":
|
||||
return
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier is None:
|
||||
return
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Tried to define new application {identifier}, but this name is already reserved.")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "Application":
|
||||
"""Create an application from a config dictionary.
|
||||
|
||||
:param config: dict of options for application components constructor
|
||||
:type config: dict
|
||||
:return: The application component.
|
||||
:rtype: Application
|
||||
"""
|
||||
if config["type"] not in cls._registry:
|
||||
raise ValueError(f"Invalid Application type {config['type']}")
|
||||
application_class = cls._registry[config["type"]]
|
||||
application_object = application_class(config=application_class.ConfigSchema(**config))
|
||||
return application_object
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -67,11 +67,19 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
|
||||
Extends the Application class to provide functionality for connecting, querying, and disconnecting from a
|
||||
Database Service. It mainly operates over TCP protocol.
|
||||
|
||||
:ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for DatabaseClient."""
|
||||
|
||||
type: str = "DatabaseClient"
|
||||
db_server_ip: Optional[IPV4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: DatabaseClient.ConfigSchema())
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
"""The IPv4 address of the Database Service server, defaults to None."""
|
||||
server_password: Optional[str] = None
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
"""Keep track of connections that were established or verified during this step. Used for rewards."""
|
||||
@@ -93,6 +101,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
kwargs["port"] = PORT_LOOKUP["POSTGRES_SERVER"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
self.server_ip_address = self.config.db_server_ip
|
||||
self.server_password = self.config.server_password
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union
|
||||
|
||||
from prettytable import PrettyTable
|
||||
from pydantic import validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
@@ -52,6 +52,13 @@ class NMAP(Application, identifier="NMAP"):
|
||||
as ping scans to discover active hosts and port scans to detect open ports on those hosts.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for NMAP."""
|
||||
|
||||
type: str = "NMAP"
|
||||
|
||||
config: "NMAP.ConfigSchema" = Field(default_factory=lambda: NMAP.ConfigSchema())
|
||||
|
||||
_active_port_scans: Dict[str, PortScanPayload] = {}
|
||||
_port_scan_responses: Dict[str, PortScanPayload] = {}
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
@@ -45,10 +45,10 @@ class C2Payload(Enum):
|
||||
"""C2 Input Command payload. Used by the C2 Server to send a command to the c2 beacon."""
|
||||
|
||||
OUTPUT = "output_command"
|
||||
"""C2 Output Command. Used by the C2 Beacon to send the results of a Input command to the c2 server."""
|
||||
"""C2 Output Command. Used by the C2 Beacon to send the results of an Input command to the c2 server."""
|
||||
|
||||
|
||||
class AbstractC2(Application, identifier="AbstractC2"):
|
||||
class AbstractC2(Application):
|
||||
"""
|
||||
An abstract command and control (c2) application.
|
||||
|
||||
@@ -60,9 +60,25 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
|
||||
Defaults to masquerading as HTTP (Port 80) via TCP.
|
||||
|
||||
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
|
||||
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""Configuration for AbstractC2."""
|
||||
|
||||
keep_alive_frequency: int = Field(default=5, ge=1)
|
||||
"""The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon."""
|
||||
|
||||
masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"])
|
||||
"""The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP."""
|
||||
|
||||
masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"])
|
||||
"""The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP."""
|
||||
|
||||
listen_on_ports: Set[Port] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]}
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: AbstractC2.ConfigSchema())
|
||||
|
||||
c2_connection_active: bool = False
|
||||
"""Indicates if the c2 server and c2 beacon are currently connected."""
|
||||
|
||||
@@ -75,19 +91,6 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
keep_alive_inactivity: int = 0
|
||||
"""Indicates how many timesteps since the last time the c2 application received a keep alive."""
|
||||
|
||||
class _C2Opts(BaseModel):
|
||||
"""A Pydantic Schema for the different C2 configuration options."""
|
||||
|
||||
keep_alive_frequency: int = Field(default=5, ge=1)
|
||||
"""The frequency at which ``Keep Alive`` packets are sent to the C2 Server from the C2 Beacon."""
|
||||
|
||||
masquerade_protocol: IPProtocol = Field(default=PROTOCOL_LOOKUP["TCP"])
|
||||
"""The currently chosen protocol that the C2 traffic is masquerading as. Defaults as TCP."""
|
||||
|
||||
masquerade_port: Port = Field(default=PORT_LOOKUP["HTTP"])
|
||||
"""The currently chosen port that the C2 traffic is masquerading as. Defaults at HTTP."""
|
||||
|
||||
c2_config: _C2Opts = _C2Opts()
|
||||
"""
|
||||
Holds the current configuration settings of the C2 Suite.
|
||||
|
||||
@@ -100,6 +103,12 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
C2 beacon to reconfigure it's configuration settings.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialise the C2 applications to by default listen for HTTP traffic."""
|
||||
kwargs["port"] = PORT_LOOKUP["NONE"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _craft_packet(
|
||||
self, c2_payload: C2Payload, c2_command: Optional[C2Command] = None, command_options: Optional[Dict] = {}
|
||||
) -> C2Packet:
|
||||
@@ -118,13 +127,13 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
:type c2_command: C2Command.
|
||||
:param command_options: The relevant C2 Beacon parameters.F
|
||||
:type command_options: Dict
|
||||
:return: Returns the construct C2Packet
|
||||
:return: Returns the constructed C2Packet
|
||||
:rtype: C2Packet
|
||||
"""
|
||||
constructed_packet = C2Packet(
|
||||
masquerade_protocol=self.c2_config.masquerade_protocol,
|
||||
masquerade_port=self.c2_config.masquerade_port,
|
||||
keep_alive_frequency=self.c2_config.keep_alive_frequency,
|
||||
masquerade_protocol=self.config.masquerade_protocol,
|
||||
masquerade_port=self.config.masquerade_port,
|
||||
keep_alive_frequency=self.config.keep_alive_frequency,
|
||||
payload_type=c2_payload,
|
||||
command=c2_command,
|
||||
payload=command_options,
|
||||
@@ -140,13 +149,6 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
"""
|
||||
return super().describe_state()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialise the C2 applications to by default listen for HTTP traffic."""
|
||||
kwargs["listen_on_ports"] = {PORT_LOOKUP["HTTP"], PORT_LOOKUP["FTP"], PORT_LOOKUP["DNS"]}
|
||||
kwargs["port"] = PORT_LOOKUP["NONE"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _host_ftp_client(self) -> Optional[FTPClient]:
|
||||
"""Return the FTPClient that is installed C2 Application's host.
|
||||
@@ -330,8 +332,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
if self.send(
|
||||
payload=keep_alive_packet,
|
||||
dest_ip_address=self.c2_remote_connection,
|
||||
dest_port=self.c2_config.masquerade_port,
|
||||
ip_protocol=self.c2_config.masquerade_protocol,
|
||||
dest_port=self.config.masquerade_port,
|
||||
ip_protocol=self.config.masquerade_protocol,
|
||||
session_id=session_id,
|
||||
):
|
||||
# Setting the keep_alive_sent guard condition to True. This is used to prevent packet storms.
|
||||
@@ -340,8 +342,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
self.sys_log.info(f"{self.name}: Keep Alive sent to {self.c2_remote_connection}")
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: Keep Alive sent to {self.c2_remote_connection} "
|
||||
f"Masquerade Port: {self.c2_config.masquerade_port} "
|
||||
f"Masquerade Protocol: {self.c2_config.masquerade_protocol} "
|
||||
f"Masquerade Port: {self.config.masquerade_port} "
|
||||
f"Masquerade Protocol: {self.config.masquerade_protocol} "
|
||||
)
|
||||
return True
|
||||
else:
|
||||
@@ -376,15 +378,15 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
|
||||
# Updating the C2 Configuration attribute.
|
||||
|
||||
self.c2_config.masquerade_port = payload.masquerade_port
|
||||
self.c2_config.masquerade_protocol = payload.masquerade_protocol
|
||||
self.c2_config.keep_alive_frequency = payload.keep_alive_frequency
|
||||
self.config.masquerade_port = payload.masquerade_port
|
||||
self.config.masquerade_protocol = payload.masquerade_protocol
|
||||
self.config.keep_alive_frequency = payload.keep_alive_frequency
|
||||
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: C2 Config Resolved Config from Keep Alive:"
|
||||
f"Masquerade Port: {self.c2_config.masquerade_port}"
|
||||
f"Masquerade Protocol: {self.c2_config.masquerade_protocol}"
|
||||
f"Keep Alive Frequency: {self.c2_config.keep_alive_frequency}"
|
||||
f"Masquerade Port: {self.config.masquerade_port}"
|
||||
f"Masquerade Protocol: {self.config.masquerade_protocol}"
|
||||
f"Keep Alive Frequency: {self.config.keep_alive_frequency}"
|
||||
)
|
||||
|
||||
# This statement is intended to catch on the C2 Application that is listening for connection.
|
||||
@@ -410,8 +412,8 @@ class AbstractC2(Application, identifier="AbstractC2"):
|
||||
self.keep_alive_inactivity = 0
|
||||
self.keep_alive_frequency = 5
|
||||
self.c2_remote_connection = None
|
||||
self.c2_config.masquerade_port = PORT_LOOKUP["HTTP"]
|
||||
self.c2_config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"]
|
||||
self.config.masquerade_port = PORT_LOOKUP["HTTP"]
|
||||
self.config.masquerade_protocol = PROTOCOL_LOOKUP["TCP"]
|
||||
|
||||
@abstractmethod
|
||||
def _confirm_remote_connection(self, timestep: int) -> bool:
|
||||
|
||||
@@ -3,7 +3,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -12,8 +12,9 @@ from primaite.simulator.system.applications.red_applications.c2 import ExfilOpts
|
||||
from primaite.simulator.system.applications.red_applications.c2.abstract_c2 import AbstractC2, C2Command, C2Payload
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal, TerminalClientConnection
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
|
||||
class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
@@ -32,15 +33,30 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
2. Leveraging the terminal application to execute requests (dependent on the command given)
|
||||
3. Sending the RequestResponse back to the C2 Server (Command output)
|
||||
|
||||
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
|
||||
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
|
||||
"""
|
||||
|
||||
class ConfigSchema(AbstractC2.ConfigSchema):
|
||||
"""ConfigSchema for C2Beacon."""
|
||||
|
||||
type: str = "C2Beacon"
|
||||
c2_server_ip_address: Optional[IPV4Address] = None
|
||||
keep_alive_frequency: int = 5
|
||||
masquerade_protocol: IPProtocol = PROTOCOL_LOOKUP["TCP"]
|
||||
masquerade_port: Port = PORT_LOOKUP["HTTP"]
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: C2Beacon.ConfigSchema())
|
||||
|
||||
keep_alive_attempted: bool = False
|
||||
"""Indicates if a keep alive has been attempted to be sent this timestep. Used to prevent packet storms."""
|
||||
|
||||
terminal_session: TerminalClientConnection = None
|
||||
"The currently in use terminal session."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "C2Beacon"
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _host_terminal(self) -> Optional[Terminal]:
|
||||
"""Return the Terminal that is installed on the same machine as the C2 Beacon."""
|
||||
@@ -119,10 +135,6 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
rm.add_request("configure", request_type=RequestType(func=_configure))
|
||||
return rm
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "C2Beacon"
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Configure is practically setter method for the ``c2.config`` attribute that also ties into the request manager.
|
||||
@validate_call
|
||||
def configure(
|
||||
@@ -146,7 +158,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
masquerade_port | What port should the C2 traffic use? (TCP or UDP)
|
||||
|
||||
These configuration options are used to reassign the fields in the inherited inner class
|
||||
``c2_config``.
|
||||
``config``.
|
||||
|
||||
If a connection is already in progress then this method also sends a keep alive to the C2
|
||||
Server in order for the C2 Server to sync with the new configuration settings.
|
||||
@@ -162,9 +174,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
:return: Returns True if the configuration was successful, False otherwise.
|
||||
"""
|
||||
self.c2_remote_connection = IPv4Address(c2_server_ip_address)
|
||||
self.c2_config.keep_alive_frequency = keep_alive_frequency
|
||||
self.c2_config.masquerade_port = masquerade_port
|
||||
self.c2_config.masquerade_protocol = masquerade_protocol
|
||||
self.config.keep_alive_frequency = keep_alive_frequency
|
||||
self.config.masquerade_port = masquerade_port
|
||||
self.config.masquerade_protocol = masquerade_protocol
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Configured {self.name} with remote C2 server connection: {c2_server_ip_address=}."
|
||||
)
|
||||
@@ -263,14 +275,12 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
if self.send(
|
||||
payload=output_packet,
|
||||
dest_ip_address=self.c2_remote_connection,
|
||||
dest_port=self.c2_config.masquerade_port,
|
||||
ip_protocol=self.c2_config.masquerade_protocol,
|
||||
dest_port=self.config.masquerade_port,
|
||||
ip_protocol=self.config.masquerade_protocol,
|
||||
session_id=session_id,
|
||||
):
|
||||
self.sys_log.info(f"{self.name}: Command output sent to {self.c2_remote_connection}")
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: on {self.c2_config.masquerade_port} via {self.c2_config.masquerade_protocol}"
|
||||
)
|
||||
self.sys_log.debug(f"{self.name}: on {self.config.masquerade_port} via {self.config.masquerade_protocol}")
|
||||
return True
|
||||
else:
|
||||
self.sys_log.warning(
|
||||
@@ -562,7 +572,7 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
:rtype bool:
|
||||
"""
|
||||
self.keep_alive_attempted = False # Resetting keep alive sent.
|
||||
if self.keep_alive_inactivity == self.c2_config.keep_alive_frequency:
|
||||
if self.keep_alive_inactivity == self.config.keep_alive_frequency:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Attempting to Send Keep Alive to {self.c2_remote_connection} at timestep {timestep}."
|
||||
)
|
||||
@@ -627,9 +637,9 @@ class C2Beacon(AbstractC2, identifier="C2Beacon"):
|
||||
self.c2_connection_active,
|
||||
self.c2_remote_connection,
|
||||
self.keep_alive_inactivity,
|
||||
self.c2_config.keep_alive_frequency,
|
||||
self.c2_config.masquerade_protocol,
|
||||
self.c2_config.masquerade_port,
|
||||
self.config.keep_alive_frequency,
|
||||
self.config.masquerade_protocol,
|
||||
self.config.masquerade_port,
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import validate_call
|
||||
from pydantic import Field, validate_call
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -31,9 +31,16 @@ class C2Server(AbstractC2, identifier="C2Server"):
|
||||
1. Sending commands to the C2 Beacon. (Command input)
|
||||
2. Parsing terminal RequestResponses back to the Agent.
|
||||
|
||||
Please refer to the Command-&-Control notebook for an in-depth example of the C2 Suite.
|
||||
Please refer to the Command-and-Control notebook for an in-depth example of the C2 Suite.
|
||||
"""
|
||||
|
||||
class ConfigSchema(AbstractC2.ConfigSchema):
|
||||
"""ConfigSchema for C2Server."""
|
||||
|
||||
type: str = "C2Server"
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: C2Server.ConfigSchema())
|
||||
|
||||
current_command_output: RequestResponse = None
|
||||
"""The Request Response by the last command send. This attribute is updated by the method _handle_command_output."""
|
||||
|
||||
@@ -251,8 +258,8 @@ class C2Server(AbstractC2, identifier="C2Server"):
|
||||
payload=command_packet,
|
||||
dest_ip_address=self.c2_remote_connection,
|
||||
session_id=self.c2_session.uuid,
|
||||
dest_port=self.c2_config.masquerade_port,
|
||||
ip_protocol=self.c2_config.masquerade_protocol,
|
||||
dest_port=self.config.masquerade_port,
|
||||
ip_protocol=self.config.masquerade_protocol,
|
||||
):
|
||||
self.sys_log.info(f"{self.name}: Successfully sent {given_command}.")
|
||||
self.sys_log.info(f"{self.name}: Awaiting command response {given_command}.")
|
||||
@@ -334,11 +341,11 @@ class C2Server(AbstractC2, identifier="C2Server"):
|
||||
:return: Returns False if the C2 beacon is considered dead. Otherwise True.
|
||||
:rtype bool:
|
||||
"""
|
||||
if self.keep_alive_inactivity > self.c2_config.keep_alive_frequency:
|
||||
if self.keep_alive_inactivity > self.config.keep_alive_frequency:
|
||||
self.sys_log.info(f"{self.name}: C2 Beacon connection considered dead due to inactivity.")
|
||||
self.sys_log.debug(
|
||||
f"{self.name}: Did not receive expected keep alive connection from {self.c2_remote_connection}"
|
||||
f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.c2_config.keep_alive_frequency}"
|
||||
f"{self.name}: Expected at timestep: {timestep} due to frequency: {self.config.keep_alive_frequency}"
|
||||
f"{self.name}: Last Keep Alive received at {(timestep - self.keep_alive_inactivity)}"
|
||||
)
|
||||
self._reset_c2_connection()
|
||||
@@ -389,8 +396,8 @@ class C2Server(AbstractC2, identifier="C2Server"):
|
||||
[
|
||||
self.c2_connection_active,
|
||||
self.c2_remote_connection,
|
||||
self.c2_config.masquerade_protocol,
|
||||
self.c2_config.masquerade_port,
|
||||
self.config.masquerade_protocol,
|
||||
self.config.masquerade_port,
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@@ -3,6 +3,8 @@ from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
@@ -10,6 +12,7 @@ from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -40,6 +43,18 @@ class DataManipulationAttackStage(IntEnum):
|
||||
class DataManipulationBot(Application, identifier="DataManipulationBot"):
|
||||
"""A bot that simulates a script which performs a SQL injection attack."""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""Configuration schema for DataManipulationBot."""
|
||||
|
||||
type: str = "DataManipulationBot"
|
||||
server_ip: Optional[IPV4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
payload: str = "DELETE"
|
||||
port_scan_p_of_success: float = 0.1
|
||||
data_manipulation_p_of_success: float = 0.1
|
||||
|
||||
config: "DataManipulationBot.ConfigSchema" = Field(default_factory=lambda: DataManipulationBot.ConfigSchema())
|
||||
|
||||
payload: Optional[str] = None
|
||||
port_scan_p_of_success: float = 0.1
|
||||
data_manipulation_p_of_success: float = 0.1
|
||||
@@ -56,6 +71,12 @@ class DataManipulationBot(Application, identifier="DataManipulationBot"):
|
||||
super().__init__(**kwargs)
|
||||
self._db_connection: Optional[DatabaseClientConnection] = None
|
||||
|
||||
self.server_ip_address = self.config.server_ip
|
||||
self.server_password = self.config.server_password
|
||||
self.payload = self.config.payload
|
||||
self.port_scan_p_of_success = self.config.port_scan_p_of_success
|
||||
self.data_manipulation_p_of_success = self.config.data_manipulation_p_of_success
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
@@ -3,11 +3,14 @@ from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -32,6 +35,20 @@ class DoSAttackStage(IntEnum):
|
||||
class DoSBot(DatabaseClient, identifier="DoSBot"):
|
||||
"""A bot that simulates a Denial of Service attack."""
|
||||
|
||||
class ConfigSchema(DatabaseClient.ConfigSchema):
|
||||
"""ConfigSchema for DoSBot."""
|
||||
|
||||
type: str = "DoSBot"
|
||||
target_ip_address: Optional[IPV4Address] = None
|
||||
target_port: Port = PORT_LOOKUP["POSTGRES_SERVER"]
|
||||
payload: Optional[str] = None
|
||||
repeat: bool = False
|
||||
port_scan_p_of_success: float = 0.1
|
||||
dos_intensity: float = 1.0
|
||||
max_sessions: int = 1000
|
||||
|
||||
config: "DoSBot.ConfigSchema" = Field(default_factory=lambda: DoSBot.ConfigSchema())
|
||||
|
||||
target_ip_address: Optional[IPv4Address] = None
|
||||
"""IP address of the target service."""
|
||||
|
||||
@@ -56,7 +73,13 @@ class DoSBot(DatabaseClient, identifier="DoSBot"):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = "DoSBot"
|
||||
self.max_sessions = 1000 # override normal max sessions
|
||||
self.target_ip_address = self.config.target_ip_address
|
||||
self.target_port = self.config.target_port
|
||||
self.payload = self.config.payload
|
||||
self.repeat = self.config.repeat
|
||||
self.port_scan_p_of_success = self.config.port_scan_p_of_success
|
||||
self.dos_intensity = self.config.dos_intensity
|
||||
self.max_sessions = self.config.max_sessions
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""
|
||||
|
||||
@@ -3,12 +3,14 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
|
||||
@@ -18,6 +20,16 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
|
||||
:ivar payload: The attack stage query payload. (Default ENCRYPT)
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for RansomwareScript."""
|
||||
|
||||
type: str = "RansomwareScript"
|
||||
server_ip: Optional[IPV4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
payload: str = "ENCRYPT"
|
||||
|
||||
config: "RansomwareScript.ConfigSchema" = Field(default_factory=lambda: RansomwareScript.ConfigSchema())
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
"""IP address of node which hosts the database."""
|
||||
server_password: Optional[str] = None
|
||||
@@ -32,6 +44,9 @@ class RansomwareScript(Application, identifier="RansomwareScript"):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._db_connection: Optional[DatabaseClientConnection] = None
|
||||
self.server_ip_address = self.config.server_ip
|
||||
self.server_password = self.config.server_password
|
||||
self.payload = self.config.payload
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
@@ -30,7 +30,13 @@ class WebBrowser(Application, identifier="WebBrowser"):
|
||||
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
|
||||
"""
|
||||
|
||||
target_url: Optional[str] = None
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for WebBrowser."""
|
||||
|
||||
type: str = "WebBrowser"
|
||||
target_url: Optional[str] = None
|
||||
|
||||
config: "WebBrowser.ConfigSchema" = Field(default_factory=lambda: WebBrowser.ConfigSchema())
|
||||
|
||||
domain_name_ip_address: Optional[IPv4Address] = None
|
||||
"The IP address of the domain name for the webpage."
|
||||
@@ -86,7 +92,7 @@ class WebBrowser(Application, identifier="WebBrowser"):
|
||||
:param: url: The address of the web page the browser requests
|
||||
:type: url: str
|
||||
"""
|
||||
url = url or self.target_url
|
||||
url = url or self.config.target_url
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ class SoftwareManager:
|
||||
return True
|
||||
return False
|
||||
|
||||
def install(self, software_class: Type[IOSoftware], **install_kwargs):
|
||||
def install(self, software_class: Type[IOSoftware], software_config: Optional[IOSoftware.ConfigSchema] = None):
|
||||
"""
|
||||
Install an Application or Service.
|
||||
|
||||
@@ -115,13 +115,22 @@ class SoftwareManager:
|
||||
if software_class in self._software_class_to_name_map:
|
||||
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
|
||||
return
|
||||
software = software_class(
|
||||
software_manager=self,
|
||||
sys_log=self.sys_log,
|
||||
file_system=self.file_system,
|
||||
dns_server=self.dns_server,
|
||||
**install_kwargs,
|
||||
)
|
||||
if software_config is None:
|
||||
software = software_class(
|
||||
software_manager=self,
|
||||
sys_log=self.sys_log,
|
||||
file_system=self.file_system,
|
||||
dns_server=self.dns_server,
|
||||
)
|
||||
else:
|
||||
software = software_class(
|
||||
software_manager=self,
|
||||
sys_log=self.sys_log,
|
||||
file_system=self.file_system,
|
||||
dns_server=self.dns_server,
|
||||
config=software_config,
|
||||
)
|
||||
|
||||
software.parent = self.node
|
||||
if isinstance(software, Application):
|
||||
self.node.applications[software.uuid] = software
|
||||
|
||||
@@ -5,6 +5,7 @@ from abc import abstractmethod
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface
|
||||
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
|
||||
@@ -14,7 +15,7 @@ from primaite.utils.validation.ipv4_address import IPV4Address
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
|
||||
class ARP(Service):
|
||||
class ARP(Service, identifier="ARP"):
|
||||
"""
|
||||
The ARP (Address Resolution Protocol) Service.
|
||||
|
||||
@@ -22,6 +23,13 @@ class ARP(Service):
|
||||
sends ARP requests and replies, and processes incoming ARP packets.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for ARP."""
|
||||
|
||||
type: str = "ARP"
|
||||
|
||||
config: "ARP.ConfigSchema" = Field(default_factory=lambda: ARP.ConfigSchema())
|
||||
|
||||
arp: Dict[IPV4Address, ARPEntry] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
@@ -17,13 +19,21 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseService(Service):
|
||||
class DatabaseService(Service, identifier="DatabaseService"):
|
||||
"""
|
||||
A class for simulating a generic SQL Server service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to simulate a SQL database.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for DatabaseService."""
|
||||
|
||||
type: str = "DatabaseService"
|
||||
backup_server_ip: Optional[IPv4Address] = None
|
||||
|
||||
config: "DatabaseService.ConfigSchema" = Field(default_factory=lambda: DatabaseService.ConfigSchema())
|
||||
|
||||
password: Optional[str] = None
|
||||
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
|
||||
|
||||
@@ -42,6 +52,7 @@ class DatabaseService(Service):
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
self._create_db_file()
|
||||
self.backup_server_ip = self.config.backup_server_ip
|
||||
|
||||
def install(self):
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
@@ -12,9 +14,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DNSClient(Service):
|
||||
class DNSClient(Service, identifier="DNSClient"):
|
||||
"""Represents a DNS Client as a Service."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for DNSClient."""
|
||||
|
||||
type: str = "DNSClient"
|
||||
|
||||
config: "DNSClient.ConfigSchema" = Field(default_factory=lambda: DNSClient.ConfigSchema())
|
||||
dns_cache: Dict[str, IPv4Address] = {}
|
||||
"A dict of known mappings between domain/URLs names and IPv4 addresses."
|
||||
dns_server: Optional[IPv4Address] = None
|
||||
|
||||
@@ -3,6 +3,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket
|
||||
@@ -13,9 +14,17 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DNSServer(Service):
|
||||
class DNSServer(Service, identifier="DNSServer"):
|
||||
"""Represents a DNS Server as a Service."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for DNSServer."""
|
||||
|
||||
type: str = "DNSServer"
|
||||
domain_mapping: dict = {}
|
||||
|
||||
config: "DNSServer.ConfigSchema" = Field(default_factory=lambda: DNSServer.ConfigSchema())
|
||||
|
||||
dns_table: Dict[str, IPv4Address] = {}
|
||||
"A dict of mappings between domain names and IPv4 addresses."
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -9,20 +11,28 @@ from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FTPClient(FTPServiceABC):
|
||||
class FTPClient(FTPServiceABC, identifier="FTPClient"):
|
||||
"""
|
||||
A class for simulating an FTP client service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to emulate FTP
|
||||
This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP
|
||||
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
|
||||
"""
|
||||
|
||||
config: "FTPClient.ConfigSchema" = Field(default_factory=lambda: FTPClient.ConfigSchema())
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for FTPClient."""
|
||||
|
||||
type: str = "FTPClient"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPClient"
|
||||
kwargs["port"] = PORT_LOOKUP["FTP"]
|
||||
@@ -108,6 +118,7 @@ class FTPClient(FTPServiceABC):
|
||||
session_id: Optional[str] = None,
|
||||
is_reattempt: Optional[bool] = False,
|
||||
) -> bool:
|
||||
self._active = True
|
||||
"""
|
||||
Connects the client to a given FTP server.
|
||||
|
||||
@@ -164,6 +175,7 @@ class FTPClient(FTPServiceABC):
|
||||
:param: is_reattempt: Set to True if attempt to disconnect from FTP Server has been attempted. Default False.
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
self._active = True
|
||||
# send a disconnect request payload to FTP server
|
||||
payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.QUIT)
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
@@ -209,6 +221,7 @@ class FTPClient(FTPServiceABC):
|
||||
:param: session_id: The id of the session
|
||||
:type: session_id: Optional[str]
|
||||
"""
|
||||
self._active = True
|
||||
# check if the file to transfer exists on the client
|
||||
file_to_transfer: File = self.file_system.get_file(folder_name=src_folder_name, file_name=src_file_name)
|
||||
if not file_to_transfer:
|
||||
@@ -266,6 +279,7 @@ class FTPClient(FTPServiceABC):
|
||||
:param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port["FTP"].
|
||||
:type: dest_port: Optional[int]
|
||||
"""
|
||||
self._active = True
|
||||
# check if FTP is currently connected to IP
|
||||
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
|
||||
@@ -317,6 +331,7 @@ class FTPClient(FTPServiceABC):
|
||||
This helps prevent an FTP request loop - FTP client and servers can exist on
|
||||
the same node.
|
||||
"""
|
||||
self._active = True
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
|
||||
@@ -1,26 +1,36 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import is_valid_port, PORT_LOOKUP
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FTPServer(FTPServiceABC):
|
||||
class FTPServer(FTPServiceABC, identifier="FTPServer"):
|
||||
"""
|
||||
A class for simulating an FTP server service.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to emulate FTP
|
||||
This class inherits from the `FTPServiceABC` class and provides methods to emulate FTP
|
||||
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
|
||||
"""
|
||||
|
||||
config: "FTPServer.ConfigSchema" = Field(default_factory=lambda: FTPServer.ConfigSchema())
|
||||
|
||||
server_password: Optional[str] = None
|
||||
"""Password needed to connect to FTP server. Default is None."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for FTPServer."""
|
||||
|
||||
type: str = "FTPServer"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPServer"
|
||||
kwargs["port"] = PORT_LOOKUP["FTP"]
|
||||
|
||||
@@ -3,9 +3,11 @@ from abc import ABC
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import StrictBool
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
|
||||
@@ -16,9 +18,22 @@ class FTPServiceABC(Service, ABC):
|
||||
Contains shared methods between both classes.
|
||||
"""
|
||||
|
||||
_active: StrictBool = False
|
||||
"""Flag that is True on timesteps where service transmits data and False when idle. Used for describe_state."""
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""When a new timestep begins, clear the _active attribute."""
|
||||
self._active = False
|
||||
return super().pre_timestep(timestep)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""Returns a Dict of the FTPService state."""
|
||||
return super().describe_state()
|
||||
state = super().describe_state()
|
||||
|
||||
# override so that the service is shows as running only if actively transmitting data this timestep
|
||||
if self.operating_state == ServiceOperatingState.RUNNING and not self._active:
|
||||
state["operating_state"] = ServiceOperatingState.STOPPED.value
|
||||
return state
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
"""
|
||||
@@ -29,6 +44,7 @@ class FTPServiceABC(Service, ABC):
|
||||
:param: session_id: session ID linked to the FTP Packet. Optional.
|
||||
:type: session_id: Optional[str]
|
||||
"""
|
||||
self._active = True
|
||||
if payload.ftp_command is not None:
|
||||
self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.")
|
||||
|
||||
@@ -51,6 +67,7 @@ class FTPServiceABC(Service, ABC):
|
||||
:param: payload: The FTP Packet that contains the file data
|
||||
:type: FTPPacket
|
||||
"""
|
||||
self._active = True
|
||||
try:
|
||||
file_name = payload.ftp_command_args["dest_file_name"]
|
||||
folder_name = payload.ftp_command_args["dest_folder_name"]
|
||||
@@ -106,6 +123,7 @@ class FTPServiceABC(Service, ABC):
|
||||
:param: is_response: is true if the data being sent is in response to a request. Default False.
|
||||
:type: is_response: bool
|
||||
"""
|
||||
self._active = True
|
||||
# send STOR request
|
||||
payload: FTPPacket = FTPPacket(
|
||||
ftp_command=FTPCommand.STOR,
|
||||
@@ -135,6 +153,7 @@ class FTPServiceABC(Service, ABC):
|
||||
:param: payload: The FTP Packet that contains the file data
|
||||
:type: FTPPacket
|
||||
"""
|
||||
self._active = True
|
||||
try:
|
||||
# find the file
|
||||
file_name = payload.ftp_command_args["src_file_name"]
|
||||
@@ -181,6 +200,7 @@ class FTPServiceABC(Service, ABC):
|
||||
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
self._active = True
|
||||
self.sys_log.info(f"{self.name}: Sending FTP {payload.ftp_command.name} {payload.ftp_command_args}")
|
||||
|
||||
return super().send(
|
||||
|
||||
@@ -3,6 +3,8 @@ import secrets
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.hardware.base import NetworkInterface
|
||||
from primaite.simulator.network.protocols.icmp import ICMPPacket, ICMPType
|
||||
@@ -14,7 +16,7 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class ICMP(Service):
|
||||
class ICMP(Service, identifier="ICMP"):
|
||||
"""
|
||||
The Internet Control Message Protocol (ICMP) service.
|
||||
|
||||
@@ -22,6 +24,13 @@ class ICMP(Service):
|
||||
network diagnostics, notably the ping command.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for ICMP."""
|
||||
|
||||
type: str = "ICMP"
|
||||
|
||||
config: "ICMP.ConfigSchema" = Field(default_factory=lambda: ICMP.ConfigSchema())
|
||||
|
||||
request_replies: Dict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user